/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.stats;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.io.Reporter;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.Pair;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import junit.framework.Assert;
import org.junit.Test;

public class LogisticRegressionTest {
    static final int[] WALLET_OUTCOME_VECTOR = new int[]{1, 1, 2, 2, 0, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 0, 1, 1, 2, 2, 2, 2, 1, 1, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 1, 0, 0, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 0, 0, 1, 0, 1, 0, 1, 0, 2, 2, 1, 2, 0, 2, 1, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 0, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 0, 2, 1, 0, 1, 2, 1, 2, 1, 1, 0, 1, 1, 0, 1, 1, 2, 2, 1, 0, 1, 2, 1, 2, 0, 1, 2, 1, 2, 2, 2, 2, 2, 1};
    static final double[][] WALLET_DATA_MATRIX = new double[][]{{1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 1.0, 3.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 3.0, 1.0}, {1.0, 1.0, 0.0, 3.0, 1.0}, {1.0, 1.0, 1.0, 2.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 
1.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 3.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 3.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 0.0, 1.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 
1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 1.0, 2.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 2.0, 1.0}, {1.0, 1.0, 1.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 3.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 3.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 1.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 3.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}};
    /**
     * Reference weight vectors that estimated models are compared against in
     * assertCorrectRegression (absolute tolerance 0.12). Row i is compared,
     * dimension by dimension, with the i-th weight vector returned by
     * {@code LogisticRegression.weightVectors()}; the last row is all zeros.
     * NOTE(review): the provenance of these numbers (e.g. a fit produced by
     * another statistics package) is not recorded here — confirm.
     */
    static final double[][] WALLET_EXPECTED_FEATURES = new double[][]{{-3.4712, 1.2673, 1.1804, 1.0817, -1.6006}, {-1.2917, 1.1699, 0.4179, 0.1957, -0.804}, {0.0, 0.0, 0.0, 0.0, 0.0}};

    /**
     * Checks {@link LogisticRegression#classify} against a softmax computed by hand.
     *
     * <p>Two explicit weight vectors are supplied. The expected probabilities
     * treat the final (third) outcome as having a linear score of zero.
     */
    @Test
    public void testClass() {
        Vector[] weights = new Vector[]{
            new DenseVector(new double[]{1.0, 2.0, 3.0}),
            new DenseVector(new double[]{-2.0, 1.0, -1.0})
        };
        LogisticRegression model = new LogisticRegression(weights);
        DenseVector input = new DenseVector(new double[]{1.0, -1.0, 2.0});

        // Linear scores of the input against each outcome:
        //   (1,2,3)·(1,-1,2) = 5, (-2,1,-1)·(1,-1,2) = -5, last outcome = 0.
        double[] scores = new double[]{5.0, -5.0, 0.0};
        double[] unnormalized = new double[scores.length];
        double normalizer = 0.0;
        for (int k = 0; k < scores.length; ++k) {
            unnormalized[k] = Math.exp(scores[k]);
            normalizer += unnormalized[k];
        }

        // Sanity check: exp(0) is 1 for the zero-score outcome.
        Assert.assertEquals(1.0, unnormalized[2], 1.0E-4);

        double[] expected = new double[scores.length];
        for (int k = 0; k < scores.length; ++k) {
            expected[k] = unnormalized[k] / normalizer;
        }

        double[] actual = model.classify(input);
        Assert.assertEquals(expected.length, actual.length);
        for (int k = 0; k < expected.length; ++k) {
            Assert.assertEquals(expected[k], actual[k], 1.0E-7);
        }
    }

    /**
     * Converts every row of {@code matrix} with {@link #sparseCopy(Vector)}.
     *
     * @param matrix rows to copy
     * @return a new array of the same length holding the sparse copies, in order
     */
    static Vector[] sparseCopy(Vector[] matrix) {
        Vector[] copies = new Vector[matrix.length];
        int row = 0;
        for (Vector original : matrix) {
            copies[row++] = sparseCopy(original);
        }
        return copies;
    }

    /**
     * Copies {@code v} into a {@link SparseFloatVector} that has the same
     * dimensionality.
     *
     * <p>Every dimension is stored explicitly, zeros included. Each value is
     * narrowed from {@code double} to {@code float}.
     *
     * @param v vector to copy
     * @return the sparse, float-valued copy
     */
    static Vector sparseCopy(Vector v) {
        final int size = v.numDimensions();
        int[] indices = new int[size];
        float[] values = new float[size];
        int d = 0;
        while (d < size) {
            indices[d] = d;
            values[d] = (float) v.value(d);
            ++d;
        }
        return new SparseFloatVector(indices, values, size);
    }

