/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.lossfunctions.impl;

import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonInclude;

/**
 * Multi-label pairwise ranking loss.
 *
 * <p>For a single example with activated output {@code out}, the score is
 * <pre>
 *   sum over k, l of  labels[k] * [labels[l] == 0] * exp(out[l] - out[k])  /  (N * S)
 * </pre>
 * where {@code N} is the number of labels equal to zero and {@code S} is the sum of the labels.
 * With 0/1 labels this penalises every (positive, negative) label pair by how close the negative
 * label's output is to (or above) the positive label's output. NOTE(review): this has the same form
 * as the BP-MLL loss (Zhang and Zhou, 2006) - confirm that is the intended reference.
 *
 * <p>Examples whose labels contain no zero at all, or consist only of zeros, contribute zero score
 * and zero gradient.
 */
@JsonInclude(value=JsonInclude.Include.NON_NULL)
public class LossMultiLabel
implements ILossFunction {

    /**
     * Shared implementation of score and gradient computation.
     *
     * <p>Every intermediate array is created in {@code preOutput}'s data type. {@code mmul} rejects
     * operands of different data types, so mixing in {@code Nd4j.defaultFloatingPointType()} would fail
     * whenever the network's type differs from the global default.
     *
     * @param labels         label array; must have the same number of columns as {@code preOutput}
     * @param preOutput      pre-activation output, one row per example
     * @param activationFn   activation applied to {@code preOutput} before computing the loss
     * @param mask           optional mask (may be null); applied per label to the score of each row
     *                       and to the whole gradient via {@link LossUtil#applyMask}
     * @param scoreOutput    if non-null, row {@code i} receives the score of example {@code i}
     * @param gradientOutput if non-null, receives the gradient with respect to {@code preOutput}
     * @throws IllegalArgumentException if both outputs are null or the label/output widths differ
     */
    private void calculate(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, INDArray scoreOutput, INDArray gradientOutput) {
        if (scoreOutput == null && gradientOutput == null) {
            throw new IllegalArgumentException("You have to provide at least one of scoreOutput or gradientOutput!");
        }
        if (labels.size(1) != preOutput.size(1)) {
            throw new IllegalArgumentException("Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer number of outputs (nOut = " + preOutput.size(1) + ") ");
        }
        labels = labels.castTo(preOutput.dataType());
        INDArray postOutput = activationFn.getActivation(preOutput.dup(), true);
        INDArray positive = labels;
        // Indicator of the labels that are exactly zero, kept in the same dtype as the outputs so
        // the outer product with 'positive' below is a same-dtype mmul.
        INDArray negative = labels.eq(0.0).castTo(preOutput.dataType());
        // Per-example normaliser: (#zero labels) * (sum of labels), as an [examples, 1] column.
        INDArray normFactor = negative.sum(true, 1).castTo(preOutput.dataType()).muli(positive.sum(true, 1));

        long nOut = preOutput.size(1);
        // Reused for every example to replicate an output row into an nOut x nOut matrix.
        INDArray onesColumn = Nd4j.ones(preOutput.dataType(), nOut, 1L);

        long examples = positive.size(0);
        for (int i = 0; i < examples; i++) {
            INDArray rowNegative = negative.getRow(i, true);
            int negativeCount = rowNegative.sumNumber().intValue();
            // With no zero labels, or only zero labels, there are no (positive, negative) pairs.
            // Skipping these rows also avoids dividing by a zero normaliser.
            if (negativeCount == 0 || negativeCount == rowNegative.columns()) {
                if (scoreOutput != null) {
                    scoreOutput.getRow(i, true).assign(0);
                }
                if (gradientOutput != null) {
                    gradientOutput.getRow(i, true).assign(0);
                }
                continue;
            }

            INDArray differences = pairwiseDifferences(onesColumn, postOutput.getRow(i, true), positive.getRow(i, true), rowNegative, normFactor.getDouble((long) i));

            if (scoreOutput != null) {
                if (mask != null) {
                    INDArray perLabel = differences.sum(0);
                    LossUtil.applyMask(perLabel, mask.getRow(i, true));
                    perLabel.sum(scoreOutput.getRow(i, true), 0);
                } else {
                    differences.sum(scoreOutput.getRow(i, true), 0, 1);
                }
            }
            if (gradientOutput != null) {
                // Entry [k, l] of 'differences' is d(score)/d(out[l]) and -d(score)/d(out[k]).
                // Column sums give the contribution towards each l; negated row sums give the
                // contribution towards each k.
                INDArray byColumn = differences.sum(true, 0);
                INDArray byRowNegated = differences.sum(true, 1).transposei().negi();
                gradientOutput.getRow(i, true).assign(byColumn.addi(byRowNegated));
            }
        }

        if (gradientOutput != null) {
            // Chain rule through the activation: dL/dPreOut from dL/dOut.
            gradientOutput.assign(activationFn.backprop(preOutput.dup(), gradientOutput).getFirst());
            if (mask != null) {
                LossUtil.applyMask(gradientOutput, mask);
            }
        }
    }

    /**
     * Builds the nOut x nOut matrix whose entry [k, l] equals
     * {@code positive[k] * negative[l] * exp(out[l] - out[k]) / normFactor}.
     *
     * @param onesColumn  [nOut, 1] column of ones in the output dtype
     * @param rowOutput   [1, nOut] activated output of one example
     * @param rowPositive [1, nOut] labels of that example
     * @param rowNegative [1, nOut] zero-label indicator of that example
     * @param normFactor  per-example normaliser; must be non-zero
     * @return a newly allocated nOut x nOut matrix
     */
    private static INDArray pairwiseDifferences(INDArray onesColumn, INDArray rowOutput, INDArray rowPositive, INDArray rowNegative, double normFactor) {
        // Every row of 'replicated' is a copy of rowOutput, so (replicated - replicated^T)[k, l] = out[l] - out[k].
        INDArray replicated = onesColumn.mmul(rowOutput);
        INDArray pairwise = Transforms.exp(replicated.sub(replicated.transpose()));
        // The outer product positive^T x negative keeps only the (positive k, negative l) pairs.
        INDArray selection = rowPositive.transpose().mmul(rowNegative);
        return pairwise.muli(selection).divi(normFactor);
    }

    /**
     * Computes the per-example score.
     *
     * @return a newly allocated [examples, 1] array in {@code preOutput}'s data type
     */
    public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = Nd4j.create(preOutput.dataType(), labels.size(0), 1L);
        this.calculate(labels, preOutput, activationFn, mask, scoreArr, null);
        return scoreArr;
    }

    /**
     * Sums a per-example score column, optionally dividing by the number of examples.
     */
    private static double totalScore(INDArray scoreArr, boolean average) {
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= (double) scoreArr.size(0);
        }
        return score;
    }

    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        return totalScore(this.scoreArray(labels, preOutput, activationFn, mask), average);
    }

    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = this.scoreArray(labels, preOutput, activationFn, mask);
        return scoreArr.sum(true, 1);
    }

    /**
     * Computes the gradient of the loss with respect to {@code preOutput}.
     *
     * @throws IllegalArgumentException if the label and output widths differ
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        labels = labels.castTo(preOutput.dataType());
        if (labels.size(1) != preOutput.size(1)) {
            throw new IllegalArgumentException("Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer number of outputs (nOut = " + preOutput.size(1) + ") ");
        }
        INDArray grad = Nd4j.ones(preOutput.dataType(), labels.shape());
        this.calculate(labels, preOutput, activationFn, mask, null, grad);
        return grad;
    }

    /**
     * Computes the total (optionally averaged) score and the gradient in a single pass.
     */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        INDArray scoreArr = Nd4j.create(preOutput.dataType(), labels.size(0), 1L);
        INDArray grad = Nd4j.ones(preOutput.dataType(), labels.shape());
        this.calculate(labels, preOutput, activationFn, mask, scoreArr, grad);
        return new Pair<Double, INDArray>(totalScore(scoreArr, average), grad);
    }

    @Override
    public String name() {
        return this.toString();
    }

    @Override
    public String toString() {
        return "LossMultiLabel";
    }

    /**
     * The loss is stateless, so all instances are equal (Lombok-style canEqual pattern kept for subclasses).
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LossMultiLabel)) {
            return false;
        }
        LossMultiLabel other = (LossMultiLabel) o;
        return other.canEqual(this);
    }

    protected boolean canEqual(Object other) {
        return other instanceof LossMultiLabel;
    }

    /**
     * Constant hash, consistent with {@link #equals(Object)} treating all instances as equal.
     */
    @Override
    public int hashCode() {
        return 1;
    }
}

