/*
 * Decompiled with CFR 0.152.
 */
package dev.zarr.zarrjava.v3.codec.core;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.store.StoreHandle;
import dev.zarr.zarrjava.utils.IndexingUtils;
import dev.zarr.zarrjava.utils.MultiArrayUtils;
import dev.zarr.zarrjava.utils.Utils;
import dev.zarr.zarrjava.v3.ArrayMetadata;
import dev.zarr.zarrjava.v3.DataType;
import dev.zarr.zarrjava.v3.codec.ArrayBytesCodec;
import dev.zarr.zarrjava.v3.codec.Codec;
import dev.zarr.zarrjava.v3.codec.CodecPipeline;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import ucar.ma2.Array;
import ucar.ma2.InvalidRangeException;

/**
 * Zarr v3 {@code sharding_indexed} array-to-bytes codec.
 *
 * <p>A shard (one outer chunk of the array) is divided into a grid of inner chunks of shape
 * {@link Configuration#chunkShape}. Each inner chunk that is not entirely fill value is encoded
 * with {@link Configuration#codecs}, and the encoded chunks are concatenated. A shard index holding
 * one (offset, length) pair of unsigned 64-bit integers per inner chunk is encoded with
 * {@link Configuration#indexCodecs}. It is stored either before or after the chunk data, as
 * selected by {@link Configuration#indexLocation}. Chunks that are not stored have both index
 * entries set to {@code -1} (all bits set).
 */
public class ShardingIndexedCodec extends ArrayBytesCodec.WithPartialDecode {

    /** Codec name ({@code sharding_indexed}). */
    public final String name = "sharding_indexed";

    @Nonnull
    public final Configuration configuration;

    /** Pipeline applied to each inner chunk; initialised by {@link #setCoreArrayMetadata}. */
    CodecPipeline codecPipeline;

    /** Pipeline applied to the shard index; initialised by {@link #setCoreArrayMetadata}. */
    CodecPipeline indexCodecPipeline;

    @JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
    public ShardingIndexedCodec(
            @Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration)
            throws ZarrException {
        this.configuration = configuration;
    }

    /**
     * Binds this codec to the outer array and builds the inner-chunk and index codec pipelines.
     *
     * @param arrayMetadata metadata of the outer array; its {@code chunkShape} is the shard shape
     * @throws ZarrException if the inner chunk shape does not tile the shard shape exactly, or if
     *     a pipeline cannot be built
     */
    @Override
    public void setCoreArrayMetadata(ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
        // Validate before mutating any state so a rejected configuration leaves no half-built codec.
        this.validateInnerChunkShape(arrayMetadata);
        super.setCoreArrayMetadata(arrayMetadata);
        // Inside a shard, the "array" is the shard itself and its chunks are the inner chunks.
        ArrayMetadata.CoreArrayMetadata innerChunkMetadata = new ArrayMetadata.CoreArrayMetadata(
                Utils.toLongArray(arrayMetadata.chunkShape),
                this.configuration.chunkShape,
                arrayMetadata.dataType,
                arrayMetadata.parsedFillValue);
        this.codecPipeline = new CodecPipeline(this.configuration.codecs, innerChunkMetadata);
        this.indexCodecPipeline = new CodecPipeline(
                this.configuration.indexCodecs,
                this.getShardIndexArrayMetadata(this.getChunksPerShard(arrayMetadata)));
    }

    /**
     * Rejects inner chunk shapes that are missing, differ in rank from the shard shape, contain
     * non-positive extents, or do not evenly divide the shard shape. Otherwise
     * {@link #getChunksPerShard} would truncate, and chunk coordinates would fall outside the index.
     */
    private void validateInnerChunkShape(ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
        int[] shardShape = arrayMetadata.chunkShape;
        int[] innerShape = this.configuration.chunkShape;
        if (innerShape == null || innerShape.length != shardShape.length) {
            throw new ZarrException(String.format(
                    "Sharding chunk_shape %s must have the same number of dimensions as the shard shape %s.",
                    Arrays.toString(innerShape), Arrays.toString(shardShape)));
        }
        for (int dim = 0; dim < shardShape.length; dim++) {
            if (innerShape[dim] <= 0 || shardShape[dim] % innerShape[dim] != 0) {
                throw new ZarrException(String.format(
                        "Sharding chunk_shape %s must evenly divide the shard shape %s.",
                        Arrays.toString(innerShape), Arrays.toString(shardShape)));
            }
        }
    }

