/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.crf;

import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.crf.ChainCrf;
import com.aliasi.crf.ChainCrfFeatureExtractor;
import com.aliasi.crf.ChainCrfFeatures;
import com.aliasi.io.Reporter;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.symbol.SymbolTableCompiler;
import com.aliasi.tag.ScoredTagging;
import com.aliasi.tag.TagLattice;
import com.aliasi.tag.Tagging;
import com.aliasi.test.unit.Asserts;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToDoubleMap;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link ChainCrf}.
 *
 * <p>{@link #testDecoder()} checks a hand-built three-tag CRF, and a
 * serialized-then-deserialized copy of it, against brute-force enumeration of
 * every tag sequence for every token sequence of length 0 through 4. It
 * covers first-best, n-best (joint and conditional) and marginal decoding.
 * {@link #testEstimate()} trains a CRF on a tiny corpus and checks a handful
 * of taggings.
 */
public class ChainCrfTest {
    static final String CAT1 = "X";
    static final String CAT2 = "Y";
    static final String CAT3 = "Z";
    static final String[] TAGS = new String[]{CAT1, CAT2, CAT3};
    static final String X1 = "a";
    static final String X2 = "b";
    static final String X3 = "c";
    static final String X4 = "d";
    static final String[] TOKENS = new String[]{X1, X2, X3, X4};

    /**
     * Feature names. The three tags come first; {@link TestCrfFeatures} emits
     * them as previous-tag edge features. The four tokens follow; they are
     * emitted as node features.
     */
    static final String[] FEATURES = new String[]{CAT1, CAT2, CAT3, X1, X2, X3, X4};

    // Transition weights: XY is the weight of tag X followed by tag Y.
    static final double XX = 1.0;
    static final double XY = 1.0;
    static final double XZ = 2.0;
    static final double YX = 2.0;
    static final double YY = -1.0;
    static final double YZ = 4.0;
    static final double ZX = 3.0;
    static final double ZY = 1.0;
    static final double ZZ = 6.0;

    /** Transition weights indexed {@code [currentTag][previousTag]}, as read by {@link #score}. */
    static final double[][] TRANSITION_WEIGHTS = new double[][]{
        {XX, YX, ZX},
        {XY, YY, ZY},
        {XZ, YZ, ZZ}
    };

    // Token weights: Xa is the weight of tag X on token a.
    static final double Xa = 4.0;
    static final double Xb = 5.0;
    static final double Xc = 6.0;
    static final double Xd = 7.0;
    static final double Ya = -1.0;
    static final double Yb = 10.0;
    static final double Yc = -1.0;
    static final double Yd = 1.0;
    static final double Za = -2.0;
    static final double Zb = -4.0;
    static final double Zc = -6.0;
    static final double Zd = 15.0;

    /** Token weights indexed {@code [tag][token]}, as read by {@link #score}. */
    static final double[][] TOKEN_WEIGHTS = new double[][]{
        {Xa, Xb, Xc, Xd},
        {Ya, Yb, Yc, Yd},
        {Za, Zb, Zc, Zd}
    };

    static final int NUM_TAGS = TAGS.length;

    /**
     * One coefficient vector per tag, with dimensions in {@link #FEATURES} order.
     * Row {@code t} is {@code TRANSITION_WEIGHTS[t]} followed by
     * {@code TOKEN_WEIGHTS[t]}.
     */
    static final Vector[] COEFFICIENTS = new DenseVector[]{
        new DenseVector(new double[]{XX, YX, ZX, Xa, Xb, Xc, Xd}),
        new DenseVector(new double[]{XY, YY, ZY, Ya, Yb, Yc, Yd}),
        new DenseVector(new double[]{XZ, YZ, ZZ, Za, Zb, Zc, Zd})
    };

    static final SymbolTable FEATURE_SYMBOL_TABLE = SymbolTableCompiler.asSymbolTable(FEATURES);
    static final ChainCrfFeatureExtractor<String> FEATURE_EXTRACTOR = new TestFeatureExtractor();
    static final boolean ADD_INTERCEPT_FEATURE = false;
    static final ChainCrf<String> CRF = new ChainCrf<String>(TAGS, COEFFICIENTS, FEATURE_SYMBOL_TABLE, FEATURE_EXTRACTOR, ADD_INTERCEPT_FEATURE);

    /**
     * Round-trips {@link #CRF} through serialization and checks that the copy
     * has identical parameters. It then checks first-best decoding for both
     * models, and n-best and marginal decoding for the original, against
     * brute-force enumeration.
     */
    @Test
    public void testDecoder() throws IOException {
        @SuppressWarnings("unchecked")
        ChainCrf<String> crf2 = (ChainCrf<String>) AbstractExternalizable.serializeDeserialize(CRF);
        assertSameModel(CRF, crf2);
        for (int length = 0; length < 5; ++length) {
            for (int[] tokenIds : allArrays(length, TOKENS.length)) {
                List<String> tokenList = new ArrayList<String>(length);
                for (int tokenId : tokenIds) {
                    tokenList.add(TOKENS[tokenId]);
                }
                ObjectToDoubleMap<int[]> otdMap = bruteForce(tokenIds, TAGS.length, TRANSITION_WEIGHTS, TOKEN_WEIGHTS);
                assertCorrectAnswer(CRF, tokenList, otdMap, TAGS);
                assertCorrectAnswer(crf2, tokenList, otdMap, TAGS);
                assertCorrectNBest(otdMap, CRF.tagNBest(tokenList, Integer.MAX_VALUE), TAGS, false);
                assertCorrectNBest(otdMap, CRF.tagNBestConditional(tokenList, Integer.MAX_VALUE), TAGS, true);
                assertCorrectMarginal(otdMap, CRF.tagMarginal(tokenList), TAGS, tokenList);
            }
        }
    }

    /** Asserts two CRFs agree on intercept flag, feature symbols, tags and coefficients. */
    private static void assertSameModel(ChainCrf<String> expected, ChainCrf<String> actual) {
        Assert.assertEquals(expected.addInterceptFeature(), actual.addInterceptFeature());

        SymbolTable expectedSymbols = expected.featureSymbolTable();
        SymbolTable actualSymbols = actual.featureSymbolTable();
        Assert.assertEquals(expectedSymbols.numSymbols(), actualSymbols.numSymbols());
        for (int id = 0; id < expectedSymbols.numSymbols(); ++id) {
            Assert.assertEquals(expectedSymbols.idToSymbol(id), actualSymbols.idToSymbol(id));
        }

        Assert.assertEquals(expected.tags(), actual.tags());

        Vector[] expectedCoeffs = expected.coefficients();
        Vector[] actualCoeffs = actual.coefficients();
        Assert.assertEquals(expectedCoeffs.length, actualCoeffs.length);
        for (int i = 0; i < expectedCoeffs.length; ++i) {
            Assert.assertEquals(expectedCoeffs[i].numDimensions(), actualCoeffs[i].numDimensions());
            int[] nonZero = expectedCoeffs[i].nonZeroDimensions();
            Assert.assertArrayEquals(nonZero, actualCoeffs[i].nonZeroDimensions());
            for (int d : nonZero) {
                Assert.assertEquals(expectedCoeffs[i].value(d), actualCoeffs[i].value(d), 1.0E-4);
            }
        }
    }

    /**
     * Checks a tag lattice against brute-force scores. It verifies the token
     * list and log Z. For each position it also verifies every per-tag log
     * marginal, and that the marginals sum to 1.
     *
     * @param otdMap brute-force map from tag-id sequence to unnormalized log score
     * @param tagLattice lattice produced by the CRF under test
     * @param tags tag names indexed by tag id
     * @param tokenList tokens the lattice was computed for
     */
    void assertCorrectMarginal(ObjectToDoubleMap<int[]> otdMap, TagLattice<String> tagLattice, String[] tags, List<String> tokenList) {
        Assert.assertEquals(tokenList, tagLattice.tokenList());
        double logZ = logZ(otdMap);
        Assert.assertEquals(logZ, tagLattice.logZ(), 0.001);
        int numLatticeTags = tagLattice.tagList().size();
        for (int pos = 0; pos < tokenList.size(); ++pos) {
            double sum = 0.0;
            for (int tagId = 0; tagId < numLatticeTags; ++tagId) {
                double logProb = tagLattice.logProbability(pos, tagId);
                // Math here is com.aliasi.util.Math, so exp must be fully qualified.
                sum += java.lang.Math.exp(logProb);
                Assert.assertEquals(logMarginal(otdMap, pos, tagId, tags.length, logZ), logProb, 1.0E-4);
            }
            Assert.assertEquals("marginals norm " + pos + " " + tokenList, 1.0, sum, 0.01);
        }
    }

    /**
     * Returns the brute-force log marginal probability of tag {@code tagId} at
     * position {@code pos}: the log-sum-exp of the scores of all sequences
     * with that tag at that position, minus {@code logZ}.
     *
     * @param numTags unused; retained for signature compatibility
     */
    static double logMarginal(ObjectToDoubleMap<int[]> otdMap, int pos, int tagId, int numTags, double logZ) {
        List<Double> matchingScores = new ArrayList<Double>();
        for (Map.Entry<int[], Double> entry : otdMap.entrySet()) {
            if (entry.getKey()[pos] == tagId) {
                matchingScores.add(entry.getValue());
            }
        }
        double[] xs = new double[matchingScores.size()];
        for (int i = 0; i < xs.length; ++i) {
            xs[i] = matchingScores.get(i);
        }
        return Math.logSumOfExponentials(xs) - logZ;
    }

    /** Returns the log-sum-exp of all scores in {@code otdMap} (the log partition function). */
    static double logZ(ObjectToDoubleMap<int[]> otdMap) {
        double[] xs = new double[otdMap.size()];
        int idx = 0;
        for (double x : otdMap.values()) {
            xs[idx++] = x;
        }
        return Math.logSumOfExponentials(xs);
    }

    /**
     * Checks an n-best iterator against brute-force scores. Every tagging must
     * appear exactly once, with the expected score, in non-increasing score
     * order. When {@code conditional} is true, expected scores are shifted by
     * {@code -logZ}.
     *
     * @param otdMap brute-force map from tag-id sequence to unnormalized log score
     * @param nBest iterator under test, expected to enumerate all taggings
     * @param tags tag names indexed by tag id
     * @param conditional whether {@code nBest} reports conditional log probabilities
     */
    void assertCorrectNBest(ObjectToDoubleMap<int[]> otdMap, Iterator<ScoredTagging<String>> nBest, String[] tags, boolean conditional) {
        double logZ = conditional ? logZ(otdMap) : 0.0;

        ObjectToDoubleMap<String> expectedScores = new ObjectToDoubleMap<String>();
        for (Map.Entry<int[], Double> entry : otdMap.entrySet()) {
            expectedScores.put(tagString(entry.getKey(), tags), entry.getValue());
        }
        TreeSet<String> expectedTaggingSet = new TreeSet<String>(expectedScores.keySet());

        TreeSet<String> foundTaggingSet = new TreeSet<String>();
        int count = 0;
        double previousScore = Double.POSITIVE_INFINITY;
        while (nBest.hasNext()) {
            ScoredTagging<String> scoredTagging = nBest.next();
            double score = scoredTagging.score();
            String tagRep = tagString(scoredTagging.tags());
            Double expectedScore = expectedScores.get(tagRep);
            Assert.assertNotNull("unexpected tagging " + tagRep, expectedScore);
            Assert.assertEquals(expectedScore - logZ, score, 1.0E-4);
            Assert.assertTrue("n-best out of order at " + tagRep, score <= previousScore + 1.0E-4);
            previousScore = score;
            foundTaggingSet.add(tagRep);
            ++count;
        }
        Assert.assertEquals(expectedTaggingSet, foundTaggingSet);
        // The set comparison alone would hide duplicated taggings.
        Assert.assertEquals(expectedScores.size(), count);
    }

    /** Concatenates the tag names for a tag-id sequence. */
    private static String tagString(int[] tagIds, String[] tags) {
        StringBuilder sb = new StringBuilder();
        for (int tagId : tagIds) {
            sb.append(tags[tagId]);
        }
        return sb.toString();
    }

    /** Concatenates a list of tag names. */
    private static String tagString(List<String> tagList) {
        StringBuilder sb = new StringBuilder();
        for (String tag : tagList) {
            sb.append(tag);
        }
        return sb.toString();
    }

    /** Checks the number of arrays {@link #allArrays(int, int)} generates. */
    @Test
    public void testAllOutputsSizes() {
        Assert.assertEquals(1, allArrays(0, 5).size());
        Assert.assertEquals(5, allArrays(1, 5).size());
        Assert.assertEquals(25, allArrays(2, 5).size());
        Assert.assertEquals(125, allArrays(3, 5).size());
    }

    /**
     * Asserts that the first-best tagging from {@code crf} is one of the
     * highest-scoring tag sequences in {@code otdMap}. Ties are accepted.
     *
     * @param crf CRF under test
     * @param tokenList tokens to tag
     * @param otdMap brute-force map from tag-id sequence to unnormalized log score
     * @param tags tag names indexed by tag id
     */
    static void assertCorrectAnswer(ChainCrf<String> crf, List<String> tokenList, ObjectToDoubleMap<int[]> otdMap, String[] tags) {
        List<String> foundTags = crf.tag(tokenList).tags();
        List<int[]> keysList = otdMap.keysOrderedByValueList();
        double bestScore = otdMap.getValue(keysList.get(0));
        for (int[] keys : keysList) {
            if (otdMap.getValue(keys) < bestScore) {
                // Walked past every tied best sequence without finding a match.
                Assert.fail("first-best tagging " + foundTags + " is not a highest-scoring tagging");
            }
            if (areEqualTags(foundTags, keys, tags)) {
                Asserts.succeed();
                return;
            }
        }
        Assert.fail("first-best tagging " + foundTags + " matches no enumerated tagging");
    }

    /**
     * Returns true if {@code foundTags} has the same length as
     * {@code expectedTags}, and each found tag equals the tag name at the
     * corresponding expected tag id.
     */
    static boolean areEqualTags(List<String> foundTags, int[] expectedTags, String[] tags) {
        if (foundTags.size() != expectedTags.length) {
            return false;
        }
        for (int i = 0; i < expectedTags.length; ++i) {
            if (!foundTags.get(i).equals(tags[expectedTags[i]])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Enumerates every tag sequence for {@code tokens} and maps each to its
     * {@link #score}.
     */
    static ObjectToDoubleMap<int[]> bruteForce(int[] tokens, int numTags, double[][] transitionWeights, double[][] tokenWeights) {
        ObjectToDoubleMap<int[]> outputMap = new ObjectToDoubleMap<int[]>();
        for (int[] output : allArrays(tokens.length, numTags)) {
            outputMap.put(output, score(tokens, output, transitionWeights, tokenWeights));
        }
        return outputMap;
    }

    /**
     * Returns the linear score of tag sequence {@code output} over
     * {@code tokens}. This is the sum of {@code tokenWeights[tag][token]} at
     * every position, plus {@code transitionWeights[tag][previousTag]} at
     * every position after the first.
     */
    static double score(int[] tokens, int[] output, double[][] transitionWeights, double[][] tokenWeights) {
        double score = 0.0;
        for (int i = 0; i < tokens.length; ++i) {
            score += tokenWeights[output[i]][tokens[i]];
        }
        for (int i = 1; i < tokens.length; ++i) {
            score += transitionWeights[output[i]][output[i - 1]];
        }
        return score;
    }

    /**
     * Returns all {@code maxVal^size} arrays of length {@code size} with
     * entries in {@code [0, maxVal)}.
     */
    static List<int[]> allArrays(int size, int maxVal) {
        List<int[]> result = new ArrayList<int[]>();
        allArrays(size, maxVal, new int[size], result);
        return result;
    }

    /**
     * Recursive helper. It fills {@code buf[size-1]}, then the lower indexes,
     * and appends a copy of {@code buf} to {@code result} once
     * {@code size == 0}.
     */
    static void allArrays(int size, int maxVal, int[] buf, List<int[]> result) {
        if (size == 0) {
            result.add(buf.clone());
            return;
        }
        for (int value = 0; value < maxVal; ++value) {
            buf[size - 1] = value;
            allArrays(size - 1, maxVal, buf, result);
        }
    }

    /** Trains a CRF on {@link TestCorpus} and checks a few taggings. */
    @Test
    public void testEstimate() throws Exception {
        TestCorpus corpus = new TestCorpus();
        int minCount = 1;
        boolean addIntercept = true;
        boolean cacheFeatureVectors = true;
        boolean allowUnseenTransitions = true;
        RegressionPrior prior = RegressionPrior.gaussian(10.0, true);
        int priorBlockSize = 3;
        AnnealingSchedule annealingSchedule = AnnealingSchedule.exponential(0.02, 0.995);
        double minImprovement = 1.0E-5;
        int minEpochs = 2;
        int maxEpochs = 2000;
        Reporter reporter = null;
        ChainCrf<String> crf = ChainCrf.estimate(corpus, FEATURE_EXTRACTOR, addIntercept, minCount, cacheFeatureVectors, allowUnseenTransitions, prior, priorBlockSize, annealingSchedule, minImprovement, minEpochs, maxEpochs, reporter);
        assertTagging(Arrays.asList("John", "ran", "."), Arrays.asList("PN", "IV", "EOS"), crf);
        assertTagging(Arrays.asList("Mary", "ran", "."), Arrays.asList("PN", "IV", "EOS"), crf);
        assertTagging(Arrays.asList("The", "dog", "ran", "."), Arrays.asList("DET", "N", "IV", "EOS"), crf);
        assertTagging(Arrays.asList("The", "dog", "ran", "!"), Arrays.asList("DET", "N", "IV", "EOS"), crf);
        assertTagging(Arrays.asList("The", "dog", "sat", "!"), Arrays.asList("DET", "N", "IV", "EOS"), crf);
        assertTagging(Arrays.asList("The", "dog", "sat", "."), Arrays.asList("DET", "N", "IV", "EOS"), crf);
        assertTagging(Arrays.asList("John", "likes", "Mary", "."), Arrays.asList("PN", "TV", "PN", "EOS"), crf);
        assertTagging(Arrays.asList("Mary", "likes", "John", "."), Arrays.asList("PN", "TV", "PN", "EOS"), crf);
        // Inputs with unseen tokens or odd orderings only need to decode without error.
        Assert.assertNotNull(crf.tag(Arrays.asList("Fred", "likes", "John", ".")));
        Assert.assertNotNull(crf.tag(Arrays.asList(";", ".", "likes", "likes")));
    }

    /** Asserts that {@code crf} tags {@code tokens} as {@code tagsExpected}. */
    static <E> void assertTagging(List<E> tokens, List<String> tagsExpected, ChainCrf<E> crf) {
        Tagging<E> tagging = crf.tag(tokens);
        Assert.assertEquals(tagsExpected, tagging.tags());
    }

    /**
     * Tiny training corpus of {words, tags} pairs, including an empty
     * sequence. The test split is empty.
     */
    static class TestCorpus
    extends Corpus<ObjectHandler<Tagging<String>>> {
        static final String[][][] WORDS_TAGSS = new String[][][]{
            {new String[0], new String[0]},
            {{"."}, {"EOS"}},
            {{"John", "ran", "."}, {"PN", "IV", "EOS"}},
            {{"Mary", "ran", "."}, {"PN", "IV", "EOS"}},
            {{"John", "jumped", "!"}, {"PN", "IV", "EOS"}},
            {{"The", "dog", "jumped", "!"}, {"DET", "N", "IV", "EOS"}},
            {{"The", "dog", "sat", "."}, {"DET", "N", "IV", "EOS"}},
            {{"Mary", "sat", "!"}, {"PN", "IV", "EOS"}},
            {{"Mary", "likes", "John", "."}, {"PN", "TV", "PN", "EOS"}},
            {{"The", "dog", "likes", "Mary", "."}, {"DET", "N", "TV", "PN", "EOS"}},
            {{"John", "likes", "the", "dog", "."}, {"PN", "TV", "DET", "N", "EOS"}},
            {{"The", "dog", "ran", "."}, {"DET", "N", "IV", "EOS"}},
            {{"The", "dog", "ran", "."}, {"DET", "N", "IV", "EOS"}}
        };

        TestCorpus() {
        }

        @Override
        public void visitTrain(ObjectHandler<Tagging<String>> handler) {
            for (String[][] wordsTags : WORDS_TAGSS) {
                handler.handle(new Tagging<String>(Arrays.asList(wordsTags[0]), Arrays.asList(wordsTags[1])));
            }
        }

        @Override
        public void visitTest(ObjectHandler<Tagging<String>> handler) {
        }
    }

    /**
     * Features with a single node feature per position and a single edge
     * feature per position. The node feature is the token itself; the edge
     * feature is the previous tag's name. Each has value 1.
     */
    static class TestCrfFeatures
    extends ChainCrfFeatures<String> {
        public TestCrfFeatures(List<String> tokens, List<String> tags) {
            super(tokens, tags);
        }

        @Override
        public Map<String, Integer> nodeFeatures(int n) {
            return Collections.singletonMap(this.token(n), 1);
        }

        @Override
        public Map<String, Integer> edgeFeatures(int n, int prevTagIndex) {
            return Collections.singletonMap(this.tag(prevTagIndex), 1);
        }
    }

    /**
     * Serializable extractor producing {@link TestCrfFeatures}. It must be
     * serializable so the CRF can be round-tripped in {@link #testDecoder()}.
     */
    static class TestFeatureExtractor
    implements ChainCrfFeatureExtractor<String>,
    Serializable {
        TestFeatureExtractor() {
        }

        @Override
        public ChainCrfFeatures<String> extract(List<String> tokens, List<String> tags) {
            return new TestCrfFeatures(tokens, tags);
        }
    }
}

