/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.aggregates;

import java.util.ArrayList;
import java.util.List;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.shade.guava.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A bounded container of {@link Aggregate} ops that are handed over together.
 *
 * <p>The batch's {@link #opNum()} is taken from the first aggregate it was created with.
 * NOTE(review): callers presumably put only aggregates of the same op into one batch; this class
 * does not enforce that.
 *
 * @param <T> concrete aggregate type held by this batch
 */
public class Batch<T extends Aggregate> {
    private static final Logger log = LoggerFactory.getLogger(Batch.class);

    /** Maximum number of aggregates {@link #append(Aggregate)} will accept into one batch. */
    private static final int BATCH_LIMIT = 512;

    private DataBuffer paramsSurface;
    private List<T> aggregates;
    private T sample;
    private int numAggregates;

    /**
     * Creates a batch backed by the given list. The list is used directly (not copied), so
     * {@link #append(Aggregate)} adds to it.
     *
     * @param aggregates non-empty list of aggregates; its first element becomes the sample
     * @throws IllegalArgumentException if {@code aggregates} is null or empty
     */
    public Batch(List<T> aggregates) {
        Preconditions.checkArgument(aggregates != null && !aggregates.isEmpty(),
                "Batch requires a non-empty list of aggregates");
        this.aggregates = aggregates;
        this.numAggregates = aggregates.size();
        this.sample = aggregates.get(0);
    }

    /** Returns the op number of the sample (first) aggregate of this batch. */
    public int opNum() {
        return this.sample.opNum();
    }

    /**
     * Adds an aggregate to this batch unless the batch is already full.
     *
     * @param aggregate aggregate to add
     * @return {@code true} if it was added, {@code false} if the batch was full
     */
    public boolean append(T aggregate) {
        if (isFull()) {
            return false;
        }
        this.aggregates.add(aggregate);
        // Keep the counter in sync so isFull()/getNumAggregates() reflect appended items.
        this.numAggregates++;
        return true;
    }

    /**
     * Returns whether this batch has reached the batch limit. Uses {@code >=} because a batch
     * may have been constructed with more than {@link #getBatchLimit()} aggregates.
     */
    public boolean isFull() {
        return this.numAggregates >= BATCH_LIMIT;
    }

    /**
     * Splits {@code list} into batches of at most {@link #getBatchLimit()} aggregates.
     *
     * @see #getBatches(List, int)
     */
    public static <U extends Aggregate> List<Batch<U>> getBatches(List<U> list) {
        return getBatches(list, BATCH_LIMIT);
    }

    /**
     * Splits {@code list} into consecutive batches of at most {@code partitionSize} aggregates.
     *
     * <p>All non-null arguments of all aggregates must share a single data type.
     *
     * @param list aggregates to split
     * @param partitionSize maximum number of aggregates per batch
     * @return one batch per partition, each backed by its own independent list
     * @throws IllegalArgumentException if the argument data types differ
     * @throws ND4JIllegalStateException if no non-null argument exists to infer a data type from
     */
    public static <U extends Aggregate> List<Batch<U>> getBatches(List<U> list, int partitionSize) {
        DataType dataType = null;
        for (U aggregate : list) {
            for (INDArray argument : aggregate.getArguments()) {
                if (argument == null) {
                    continue;
                }
                if (dataType == null) {
                    dataType = argument.dataType();
                } else {
                    Preconditions.checkArgument(dataType == argument.dataType(),
                            "All arguments must have same data type");
                }
            }
        }
        if (dataType == null) {
            throw new ND4JIllegalStateException("Can't infer data type from arguments");
        }

        List<List<U>> partitions = Lists.partition(list, partitionSize);
        List<Batch<U>> split = new ArrayList<>(partitions.size());
        for (List<U> partition : partitions) {
            // Lists.partition returns subList views of the caller's list; copy each one so that
            // Batch.append() cannot mutate the input list or invalidate the sibling batches.
            split.add(new Batch<>(new ArrayList<>(partition)));
        }
        return split;
    }

    /** Returns the params buffer associated with this batch, or {@code null} if not set. */
    public DataBuffer getParamsSurface() {
        return this.paramsSurface;
    }

    /** Sets the params buffer associated with this batch. */
    public void setParamsSurface(DataBuffer paramsSurface) {
        this.paramsSurface = paramsSurface;
    }

    /** Returns the maximum number of aggregates a batch accepts via {@code append}. */
    public static int getBatchLimit() {
        return BATCH_LIMIT;
    }

    /** Returns the live list of aggregates held by this batch (not a copy). */
    public List<T> getAggregates() {
        return this.aggregates;
    }

    /** Returns the first aggregate this batch was created with. */
    public T getSample() {
        return this.sample;
    }

    /** Returns the number of aggregates in this batch, including appended ones. */
    public int getNumAggregates() {
        return this.numAggregates;
    }
}

