/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.evaluation.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.common.primitives.Triple;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

public class RegressionEvaluation
extends BaseEvaluation<RegressionEvaluation> {
    public static final int DEFAULT_PRECISION = 5;
    protected int axis = 1;
    private boolean initialized;
    private List<String> columnNames;
    private long precision;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray exampleCountPerColumn;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray labelsSumPerColumn;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray sumSquaredErrorsPerColumn;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray sumAbsErrorsPerColumn;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray currentMean;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray currentPredictionMean;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray sumOfProducts;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray sumSquaredLabels;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray sumSquaredPredicted;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray sumLabels;

    protected RegressionEvaluation(int axis, List<String> columnNames, long precision) {
        this.axis = axis;
        this.columnNames = columnNames;
        this.precision = precision;
    }

    public RegressionEvaluation() {
        this(null, 5L);
    }

    public RegressionEvaluation(long nColumns) {
        this(RegressionEvaluation.createDefaultColumnNames(nColumns), 5L);
    }

    public RegressionEvaluation(long nColumns, long precision) {
        this(RegressionEvaluation.createDefaultColumnNames(nColumns), precision);
    }

    public RegressionEvaluation(String ... columnNames) {
        this(columnNames == null || columnNames.length == 0 ? null : Arrays.asList(columnNames), 5L);
    }

    public RegressionEvaluation(List<String> columnNames) {
        this(columnNames, 5L);
    }

    public RegressionEvaluation(List<String> columnNames, long precision) {
        this.precision = precision;
        if (columnNames == null || columnNames.isEmpty()) {
            this.initialized = false;
        } else {
            this.columnNames = columnNames;
            this.initialize(columnNames.size());
        }
    }

    public void setAxis(int axis) {
        this.axis = axis;
    }

    public int getAxis() {
        return this.axis;
    }

    @Override
    public void reset() {
        this.initialized = false;
    }

    private void initialize(int n) {
        if (this.columnNames == null || this.columnNames.size() != n) {
            this.columnNames = RegressionEvaluation.createDefaultColumnNames(n);
        }
        this.exampleCountPerColumn = Nd4j.zeros(DataType.DOUBLE, n);
        this.labelsSumPerColumn = Nd4j.zeros(DataType.DOUBLE, n);
        this.sumSquaredErrorsPerColumn = Nd4j.zeros(DataType.DOUBLE, n);
        this.sumAbsErrorsPerColumn = Nd4j.zeros(DataType.DOUBLE, n);
        this.currentMean = Nd4j.zeros(DataType.DOUBLE, n);
        this.currentPredictionMean = Nd4j.zeros(DataType.DOUBLE, n);
        this.sumOfProducts = Nd4j.zeros(DataType.DOUBLE, n);
        this.sumSquaredLabels = Nd4j.zeros(DataType.DOUBLE, n);
        this.sumSquaredPredicted = Nd4j.zeros(DataType.DOUBLE, n);
        this.sumLabels = Nd4j.zeros(DataType.DOUBLE, n);
        this.initialized = true;
    }

    private static List<String> createDefaultColumnNames(long nColumns) {
        ArrayList<String> list = new ArrayList<String>((int)nColumns);
        int i = 0;
        while ((long)i < nColumns) {
            list.add("col_" + i);
            ++i;
        }
        return list;
    }

    @Override
    public void eval(INDArray labels, INDArray predictions) {
        this.eval(labels, predictions, (INDArray)null);
    }

    @Override
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
        this.eval(labels, networkPredictions, maskArray);
    }

    @Override
    public void eval(INDArray labelsArr, INDArray predictionsArr, INDArray maskArr) {
        Triple<INDArray, INDArray, INDArray> p = BaseEvaluation.reshapeAndExtractNotMasked(labelsArr, predictionsArr, maskArr, this.axis);
        INDArray labels = p.getFirst();
        INDArray predictions = p.getSecond();
        INDArray maskArray = p.getThird();
        if (labels.dataType() != predictions.dataType()) {
            labels = labels.castTo(predictions.dataType());
        }
        if (!this.initialized) {
            this.initialize((int)labels.size(1));
        }
        if ((long)this.columnNames.size() != labels.size(1) || (long)this.columnNames.size() != predictions.size(1)) {
            throw new IllegalArgumentException("Number of the columns of labels and predictions must match specification (" + this.columnNames.size() + "). Got " + labels.size(1) + " and " + predictions.size(1));
        }
        if (maskArray != null) {
            labels = labels.mul(maskArray);
            predictions = predictions.mul(maskArray);
        }
        this.labelsSumPerColumn.addi(labels.sum(0).castTo(this.labelsSumPerColumn.dataType()));
        INDArray error = predictions.sub(labels);
        INDArray absErrorSum = Nd4j.getExecutioner().exec(new ASum(error, 0));
        INDArray squaredErrorSum = error.mul(error).sum(0);
        this.sumAbsErrorsPerColumn.addi(absErrorSum.castTo(this.labelsSumPerColumn.dataType()));
        this.sumSquaredErrorsPerColumn.addi(squaredErrorSum.castTo(this.labelsSumPerColumn.dataType()));
        this.sumOfProducts.addi(labels.mul(predictions).sum(0).castTo(this.labelsSumPerColumn.dataType()));
        this.sumSquaredLabels.addi(labels.mul(labels).sum(0).castTo(this.labelsSumPerColumn.dataType()));
        this.sumSquaredPredicted.addi(predictions.mul(predictions).sum(0).castTo(this.labelsSumPerColumn.dataType()));
        long nRows = labels.size(0);
        INDArray newExampleCountPerColumn = maskArray == null ? this.exampleCountPerColumn.add(nRows) : this.exampleCountPerColumn.add(maskArray.sum(0).castTo(this.labelsSumPerColumn.dataType()));
        this.currentMean.muliRowVector(this.exampleCountPerColumn).addi(labels.sum(0).castTo(this.labelsSumPerColumn.dataType())).diviRowVector(newExampleCountPerColumn);
        this.currentPredictionMean.muliRowVector(this.exampleCountPerColumn).addi(predictions.sum(0).castTo(this.labelsSumPerColumn.dataType())).divi(newExampleCountPerColumn);
        this.exampleCountPerColumn = newExampleCountPerColumn;
        this.sumLabels.addi(labels.sum(0).castTo(this.labelsSumPerColumn.dataType()));
    }

    @Override
    public void merge(RegressionEvaluation other) {
        if (other.labelsSumPerColumn == null) {
            return;
        }
        if (this.labelsSumPerColumn == null) {
            this.columnNames = other.columnNames;
            this.precision = other.precision;
            this.exampleCountPerColumn = other.exampleCountPerColumn;
            this.labelsSumPerColumn = other.labelsSumPerColumn.dup();
            this.sumSquaredErrorsPerColumn = other.sumSquaredErrorsPerColumn.dup();
            this.sumAbsErrorsPerColumn = other.sumAbsErrorsPerColumn.dup();
            this.currentMean = other.currentMean.dup();
            this.currentPredictionMean = other.currentPredictionMean.dup();
            this.sumOfProducts = other.sumOfProducts.dup();
            this.sumSquaredLabels = other.sumSquaredLabels.dup();
            this.sumSquaredPredicted = other.sumSquaredPredicted.dup();
            return;
        }
        this.labelsSumPerColumn.addi(other.labelsSumPerColumn);
        this.sumSquaredErrorsPerColumn.addi(other.sumSquaredErrorsPerColumn);
        this.sumAbsErrorsPerColumn.addi(other.sumAbsErrorsPerColumn);
        this.currentMean.muliRowVector(this.exampleCountPerColumn).addi(other.currentMean.mulRowVector(other.exampleCountPerColumn)).diviRowVector(this.exampleCountPerColumn.add(other.exampleCountPerColumn));
        this.currentPredictionMean.muliRowVector(this.exampleCountPerColumn).addi(other.currentPredictionMean.mulRowVector(other.exampleCountPerColumn)).diviRowVector(this.exampleCountPerColumn.add(other.exampleCountPerColumn));
        this.sumOfProducts.addi(other.sumOfProducts);
        this.sumSquaredLabels.addi(other.sumSquaredLabels);
        this.sumSquaredPredicted.addi(other.sumSquaredPredicted);
        this.exampleCountPerColumn.addi(other.exampleCountPerColumn);
    }

    @Override
    public String stats() {
        if (!this.initialized) {
            return "RegressionEvaluation: No Data";
        }
        if (this.columnNames == null) {
            this.columnNames = RegressionEvaluation.createDefaultColumnNames(this.numColumns());
        }
        int maxLabelLength = 0;
        for (String s : this.columnNames) {
            maxLabelLength = Math.max(maxLabelLength, s.length());
        }
        int labelWidth = maxLabelLength + 5;
        long columnWidth = this.precision + 10L;
        String resultFormat = "%-" + labelWidth + "s%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e";
        StringBuilder sb = new StringBuilder();
        String headerFormat = "%-" + labelWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s";
        sb.append(String.format(headerFormat, "Column", "MSE", "MAE", "RMSE", "RSE", "PC", "R^2"));
        sb.append("\n");
        for (int i = 0; i < this.columnNames.size(); ++i) {
            String name = this.columnNames.get(i);
            double mse = this.meanSquaredError(i);
            double mae = this.meanAbsoluteError(i);
            double rmse = this.rootMeanSquaredError(i);
            double rse = this.relativeSquaredError(i);
            double corr = this.pearsonCorrelation(i);
            double r2 = this.rSquared(i);
            sb.append(String.format(resultFormat, name, mse, mae, rmse, rse, corr, r2));
            sb.append("\n");
        }
        return sb.toString();
    }

    public int numColumns() {
        if (this.columnNames == null) {
            if (this.exampleCountPerColumn == null) {
                return 0;
            }
            return (int)this.exampleCountPerColumn.size(1);
        }
        return this.columnNames.size();
    }

    public double meanSquaredError(int column) {
        return this.sumSquaredErrorsPerColumn.getDouble((long)column) / this.exampleCountPerColumn.getDouble((long)column);
    }

    public double meanAbsoluteError(int column) {
        return this.sumAbsErrorsPerColumn.getDouble((long)column) / this.exampleCountPerColumn.getDouble((long)column);
    }

    public double rootMeanSquaredError(int column) {
        return Math.sqrt(this.sumSquaredErrorsPerColumn.getDouble((long)column) / this.exampleCountPerColumn.getDouble((long)column));
    }

    @Deprecated
    public double correlationR2(int column) {
        return this.pearsonCorrelation(column);
    }

    public double pearsonCorrelation(int column) {
        double sumxiyi = this.sumOfProducts.getDouble((long)column);
        double predictionMean = this.currentPredictionMean.getDouble((long)column);
        double labelMean = this.currentMean.getDouble((long)column);
        double sumSquaredLabels = this.sumSquaredLabels.getDouble((long)column);
        double sumSquaredPredicted = this.sumSquaredPredicted.getDouble((long)column);
        double exampleCount = this.exampleCountPerColumn.getDouble((long)column);
        double r = sumxiyi - exampleCount * predictionMean * labelMean;
        return r /= Math.sqrt(sumSquaredLabels - exampleCount * labelMean * labelMean) * Math.sqrt(sumSquaredPredicted - exampleCount * predictionMean * predictionMean);
    }

    public double rSquared(int column) {
        double sumLabelSquared = this.sumSquaredLabels.getDouble((long)column);
        double meanLabel = this.currentMean.getDouble((long)column);
        double sumLabel = this.sumLabels.getDouble((long)column);
        double n = this.exampleCountPerColumn.getDouble((long)column);
        double sstot = sumLabelSquared + meanLabel * (n * meanLabel - 2.0 * sumLabel);
        double ssres = this.sumSquaredErrorsPerColumn.getDouble((long)column);
        return (sstot - ssres) / sstot;
    }

    public double relativeSquaredError(int column) {
        double numerator = this.sumSquaredPredicted.getDouble((long)column) - 2.0 * this.sumOfProducts.getDouble((long)column) + this.sumSquaredLabels.getDouble((long)column);
        double denominator = this.sumSquaredLabels.getDouble((long)column) - this.exampleCountPerColumn.getDouble((long)column) * this.currentMean.getDouble((long)column) * this.currentMean.getDouble((long)column);
        if (Math.abs(denominator) > Nd4j.EPS_THRESHOLD) {
            return numerator / denominator;
        }
        return Double.POSITIVE_INFINITY;
    }

    public double averageMeanSquaredError() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.meanSquaredError(i);
        }
        return ret / (double)this.numColumns();
    }

    public double averageMeanAbsoluteError() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.meanAbsoluteError(i);
        }
        return ret / (double)this.numColumns();
    }

    public double averagerootMeanSquaredError() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.rootMeanSquaredError(i);
        }
        return ret / (double)this.numColumns();
    }

    public double averagerelativeSquaredError() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.relativeSquaredError(i);
        }
        return ret / (double)this.numColumns();
    }

    @Deprecated
    public double averagecorrelationR2() {
        return this.averagePearsonCorrelation();
    }

    public double averagePearsonCorrelation() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.pearsonCorrelation(i);
        }
        return ret / (double)this.numColumns();
    }

    public double averageRSquared() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.rSquared(i);
        }
        return ret / (double)this.numColumns();
    }

    @Override
    public double getValue(IMetric metric) {
        if (metric instanceof Metric) {
            return this.scoreForMetric((Metric)metric);
        }
        throw new IllegalStateException("Can't get value for non-regression Metric " + metric);
    }

    public double scoreForMetric(Metric metric) {
        switch (metric) {
            case MSE: {
                return this.averageMeanSquaredError();
            }
            case MAE: {
                return this.averageMeanAbsoluteError();
            }
            case RMSE: {
                return this.averagerootMeanSquaredError();
            }
            case RSE: {
                return this.averagerelativeSquaredError();
            }
            case PC: {
                return this.averagePearsonCorrelation();
            }
            case R2: {
                return this.averageRSquared();
            }
        }
        throw new IllegalStateException("Unknown metric: " + metric);
    }

    public static RegressionEvaluation fromJson(String json) {
        return RegressionEvaluation.fromJson(json, RegressionEvaluation.class);
    }

    @Override
    public RegressionEvaluation newInstance() {
        return new RegressionEvaluation(this.axis, this.columnNames, this.precision);
    }

    public boolean isInitialized() {
        return this.initialized;
    }

    public List<String> getColumnNames() {
        return this.columnNames;
    }

    public long getPrecision() {
        return this.precision;
    }

    public INDArray getExampleCountPerColumn() {
        return this.exampleCountPerColumn;
    }

    public INDArray getLabelsSumPerColumn() {
        return this.labelsSumPerColumn;
    }

    public INDArray getSumSquaredErrorsPerColumn() {
        return this.sumSquaredErrorsPerColumn;
    }

    public INDArray getSumAbsErrorsPerColumn() {
        return this.sumAbsErrorsPerColumn;
    }

    public INDArray getCurrentMean() {
        return this.currentMean;
    }

    public INDArray getCurrentPredictionMean() {
        return this.currentPredictionMean;
    }

    public INDArray getSumOfProducts() {
        return this.sumOfProducts;
    }

    public INDArray getSumSquaredLabels() {
        return this.sumSquaredLabels;
    }

    public INDArray getSumSquaredPredicted() {
        return this.sumSquaredPredicted;
    }

    public INDArray getSumLabels() {
        return this.sumLabels;
    }

    public void setInitialized(boolean initialized) {
        this.initialized = initialized;
    }

    public void setColumnNames(List<String> columnNames) {
        this.columnNames = columnNames;
    }

    public void setPrecision(long precision) {
        this.precision = precision;
    }

    public void setExampleCountPerColumn(INDArray exampleCountPerColumn) {
        this.exampleCountPerColumn = exampleCountPerColumn;
    }

    public void setLabelsSumPerColumn(INDArray labelsSumPerColumn) {
        this.labelsSumPerColumn = labelsSumPerColumn;
    }

    public void setSumSquaredErrorsPerColumn(INDArray sumSquaredErrorsPerColumn) {
        this.sumSquaredErrorsPerColumn = sumSquaredErrorsPerColumn;
    }

    public void setSumAbsErrorsPerColumn(INDArray sumAbsErrorsPerColumn) {
        this.sumAbsErrorsPerColumn = sumAbsErrorsPerColumn;
    }

    public void setCurrentMean(INDArray currentMean) {
        this.currentMean = currentMean;
    }

    public void setCurrentPredictionMean(INDArray currentPredictionMean) {
        this.currentPredictionMean = currentPredictionMean;
    }

    public void setSumOfProducts(INDArray sumOfProducts) {
        this.sumOfProducts = sumOfProducts;
    }

    public void setSumSquaredLabels(INDArray sumSquaredLabels) {
        this.sumSquaredLabels = sumSquaredLabels;
    }

    public void setSumSquaredPredicted(INDArray sumSquaredPredicted) {
        this.sumSquaredPredicted = sumSquaredPredicted;
    }

    public void setSumLabels(INDArray sumLabels) {
        this.sumLabels = sumLabels;
    }

    @Override
    public String toString() {
        return "RegressionEvaluation(axis=" + this.getAxis() + ", initialized=" + this.isInitialized() + ", columnNames=" + this.getColumnNames() + ", precision=" + this.getPrecision() + ", exampleCountPerColumn=" + this.getExampleCountPerColumn() + ", labelsSumPerColumn=" + this.getLabelsSumPerColumn() + ", sumSquaredErrorsPerColumn=" + this.getSumSquaredErrorsPerColumn() + ", sumAbsErrorsPerColumn=" + this.getSumAbsErrorsPerColumn() + ", currentMean=" + this.getCurrentMean() + ", currentPredictionMean=" + this.getCurrentPredictionMean() + ", sumOfProducts=" + this.getSumOfProducts() + ", sumSquaredLabels=" + this.getSumSquaredLabels() + ", sumSquaredPredicted=" + this.getSumSquaredPredicted() + ", sumLabels=" + this.getSumLabels() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof RegressionEvaluation)) {
            return false;
        }
        RegressionEvaluation other = (RegressionEvaluation)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (this.isInitialized() != other.isInitialized()) {
            return false;
        }
        List<String> this$columnNames = this.getColumnNames();
        List<String> other$columnNames = other.getColumnNames();
        if (this$columnNames == null ? other$columnNames != null : !((Object)this$columnNames).equals(other$columnNames)) {
            return false;
        }
        if (this.getPrecision() != other.getPrecision()) {
            return false;
        }
        INDArray this$exampleCountPerColumn = this.getExampleCountPerColumn();
        INDArray other$exampleCountPerColumn = other.getExampleCountPerColumn();
        if (this$exampleCountPerColumn == null ? other$exampleCountPerColumn != null : !this$exampleCountPerColumn.equals(other$exampleCountPerColumn)) {
            return false;
        }
        INDArray this$labelsSumPerColumn = this.getLabelsSumPerColumn();
        INDArray other$labelsSumPerColumn = other.getLabelsSumPerColumn();
        if (this$labelsSumPerColumn == null ? other$labelsSumPerColumn != null : !this$labelsSumPerColumn.equals(other$labelsSumPerColumn)) {
            return false;
        }
        INDArray this$sumSquaredErrorsPerColumn = this.getSumSquaredErrorsPerColumn();
        INDArray other$sumSquaredErrorsPerColumn = other.getSumSquaredErrorsPerColumn();
        if (this$sumSquaredErrorsPerColumn == null ? other$sumSquaredErrorsPerColumn != null : !this$sumSquaredErrorsPerColumn.equals(other$sumSquaredErrorsPerColumn)) {
            return false;
        }
        INDArray this$sumAbsErrorsPerColumn = this.getSumAbsErrorsPerColumn();
        INDArray other$sumAbsErrorsPerColumn = other.getSumAbsErrorsPerColumn();
        if (this$sumAbsErrorsPerColumn == null ? other$sumAbsErrorsPerColumn != null : !this$sumAbsErrorsPerColumn.equals(other$sumAbsErrorsPerColumn)) {
            return false;
        }
        INDArray this$currentMean = this.getCurrentMean();
        INDArray other$currentMean = other.getCurrentMean();
        if (this$currentMean == null ? other$currentMean != null : !this$currentMean.equals(other$currentMean)) {
            return false;
        }
        INDArray this$currentPredictionMean = this.getCurrentPredictionMean();
        INDArray other$currentPredictionMean = other.getCurrentPredictionMean();
        if (this$currentPredictionMean == null ? other$currentPredictionMean != null : !this$currentPredictionMean.equals(other$currentPredictionMean)) {
            return false;
        }
        INDArray this$sumOfProducts = this.getSumOfProducts();
        INDArray other$sumOfProducts = other.getSumOfProducts();
        if (this$sumOfProducts == null ? other$sumOfProducts != null : !this$sumOfProducts.equals(other$sumOfProducts)) {
            return false;
        }
        INDArray this$sumSquaredLabels = this.getSumSquaredLabels();
        INDArray other$sumSquaredLabels = other.getSumSquaredLabels();
        if (this$sumSquaredLabels == null ? other$sumSquaredLabels != null : !this$sumSquaredLabels.equals(other$sumSquaredLabels)) {
            return false;
        }
        INDArray this$sumSquaredPredicted = this.getSumSquaredPredicted();
        INDArray other$sumSquaredPredicted = other.getSumSquaredPredicted();
        if (this$sumSquaredPredicted == null ? other$sumSquaredPredicted != null : !this$sumSquaredPredicted.equals(other$sumSquaredPredicted)) {
            return false;
        }
        INDArray this$sumLabels = this.getSumLabels();
        INDArray other$sumLabels = other.getSumLabels();
        return !(this$sumLabels == null ? other$sumLabels != null : !this$sumLabels.equals(other$sumLabels));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof RegressionEvaluation;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        result = result * 59 + (this.isInitialized() ? 79 : 97);
        List<String> $columnNames = this.getColumnNames();
        result = result * 59 + ($columnNames == null ? 43 : ((Object)$columnNames).hashCode());
        long $precision = this.getPrecision();
        result = result * 59 + (int)($precision >>> 32 ^ $precision);
        INDArray $exampleCountPerColumn = this.getExampleCountPerColumn();
        result = result * 59 + ($exampleCountPerColumn == null ? 43 : $exampleCountPerColumn.hashCode());
        INDArray $labelsSumPerColumn = this.getLabelsSumPerColumn();
        result = result * 59 + ($labelsSumPerColumn == null ? 43 : $labelsSumPerColumn.hashCode());
        INDArray $sumSquaredErrorsPerColumn = this.getSumSquaredErrorsPerColumn();
        result = result * 59 + ($sumSquaredErrorsPerColumn == null ? 43 : $sumSquaredErrorsPerColumn.hashCode());
        INDArray $sumAbsErrorsPerColumn = this.getSumAbsErrorsPerColumn();
        result = result * 59 + ($sumAbsErrorsPerColumn == null ? 43 : $sumAbsErrorsPerColumn.hashCode());
        INDArray $currentMean = this.getCurrentMean();
        result = result * 59 + ($currentMean == null ? 43 : $currentMean.hashCode());
        INDArray $currentPredictionMean = this.getCurrentPredictionMean();
        result = result * 59 + ($currentPredictionMean == null ? 43 : $currentPredictionMean.hashCode());
        INDArray $sumOfProducts = this.getSumOfProducts();
        result = result * 59 + ($sumOfProducts == null ? 43 : $sumOfProducts.hashCode());
        INDArray $sumSquaredLabels = this.getSumSquaredLabels();
        result = result * 59 + ($sumSquaredLabels == null ? 43 : $sumSquaredLabels.hashCode());
        INDArray $sumSquaredPredicted = this.getSumSquaredPredicted();
        result = result * 59 + ($sumSquaredPredicted == null ? 43 : $sumSquaredPredicted.hashCode());
        INDArray $sumLabels = this.getSumLabels();
        result = result * 59 + ($sumLabels == null ? 43 : $sumLabels.hashCode());
        return result;
    }

    public static enum Metric implements IMetric
    {
        MSE,
        MAE,
        RMSE,
        RSE,
        PC,
        R2;


        @Override
        public Class<? extends IEvaluation> getEvaluationClass() {
            return RegressionEvaluation.class;
        }

        @Override
        public boolean minimize() {
            return this != R2 && this != PC;
        }
    }
}

