/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Custom op registered under the name {@code "gather"} (see {@link #opName()}).
 *
 * <p>It is mapped to the TensorFlow ops {@code Gather} / {@code GatherV2} and the ONNX op
 * {@code Gather}. The axis is always passed as the first integer argument of the op. The
 * indices are passed either as a second input (variable / array) or as additional integer
 * arguments following the axis, depending on the constructor used.
 */
public class Gather extends DynamicCustomOp {
    /** Indices given as a plain int array; stays {@code null} when the indices are an op input. */
    protected int[] indices;
    /** Axis passed at construction time; may be negative (see {@link #doDiff(List)}). */
    protected int jaxis = 0;

    /**
     * Creates a non-in-place gather op whose indices are a SameDiff variable.
     *
     * @param sameDiff the SameDiff instance the op belongs to
     * @param df       the input variable
     * @param indices  the indices variable
     * @param axis     the axis argument of the op
     */
    public Gather(SameDiff sameDiff, SDVariable df, SDVariable indices, int axis) {
        this(sameDiff, df, indices, axis, false);
    }

    /**
     * Creates a non-in-place gather op whose indices are a fixed int array.
     *
     * @param sameDiff the SameDiff instance the op belongs to
     * @param df       the input variable
     * @param indices  the indices, passed as integer arguments after the axis
     * @param axis     the axis argument of the op
     */
    public Gather(SameDiff sameDiff, SDVariable df, int[] indices, int axis) {
        this(sameDiff, df, indices, axis, false);
    }

    /**
     * Creates a gather op with a single input variable; the indices are passed as integer
     * arguments following the axis.
     *
     * @param sameDiff the SameDiff instance the op belongs to
     * @param input    the input variable
     * @param indices  the indices, appended to the integer arguments after the axis
     * @param axis     the axis argument of the op
     * @param inPlace  forwarded unchanged to the {@link DynamicCustomOp} constructor
     */
    public Gather(SameDiff sameDiff, SDVariable input, int[] indices, int axis, boolean inPlace) {
        super(null, sameDiff, new SDVariable[]{input}, inPlace);
        // Integer argument order matters: axis first, then the indices.
        this.addIArgument(axis);
        this.addIArgument(indices);
        this.jaxis = axis;
        this.indices = indices;
    }

    /**
     * Creates a gather op with two input variables (input and indices); only the axis is
     * passed as an integer argument.
     *
     * @param sameDiff the SameDiff instance the op belongs to
     * @param input    the input variable
     * @param indices  the indices variable (second op input)
     * @param axis     the axis argument of the op
     * @param inPlace  forwarded unchanged to the {@link DynamicCustomOp} constructor
     */
    public Gather(SameDiff sameDiff, SDVariable input, SDVariable indices, int axis, boolean inPlace) {
        super(null, sameDiff, new SDVariable[]{input, indices}, inPlace);
        this.addIArgument(axis);
        this.jaxis = axis;
    }

    /**
     * Creates a gather op on a concrete array, with the indices given as an int array.
     *
     * @param df      the input array
     * @param indexes the indices, appended to the integer arguments after the axis
     * @param axis    the axis argument of the op
     */
    public Gather(INDArray df, int[] indexes, int axis) {
        this.addInputArgument(df);
        // Integer argument order matters: axis first, then the indices.
        this.addIArgument(axis);
        this.addIArgument(indexes);
        this.jaxis = axis;
        // Store the supplied indices (previously a self-assignment left the field null).
        this.indices = indexes;
    }

    /**
     * Creates a gather op on concrete arrays, with the indices given as a second input array.
     *
     * @param df      the input array
     * @param indexes the indices array (second op input)
     * @param axis    the axis argument of the op
     */
    public Gather(INDArray df, INDArray indexes, int axis) {
        this.addInputArgument(df, indexes);
        this.addIArgument(axis);
        this.jaxis = axis;
        // The indices are an op input here, so the int[] field is intentionally left unset.
    }

    /**
     * No-arg constructor; leaves the op without inputs or integer arguments.
     */
    public Gather() {
    }

    /** @return the ONNX op name, {@code "Gather"} */
    @Override
    public String onnxName() {
        return "Gather";
    }

    /** @return the TensorFlow op names this op is mapped to: {@code "Gather"} and {@code "GatherV2"} */
    @Override
    public String[] tensorflowNames() {
        return new String[]{"Gather", "GatherV2"};
    }

    /**
     * Initializes this op from a TensorFlow node by delegating to
     * {@link TFGraphMapper#initFunctionFromProperties}, keyed by the node's op name.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    }

    /**
     * ONNX initialization is a no-op in this class.
     */
    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
    }

    /**
     * Builds the property mappings used when importing this op.
     *
     * <p>Two mapping tables are produced:
     * <ul>
     *   <li>for the first TensorFlow name ({@code "Gather"}) and the ONNX name: {@code indices}
     *       taken from TF input position 1 / ONNX attribute {@code "indices"};</li>
     *   <li>for {@code "GatherV2"}: {@code indices} from TF input position 1 and {@code axis}
     *       from TF input position 2.</li>
     * </ul>
     *
     * @return map from op name to (property name to mapping)
     */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();

        // Shared by TF "Gather" and ONNX "Gather": the same map instance is registered twice.
        Map<String, PropertyMapping> gatherMappings = new HashMap<>();
        PropertyMapping indicesMapping = PropertyMapping.builder().onnxAttrName("indices").tfInputPosition(1).propertyNames(new String[]{"indices"}).build();
        gatherMappings.put("indices", indicesMapping);
        ret.put(this.tensorflowNames()[0], gatherMappings);
        ret.put(this.onnxName(), gatherMappings);

        // "GatherV2" additionally maps the axis from the third input.
        Map<String, PropertyMapping> gatherV2Mappings = new HashMap<>();
        PropertyMapping indicesMappingV2 = PropertyMapping.builder().tfInputPosition(1).propertyNames(new String[]{"indices"}).build();
        gatherV2Mappings.put("indices", indicesMappingV2);
        PropertyMapping axisMappingV2 = PropertyMapping.builder().tfInputPosition(2).propertyNames(new String[]{"axis"}).build();
        gatherV2Mappings.put("axis", axisMappingV2);
        ret.put("GatherV2", gatherV2Mappings);
        return ret;
    }

    /** @return the op name, {@code "gather"} */
    @Override
    public String opName() {
        return "gather";
    }

    /**
     * Builds the gradient graph for this op.
     *
     * <p>The indices receive a zero gradient. The input gradient is built by permuting both a
     * zero array shaped like the input and the incoming gradient so that the gather axis
     * becomes dimension 0, applying {@code scatterAdd} with the indices, and permuting the
     * result back.
     *
     * <p>NOTE(review): this method reads {@code arg(1)} unconditionally, so it appears to
     * require the indices to be an op input; ops built with {@code int[]} indices have only
     * one input - confirm whether backprop is expected to be supported for them.
     *
     * @param i_v gradients at the op outputs; only the first element is used
     * @return list of two gradients: for the input and for the indices, in that order
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        SDVariable axis;
        SDVariable indicesGrad = this.sameDiff.zerosLike(this.arg(1));
        SDVariable inputGrad = this.sameDiff.zerosLike(this.arg(0));
        SDVariable[] inputs = this.args();
        SDVariable rank = inputs[0].rank();
        if (inputs.length == 2) {
            // Axis comes from the constructor; a negative axis is shifted by the input rank.
            axis = this.sameDiff.constant(this.jaxis);
            if (this.jaxis < 0) {
                axis = axis.add(rank);
            }
        } else {
            // Otherwise the axis is taken from the third op input as-is.
            axis = inputs[2];
        }

        // Permutation that moves the gather axis to the front: [axis, <other dims>].
        SDVariable dimsExAxis = this.sameDiff.range(null, this.sameDiff.constant(0), rank, this.sameDiff.constant(1), DataType.INT);
        SDVariable axisRank1 = axis.reshape(1);
        // Presumably the first listDiff output holds the dims of [0, rank) other than the axis - TODO confirm.
        dimsExAxis = this.sameDiff.math().listDiff(dimsExAxis, axisRank1)[0];
        SDVariable permuteDims = this.sameDiff.concat(0, axisRank1, dimsExAxis);
        SDVariable invertDims = this.sameDiff.invertPermutation(permuteDims);

        // Scatter-add the permuted output gradient into the permuted zeros, then undo the permute.
        SDVariable gradAtOut = i_v.get(0);
        SDVariable permuteGrad = gradAtOut.permute(permuteDims);
        SDVariable inputGradPermute = inputGrad.permute(permuteDims);
        inputGrad = this.sameDiff.scatterAdd(inputGradPermute, this.arg(1), permuteGrad);
        inputGrad = inputGrad.permute(invertDims);
        return Arrays.asList(inputGrad, indicesGrad);
    }

    /**
     * Computes the output data types: a single output with the data type of the first input.
     *
     * @param dataTypes data types of the op inputs; the first element is used
     * @return singleton list containing {@code dataTypes.get(0)}
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        return Collections.singletonList(dataTypes.get(0));
    }
}