    /**
     * Fits the wallet data under several input representations and checks that
     * each fit agrees with {@code WALLET_EXPECTED_FEATURES}.
     *
     * <p>Two kinds of data are covered:
     * <ul>
     * <li>Hard labels, with dense and with sparse rows.</li>
     * <li>Soft labels from {@link #convertDataToWeightedProbs}: first with one
     * row per original example, then consolidated to one row per distinct
     * input. Each is tried with dense and with sparse rows.</li>
     * </ul>
     */
    @Test
    public void testEstimation() throws IOException, ClassNotFoundException {
        final int distinctInputs = 23;

        Vector[] dense = new Vector[WALLET_DATA_MATRIX.length];
        for (int row = 0; row < dense.length; ++row) {
            dense[row] = new DenseVector(WALLET_DATA_MATRIX[row]);
        }
        Vector[] sparse = sparseCopy(dense);

        // Hard (integer) outcome labels.
        assertCorrectRegression(dense, null);
        assertCorrectRegression(sparse, null);

        // Soft outcome distributions: unconsolidated first, then consolidated.
        for (boolean consolidate : new boolean[]{false, true}) {
            for (Vector[] inputs : new Vector[][]{dense, sparse}) {
                Pair<Vector[], Vector[]> weighted =
                    convertDataToWeightedProbs(inputs, WALLET_OUTCOME_VECTOR, distinctInputs, consolidate);
                assertCorrectRegression(weighted.a(), weighted.b());
            }
        }
    }

    /**
     * Re-expresses hard-labelled training data as per-input outcome weights.
     *
     * <p>Rows are grouped by input vector. This relies on the {@code Vector}
     * implementations' {@code equals}/{@code hashCode} treating equal-valued
     * vectors as the same key. For each distinct input, the outcome counts are
     * turned into a {@link SparseFloatVector} of length {@code numOutcomes},
     * where {@code numOutcomes} is one more than the largest outcome seen.
     * <ul>
     * <li>If {@code consolidateToUniqueInputs} is false, the weights are
     * relative frequencies (summing to 1). The input/weight pair is emitted
     * once per original occurrence, so the result has as many rows as the
     * input.</li>
     * <li>If it is true, the weights are the raw counts. Each distinct input
     * is emitted exactly once.</li>
     * </ul>
     *
     * @param input_data_matrix input rows
     * @param outcomes outcome index for each input row
     * @param uniqueInputs expected number of distinct input rows (asserted)
     * @param consolidateToUniqueInputs whether to collapse duplicates into one
     *     row carrying raw counts
     * @return pair of (inputs, outcome-weight vectors), aligned by index
     */
    Pair<Vector[], Vector[]> convertDataToWeightedProbs(Vector[] input_data_matrix, int[] outcomes, int uniqueInputs, boolean consolidateToUniqueInputs) {
        // Count outcomes per distinct input. Typed map: no unchecked casts and
        // no raw ObjectToCounterMap (whose raw keySet() cannot feed an
        // Integer for-each loop).
        HashMap<Vector, ObjectToCounterMap<Integer>> countsByInput =
            new HashMap<Vector, ObjectToCounterMap<Integer>>();
        int numOutcomes = 0;
        for (int i = 0; i < input_data_matrix.length; ++i) {
            Vector input = input_data_matrix[i];
            ObjectToCounterMap<Integer> counts = countsByInput.get(input);
            if (counts == null) {
                counts = new ObjectToCounterMap<Integer>();
                countsByInput.put(input, counts);
            }
            counts.increment(outcomes[i]);
            numOutcomes = Math.max(numOutcomes, outcomes[i] + 1);
        }
        Assert.assertEquals(uniqueInputs, countsByInput.size());

        ArrayList<Vector> convertedInputs = new ArrayList<Vector>();
        ArrayList<Vector> convertedOutputs = new ArrayList<Vector>();
        for (Vector input : countsByInput.keySet()) {
            ObjectToCounterMap<Integer> counts = countsByInput.get(input);

            int repeats = 0;
            for (Integer outcome : counts.keySet()) {
                repeats += counts.getCount(outcome);
            }

            // Raw counts when consolidating, otherwise relative frequencies.
            ObjectToDoubleMap<Integer> weights = new ObjectToDoubleMap<Integer>();
            for (Integer outcome : counts.keySet()) {
                double count = counts.getCount(outcome);
                weights.set(outcome, consolidateToUniqueInputs ? count : count / (double) repeats);
            }

            int copies = consolidateToUniqueInputs ? 1 : repeats;
            for (int r = 0; r < copies; ++r) {
                convertedOutputs.add(new SparseFloatVector(weights, numOutcomes));
                convertedInputs.add(input);
            }
        }

        Pair<Vector[], Vector[]> inOutPair = new Pair<Vector[], Vector[]>(
            convertedInputs.toArray(new Vector[0]),
            convertedOutputs.toArray(new Vector[0]));
        int expectedRows = consolidateToUniqueInputs ? uniqueInputs : input_data_matrix.length;
        Assert.assertEquals(expectedRows, inOutPair.a().length);
        return inOutPair;
    }