    /**
     * Describes the shard index as an array of shape {@code chunksPerShard + [2]} with dtype
     * {@code UINT64} and fill value {@code -1}.
     */
    ArrayMetadata.CoreArrayMetadata getShardIndexArrayMetadata(int[] chunksPerShard) {
        int[] indexShape = this.extendArrayBy1(chunksPerShard, 2);
        return new ArrayMetadata.CoreArrayMetadata(Utils.toLongArray(indexShape), indexShape, DataType.UINT64, -1);
    }

    /** Returns the number of inner chunks along each dimension of one shard. */
    public int[] getChunksPerShard(ArrayMetadata.CoreArrayMetadata arrayMetadata) {
        int ndim = arrayMetadata.ndim();
        int[] chunksPerShard = new int[ndim];
        for (int dimIdx = 0; dimIdx < ndim; ++dimIdx) {
            chunksPerShard[dimIdx] = arrayMetadata.chunkShape[dimIdx] / this.configuration.chunkShape[dimIdx];
        }
        return chunksPerShard;
    }

    /** Returns a copy of {@code array} with {@code value} appended. */
    int[] extendArrayBy1(int[] array, int value) {
        int[] out = Arrays.copyOf(array, array.length + 1);
        out[array.length] = value;
        return out;
    }

    /** Returns a copy of {@code array} with {@code value} appended. */
    long[] extendArrayBy1(long[] array, long value) {
        long[] out = Arrays.copyOf(array, array.length + 1);
        out[array.length] = value;
        return out;
    }

    /**
     * Reads one index entry for the inner chunk at {@code chunkCoords}.
     *
     * @param idx {@code 0} for the byte offset, {@code 1} for the byte length
     */
    long getValueFromShardIndexArray(Array shardIndexArray, long[] chunkCoords, int idx) {
        return shardIndexArray.getLong(
                shardIndexArray.getIndex().set(Utils.toIntArray(this.extendArrayBy1(chunkCoords, (long) idx))));
    }

    /**
     * Writes one index entry for the inner chunk at {@code chunkCoords}.
     *
     * @param idx {@code 0} for the byte offset, {@code 1} for the byte length
     */
    void setValueFromShardIndexArray(Array shardIndexArray, long[] chunkCoords, int idx, long value) {
        shardIndexArray.setLong(
                shardIndexArray.getIndex().set(Utils.toIntArray(this.extendArrayBy1(chunkCoords, (long) idx))),
                value);
    }

    /** Decodes a complete in-memory shard into an array of the shard shape. */
    @Override
    public Array decode(ByteBuffer shardBytes) throws ZarrException {
        return this.decodeInternal(
                new ByteBufferDataProvider(shardBytes),
                new long[this.arrayMetadata.ndim()],
                this.arrayMetadata.chunkShape,
                this.arrayMetadata);
    }

