/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.custom;

import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

public class ScatterUpdate
implements CustomOp {
    protected CustomOp op;

    public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) {
        this(original, updates, null, indices, dimension, op);
        if (original == null) {
            throw new NullPointerException("original is marked @NonNull but is null");
        }
        if (updates == null) {
            throw new NullPointerException("updates is marked @NonNull but is null");
        }
        if (indices == null) {
            throw new NullPointerException("indices is marked @NonNull but is null");
        }
        if (op == null) {
            throw new NullPointerException("op is marked @NonNull but is null");
        }
    }

    public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, INDArray result, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) {
        if (original == null) {
            throw new NullPointerException("original is marked @NonNull but is null");
        }
        if (updates == null) {
            throw new NullPointerException("updates is marked @NonNull but is null");
        }
        if (indices == null) {
            throw new NullPointerException("indices is marked @NonNull but is null");
        }
        if (op == null) {
            throw new NullPointerException("op is marked @NonNull but is null");
        }
        ArrayList<Integer> iargs = new ArrayList<Integer>();
        iargs.add(op.ordinal());
        iargs.add(dimension.length);
        for (int v : dimension) {
            iargs.add(v);
        }
        iargs.add(indices.length);
        for (int v : indices) {
            iargs.add(v);
        }
        if (updates.tensorAlongDimension(0, dimension).lengthLong() != original.tensorAlongDimension(0, dimension).lengthLong()) {
            throw new ND4JIllegalStateException("ScatterUpdate requires equal shaped tensors for operation along given dimension(s)");
        }
        long numTensors = original.tensorssAlongDimension(dimension);
        for (int idx : indices) {
            if ((long)idx < numTensors) continue;
            throw new ND4JIllegalStateException("Can't update index higher then num tensors");
        }
        this.op = DynamicCustomOp.builder("scatter_update").addInputs(original, updates).callInplace(true).addIntegerArguments(iargs).build();
    }

    @Override
    public String opName() {
        return this.op.opName();
    }

    @Override
    public long opHash() {
        return this.op.opHash();
    }

    @Override
    public boolean isInplaceCall() {
        return this.op.isInplaceCall();
    }

    @Override
    public INDArray[] outputArguments() {
        return this.op.outputArguments();
    }

    @Override
    public INDArray[] inputArguments() {
        return this.op.inputArguments();
    }

    @Override
    public long[] iArgs() {
        return this.op.iArgs();
    }

    @Override
    public double[] tArgs() {
        return this.op.tArgs();
    }

    @Override
    public void addIArgument(int ... arg) {
        this.op.addIArgument(arg);
    }

    @Override
    public void addIArgument(long ... arg) {
        this.op.addIArgument(arg);
    }

    @Override
    public void removeIArgument(Integer arg) {
        this.op.removeIArgument(arg);
    }

    @Override
    public Long getIArgument(int index) {
        return this.op.getIArgument(index);
    }

    @Override
    public int numIArguments() {
        return this.op.numIArguments();
    }

    @Override
    public void addTArgument(double ... arg) {
        this.op.addTArgument(arg);
    }

    @Override
    public void removeTArgument(Double arg) {
        this.op.removeTArgument(arg);
    }

    @Override
    public Double getTArgument(int index) {
        return this.op.getTArgument(index);
    }

    @Override
    public int numTArguments() {
        return this.op.numTArguments();
    }

    @Override
    public void addInputArgument(INDArray ... arg) {
        this.op.addInputArgument(arg);
    }

    @Override
    public void removeInputArgument(INDArray arg) {
        this.op.removeInputArgument(arg);
    }

    @Override
    public INDArray getInputArgument(int index) {
        return this.op.getInputArgument(index);
    }

    @Override
    public int numInputArguments() {
        return this.op.numInputArguments();
    }

    @Override
    public void addOutputArgument(INDArray ... arg) {
        this.op.addOutputArgument(arg);
    }

    @Override
    public void removeOutputArgument(INDArray arg) {
    }

    @Override
    public INDArray getOutputArgument(int index) {
        return this.op.getOutputArgument(index);
    }

    @Override
    public int numOutputArguments() {
        return this.op.numOutputArguments();
    }

    @Override
    public List<long[]> calculateOutputShape() {
        return Nd4j.getExecutioner().calculateOutputShape(this);
    }

    @Override
    public CustomOpDescriptor getDescriptor() {
        return this.op.getDescriptor();
    }

    @Override
    public void assertValidForExecution() {
    }

    @Override
    public void populateInputsAndOutputsFromSameDiff() {
    }

    public static enum UpdateOp {
        ADD,
        SUBTRACT,
        MILTIPLY,
        DIVIDE,
        RSUBTRACT,
        RDIVIDE,
        ASSIGN;

    }
}

