/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.dca;

import com.aliasi.dca.DiscreteChooser;
import com.aliasi.dca.DiscreteObjectChooser;
import com.aliasi.io.Reporter;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.FeatureExtractor;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

/**
 * Unit test for {@link DiscreteObjectChooser}.
 *
 * <p>The test simulates choices among randomly generated alternatives from
 * a {@link DiscreteChooser} with known coefficients. It fits a
 * {@code DiscreteObjectChooser} over integer handles for those alternatives.
 * It then checks that the estimated coefficients recover the simulation
 * coefficients, both directly and after a serialization round trip.
 */
public class DiscreteObjectChooserTest {

    /** Absolute tolerance when comparing estimated to true coefficients. */
    private static final double COEFF_TOLERANCE = 0.1;

    @Test
    public void testSim() throws IOException {
        int numSamples = 1000;
        double[] simCoeffs = new double[] { 0.0, 3.0, -2.0, 1.0 };
        int numDims = simCoeffs.length;
        DenseVector simCoeffVector = new DenseVector(simCoeffs);
        DiscreteChooser simChooser = new DiscreteChooser(simCoeffVector);
        // Fixed seed: the 0.1 tolerance below is only meaningful for a
        // reproducible sample, so the order of random draws must not change.
        Random random = new Random(42L);

        Vector[][] alternativess = new Vector[numSamples][];
        int[] choices = new int[numSamples];
        for (int i = 0; i < numSamples; ++i) {
            alternativess[i] = randomAlternatives(random, numDims);
            double[] choiceProbs = simChooser.choiceProbs(alternativess[i]);
            choices[i] = sampleIndex(choiceProbs, random.nextDouble());
        }

        double priorVariance = 5.0;
        boolean nonInformativeIntercept = true;
        RegressionPrior prior = RegressionPrior.gaussian(priorVariance, nonInformativeIntercept);
        int priorBlockSize = 100;
        double initialLearningRate = 0.1;
        double decayBase = 0.99;
        AnnealingSchedule annealingSchedule = AnnealingSchedule.exponential(initialLearningRate, decayBase);
        double minImprovement = 1.0E-5;
        int minEpochs = 5;
        int maxEpochs = 500;
        Reporter reporter = null;

        // Replace each alternative vector with a distinct Integer handle.
        // The feature extractor maps a handle back to its vector's features.
        Map<Integer, Vector> vectorMap = new HashMap<Integer, Vector>();
        List<List<Integer>> alternativeObjectss = new ArrayList<List<Integer>>(alternativess.length);
        int nextId = 0;
        for (Vector[] alternatives : alternativess) {
            List<Integer> alternativeObjects = new ArrayList<Integer>(alternatives.length);
            for (Vector alternative : alternatives) {
                Integer obj = Integer.valueOf(nextId++);
                vectorMap.put(obj, alternative);
                alternativeObjects.add(obj);
            }
            alternativeObjectss.add(alternativeObjects);
        }

        MapFeatureExtractor featureExtractor = new MapFeatureExtractor(vectorMap);
        int minFeatureCount = 5;
        DiscreteObjectChooser<Integer> objectChooser
            = DiscreteObjectChooser.estimate(featureExtractor, alternativeObjectss, choices,
                                             minFeatureCount, prior, priorBlockSize,
                                             annealingSchedule, minImprovement,
                                             minEpochs, maxEpochs, reporter);
        assertCoefficientsMatch("estimated", simCoeffVector, objectChooser);

        DiscreteObjectChooser<?> deserChooser
            = (DiscreteObjectChooser<?>) AbstractExternalizable.serializeDeserialize(objectChooser);
        assertCoefficientsMatch("deserialized", simCoeffVector, deserChooser);
    }