    /**
     * Encodes a shard-shaped array into shard bytes: the inner chunk data plus the encoded index.
     *
     * <p>Inner chunks are encoded in parallel but written in the order returned by
     * {@code IndexingUtils.computeChunkCoords}, so the output is deterministic for a given input.
     *
     * @throws ZarrException if a codec fails, if the index location is unsupported, or if the
     *     encoded shard does not fit into a single {@link ByteBuffer}
     */
    @Override
    public ByteBuffer encode(Array shardArray) throws ZarrException {
        ArrayMetadata.CoreArrayMetadata innerChunkMetadata = this.codecPipeline.arrayMetadata;
        boolean indexAtStart = this.indexAtStart();
        int[] chunksPerShard = this.getChunksPerShard(this.arrayMetadata);
        long shardIndexSize = this.getShardIndexSize(this.arrayMetadata);
        Array shardIndexArray = Array.factory(ucar.ma2.DataType.ULONG, this.extendArrayBy1(chunksPerShard, 2));

        long[][] allChunkCoords =
                IndexingUtils.computeChunkCoords(innerChunkMetadata.shape, innerChunkMetadata.chunkShape);
        ByteBuffer[] encodedChunks = this.encodeChunks(shardArray, allChunkCoords);

        // Sequential layout pass: offsets are relative to the start of the shard.
        long chunkByteOffset = indexAtStart ? shardIndexSize : 0L;
        long chunkBytesTotal = 0L;
        for (int i = 0; i < allChunkCoords.length; i++) {
            ByteBuffer chunkBytes = encodedChunks[i];
            if (chunkBytes == null) {
                // All-fill-value chunk: not stored, marked empty in the index.
                this.setValueFromShardIndexArray(shardIndexArray, allChunkCoords[i], 0, -1L);
                this.setValueFromShardIndexArray(shardIndexArray, allChunkCoords[i], 1, -1L);
                continue;
            }
            // remaining() is exactly what ByteBuffer.put will copy below.
            long chunkByteLength = chunkBytes.remaining();
            this.setValueFromShardIndexArray(shardIndexArray, allChunkCoords[i], 0, chunkByteOffset);
            this.setValueFromShardIndexArray(shardIndexArray, allChunkCoords[i], 1, chunkByteLength);
            chunkByteOffset += chunkByteLength;
            chunkBytesTotal += chunkByteLength;
        }

        long shardByteLength = chunkBytesTotal + shardIndexSize;
        if (shardByteLength > Integer.MAX_VALUE) {
            throw new ZarrException(String.format(
                    "Encoded shard of %d bytes exceeds the maximum buffer size.", shardByteLength));
        }
        ByteBuffer shardBytes = ByteBuffer.allocate((int) shardByteLength);
        if (indexAtStart) {
            shardBytes.put(this.indexCodecPipeline.encode(shardIndexArray));
        }
        for (ByteBuffer chunkBytes : encodedChunks) {
            if (chunkBytes != null) {
                shardBytes.put(chunkBytes);
            }
        }
        if (!indexAtStart) {
            shardBytes.put(this.indexCodecPipeline.encode(shardIndexArray));
        }
        shardBytes.rewind();
        return shardBytes;
    }

    /**
     * Encodes every inner chunk of {@code shardArray} in parallel.
     *
     * @return an array aligned with {@code allChunkCoords}; an entry is {@code null} when the
     *     corresponding inner chunk consists solely of the fill value and is therefore not stored
     */
    private ByteBuffer[] encodeChunks(Array shardArray, long[][] allChunkCoords) {
        ArrayMetadata.CoreArrayMetadata innerChunkMetadata = this.codecPipeline.arrayMetadata;
        ByteBuffer[] encodedChunks = new ByteBuffer[allChunkCoords.length];
        // Each task writes only its own slot, so no locking is needed.
        IntStream.range(0, allChunkCoords.length).parallel().forEach(i -> {
            try {
                IndexingUtils.ChunkProjection chunkProjection = IndexingUtils.computeProjection(
                        allChunkCoords[i], innerChunkMetadata.shape, innerChunkMetadata.chunkShape);
                Array chunkArray = shardArray.sectionNoReduce(chunkProjection.outOffset, chunkProjection.shape, null);
                if (!MultiArrayUtils.allValuesEqual(chunkArray, innerChunkMetadata.parsedFillValue)) {
                    encodedChunks[i] = this.codecPipeline.encode(chunkArray);
                }
            } catch (ZarrException | InvalidRangeException e) {
                throw new RuntimeException(e);
            }
        });
        return encodedChunks;
    }

    /** Returns {@code inputByteLength} plus the encoded size of the shard index. */
    @Override
    public long computeEncodedSize(long inputByteLength, ArrayMetadata.CoreArrayMetadata arrayMetadata)
            throws ZarrException {
        return inputByteLength + this.getShardIndexSize(arrayMetadata);
    }

    /** Encoded size of the shard index: 16 raw bytes (two uint64) per inner chunk, passed through the index codecs. */
    private long getShardIndexSize(ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
        long chunkCount = Arrays.stream(this.getChunksPerShard(arrayMetadata)).reduce(1, (r, a) -> r * a);
        return this.indexCodecPipeline.computeEncodedSize(16L * chunkCount, arrayMetadata);
    }

