/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.translate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import java.util.Arrays;
import java.util.stream.IntStream;

/**
 * A {@link Batchifier} that combines samples by stacking their arrays and takes batches apart by
 * splitting along axis 0.
 *
 * <p>Every sample is an {@link NDList}; the array at position {@code i} of each sample is treated
 * as the same "kind" of input and is combined with its counterparts from the other samples.
 */
public class StackBatchifier implements Batchifier {

    /**
     * Stacks the arrays found at the same position of every sample into one array per position.
     *
     * <p>The number of positions is taken from the first sample. Each stacked array is given the
     * name of the corresponding array in the first sample, mirroring the name propagation done by
     * {@link #unbatchify(NDList)} and {@link #split(NDList, int, boolean)}.
     *
     * @param inputs the samples to combine
     * @return a list with one stacked array per position; empty when {@code inputs} is empty or
     *     the first sample holds no arrays
     */
    @Override
    public NDList batchify(NDList[] inputs) {
        int batchSize = inputs.length;
        if (batchSize == 0) {
            // Nothing to stack; avoids indexing into an empty array below.
            return new NDList();
        }
        NDList first = inputs[0];
        int numInputKinds = first.size();
        if (numInputKinds == 0) {
            return new NDList();
        }
        NDList result = new NDList(numInputKinds);
        for (int i = 0; i < numInputKinds; ++i) {
            NDList inputsOfKind = new NDList(batchSize);
            for (NDList input : inputs) {
                inputsOfKind.add(input.get(i));
            }
            // NOTE(review): NDArrays.stack is presumed to stack along a new leading axis, which is
            // what unbatchify (split + squeeze(0)) undoes - confirm against the NDArrays docs.
            NDArray stacked = NDArrays.stack(inputsOfKind);
            stacked.setName(first.get(i).getName());
            result.add(stacked);
        }
        return result;
    }

    /**
     * Breaks a batch into individual samples.
     *
     * <p>The batch size is read from axis 0 of the first array. Every array is split into that
     * many pieces, axis 0 of each piece is squeezed away, and the piece inherits the name of the
     * array it came from.
     *
     * @param inputs the batched arrays
     * @return one {@link NDList} per sample; an empty array when {@code inputs} is empty or the
     *     batch size is 0
     */
    @Override
    public NDList[] unbatchify(NDList inputs) {
        int numInputKinds = inputs.size();
        if (numInputKinds == 0) {
            return new NDList[0];
        }
        int batchSize = Math.toIntExact(inputs.head().size(0));
        if (batchSize == 0) {
            return new NDList[0];
        }
        NDList[] dataList = new NDList[batchSize];
        for (int i = 0; i < batchSize; ++i) {
            dataList[i] = new NDList();
        }
        for (NDArray input : inputs) {
            String name = input.getName();
            NDList splitList = input.split(batchSize);
            for (int i = 0; i < batchSize; ++i) {
                NDArray array = splitList.get(i).squeeze(0);
                array.setName(name);
                dataList[i].add(array);
            }
        }
        return dataList;
    }

    /**
     * Partitions a batch along axis 0 into several smaller batches.
     *
     * <p>The batch size is read from axis 0 of the first array. If more slices are requested than
     * there are rows, the slice count is reduced to the batch size. Each resulting array keeps the
     * name of the array it was cut from.
     *
     * @param list the batched arrays to partition
     * @param numOfSlices the requested number of slices; must be positive
     * @param evenSplit whether all slices must contain the same number of rows
     * @return one {@link NDList} per slice; an empty array when {@code list} is empty or the batch
     *     size is 0
     * @throws IllegalArgumentException if {@code numOfSlices} is not positive, or if
     *     {@code evenSplit} is set and the batch cannot be divided evenly
     */
    @Override
    public NDList[] split(NDList list, int numOfSlices, boolean evenSplit) {
        if (numOfSlices <= 0) {
            throw new IllegalArgumentException("Slice number(" + numOfSlices + ") must be positive.");
        }
        if (list.isEmpty()) {
            return new NDList[0];
        }
        int batchSize = Math.toIntExact(list.head().size(0));
        if (batchSize == 0) {
            // An empty batch has nothing to distribute; also avoids a modulo-by-zero below.
            return new NDList[0];
        }
        int sliceCount = Math.min(numOfSlices, batchSize);
        NDList[] splitted = new NDList[sliceCount];
        Arrays.setAll(splitted, i -> new NDList());
        for (NDArray nd : list) {
            String name = nd.getName();
            NDList rows = split(nd, sliceCount, evenSplit);
            for (int i = 0; i < sliceCount; ++i) {
                NDArray array = rows.get(i);
                array.setName(name);
                splitted[i].add(array);
            }
        }
        return splitted;
    }

    /**
     * Splits a single array along axis 0 into {@code numOfSlices} pieces.
     *
     * <p>For an uneven split the boundaries are placed at multiples of
     * {@code ceil(batchSize / numOfSlices)}, but each boundary is capped so that every later slice
     * still receives at least one row. Without the cap, a request such as 5 rows in 4 slices would
     * produce the boundaries 2, 4, 6, the last of which lies past the end of the array.
     *
     * @param array the array to split
     * @param numOfSlices the number of pieces to produce
     * @param evenSplit whether all pieces must have the same number of rows
     * @return the pieces, in order
     * @throws IllegalArgumentException if the array has fewer rows than {@code numOfSlices}, or if
     *     {@code evenSplit} is set and the rows cannot be divided evenly
     */
    private NDList split(NDArray array, int numOfSlices, boolean evenSplit) {
        int batchSize = Math.toIntExact(array.size(0));
        if (batchSize < numOfSlices) {
            throw new IllegalArgumentException("Batch size(" + batchSize + ") is less then slice number(" + numOfSlices + ").");
        }
        if (evenSplit) {
            if (batchSize % numOfSlices != 0) {
                throw new IllegalArgumentException("data with shape " + batchSize + " cannot be evenly split into " + numOfSlices + ". Use a batch size that's multiple of " + numOfSlices + " or set even_split=false to allow uneven partitioning of data.");
            }
            return array.split(numOfSlices);
        }
        int step = (int) Math.ceil((double) batchSize / (double) numOfSlices);
        // Boundary i may not exceed batchSize - (numOfSlices - i): the remaining (numOfSlices - i)
        // slices each need one row. Long arithmetic guards the product against int overflow.
        int[] indices = IntStream.range(1, numOfSlices)
                .map(i -> (int) Math.min((long) i * step, (long) batchSize - numOfSlices + i))
                .toArray();
        return array.split(indices);
    }
}

