/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.factory.ops;

import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss;
import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss;
import org.nd4j.linalg.api.ops.impl.loss.CtcLoss;
import org.nd4j.linalg.api.ops.impl.loss.HingeLoss;
import org.nd4j.linalg.api.ops.impl.loss.HuberLoss;
import org.nd4j.linalg.api.ops.impl.loss.L2Loss;
import org.nd4j.linalg.api.ops.impl.loss.LogLoss;
import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss;
import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss;
import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss;
import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits;
import org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Eager, {@link INDArray}-based entry points for the loss operations in
 * {@code org.nd4j.linalg.api.ops.impl.loss}.
 *
 * <p>Every method follows the same pattern:
 * <ol>
 *   <li>check the inputs with {@link NDValidation};</li>
 *   <li>build the corresponding op;</li>
 *   <li>run it through {@code Nd4j.exec};</li>
 *   <li>return element {@code 0} of the resulting output array.</li>
 * </ol>
 *
 * <p>Convenience overloads that omit the {@link LossReduce} argument of a sibling method
 * use {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
 */
public class NDLoss {

    /** Reduction used by every overload that does not take an explicit {@link LossReduce}. */
    private static final LossReduce DEFAULT_LOSS_REDUCE = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

    /**
     * Runs {@link NDValidation#validateNumerical} on three inputs of a single op, in argument order.
     *
     * @param opName     op name handed to the validator
     * @param firstName  name reported for {@code first}
     * @param first      first input to check
     * @param secondName name reported for {@code second}
     * @param second     second input to check
     * @param thirdName  name reported for {@code third}
     * @param third      third input to check
     */
    private static void validateNumerical3(String opName,
                                           String firstName, INDArray first,
                                           String secondName, INDArray second,
                                           String thirdName, INDArray third) {
        NDValidation.validateNumerical(opName, firstName, first);
        NDValidation.validateNumerical(opName, secondName, second);
        NDValidation.validateNumerical(opName, thirdName, third);
    }

    /**
     * Shorthand for {@link #validateNumerical3} when the op's inputs are named
     * {@code label}, {@code predictions} and {@code weights}.
     *
     * @param opName      op name handed to the validator
     * @param label       label input
     * @param predictions predictions input
     * @param weights     weights input
     */
    private static void validateLabelPredictionsWeights(String opName, INDArray label,
                                                        INDArray predictions, INDArray weights) {
        validateNumerical3(opName, "label", label, "predictions", predictions, "weights", weights);
    }

    /**
     * Executes {@link AbsoluteDifferenceLoss}.
     *
     * @param label       label array
     * @param predictions predictions array
     * @param weights     weights array
     * @param lossReduce  reduction mode passed to the op
     * @return the first output of the op
     */
    public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce) {
        validateLabelPredictionsWeights("absoluteDifference", label, predictions, weights);
        AbsoluteDifferenceLoss op = new AbsoluteDifferenceLoss(label, predictions, weights, lossReduce);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #absoluteDifference(INDArray, INDArray, INDArray, LossReduce)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     */
    public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights) {
        return absoluteDifference(label, predictions, weights, DEFAULT_LOSS_REDUCE);
    }

    /**
     * Executes {@link CosineDistanceLoss}.
     *
     * @param label       label array
     * @param predictions predictions array
     * @param weights     weights array
     * @param lossReduce  reduction mode passed to the op
     * @param dimension   dimension argument passed to the op
     * @return the first output of the op
     */
    public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension) {
        validateLabelPredictionsWeights("cosineDistance", label, predictions, weights);
        CosineDistanceLoss op = new CosineDistanceLoss(label, predictions, weights, lossReduce, dimension);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #cosineDistance(INDArray, INDArray, INDArray, LossReduce, int)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     */
    public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights, int dimension) {
        return cosineDistance(label, predictions, weights, DEFAULT_LOSS_REDUCE, dimension);
    }

    /**
     * Executes {@link CtcLoss}. All four inputs are checked with {@code validateNumerical}.
     *
     * @param targetLabels       target labels array
     * @param logitInput         logit input array
     * @param targetLabelLengths target label lengths array
     * @param logitInputLengths  logit input lengths array
     * @return the first output of the op
     */
    public INDArray ctcLoss(INDArray targetLabels, INDArray logitInput, INDArray targetLabelLengths, INDArray logitInputLengths) {
        final String opName = "ctcLoss";
        NDValidation.validateNumerical(opName, "targetLabels", targetLabels);
        NDValidation.validateNumerical(opName, "logitInput", logitInput);
        NDValidation.validateNumerical(opName, "targetLabelLengths", targetLabelLengths);
        NDValidation.validateNumerical(opName, "logitInputLengths", logitInputLengths);
        CtcLoss op = new CtcLoss(targetLabels, logitInput, targetLabelLengths, logitInputLengths);
        return Nd4j.exec(op)[0];
    }

    /**
     * Executes {@link HingeLoss}.
     *
     * @param label       label array
     * @param predictions predictions array
     * @param weights     weights array
     * @param lossReduce  reduction mode passed to the op
     * @return the first output of the op
     */
    public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce) {
        validateLabelPredictionsWeights("hingeLoss", label, predictions, weights);
        HingeLoss op = new HingeLoss(label, predictions, weights, lossReduce);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #hingeLoss(INDArray, INDArray, INDArray, LossReduce)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     */
    public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights) {
        return hingeLoss(label, predictions, weights, DEFAULT_LOSS_REDUCE);
    }

    /**
     * Executes {@link HuberLoss}.
     *
     * @param label       label array
     * @param predictions predictions array
     * @param weights     weights array
     * @param lossReduce  reduction mode passed to the op
     * @param delta       delta argument passed to the op
     * @return the first output of the op
     */
    public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta) {
        validateLabelPredictionsWeights("huberLoss", label, predictions, weights);
        HuberLoss op = new HuberLoss(label, predictions, weights, lossReduce, delta);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #huberLoss(INDArray, INDArray, INDArray, LossReduce, double)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     */
    public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, double delta) {
        return huberLoss(label, predictions, weights, DEFAULT_LOSS_REDUCE, delta);
    }

    /**
     * Executes {@link L2Loss} on a single input.
     *
     * @param var input array
     * @return the first output of the op
     */
    public INDArray l2Loss(INDArray var) {
        NDValidation.validateNumerical("l2Loss", "var", var);
        L2Loss op = new L2Loss(var);
        return Nd4j.exec(op)[0];
    }

    /**
     * Executes {@link LogLoss}.
     *
     * @param label       label array
     * @param predictions predictions array
     * @param weights     weights array
     * @param lossReduce  reduction mode passed to the op
     * @param epsilon     epsilon argument passed to the op
     * @return the first output of the op
     */
    public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon) {
        validateLabelPredictionsWeights("logLoss", label, predictions, weights);
        LogLoss op = new LogLoss(label, predictions, weights, lossReduce, epsilon);
        return Nd4j.exec(op)[0];
    }

    /**
     * Executes {@link LogLoss} without weights.
     *
     * <p>The op receives {@code null} weights, {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
     * and an epsilon of {@code 0.0}. Only {@code label} and {@code predictions} are validated.
     *
     * @param label       label array
     * @param predictions predictions array
     * @return the first output of the op
     */
    public INDArray logLoss(INDArray label, INDArray predictions) {
        NDValidation.validateNumerical("logLoss", "label", label);
        NDValidation.validateNumerical("logLoss", "predictions", predictions);
        // The literal null is kept (rather than a typed local) so constructor overload resolution is unchanged.
        LogLoss op = new LogLoss(label, predictions, null, DEFAULT_LOSS_REDUCE, 0.0);
        return Nd4j.exec(op)[0];
    }

    /**
     * Executes {@link LogPoissonLoss}.
     *
     * @param label       label array
     * @param predictions predictions array
     * @param weights     weights array
     * @param lossReduce  reduction mode passed to the op
     * @param full        value of the {@code full} flag passed to the op
     * @return the first output of the op
     */
    public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce, boolean full) {
        validateLabelPredictionsWeights("logPoisson", label, predictions, weights);
        LogPoissonLoss op = new LogPoissonLoss(label, predictions, weights, lossReduce, full);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #logPoisson(INDArray, INDArray, INDArray, LossReduce, boolean)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     */
    public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, boolean full) {
        return logPoisson(label, predictions, weights, DEFAULT_LOSS_REDUCE, full);
    }

    /**
     * Executes {@link MeanPairwiseSquaredErrorLoss}.
     *
     * @param label       label array
     * @param predictions predictions array
     * @param weights     weights array
     * @param lossReduce  reduction mode passed to the op
     * @return the first output of the op
     */
    public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce) {
        validateLabelPredictionsWeights("meanPairwiseSquaredError", label, predictions, weights);
        MeanPairwiseSquaredErrorLoss op = new MeanPairwiseSquaredErrorLoss(label, predictions, weights, lossReduce);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #meanPairwiseSquaredError(INDArray, INDArray, INDArray, LossReduce)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     */
    public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights) {
        return meanPairwiseSquaredError(label, predictions, weights, DEFAULT_LOSS_REDUCE);
    }

    /**
     * Executes {@link MeanSquaredErrorLoss}.
     *
     * @param label       label array
     * @param predictions predictions array
     * @param weights     weights array
     * @param lossReduce  reduction mode passed to the op
     * @return the first output of the op
     */
    public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce) {
        validateLabelPredictionsWeights("meanSquaredError", label, predictions, weights);
        MeanSquaredErrorLoss op = new MeanSquaredErrorLoss(label, predictions, weights, lossReduce);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #meanSquaredError(INDArray, INDArray, INDArray, LossReduce)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     */
    public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights) {
        return meanSquaredError(label, predictions, weights, DEFAULT_LOSS_REDUCE);
    }

    /**
     * Executes {@link SigmoidCrossEntropyLoss}.
     *
     * @param label            label array
     * @param predictionLogits prediction logits array
     * @param weights          weights array
     * @param lossReduce       reduction mode passed to the op
     * @param labelSmoothing   label smoothing argument passed to the op
     * @return the first output of the op
     */
    public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights, LossReduce lossReduce, double labelSmoothing) {
        validateNumerical3("sigmoidCrossEntropy",
                "label", label,
                "predictionLogits", predictionLogits,
                "weights", weights);
        SigmoidCrossEntropyLoss op = new SigmoidCrossEntropyLoss(label, predictionLogits, weights, lossReduce, labelSmoothing);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #sigmoidCrossEntropy(INDArray, INDArray, INDArray, LossReduce, double)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} and a label smoothing of {@code 0.0}.
     */
    public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights) {
        return sigmoidCrossEntropy(label, predictionLogits, weights, DEFAULT_LOSS_REDUCE, 0.0);
    }

    /**
     * Executes {@link SoftmaxCrossEntropyLoss}.
     *
     * @param oneHotLabels     labels array
     * @param logitPredictions logit predictions array
     * @param weights          weights array
     * @param lossReduce       reduction mode passed to the op
     * @param labelSmoothing   label smoothing argument passed to the op
     * @return the first output of the op
     */
    public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions, INDArray weights, LossReduce lossReduce, double labelSmoothing) {
        validateNumerical3("softmaxCrossEntropy",
                "oneHotLabels", oneHotLabels,
                "logitPredictions", logitPredictions,
                "weights", weights);
        SoftmaxCrossEntropyLoss op = new SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing);
        return Nd4j.exec(op)[0];
    }

    /**
     * Same as {@link #softmaxCrossEntropy(INDArray, INDArray, INDArray, LossReduce, double)} with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} and a label smoothing of {@code 0.0}.
     */
    public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions, INDArray weights) {
        return softmaxCrossEntropy(oneHotLabels, logitPredictions, weights, DEFAULT_LOSS_REDUCE, 0.0);
    }

    /**
     * Executes {@link SparseSoftmaxCrossEntropyLossWithLogits}.
     *
     * <p>{@code logits} is checked with {@code validateNumerical}. {@code labels} is checked with
     * {@code validateInteger}, which presumably requires an integer data type (confirm against
     * {@link NDValidation}).
     *
     * @param logits logits array
     * @param labels labels array
     * @return the first output of the op
     */
    public INDArray sparseSoftmaxCrossEntropy(INDArray logits, INDArray labels) {
        NDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits);
        NDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels);
        SparseSoftmaxCrossEntropyLossWithLogits op = new SparseSoftmaxCrossEntropyLossWithLogits(logits, labels);
        return Nd4j.exec(op)[0];
    }

    /**
     * Executes {@link WeightedCrossEntropyLoss}.
     *
     * @param targets targets array
     * @param inputs  inputs array
     * @param weights weights array
     * @return the first output of the op
     */
    public INDArray weightedCrossEntropyWithLogits(INDArray targets, INDArray inputs, INDArray weights) {
        validateNumerical3("weightedCrossEntropyWithLogits",
                "targets", targets,
                "inputs", inputs,
                "weights", weights);
        WeightedCrossEntropyLoss op = new WeightedCrossEntropyLoss(targets, inputs, weights);
        return Nd4j.exec(op)[0];
    }
}

