/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.ops;

import lombok.NonNull;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.ops.SDOps;
import org.nd4j.autodiff.samediff.ops.SDValidation;
import org.nd4j.base.Preconditions;

/**
 * SameDiff loss-function operations.
 *
 * <p>Every method in this class builds a loss op through the function factory returned by
 * {@code f()}, applies the requested variable name via
 * {@code updateVariableNameAndReference}, and marks the resulting variable as a loss with
 * {@link SDVariable#markAsLoss()}.
 *
 * <p>For the weighted losses, a {@code null} {@code weights} argument is replaced by a scalar
 * {@code 1.0} created with the data type of the predictions variable.
 */
public class SDLoss extends SDOps {

    /** Reduction used by the overloads that do not take an explicit {@link LossReduce}. */
    private static final LossReduce DEFAULT_LOSS_REDUCE = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

    /** Epsilon passed to the log loss op by the overloads that do not take an explicit epsilon. */
    private static final double DEFAULT_LOG_LOSS_EPSILON = 1.0E-7;

    /** Label smoothing passed to the cross entropy ops by the overloads that do not take one. */
    private static final double DEFAULT_LABEL_SMOOTHING = 0.0;

    /**
     * @param sameDiff SameDiff instance handed to the {@link SDOps} base class
     */
    public SDLoss(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Absolute difference loss with no weights and the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}).
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null
     * @param predictions predictions variable, must not be null
     * @return the loss variable
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return absoluteDifference(name, label, predictions, null, DEFAULT_LOSS_REDUCE);
    }

    /**
     * Absolute difference loss.
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("absolute difference loss", "predictions", predictions);
        SDValidation.validateNumerical("absolute difference loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossAbsoluteDifference(label, predictions, effectiveWeights, lossReduce), name);
    }

    /**
     * Absolute difference loss with no weights.
     *
     * @see #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return absoluteDifference(name, label, predictions, null, lossReduce);
    }

    /**
     * Cosine distance loss with no weights and the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}).
     *
     * @see #cosineDistance(String, SDVariable, SDVariable, SDVariable, LossReduce, int)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, int dimension) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return cosineDistance(name, label, predictions, null, DEFAULT_LOSS_REDUCE, dimension);
    }

    /**
     * Cosine distance loss.
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @param dimension   dimension argument forwarded unchanged to the cosine distance op
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce, int dimension) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("cosine distance loss", "predictions", predictions);
        SDValidation.validateNumerical("cosine distance loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossCosineDistance(label, predictions, effectiveWeights, lossReduce, dimension), name);
    }

    /**
     * Cosine distance loss with no weights.
     *
     * @see #cosineDistance(String, SDVariable, SDVariable, SDVariable, LossReduce, int)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce, int dimension) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return cosineDistance(name, label, predictions, null, lossReduce, dimension);
    }

    /**
     * Hinge loss with no weights and the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}).
     *
     * @see #hingeLoss(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return hingeLoss(name, label, predictions, null, DEFAULT_LOSS_REDUCE);
    }

    /**
     * Hinge loss.
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("hinge loss", "predictions", predictions);
        SDValidation.validateNumerical("hinge loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossHinge(label, predictions, effectiveWeights, lossReduce), name);
    }

    /**
     * Hinge loss with no weights.
     *
     * @see #hingeLoss(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return hingeLoss(name, label, predictions, null, lossReduce);
    }

    /**
     * Huber loss with no weights and the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}).
     *
     * @see #huberLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, double delta) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return huberLoss(name, label, predictions, null, DEFAULT_LOSS_REDUCE, delta);
    }

    /**
     * Huber loss.
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @param delta       delta argument forwarded unchanged to the huber loss op
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce, double delta) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("huber loss", "predictions", predictions);
        SDValidation.validateNumerical("huber loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossHuber(label, predictions, effectiveWeights, lossReduce, delta), name);
    }

    /**
     * Huber loss with no weights.
     *
     * @see #huberLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce, double delta) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return huberLoss(name, label, predictions, null, lossReduce, delta);
    }

    /**
     * L2 loss with a null output name.
     *
     * @see #l2Loss(String, SDVariable)
     * @throws NullPointerException if {@code var} is null
     */
    public SDVariable l2Loss(@NonNull SDVariable var) {
        requireNonNullArg(var, "var");
        return l2Loss(null, var);
    }

    /**
     * L2 loss of a single variable.
     *
     * @param name name for the output variable (passed through to the naming helper)
     * @param var  input variable, must not be null; checked with {@code validateNumerical}
     * @return the loss variable
     * @throws NullPointerException if {@code var} is null
     */
    public SDVariable l2Loss(String name, @NonNull SDVariable var) {
        requireNonNullArg(var, "var");
        SDValidation.validateNumerical("l2 loss", var);
        return registerLoss(f().lossL2(var), name);
    }

    /**
     * Log loss with no weights, the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}) and epsilon {@code 1.0E-7}.
     *
     * @see #logLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return logLoss(name, label, predictions, null, DEFAULT_LOSS_REDUCE, DEFAULT_LOG_LOSS_EPSILON);
    }

    /**
     * Log loss.
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @param epsilon     epsilon argument forwarded unchanged to the log loss op
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("log loss", "predictions", predictions);
        SDValidation.validateNumerical("log loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossLog(label, predictions, effectiveWeights, lossReduce, epsilon), name);
    }

    /**
     * Log loss with no weights and epsilon {@code 1.0E-7}.
     *
     * @see #logLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return logLoss(name, label, predictions, null, lossReduce, DEFAULT_LOG_LOSS_EPSILON);
    }

    /**
     * Log poisson loss with no weights and the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}).
     *
     * @see #logPoisson(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return logPoisson(name, label, predictions, null, DEFAULT_LOSS_REDUCE);
    }

    /**
     * Log poisson loss (built with {@code lossLogPoisson}).
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("log poisson loss", "predictions", predictions);
        SDValidation.validateNumerical("log poisson loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossLogPoisson(label, predictions, effectiveWeights, lossReduce), name);
    }

    /**
     * Log poisson loss with no weights.
     *
     * @see #logPoisson(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return logPoisson(name, label, predictions, null, lossReduce);
    }

    /**
     * "Full" log poisson loss with no weights and the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}).
     *
     * @see #logPoissonFull(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return logPoissonFull(name, label, predictions, null, DEFAULT_LOSS_REDUCE);
    }

    /**
     * "Full" log poisson loss (built with {@code lossLogPoissonFull}).
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("log poisson (full) loss", "predictions", predictions);
        SDValidation.validateNumerical("log poisson (full) loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossLogPoissonFull(label, predictions, effectiveWeights, lossReduce), name);
    }

    /**
     * "Full" log poisson loss with no weights.
     *
     * @see #logPoissonFull(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return logPoissonFull(name, label, predictions, null, lossReduce);
    }

    /**
     * Mean pairwise squared error loss with no weights.
     *
     * @see #meanPairwiseSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return meanPairwiseSquaredError(name, label, predictions, null, lossReduce);
    }

    /**
     * Mean pairwise squared error loss.
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        // Op name previously read "main pairwise ..." here; both validations now use the same name.
        SDValidation.validateFloatingPoint("mean pairwise squared error loss", "predictions", predictions);
        SDValidation.validateNumerical("mean pairwise squared error loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossMeanPairwiseSquaredError(label, predictions, effectiveWeights, lossReduce), name);
    }

    /**
     * Mean squared error loss with no weights and the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}).
     *
     * @see #meanSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return meanSquaredError(name, label, predictions, null, DEFAULT_LOSS_REDUCE);
    }

    /**
     * Mean squared error loss.
     *
     * @param name        name for the output variable (passed through to the naming helper)
     * @param label       label variable, must not be null; checked with {@code validateNumerical}
     * @param predictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights     weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce  reduction mode, must not be null
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("mean squared error loss", "predictions", predictions);
        SDValidation.validateNumerical("mean squared error loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictions);
        return registerLoss(f().lossMeanSquaredError(label, predictions, effectiveWeights, lossReduce), name);
    }

    /**
     * Mean squared error loss with no weights.
     *
     * @see #meanSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return meanSquaredError(name, label, predictions, null, lossReduce);
    }

    /**
     * Sigmoid cross entropy loss with no weights, the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}) and label smoothing {@code 0.0}.
     *
     * @see #sigmoidCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return sigmoidCrossEntropy(name, label, predictions, null, DEFAULT_LOSS_REDUCE, DEFAULT_LABEL_SMOOTHING);
    }

    /**
     * Sigmoid cross entropy loss.
     *
     * @param name             name for the output variable (passed through to the naming helper)
     * @param label            label variable, must not be null; checked with {@code validateNumerical}
     * @param predictionLogits predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights          weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce       reduction mode, must not be null
     * @param labelSmoothing   label smoothing argument forwarded unchanged to the op
     * @return the loss variable
     * @throws NullPointerException if {@code label}, {@code predictionLogits} or {@code lossReduce} is null
     */
    public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictionLogits, SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictionLogits, "predictionLogits");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits);
        SDValidation.validateNumerical("sigmoid cross entropy loss", "labels", label);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, predictionLogits);
        return registerLoss(f().lossSigmoidCrossEntropy(label, predictionLogits, effectiveWeights, lossReduce, labelSmoothing), name);
    }

    /**
     * Sigmoid cross entropy loss with no weights and label smoothing {@code 0.0}.
     *
     * @see #sigmoidCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return sigmoidCrossEntropy(name, label, predictions, null, lossReduce, DEFAULT_LABEL_SMOOTHING);
    }

    /**
     * Softmax cross entropy loss with no weights, the default reduction
     * ({@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}) and label smoothing {@code 0.0}.
     *
     * @see #softmaxCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return softmaxCrossEntropy(name, label, predictions, null, DEFAULT_LOSS_REDUCE, DEFAULT_LABEL_SMOOTHING);
    }

    /**
     * Softmax cross entropy loss.
     *
     * @param name             name for the output variable (passed through to the naming helper)
     * @param oneHotLabels     label variable, must not be null; checked with {@code validateNumerical}
     * @param logitPredictions predictions variable, must not be null; checked with {@code validateFloatingPoint}
     * @param weights          weights variable; when null a scalar 1.0 of the predictions' data type is used
     * @param lossReduce       reduction mode, must not be null
     * @param labelSmoothing   label smoothing argument forwarded unchanged to the op
     * @return the loss variable
     * @throws NullPointerException if {@code oneHotLabels}, {@code logitPredictions} or {@code lossReduce} is null
     */
    public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable oneHotLabels, @NonNull SDVariable logitPredictions, SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
        requireNonNullArg(oneHotLabels, "oneHotLabels");
        requireNonNullArg(logitPredictions, "logitPredictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions);
        SDValidation.validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels);
        SDVariable effectiveWeights = unitWeightsIfAbsent(weights, logitPredictions);
        return registerLoss(f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, effectiveWeights, lossReduce, labelSmoothing), name);
    }

    /**
     * Softmax cross entropy loss with no weights and label smoothing {@code 0.0}.
     *
     * @see #softmaxCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return softmaxCrossEntropy(name, label, predictions, null, lossReduce, DEFAULT_LABEL_SMOOTHING);
    }

    /**
     * Sparse softmax cross entropy loss with a null output name.
     *
     * @see #sparseSoftmaxCrossEntropy(String, SDVariable, SDVariable)
     * @throws NullPointerException if {@code logits} or {@code labels} is null
     */
    public SDVariable sparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull SDVariable labels) {
        requireNonNullArg(logits, "logits");
        requireNonNullArg(labels, "labels");
        return sparseSoftmaxCrossEntropy(null, logits, labels);
    }

    /**
     * Sparse softmax cross entropy loss.
     *
     * @param name   name for the output variable (passed through to the naming helper)
     * @param logits logits (predictions) variable, must not be null; checked with {@code validateFloatingPoint}
     * @param labels labels variable, must not be null; checked with {@code validateInteger} and required
     *               to have an integer data type ({@code dataType().isIntType()})
     * @return the loss variable
     * @throws NullPointerException if {@code logits} or {@code labels} is null
     */
    public SDVariable sparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) {
        requireNonNullArg(logits, "logits");
        requireNonNullArg(labels, "labels");
        SDValidation.validateFloatingPoint("sparse softmax cross entropy", "logits (predictions)", logits);
        SDValidation.validateInteger("sparse softmax cross entropy", "labels", labels);
        // Report the offending labels data type; the message previously interpolated the logits variable.
        Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", labels.dataType());
        return registerLoss(f().lossSparseSoftmaxCrossEntropy(logits, labels), name);
    }

    /**
     * Weighted cross entropy with logits, with a null output name.
     *
     * @see #weightedCrossEntropyWithLogits(String, SDVariable, SDVariable, SDVariable)
     */
    public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights) {
        return weightedCrossEntropyWithLogits(null, targets, inputs, weights);
    }

    /**
     * Weighted cross entropy with logits.
     *
     * <p>Unlike the other losses in this class, this method performs no explicit null checks and
     * does not substitute a default for {@code weights}; all three variables are forwarded to the
     * op as given.
     *
     * @param name    name for the output variable (passed through to the naming helper)
     * @param targets targets variable; checked with {@code validateNumerical}
     * @param inputs  inputs variable; checked with {@code validateFloatingPoint}
     * @param weights weights variable, forwarded unchanged
     * @return the loss variable
     */
    public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs, SDVariable weights) {
        SDValidation.validateFloatingPoint("weighted cross entropy with logits", "inputs", inputs);
        SDValidation.validateNumerical("weighted cross entropy with logits", "targets", targets);
        return registerLoss(f().weightedCrossEntropyWithLogits(targets, inputs, weights), name);
    }

    /** Throws the same NullPointerException (type and message) as a Lombok {@code @NonNull} check. */
    private static void requireNonNullArg(Object value, String paramName) {
        if (value == null) {
            throw new NullPointerException(paramName + " is marked @NonNull but is null");
        }
    }

    /** Returns {@code weights}, or a scalar 1.0 with the predictions' data type when it is null. */
    private SDVariable unitWeightsIfAbsent(SDVariable weights, SDVariable predictions) {
        return weights != null ? weights : sd.scalar(null, predictions.dataType(), 1.0);
    }

    /** Applies the requested name to a freshly built loss op output and marks it as a loss. */
    private SDVariable registerLoss(SDVariable result, String name) {
        SDVariable named = updateVariableNameAndReference(result, name);
        named.markAsLoss();
        return named;
    }
}