    /**
     * Fits a logistic regression and runs three checks on it.
     * <ol>
     * <li>The fitted weights are within 0.12 of {@code WALLET_EXPECTED_FEATURES}.</li>
     * <li>The model survives a round trip through
     * {@link AbstractExternalizable#compile}: same outcome count, same input
     * dimensionality and equal weight vectors.</li>
     * <li>A second fit, hot-started from the first with a tighter
     * minimum-improvement threshold, is also within tolerance.</li>
     * </ol>
     *
     * @param data_matrix input rows
     * @param outcomes per-row outcome weight vectors, or {@code null} to use
     *     the hard labels in {@code WALLET_OUTCOME_VECTOR}
     * @throws IOException if compiling the model fails
     * @throws ClassNotFoundException if reading the compiled model fails
     */
    void assertCorrectRegression(Vector[] data_matrix, Vector[] outcomes) throws IOException, ClassNotFoundException {
        double allowableError = 0.12;

        LogisticRegression regression =
            estimateWalletRegression(data_matrix, outcomes, 5, null, 1.0E-5);
        assertWeightsNearExpected(regression.weightVectors(), allowableError);

        // Serialization round trip must preserve the model exactly.
        LogisticRegression regression2 = (LogisticRegression) AbstractExternalizable.compile(regression);
        Assert.assertEquals(regression.numOutcomes(), regression2.numOutcomes());
        // Previously compared regression with itself; compare against the copy.
        Assert.assertEquals(regression.numInputDimensions(), regression2.numInputDimensions());
        Vector[] vs1 = regression.weightVectors();
        Vector[] vs2 = regression2.weightVectors();
        Assert.assertEquals(vs1.length, vs2.length);
        for (int i = 0; i < vs1.length; ++i) {
            Assert.assertEquals((Object) vs1[i], (Object) vs2[i]);
        }

        // Re-fit hot-started from the first model, with a tighter threshold.
        LogisticRegression regression3 =
            estimateWalletRegression(data_matrix, outcomes, 6, regression, 1.0E-7);
        assertWeightsNearExpected(regression3.weightVectors(), allowableError);
    }

    /**
     * Runs {@code LogisticRegression.estimate} with this test's fixed settings.
     * Those settings are a noninformative prior, an inverse(0.05, 100.0)
     * annealing schedule, the integer arguments 5, 10 and 500000, and no
     * handler or reporter.
     * It uses the hard-label overload when {@code outcomes} is null and the
     * outcome-vector overload otherwise.
     */
    private static LogisticRegression estimateWalletRegression(Vector[] data_matrix, Vector[] outcomes,
            int priorBlockSize, LogisticRegression hotStart, double minImprovement) {
        // Typed nulls keep overload resolution unambiguous.
        ObjectHandler<LogisticRegression> handler = null;
        Reporter reporter = null;
        if (outcomes == null) {
            return LogisticRegression.estimate(data_matrix, WALLET_OUTCOME_VECTOR, RegressionPrior.noninformative(), priorBlockSize, hotStart, AnnealingSchedule.inverse(0.05, 100.0), minImprovement, 5, 10, 500000, handler, reporter);
        }
        return LogisticRegression.estimate(data_matrix, outcomes, RegressionPrior.noninformative(), priorBlockSize, hotStart, AnnealingSchedule.inverse(0.05, 100.0), minImprovement, 5, 10, 500000, handler, reporter);
    }

    /**
     * Asserts that each estimated weight vector is within {@code tolerance} of
     * the matching row of {@code WALLET_EXPECTED_FEATURES}, dimension by
     * dimension.
     */
    private static void assertWeightsNearExpected(Vector[] weightVectors, double tolerance) {
        for (int i = 0; i < weightVectors.length; ++i) {
            for (int j = 0; j < weightVectors[i].numDimensions(); ++j) {
                Assert.assertEquals("weight[" + i + "][" + j + "]",
                    WALLET_EXPECTED_FEATURES[i][j], weightVectors[i].value(j), tolerance);
            }
        }
    }
}