    /**
     * Draws between 1 and 8 random alternatives of dimension {@code numDims}.
     * In each alternative, dimension 0 is the constant 1.0 and every other
     * dimension is {@code 2.0 * random.nextGaussian()}.
     *
     * @param random source of randomness; consumed in a fixed order.
     * @param numDims dimensionality of each alternative vector.
     * @return the freshly drawn alternatives.
     */
    private static Vector[] randomAlternatives(Random random, int numDims) {
        int numChoices = 1 + random.nextInt(8);
        Vector[] alternatives = new Vector[numChoices];
        for (int k = 0; k < numChoices; ++k) {
            double[] xs = new double[numDims];
            xs[0] = 1.0;
            for (int d = 1; d < numDims; ++d) {
                xs[d] = 2.0 * random.nextGaussian();
            }
            alternatives[k] = new DenseVector(xs);
        }
        return alternatives;
    }

    /**
     * Samples an index by inverse-CDF sampling: returns the first {@code k}
     * such that {@code u} is less than {@code probs[0] + ... + probs[k]}.
     * If rounding leaves the cumulative sum at or below {@code u}, the last
     * index is returned, so a choice is always made.
     *
     * @param probs non-empty array of choice probabilities.
     * @param u uniform draw in [0, 1).
     * @return the sampled index, in {@code [0, probs.length)}.
     */
    private static int sampleIndex(double[] probs, double u) {
        double cumProb = 0.0;
        for (int k = 0; k < probs.length - 1; ++k) {
            cumProb += probs[k];
            if (u < cumProb) {
                return k;
            }
        }
        return probs.length - 1;
    }

    /**
     * Asserts that, for every simulated dimension {@code d}, the chooser's
     * coefficient for feature {@code Integer.toString(d)} is within
     * {@link #COEFF_TOLERANCE} of {@code expected.value(d)}.
     *
     * <p>The loop iterates over the simulated dimensions rather than the
     * learned vector's dimensions. Each feature id is looked up through the
     * chooser's symbol table, so the learned vector's layout may differ.
     *
     * @param label prefix for failure messages.
     * @param expected the true simulation coefficients.
     * @param objectChooser the fitted chooser to check.
     */
    private static void assertCoefficientsMatch(String label,
                                                Vector expected,
                                                DiscreteObjectChooser<?> objectChooser) {
        Vector actual = objectChooser.chooser().coefficients();
        SymbolTable featureSymbolTable = objectChooser.featureSymbolTable();
        for (int d = 0; d < expected.numDimensions(); ++d) {
            String feature = Integer.toString(d);
            int id = featureSymbolTable.symbolToID(feature);
            Assert.assertTrue(label + ": feature " + feature + " missing from symbol table", id >= 0);
            Assert.assertEquals(label + ": coefficient for feature " + feature,
                                expected.value(d), actual.value(id), COEFF_TOLERANCE);
        }
    }

    /**
     * Feature extractor that looks up an {@code Integer} handle in a map. It
     * returns the mapped vector's values keyed by dimension index as a
     * string ({@code "0"}, {@code "1"}, ...).
     *
     * <p>It implements {@link Serializable}, presumably because serializing
     * the fitted {@link DiscreteObjectChooser} also serializes its feature
     * extractor. TODO confirm against the DiscreteObjectChooser source.
     */
    static class MapFeatureExtractor
        implements FeatureExtractor<Integer>,
                   Serializable {

        private static final long serialVersionUID = 1L;

        final Map<Integer, Vector> mMap;

        MapFeatureExtractor(Map<Integer, Vector> map) {
            this.mMap = map;
        }

        /**
         * Returns the features of the vector registered for {@code i}.
         *
         * @param i handle of a registered vector.
         * @return a fresh map from dimension index (as a string) to value.
         * @throws IllegalArgumentException if no vector is registered for {@code i}.
         */
        @Override
        public Map<String, Double> features(Integer i) {
            Vector v = this.mMap.get(i);
            if (v == null) {
                throw new IllegalArgumentException("No vector registered for object " + i);
            }
            int numDims = v.numDimensions();
            // Presize so that numDims entries fit under the default 0.75 load factor.
            Map<String, Double> result = new HashMap<String, Double>((numDims * 4) / 3 + 1);
            for (int d = 0; d < numDims; ++d) {
                result.put(Integer.toString(d), v.value(d));
            }
            return result;
        }
    }
}

