/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.dca;

import com.aliasi.dca.DiscreteChooser;
import com.aliasi.io.Reporter;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.test.unit.Asserts;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Math;
import java.io.IOException;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

public class DiscreteChooserTest {
    @Test
    public void testSim() throws IOException {
        int numSamples = 1000;
        double[] simCoeffs = new double[]{0.0, 3.0, -2.0, 1.0};
        int numDims = simCoeffs.length;
        DenseVector simCoeffVector = new DenseVector(simCoeffs);
        DiscreteChooser simChooser = new DiscreteChooser(simCoeffVector);
        Random random = new Random(42L);
        Vector[][] alternativess = new Vector[numSamples][];
        int[] choices = new int[numSamples];
        block0: for (int i = 0; i < numSamples; ++i) {
            int numChoices = 1 + random.nextInt(8);
            alternativess[i] = new Vector[numChoices];
            for (int k = 0; k < numChoices; ++k) {
                double[] xs = new double[numDims];
                xs[0] = 1.0;
                for (int d = 1; d < numDims; ++d) {
                    xs[d] = 2.0 * random.nextGaussian();
                }
                alternativess[i][k] = new DenseVector(xs);
            }
            double[] choiceProbs = simChooser.choiceProbs(alternativess[i]);
            double choiceProb = random.nextDouble();
            double cumProb = 0.0;
            for (int k = 0; k < numChoices; ++k) {
                if (!(choiceProb < (cumProb += choiceProbs[k])) && k != numChoices - 1) continue;
                choices[i] = k;
                continue block0;
            }
        }
        double priorVariance = 5.0;
        boolean nonInformativeIntercept = true;
        RegressionPrior prior = RegressionPrior.gaussian(priorVariance, nonInformativeIntercept);
        int priorBlockSize = 100;
        double initialLearningRate = 0.1;
        double decayBase = 0.99;
        AnnealingSchedule annealingSchedule = AnnealingSchedule.exponential(initialLearningRate, decayBase);
        double minImprovement = 1.0E-5;
        int minEpochs = 5;
        int maxEpochs = 500;
        Reporter reporter = null;
        DiscreteChooser chooser = DiscreteChooser.estimate(alternativess, choices, prior, priorBlockSize, annealingSchedule, minImprovement, minEpochs, maxEpochs, reporter);
        Vector coeffVector = chooser.coefficients();
        for (int d = 0; d < coeffVector.numDimensions(); ++d) {
            Assert.assertEquals((double)simCoeffVector.value(d), (double)coeffVector.value(d), (double)0.1);
        }
        DiscreteChooser deserChooser = (DiscreteChooser)AbstractExternalizable.serializeDeserialize(chooser);
        Vector deserCoeffVector = deserChooser.coefficients();
        for (int d = 0; d < coeffVector.numDimensions(); ++d) {
            Assert.assertEquals((double)coeffVector.value(d), (double)deserCoeffVector.value(d), (double)1.0E-5);
        }
    }

    @Test
    public void testChoice() throws IOException {
        this.assertChoice(new double[0], new double[]{0.2, 0.8}, new double[0][]);
        this.assertChoice(new double[0], new double[]{0.2, 0.8}, new double[][]{{-1.0, 1.0}});
        this.assertChoice(new double[0], new double[]{0.2, -1.2, 0.8}, {-1.0, 1.0, 1.0}, {2.0, 1.0, -1.0}, {-1.0, -1.0, -21.0}, {-1.0, 2.0, 1.0}, {1.0, -2.0, -1.0});
    }

    void assertChoice(double[] expectedBases, double[] coeffs, double[] ... inputs) throws IOException {
        DenseVector coeffVector = new DenseVector(coeffs);
        DiscreteChooser chooser = new DiscreteChooser(coeffVector);
        this.assertChoice(coeffVector, chooser, expectedBases, coeffs, inputs);
        DiscreteChooser serDeserChooser = (DiscreteChooser)AbstractExternalizable.serializeDeserialize(chooser);
        this.assertChoice(coeffVector, serDeserChooser, expectedBases, coeffs, inputs);
    }

    void assertChoice(Vector coeffVector, DiscreteChooser chooser, double[] expectedBases, double[] coeffs, double[][] inputs) {
        Vector[] inputVecs = new Vector[inputs.length];
        for (int i = 0; i < inputs.length; ++i) {
            inputVecs[i] = new DenseVector(inputs[i]);
        }
        if (inputVecs.length == 0) {
            try {
                chooser.choose(inputVecs);
                Assert.fail();
            }
            catch (IllegalArgumentException e) {
                Asserts.succeed();
            }
            try {
                chooser.choiceProbs(inputVecs);
                Assert.fail();
            }
            catch (IllegalArgumentException e) {
                Asserts.succeed();
            }
            try {
                chooser.choiceLogProbs(inputVecs);
                Assert.fail();
            }
            catch (IllegalArgumentException e) {
                Asserts.succeed();
            }
            return;
        }
        int choice = chooser.choose(inputVecs);
        double[] choiceProbs = chooser.choiceProbs(inputVecs);
        double[] choiceLogProbs = chooser.choiceLogProbs(inputVecs);
        double[] bases = new double[inputs.length];
        for (int i = 0; i < bases.length; ++i) {
            bases[i] = inputVecs[i].dotProduct(coeffVector);
        }
        double[] expBases = new double[inputs.length];
        for (int i = 0; i < expBases.length; ++i) {
            expBases[i] = java.lang.Math.exp(bases[i]);
        }
        double Z = 0.0;
        for (int i = 0; i < expBases.length; ++i) {
            Z += expBases[i];
        }
        double[] expProbs = new double[inputs.length];
        for (int i = 0; i < expProbs.length; ++i) {
            expProbs[i] = expBases[i] / Z;
        }
        double[] expLogProbs = new double[inputs.length];
        for (int i = 0; i < expLogProbs.length; ++i) {
            expLogProbs[i] = java.lang.Math.log(expProbs[i]);
        }
        int expChoice = 0;
        for (int i = 1; i < expBases.length; ++i) {
            if (!(expBases[i] > expBases[expChoice])) continue;
            expChoice = i;
        }
        Assert.assertEquals((int)expChoice, (int)choice);
        Asserts.assertEqualsArray(expProbs, choiceProbs, 0.001);
        Asserts.assertEqualsArray(expLogProbs, choiceLogProbs, 0.001);
        Assert.assertEquals((double)Math.sum(choiceProbs), (double)1.0, (double)0.001);
    }
}

