/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.neural;

import java.io.Serializable;
import java.util.Arrays;
import weka.core.Utils;
import weka.core.matrix.Matrix;

/**
 * A linear function that maps a vector of label confidences to a threshold value.
 * <p>
 * The function has the form {@code t(c) = w_0*c_0 + ... + w_{n-1}*c_{n-1} + b}, where the
 * weights {@code w_i} and the bias {@code b} are fitted by {@link #build(double[][], double[][])}
 * from model outputs and the corresponding ideal label assignments.
 */
public class ThresholdFunction implements Serializable {
    private static final long serialVersionUID = 5347411552628371402L;

    /** Offset applied to the target threshold of an example that has labels of only one kind. */
    private static final double MARGIN = 0.1;

    /** Fitted weights, one per label, followed by the bias term as the last element. */
    private double[] parameters;

    /**
     * Creates a threshold function fitted to the given data.
     *
     * @param idealLabels    ideal label assignments, one row per example; an entry equal to
     *                       {@code 1.0} marks a relevant label, any other value an irrelevant one
     * @param modelOutLabels model outputs, with the same dimensions as {@code idealLabels}
     * @throws IllegalArgumentException if the inputs are null, empty, contain null rows or
     *                                  have mismatched dimensions
     */
    public ThresholdFunction(double[][] idealLabels, double[][] modelOutLabels) {
        // NOTE(review): this invokes the overridable build(...) from the constructor, so a
        // subclass override would run before the subclass is initialized. Kept as-is because
        // existing subclasses may rely on their override being used here.
        this.build(idealLabels, modelOutLabels);
    }

    /**
     * Computes the threshold for the given label confidences.
     *
     * @param labelsConfidences one confidence per label; its length must equal the number of
     *                          labels the function was built with
     * @return the weighted sum of the confidences plus the bias term
     * @throws IllegalArgumentException if {@code labelsConfidences} has the wrong length
     */
    public double computeThreshold(double[] labelsConfidences) {
        int expectedDim = this.parameters.length - 1;
        if (labelsConfidences.length != expectedDim) {
            throw new IllegalArgumentException("The array of label confidences has wrong dimension.The function expect parameters of length : " + expectedDim);
        }
        double threshold = 0.0;
        for (int index = 0; index < expectedDim; ++index) {
            threshold += labelsConfidences[index] * this.parameters[index];
        }
        // The last parameter is the bias term; it is added after the weighted sum.
        return threshold + this.parameters[expectedDim];
    }

    /**
     * Fits the function parameters, replacing any previously fitted ones.
     * <p>
     * A target threshold is derived for every example (see
     * {@link #idealThreshold(double[], double[])}). The weights and bias are then obtained by
     * solving the linear system {@code [modelOutLabels | 1] * parameters = thresholds} with
     * {@code Matrix.solve}.
     *
     * @param idealLabels    ideal label assignments, one row per example; an entry equal to
     *                       {@code 1.0} marks a relevant label, any other value an irrelevant one
     * @param modelOutLabels model outputs, with the same dimensions as {@code idealLabels}
     * @throws IllegalArgumentException if either argument is null, contains a null row, has no
     *                                  examples or no labels, or the two matrices (or any pair of
     *                                  their rows) differ in size
     */
    public void build(double[][] idealLabels, double[][] modelOutLabels) {
        checkInput(idealLabels, modelOutLabels);
        int numExamples = idealLabels.length;
        int numLabels = idealLabels[0].length;

        double[] thresholds = new double[numExamples];
        for (int example = 0; example < numExamples; ++example) {
            thresholds[example] = idealThreshold(idealLabels[example], modelOutLabels[example]);
        }

        // Design matrix [modelOutLabels | 1]: the trailing column of ones fits the bias term.
        Matrix modelMatrix = new Matrix(numExamples, numLabels + 1, 1.0);
        modelMatrix.setMatrix(0, numExamples - 1, 0, numLabels - 1, new Matrix(modelOutLabels));
        // NOTE(review): failures inside Matrix.solve (e.g. a singular or rank-deficient system)
        // propagate unchanged; confirm the exact exception type for the weka version in use.
        Matrix weights = modelMatrix.solve(new Matrix(thresholds, thresholds.length));
        // weights is a (numLabels + 1) x 1 column vector: label weights followed by the bias.
        this.parameters = weights.getColumnPackedCopy();
    }

    /**
     * Validates the arguments of {@link #build(double[][], double[][])}.
     * Every row is checked, not only the first one. Ragged input therefore fails fast with a clear
     * message, instead of an index error mid-computation or surplus columns being ignored.
     */
    private static void checkInput(double[][] idealLabels, double[][] modelOutLabels) {
        if (idealLabels == null || modelOutLabels == null) {
            throw new IllegalArgumentException("Non of the input parameters can be null.");
        }
        int numExamples = idealLabels.length;
        if (numExamples == 0) {
            throw new IllegalArgumentException("At least one example is required.");
        }
        if (modelOutLabels.length != numExamples) {
            throw new IllegalArgumentException("Matrix dimensions of input parameters does not agree.");
        }
        if (idealLabels[0] == null) {
            throw new IllegalArgumentException("Input matrices must not contain null rows.");
        }
        int numLabels = idealLabels[0].length;
        if (numLabels == 0) {
            throw new IllegalArgumentException("At least one label is required.");
        }
        for (int example = 0; example < numExamples; ++example) {
            double[] ideal = idealLabels[example];
            double[] modelOut = modelOutLabels[example];
            if (ideal == null || modelOut == null) {
                throw new IllegalArgumentException("Input matrices must not contain null rows.");
            }
            if (ideal.length != numLabels || modelOut.length != numLabels) {
                throw new IllegalArgumentException("Matrix dimensions of input parameters does not agree.");
            }
        }
    }

    /**
     * Derives the target threshold for a single example from its model outputs.
     * A label is relevant when its ideal value is exactly {@code 1.0}.
     * <ul>
     * <li>Relevant and irrelevant labels present: the midpoint between the smallest output of a
     *     relevant label and the largest output of an irrelevant label.</li>
     * <li>Only irrelevant labels: the largest irrelevant output plus {@link #MARGIN}.</li>
     * <li>Only relevant labels: the smallest relevant output minus {@link #MARGIN}.</li>
     * </ul>
     * Presence is tracked with explicit flags rather than sentinel values, so an output that
     * happens to equal {@code Double.MAX_VALUE} is not mistaken for a missing label group.
     */
    private static double idealThreshold(double[] ideal, double[] modelOut) {
        double minRelevant = Double.POSITIVE_INFINITY;
        double maxIrrelevant = Double.NEGATIVE_INFINITY;
        boolean hasRelevant = false;
        boolean hasIrrelevant = false;
        for (int label = 0; label < ideal.length; ++label) {
            double out = modelOut[label];
            if (ideal[label] == 1.0) {
                hasRelevant = true;
                if (out < minRelevant) {
                    minRelevant = out;
                }
            } else {
                hasIrrelevant = true;
                if (out > maxIrrelevant) {
                    maxIrrelevant = out;
                }
            }
        }
        if (!hasRelevant) {
            return maxIrrelevant + MARGIN;
        }
        if (!hasIrrelevant) {
            return minRelevant - MARGIN;
        }
        return (minRelevant + maxIrrelevant) / 2.0;
    }

    /**
     * Returns a copy of the fitted parameters: one weight per label followed by the bias term.
     *
     * @return a defensive copy of the parameter vector
     */
    protected double[] getFunctionParameters() {
        return this.parameters.clone();
    }
}

