/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Dynamic custom op wrapper registered under the op name {@code "argmax"}
 * (TensorFlow name {@code "ArgMax"}).
 *
 * <p>The op carries three pieces of configuration: the dimensions to operate on,
 * a {@code keepDims} flag, and the data type of the produced output. When an instance
 * is built through one of the argument-taking constructors, these are also registered
 * as integer, boolean and data-type arguments of the underlying custom op, in that order.
 *
 * <p>NOTE(review): the setters only update the fields of this object; they do not touch
 * the arguments that were registered at construction time.
 */
public class ArgMax
extends DynamicCustomOp {
    /** Multiplier used when combining field hashes in {@link #hashCode()}. */
    private static final int HASH_PRIME = 59;

    /** Value of the keep-dimensions flag passed to the op; {@code false} unless set. */
    protected boolean keepDims;

    /** Dimensions passed to the op; may be {@code null} or empty. */
    private int[] dimensions;

    /** Data type reported for the op output; initialised to {@link DataType#INT64}. */
    protected DataType outputType = DataType.INT64;

    /**
     * Creates the op as part of a {@link SameDiff} graph.
     *
     * @param sameDiff   graph instance the op belongs to
     * @param i_v        input variable
     * @param keepDims   keep-dimensions flag, registered as a boolean argument
     * @param dimensions dimensions to operate on; registered as integer arguments only
     *                   when non-null and non-empty
     */
    public ArgMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
        super(sameDiff, i_v);
        this.registerArguments(keepDims, dimensions);
    }

    /**
     * Creates an unconfigured op: {@code keepDims} is {@code false}, the output type is
     * {@link DataType#INT64}, and no arguments are registered.
     */
    public ArgMax() {
    }

    /**
     * Creates the op for direct execution on arrays.
     *
     * @param x          input array
     * @param z          output array, or {@code null} to construct the op with no
     *                   output array
     * @param keepDims   keep-dimensions flag, registered as a boolean argument
     * @param dimensions dimensions to operate on; registered as integer arguments only
     *                   when non-null and non-empty
     */
    public ArgMax(INDArray x, INDArray z, boolean keepDims, int ... dimensions) {
        // The explicit super(...) call has to be the first statement of the constructor,
        // so the optional output array is selected inline.
        super(new INDArray[]{x}, z != null ? new INDArray[]{z} : new INDArray[0]);
        this.registerArguments(keepDims, dimensions);
    }

    /**
     * Same as {@link #ArgMax(INDArray, INDArray, boolean, int...)} with
     * {@code keepDims == false}.
     *
     * @param x          input array
     * @param z          output array, or {@code null}
     * @param dimensions dimensions to operate on
     */
    public ArgMax(INDArray x, INDArray z, int ... dimensions) {
        this(x, z, false, dimensions);
    }

    /**
     * Same as {@link #ArgMax(INDArray, INDArray, boolean, int...)} with no output array
     * and {@code keepDims == false}.
     *
     * @param x          input array
     * @param dimensions dimensions to operate on
     */
    public ArgMax(INDArray x, int ... dimensions) {
        this(x, null, dimensions);
    }

    /**
     * Same as {@link #ArgMax(INDArray, INDArray, boolean, int...)} with no output array.
     *
     * @param x          input array
     * @param keepDims   keep-dimensions flag
     * @param dimensions dimensions to operate on
     */
    public ArgMax(INDArray x, boolean keepDims, int ... dimensions) {
        this(x, null, keepDims, dimensions);
    }

    /**
     * Stores the configuration on this instance and registers it with the underlying
     * custom op: integer arguments (dimensions, skipped when null or empty), then the
     * boolean argument (keepDims), then the data-type argument (current output type).
     */
    private void registerArguments(boolean keepDims, int[] dimensions) {
        this.keepDims = keepDims;
        this.dimensions = dimensions;
        if (dimensions != null && dimensions.length > 0) {
            this.addIArgument(dimensions);
        }
        this.addBArgument(keepDims);
        this.addDArgument(this.outputType);
    }

    /** @return the op name, {@code "argmax"} */
    @Override
    public String opName() {
        return "argmax";
    }

    /** @return the TensorFlow op name, {@code "ArgMax"} */
    @Override
    public String tensorflowName() {
        return "ArgMax";
    }

    /**
     * Reads the output data type from the TensorFlow node attributes.
     *
     * <p>If an {@code "output_type"} attribute is present it is converted through
     * {@link TFGraphMapper#convertType}; otherwise {@link DataType#LONG} is used.
     * Only {@code attributesForNode} is consulted.
     *
     * @param nodeDef           TensorFlow node definition (unused here)
     * @param initWith          graph being imported into (unused here)
     * @param attributesForNode attributes of the node, keyed by attribute name
     * @param graph             TensorFlow graph definition (unused here)
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        if (attributesForNode.containsKey("output_type")) {
            this.outputType = TFGraphMapper.convertType(attributesForNode.get("output_type").getType());
        } else {
            this.outputType = DataType.LONG;
        }
    }

    /**
     * Reports the data type of the single op output.
     *
     * @param inputDataTypes data types of the op inputs; must be non-null and contain
     *                       one or two entries, otherwise the {@link Preconditions}
     *                       state check fails
     * @return a single-element list holding the configured output type, or
     *         {@link DataType#LONG} when no output type is set
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        boolean validInputCount = inputDataTypes != null
                && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2);
        Preconditions.checkState(validInputCount, "Expected 1 or 2 input datatype to argmax, got %s", inputDataTypes);
        DataType resolved = this.outputType == null ? DataType.LONG : this.outputType;
        return Collections.singletonList(resolved);
    }

    /** @return the keep-dimensions flag */
    public boolean isKeepDims() {
        return this.keepDims;
    }

    /** @return the stored dimensions array itself (not a copy); may be {@code null} */
    @Override
    public int[] getDimensions() {
        return this.dimensions;
    }

    /** @return the configured output data type; may be {@code null} if set so */
    public DataType getOutputType() {
        return this.outputType;
    }

    /** @param keepDims new value of the keep-dimensions flag */
    public void setKeepDims(boolean keepDims) {
        this.keepDims = keepDims;
    }

    /** @param dimensions new dimensions array; stored by reference */
    @Override
    public void setDimensions(int[] dimensions) {
        this.dimensions = dimensions;
    }

    /** @param outputType new output data type */
    public void setOutputType(DataType outputType) {
        this.outputType = outputType;
    }

    /**
     * Two {@code ArgMax} instances are equal when their keep-dimensions flag, dimensions
     * (compared element-wise) and output type all match, and the other object agrees
     * via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ArgMax)) {
            return false;
        }
        ArgMax other = (ArgMax)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.isKeepDims() != other.isKeepDims()) {
            return false;
        }
        if (!Arrays.equals(this.getDimensions(), other.getDimensions())) {
            return false;
        }
        DataType mine = this.getOutputType();
        DataType theirs = other.getOutputType();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /**
     * Hook that lets subclasses opt out of equality with plain {@code ArgMax} instances.
     *
     * @param other candidate object
     * @return {@code true} if {@code other} is an {@code ArgMax}
     */
    protected boolean canEqual(Object other) {
        return other instanceof ArgMax;
    }

    /** Hash over the same three properties that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_PRIME + (this.isKeepDims() ? 79 : 97);
        result = result * HASH_PRIME + Arrays.hashCode(this.getDimensions());
        DataType type = this.getOutputType();
        result = result * HASH_PRIME + (type == null ? 43 : type.hashCode());
        return result;
    }

    /** @return a string listing keepDims, dimensions and outputType */
    @Override
    public String toString() {
        return "ArgMax(keepDims=" + this.isKeepDims()
                + ", dimensions=" + Arrays.toString(this.getDimensions())
                + ", outputType=" + this.getOutputType() + ")";
    }
}

