/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseGradientOp;
import org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Gradient op registered under the name {@code "softmaxderivative"}.
 *
 * <p>With {@code s = softmax(x)}, {@link #exec()} computes
 * {@code s * y - s * rowSum(s * y)} and stores it in {@code z}.
 * NOTE(review): {@code y} is presumably the incoming gradient the softmax
 * derivative is taken with respect to; confirm against {@link BaseGradientOp}.
 */
public class SoftMaxDerivative
extends BaseGradientOp {
    /**
     * Creates the op from an input and an output array.
     *
     * @param x input array
     * @param z output array
     */
    public SoftMaxDerivative(INDArray x, INDArray z) {
        super(x, z);
    }

    /** Creates an op with no arrays attached. */
    public SoftMaxDerivative() {
    }

    /**
     * Creates the op from an input array, an output array and an element count.
     *
     * @param x input array
     * @param z output array
     * @param n element count forwarded to the superclass
     */
    public SoftMaxDerivative(INDArray x, INDArray z, long n) {
        super(x, z, n);
    }

    /**
     * Creates the op from two inputs and an output; the element count passed
     * to the superclass is {@code z.lengthLong()}.
     *
     * @param x first input array
     * @param y second input array (multiplied with the softmax of {@code x} in {@link #exec()})
     * @param z output array; must not be {@code null} because its length is read here
     */
    public SoftMaxDerivative(INDArray x, INDArray y, INDArray z) {
        super(x, y, z, z.lengthLong());
    }

    /**
     * Creates the op from a single input array.
     *
     * @param x input array
     */
    public SoftMaxDerivative(INDArray x) {
        super(x);
    }

    /** @return the op number, always {@code 0} */
    @Override
    public int opNum() {
        return 0;
    }

    /** @return the op name, {@code "softmaxderivative"} */
    @Override
    public String opName() {
        return "softmaxderivative";
    }

    /**
     * @return never returns normally
     * @throws NoOpNameFoundException always; no ONNX name is defined for this op
     */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + this.opName());
    }

    /**
     * @return never returns normally
     * @throws NoOpNameFoundException always; no TensorFlow name is defined for this op
     */
    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + this.opName());
    }

    /**
     * Computes {@code s * y - s * rowSum(s * y)} with {@code s = softmax(x)}
     * and stores the result in {@code z}.
     *
     * <p>The previous implementation used {@code muli} for the first product.
     * Because in-place INDArray operations return the receiver, the "softmax"
     * and "product" references pointed at the same buffer, so the final
     * subtraction was {@code a - a} and the outcome was all zeros, and nothing
     * was ever written to {@code z}.
     */
    @Override
    public void exec() {
        // Softmax is run on a copy: OldSoftMax built from a single array is
        // presumably in-place (TODO confirm), and the input must survive.
        INDArray softmaxed = Nd4j.getExecutioner().execAndReturn(new OldSoftMax(this.x.dup()));
        // Non-in-place multiply so the softmax values stay intact for the second term.
        INDArray weighted = softmaxed.mul(this.y);
        // Reduction over dimension -1 (presumably the last dimension; kept from the original).
        INDArray rowSums = weighted.sum(-1);
        // softmaxed is a private temporary, so scaling it in place is safe here.
        weighted.subi(softmaxed.muliColumnVector(rowSums));
        if (this.z != null) {
            this.z.assign(weighted);
        } else {
            this.z = weighted;
        }
    }

    /**
     * Executes along the given dimensions by delegating to the superclass.
     *
     * @param dimensions dimensions forwarded unchanged to {@code super.exec}
     */
    @Override
    public void exec(int ... dimensions) {
        super.exec(dimensions);
    }

    /**
     * Automatic differentiation is not supported for this op.
     *
     * @param i_v ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        throw new UnsupportedOperationException();
    }
}