    /**
     * Returns {@code true} if the index precedes the chunk data, {@code false} if it follows it.
     *
     * @throws ZarrException if {@code index_location} is neither {@code "start"} nor {@code "end"}
     */
    private boolean indexAtStart() throws ZarrException {
        String indexLocation = this.configuration.indexLocation;
        if ("start".equals(indexLocation)) {
            return true;
        }
        if ("end".equals(indexLocation)) {
            return false;
        }
        throw new ZarrException("Only index_location \"start\" or \"end\" are supported.");
    }

    /**
     * Decodes the region {@code [offset, offset + shape)} of a shard.
     *
     * <p>Reads the shard index first, then only the inner chunks that intersect the region. Chunks
     * marked empty in the index are filled with the fill value.
     *
     * @throws ZarrException if the index location is unsupported or the index cannot be read
     */
    private Array decodeInternal(
            DataProvider dataProvider, long[] offset, int[] shape, ArrayMetadata.CoreArrayMetadata arrayMetadata)
            throws ZarrException {
        ArrayMetadata.CoreArrayMetadata innerChunkMetadata = this.codecPipeline.arrayMetadata;
        Array outputArray = Array.factory((ucar.ma2.DataType) arrayMetadata.dataType.getMA2DataType(), shape);
        long shardIndexByteLength = this.getShardIndexSize(arrayMetadata);
        ByteBuffer shardIndexBytes = this.indexAtStart()
                ? dataProvider.readPrefix(shardIndexByteLength)
                : dataProvider.readSuffix(shardIndexByteLength);
        if (shardIndexBytes == null) {
            throw new ZarrException("Could not read shard index.");
        }
        Array shardIndexArray = this.indexCodecPipeline.decode(shardIndexBytes);
        long[][] allChunkCoords = IndexingUtils.computeChunkCoords(
                innerChunkMetadata.shape, innerChunkMetadata.chunkShape, offset, shape);
        Arrays.stream(allChunkCoords).forEach(chunkCoords -> {
            try {
                long chunkByteOffset = this.getValueFromShardIndexArray(shardIndexArray, chunkCoords, 0);
                long chunkByteLength = this.getValueFromShardIndexArray(shardIndexArray, chunkCoords, 1);
                IndexingUtils.ChunkProjection chunkProjection = IndexingUtils.computeProjection(
                        chunkCoords, innerChunkMetadata.shape, innerChunkMetadata.chunkShape, offset, shape);
                Array chunkArray = null;
                if (chunkByteOffset != -1L && chunkByteLength != -1L) {
                    ByteBuffer chunkBytes = dataProvider.read(chunkByteOffset, chunkByteLength);
                    if (chunkBytes == null) {
                        throw new ZarrException(
                                String.format("Could not load byte data for chunk %s", Arrays.toString(chunkCoords)));
                    }
                    chunkArray = this.codecPipeline.decode(chunkBytes);
                }
                if (chunkArray == null) {
                    chunkArray = innerChunkMetadata.allocateFillValueChunk();
                }
                MultiArrayUtils.copyRegion(
                        chunkArray, chunkProjection.chunkOffset, outputArray, chunkProjection.outOffset,
                        chunkProjection.shape);
            } catch (ZarrException e) {
                throw new RuntimeException(e);
            }
        });
        return outputArray;
    }

    /**
     * Decodes a region of the shard stored at {@code chunkHandle}.
     *
     * <p>If the whole shard is requested, it is read once and decoded from memory, and a missing
     * shard yields a fill-value chunk. Otherwise the index and the needed inner chunks are fetched
     * through the store handle.
     */
    @Override
    public Array decodePartial(StoreHandle chunkHandle, long[] offset, int[] shape) throws ZarrException {
        if (Arrays.equals(shape, this.arrayMetadata.chunkShape)) {
            ByteBuffer chunkBytes = chunkHandle.read();
            if (chunkBytes == null) {
                return this.arrayMetadata.allocateFillValueChunk();
            }
            // Reuse the bytes already fetched instead of reading the store a second time.
            return this.decodeInternal(new ByteBufferDataProvider(chunkBytes), offset, shape, this.arrayMetadata);
        }
        return this.decodeInternal(new StoreHandleDataProvider(chunkHandle), offset, shape, this.arrayMetadata);
    }

