/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.BatchSampler;
import ai.djl.training.dataset.DataIterable;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomSampler;
import ai.djl.training.dataset.Record;
import ai.djl.training.dataset.Sampler;
import ai.djl.training.dataset.SequenceSampler;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.util.RandomUtils;
import java.io.IOException;
import java.util.Arrays;
import java.util.RandomAccess;
import java.util.concurrent.ExecutorService;
import java.util.stream.IntStream;

/**
 * A {@link Dataset} whose records can be fetched by position, which allows it to be sampled in
 * arbitrary order and split into subsets with {@link #randomSplit(int...)}.
 *
 * <p>The iteration settings (sampler, batchifier, pipelines, executor, prefetch number, iteration
 * cap and device) are copied from a {@link BaseBuilder} at construction time. {@link
 * #getData(NDManager)} hands them to a {@link DataIterable}.
 */
public abstract class RandomAccessDataset
implements Dataset,
RandomAccess {
    protected Sampler sampler;
    protected Batchifier batchifier;
    protected Pipeline pipeline;
    protected Pipeline targetPipeline;
    protected ExecutorService executor;
    protected int prefetchNumber;
    protected long maxIteration;
    protected Device device;

    /**
     * Creates a dataset whose iteration settings keep their Java defaults (null references, 0 for
     * {@code prefetchNumber} and {@code maxIteration}). The caller is responsible for populating
     * them.
     */
    RandomAccessDataset() {
    }

    /**
     * Creates a dataset configured from the given builder.
     *
     * @param builder the builder holding the sampler and other iteration settings
     * @throws IllegalArgumentException if the builder has no sampler set
     */
    public RandomAccessDataset(BaseBuilder<?> builder) {
        this.sampler = builder.getSampler();
        this.batchifier = builder.batchifier;
        this.pipeline = builder.pipeline;
        this.targetPipeline = builder.targetPipeline;
        this.executor = builder.executor;
        this.prefetchNumber = builder.prefetchNumber;
        this.maxIteration = builder.maxIteration;
        this.device = builder.device;
    }

    /**
     * Returns the record stored at the given position.
     *
     * @param manager the manager used to create the record's arrays
     * @param index the position of the record, presumably in {@code [0, size())}
     * @return the record at {@code index}
     * @throws IOException if the record cannot be read
     */
    public abstract Record get(NDManager manager, long index) throws IOException;

    /**
     * Returns an iterable over batches of this dataset.
     *
     * <p>The batches are produced by a {@link DataIterable} built from this dataset's sampler,
     * batchifier, pipelines, executor, prefetch number, iteration cap and device.
     *
     * @param manager the manager passed through to the {@link DataIterable}
     * @return the batch iterable
     */
    @Override
    public Iterable<Batch> getData(NDManager manager) {
        return new DataIterable(this, manager, this.sampler, this.batchifier, this.pipeline, this.targetPipeline, this.executor, this.prefetchNumber, this.maxIteration, this.device);
    }

    /**
     * Returns the number of records in this dataset.
     *
     * @return the record count
     */
    public abstract long size();

    /**
     * Returns the number of iterations one pass over the data takes.
     *
     * <p>The result is {@code size() / batchSize} using integer division, so a trailing partial
     * batch is not counted. It is capped at {@code maxIteration}.
     *
     * @return the iteration count, or {@code -1} when the sampler reports a batch size of
     *     {@code -1}
     */
    public long getNumIterations() {
        int batchSize = this.sampler.getBatchSize();
        if (batchSize == -1) {
            return -1L;
        }
        long iteration = this.size() / (long)batchSize;
        return Math.min(this.maxIteration, iteration);
    }

    /**
     * Randomly partitions this dataset into {@code ratio.length} disjoint subsets.
     *
     * <p>The indices {@code [0, size())} are shuffled. They are then cut into consecutive runs:
     * subset {@code i} (except the last) receives {@code floor(ratio[i] * size() / sum(ratio))}
     * indices, and the last subset receives all remaining indices. Together the subsets cover
     * every index exactly once.
     *
     * <p>Each subset inherits this dataset's sampler, batchifier, pipelines, executor, prefetch
     * number, iteration cap and device. Its {@link #getData(NDManager)} and {@link
     * #getNumIterations()} therefore operate on the subset alone.
     *
     * @param ratio the relative weight of each subset: at least two non-negative values whose sum
     *     is positive
     * @return the subsets, in the same order as {@code ratio}
     * @throws IllegalArgumentException if fewer than two ratios are given, any ratio is negative,
     *     or all ratios are zero
     * @throws ArithmeticException if {@link #size()} does not fit in an {@code int}
     */
    public RandomAccessDataset[] randomSplit(int ... ratio) {
        if (ratio.length < 2) {
            throw new IllegalArgumentException("Requires at least two split portion.");
        }
        if (Arrays.stream(ratio).anyMatch(r -> r < 0)) {
            throw new IllegalArgumentException("Split portions must be non-negative: " + Arrays.toString(ratio));
        }
        // Sum as long so that large int ratios cannot overflow.
        long sum = Arrays.stream(ratio).asLongStream().sum();
        if (sum == 0L) {
            throw new IllegalArgumentException("Split portions must not all be zero: " + Arrays.toString(ratio));
        }
        int size = Math.toIntExact(this.size());
        int[] indices = IntStream.range(0, size).toArray();
        // Fisher-Yates shuffle: position i is swapped only with a slot in [0, i]. Given a uniform
        // nextInt, this yields a uniform permutation. Swapping with any slot in [0, size) does not.
        for (int i = size - 1; i > 0; --i) {
            RandomAccessDataset.swap(indices, i, RandomUtils.nextInt(i + 1));
        }
        RandomAccessDataset[] ret = new RandomAccessDataset[ratio.length];
        int from = 0;
        for (int i = 0; i < ratio.length - 1; ++i) {
            // Exact integer floor of ratio[i] * size / sum. ratio[i] * size fits in a long, and the
            // floors of the leading portions add up to at most size, so "to" never exceeds size.
            int to = from + (int)((long)ratio[i] * size / sum);
            ret[i] = new SubDataset(this, indices, from, to);
            // The next subset begins exactly where this one ended.
            from = to;
        }
        ret[ratio.length - 1] = new SubDataset(this, indices, from, size);
        return ret;
    }

    /** Exchanges {@code arr[i]} and {@code arr[j]} in place. */
    private static void swap(int[] arr, int i, int j) {
        int tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }

    /**
     * A view over the positions {@code indices[from]} through {@code indices[to - 1]} of a parent
     * dataset, as produced by {@link RandomAccessDataset#randomSplit(int...)}.
     *
     * <p>The parent's iteration settings are copied at construction. As a result, the inherited
     * {@code getData} and {@code getNumIterations} walk and count this subset rather than the
     * whole parent.
     */
    private static final class SubDataset
    extends RandomAccessDataset {
        private final RandomAccessDataset dataset;
        private final int[] indices;
        private final int from;
        private final int to;

        /**
         * Creates a view over a slice of the parent's (shuffled) index array.
         *
         * @param dataset the parent dataset that records are read from
         * @param indices the index array, shared read-only with sibling subsets
         * @param from the first slot of {@code indices} in this subset (inclusive)
         * @param to the end slot of {@code indices} for this subset (exclusive)
         */
        public SubDataset(RandomAccessDataset dataset, int[] indices, int from, int to) {
            this.dataset = dataset;
            this.indices = indices;
            this.from = from;
            this.to = to;
            // Without these copies the no-arg super constructor leaves sampler null and
            // maxIteration at 0, which would break the inherited getData / getNumIterations.
            this.sampler = dataset.sampler;
            this.batchifier = dataset.batchifier;
            this.pipeline = dataset.pipeline;
            this.targetPipeline = dataset.targetPipeline;
            this.executor = dataset.executor;
            this.prefetchNumber = dataset.prefetchNumber;
            this.maxIteration = dataset.maxIteration;
            this.device = dataset.device;
        }

        /**
         * Returns the parent record mapped to position {@code index} of this subset.
         *
         * @throws IndexOutOfBoundsException if {@code index} is negative or not less than
         *     {@link #size()}
         */
        @Override
        public Record get(NDManager manager, long index) throws IOException {
            long size = this.size();
            // Reject negative indices too: they could otherwise land on another subset's slot.
            if (index < 0L || index >= size) {
                throw new IndexOutOfBoundsException("index(" + index + ") out of range [0, " + size + ").");
            }
            return this.dataset.get(manager, this.indices[Math.toIntExact(index) + this.from]);
        }

        /** Returns the number of indices in this subset, {@code to - from}. */
        @Override
        public long size() {
            return this.to - this.from;
        }
    }

    /**
     * Base builder collecting the iteration settings for a {@link RandomAccessDataset}.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public static abstract class BaseBuilder<T extends BaseBuilder> {
        protected Sampler sampler;
        protected Batchifier batchifier = Batchifier.STACK;
        protected Pipeline pipeline;
        protected Pipeline targetPipeline;
        protected ExecutorService executor;
        protected int prefetchNumber;
        protected long maxIteration = Long.MAX_VALUE;
        protected Device device;

        /**
         * Returns the configured sampler.
         *
         * @return the sampler
         * @throws IllegalArgumentException if no sampler has been set
         */
        public Sampler getSampler() {
            if (this.sampler == null) {
                throw new IllegalArgumentException("The sampler must be set");
            }
            return this.sampler;
        }

        /**
         * Uses a {@link BatchSampler} of the given batch size that keeps the last partial batch.
         * This is equivalent to {@code setSampling(batchSize, random, false)}.
         *
         * @param batchSize the batch size
         * @param random whether to wrap a {@link RandomSampler} (true) or a {@link
         *     SequenceSampler} (false)
         * @return this builder
         */
        public T setSampling(int batchSize, boolean random) {
            return this.setSampling(batchSize, random, false);
        }

        /**
         * Uses a {@link BatchSampler} of the given batch size.
         *
         * @param batchSize the batch size
         * @param random whether to wrap a {@link RandomSampler} (true) or a {@link
         *     SequenceSampler} (false)
         * @param dropLast the {@code dropLast} flag forwarded to the {@link BatchSampler}
         * @return this builder
         */
        public T setSampling(int batchSize, boolean random, boolean dropLast) {
            this.sampler = random ? new BatchSampler(new RandomSampler(), batchSize, dropLast) : new BatchSampler(new SequenceSampler(), batchSize, dropLast);
            return this.self();
        }

        /**
         * Uses the given sampler.
         *
         * @param sampler the sampler
         * @return this builder
         */
        public T setSampling(Sampler sampler) {
            this.sampler = sampler;
            return this.self();
        }

        /**
         * Sets the batchifier; the default is {@link Batchifier#STACK}. (The method name's
         * spelling is kept for API compatibility.)
         *
         * @param batchier the batchifier
         * @return this builder
         */
        public T optBatchier(Batchifier batchier) {
            this.batchifier = batchier;
            return this.self();
        }

        /**
         * Sets the data pipeline.
         *
         * @param pipeline the pipeline
         * @return this builder
         */
        public T optPipeline(Pipeline pipeline) {
            this.pipeline = pipeline;
            return this.self();
        }

        /**
         * Sets the target pipeline.
         *
         * @param targetPipeline the target pipeline
         * @return this builder
         */
        public T optTargetPipeline(Pipeline targetPipeline) {
            this.targetPipeline = targetPipeline;
            return this.self();
        }

        /**
         * Sets the executor and prefetch number forwarded to {@link DataIterable}. (The method
         * name's spelling is kept for API compatibility.)
         *
         * @param executor the executor service
         * @param prefetchNumber the prefetch number
         * @return this builder
         */
        public T optExcutor(ExecutorService executor, int prefetchNumber) {
            this.executor = executor;
            this.prefetchNumber = prefetchNumber;
            return this.self();
        }

        /**
         * Sets the device forwarded to {@link DataIterable}.
         *
         * @param device the device
         * @return this builder
         */
        public T optDevice(Device device) {
            this.device = device;
            return this.self();
        }

        /**
         * Sets the cap on the number of iterations; the default is {@link Long#MAX_VALUE}.
         *
         * @param maxIteration the maximum number of iterations
         * @return this builder
         */
        public T optMaxIteration(long maxIteration) {
            this.maxIteration = maxIteration;
            return this.self();
        }

        /**
         * Returns this builder as its concrete type.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}

