/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.broadcast;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public class BiasAdd
extends DynamicCustomOp {
    protected boolean nchw = true;

    public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias, boolean nchw) {
        super(null, sameDiff, new SDVariable[]{input, bias}, false);
        this.bArguments.clear();
        this.bArguments.add(nchw);
        this.nchw = nchw;
    }

    public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, boolean nchw) {
        this(input, bias, null, nchw);
        if (input == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        if (bias == null) {
            throw new NullPointerException("bias is marked non-null but is null");
        }
    }

    public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw) {
        super(new INDArray[]{input, bias}, BiasAdd.wrapOrNull(output));
        if (input == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        if (bias == null) {
            throw new NullPointerException("bias is marked non-null but is null");
        }
        this.bArguments.clear();
        this.bArguments.add(nchw);
        this.nchw = nchw;
    }

    @Override
    public String opName() {
        return "biasadd";
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
        this.nchw = attributesForNode.containsKey("data_format") ? "NCHW".equalsIgnoreCase(attributesForNode.get("data_format").getS().toStringUtf8()) : false;
        this.bArguments.clear();
        this.bArguments.add(this.nchw);
    }

    @Override
    public String onnxName() {
        return "BiasAdd";
    }

    @Override
    public String[] tensorflowNames() {
        return new String[]{"BiasAdd", "BiasAddV1"};
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradient) {
        return new BiasAddGrad(this.sameDiff, this.arg(0), this.arg(1), gradient.get(0), this.nchw).outputs();
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input data types for %s, got %s", this.getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }

    public BiasAdd() {
    }
}