    /** {@link DataProvider} that delegates byte-range reads to a {@link StoreHandle}. */
    static class StoreHandleDataProvider implements DataProvider {
        @Nonnull
        final StoreHandle storeHandle;

        StoreHandleDataProvider(@Nonnull StoreHandle storeHandle) {
            this.storeHandle = storeHandle;
        }

        @Override
        public ByteBuffer readSuffix(long suffixLength) {
            // NOTE(review): relies on StoreHandle.read(negative) reading the trailing bytes — confirm.
            return this.storeHandle.read(-suffixLength);
        }

        @Override
        public ByteBuffer readPrefix(long prefixLength) {
            return this.storeHandle.read(0L, prefixLength);
        }

        @Override
        public ByteBuffer read(long start, long length) {
            // StoreHandle.read takes (start, end); presumably end-exclusive.
            return this.storeHandle.read(start, start + length);
        }
    }

    /**
     * {@link DataProvider} over an in-memory buffer. Offsets are relative to the buffer's position
     * at call time.
     *
     * <p>Out-of-range requests return {@code null} instead of throwing. Offsets and lengths come
     * from the decoded shard index, so a corrupt or truncated shard must be reported as a read
     * failure.
     */
    static class ByteBufferDataProvider implements DataProvider {
        @Nonnull
        final ByteBuffer buffer;

        ByteBufferDataProvider(@Nonnull ByteBuffer buffer) {
            this.buffer = buffer;
        }

        @Override
        public ByteBuffer readSuffix(long suffixLength) {
            return this.read(this.buffer.remaining() - suffixLength, suffixLength);
        }

        @Override
        public ByteBuffer readPrefix(long prefixLength) {
            return this.read(0L, prefixLength);
        }

        @Override
        public ByteBuffer read(long start, long length) {
            long available = this.buffer.remaining();
            if (start < 0 || length < 0 || start > available - length) {
                return null;
            }
            ByteBuffer bufferSlice = this.buffer.slice();
            bufferSlice.position((int) start);
            bufferSlice.limit((int) (start + length));
            return bufferSlice.slice();
        }
    }

    /** JSON {@code configuration} object of the sharding codec. */
    public static final class Configuration {
        /** Shape of the inner chunks within a shard. */
        @JsonProperty(value = "chunk_shape")
        public final int[] chunkShape;

        /** Codecs applied to each inner chunk. */
        @Nonnull
        @JsonProperty(value = "codecs")
        public final Codec[] codecs;

        /** Codecs applied to the shard index. */
        @Nonnull
        @JsonProperty(value = "index_codecs")
        public final Codec[] indexCodecs;

        /** Either {@code "start"} or {@code "end"}; defaults to {@code "end"} when absent. */
        @Nonnull
        @JsonProperty(value = "index_location")
        public String indexLocation;

        @JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
        public Configuration(
                @JsonProperty(value = "chunk_shape", required = true) int[] chunkShape,
                @Nonnull @JsonProperty(value = "codecs") Codec[] codecs,
                @Nonnull @JsonProperty(value = "index_codecs") Codec[] indexCodecs,
                @JsonProperty(value = "index_location", defaultValue = "end") String indexLocation)
                throws ZarrException {
            if (indexLocation == null) {
                indexLocation = "end";
            }
            if (!indexLocation.equals("start") && !indexLocation.equals("end")) {
                throw new ZarrException("Only index_location \"start\" or \"end\" are supported.");
            }
            this.chunkShape = chunkShape;
            this.codecs = codecs;
            this.indexCodecs = indexCodecs;
            this.indexLocation = indexLocation;
        }
    }

    /**
     * Source of shard bytes. Implementations may return {@code null} when the requested bytes
     * cannot be read; {@code decodeInternal} treats {@code null} as a read failure.
     */
    interface DataProvider {
        /** Reads {@code length} bytes starting at byte {@code start}. */
        ByteBuffer read(long start, long length);

        /** Reads the last {@code suffixLength} bytes. */
        ByteBuffer readSuffix(long suffixLength);

        /** Reads the first {@code prefixLength} bytes. */
        ByteBuffer readPrefix(long prefixLength);
    }
}

