/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.lm;

import com.aliasi.lm.CompiledTokenizedLM;
import com.aliasi.lm.NGramBoundaryLM;
import com.aliasi.lm.TokenizedLM;
import com.aliasi.lm.TrieIntSeqCounter;
import com.aliasi.lm.UniformBoundaryLM;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.test.unit.Asserts;
import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
import com.aliasi.util.Math;
import com.aliasi.util.ScoredObject;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import java.util.SortedSet;
import org.junit.Assert;
import org.junit.Test;

public class TokenizedLMTest {
    private static final int MAX_NGRAM = 3;
    private static final double LAMBDA_FACTOR = 4.0;

    /**
     * Prints a token array and the three token-probability variants the model reports
     * for the full array, followed by an empty separator line. Nothing is asserted; the
     * output goes to standard out for manual inspection.
     *
     * <p>The third label is spelled {@code NoBound} although the method it reports is
     * {@code tokenProbCharSmoothNoBounds}; the label is left exactly as it was.
     *
     * @param tokens tokens scored as one sequence, from index 0 to {@code tokens.length}
     * @param lm model queried for each probability
     */
    void dumpProbs(String[] tokens, TokenizedLM lm) {
        final int end = tokens.length;
        final java.io.PrintStream out = System.out;
        out.println("TOKENS: " + Arrays.asList(tokens));
        out.println("lm.tokenProbability(): " + lm.tokenProbability(tokens, 0, end));
        out.println("lm.tokenProbCharSmooth(): " + lm.tokenProbCharSmooth(tokens, 0, end));
        out.println("lm.tokenProbCharSmoothNoBound(): " + lm.tokenProbCharSmoothNoBounds(tokens, 0, end));
        out.println();
    }

    /**
     * Smoke test: trains a trigram model with n-gram boundary sub-models on one sentence
     * one hundred times, then dumps the probabilities of three single-token probes.
     * The test makes no assertions; it passes as long as nothing throws.
     */
    @Test
    public void testPabloBug() {
        final int order = 3;
        TokenizedLM lm = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, order, new NGramBoundaryLM(order), new NGramBoundaryLM(order), 1.0);

        final String sentence = "ba ba be lo and behold and and lo some more";
        int remaining = 100;
        while (remaining-- > 0) {
            lm.handle(sentence);
        }

        // "ba" occurs in the training sentence; "bo" and "%%" do not.
        for (String probe : new String[] {"ba", "bo", "%%"}) {
            this.dumpProbs(new String[] {probe}, lm);
        }
    }

    /**
     * Checks that {@code trainSequence(text, n)} adds {@code n} to the counter entry for the
     * tokenized sequence, and that extension counts add up across trained sequences.
     * The empty sequence is expected to start at a count of 1 and to reach 1000 after
     * training the empty string with a count of 999.
     */
    @Test
    public void testTrainSequence() {
        TokenizedLM lm = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, 3);
        SymbolTable symbols = lm.symbolTable();
        TrieIntSeqCounter counts = lm.sequenceCounter();

        // Baseline for the empty sequence before any training.
        junit.framework.Assert.assertEquals(1, counts.count(new int[0], 0, 0));

        lm.trainSequence("a b", 2);
        lm.trainSequence("a c", 3);
        lm.trainSequence("a b c", 4);

        final int idA = symbols.symbolToID("a");
        final int idB = symbols.symbolToID("b");
        final int idC = symbols.symbolToID("c");

        junit.framework.Assert.assertEquals(2, counts.count(new int[] {idA, idB}, 0, 2));
        junit.framework.Assert.assertEquals(3, counts.count(new int[] {idA, idC}, 0, 2));
        junit.framework.Assert.assertEquals(4, counts.count(new int[] {idA, idB, idC}, 0, 3));
        // 2 (from "a b") + 3 (from "a c") = 5 extensions of "a".
        junit.framework.Assert.assertEquals(5L, counts.extensionCount(new int[] {idA}, 0, 1));

        lm.trainSequence("a a a c c c", 111);
        junit.framework.Assert.assertEquals(111, counts.count(new int[] {idC, idC, idC}, 0, 3));
        junit.framework.Assert.assertEquals(111L, counts.extensionCount(new int[] {idC, idC}, 0, 2));

        lm.trainSequence("", 999);
        junit.framework.Assert.assertEquals(1000, counts.count(new int[0], 0, 0));
    }

    /**
     * Verifies that constructing a model with an n-gram order of zero is rejected with an
     * {@link IllegalArgumentException}.
     */
    @Test
    public void testZeroGram() {
        IndoEuropeanTokenizerFactory factory = IndoEuropeanTokenizerFactory.INSTANCE;
        try {
            new TokenizedLM(factory, 0, new UniformBoundaryLM(16), new UniformBoundaryLM(16), 4.0);
        }
        catch (IllegalArgumentException expected) {
            Asserts.succeed();
            return;
        }
        // Reached only when the constructor accepted the zero order.
        junit.framework.Assert.fail();
    }

    /**
     * Smoke test: a model of order 1 can be constructed and trained on a two-token string.
     * No assertions are made; the test passes if nothing throws.
     */
    @Test
    public void testUnigram() {
        IndoEuropeanTokenizerFactory factory = IndoEuropeanTokenizerFactory.INSTANCE;
        UniformBoundaryLM unknownTokenModel = new UniformBoundaryLM(16);
        UniformBoundaryLM whitespaceModel = new UniformBoundaryLM(16);
        new TokenizedLM(factory, 1, unknownTokenModel, whitespaceModel, 4.0).train("John Smith");
    }

    /**
     * Smoke test: a model of order 4 can be trained on a string with only two tokens,
     * that is, one shorter than the n-gram order. No assertions are made.
     */
    @Test
    public void testBiggerGram() {
        IndoEuropeanTokenizerFactory factory = IndoEuropeanTokenizerFactory.INSTANCE;
        UniformBoundaryLM unknownTokenModel = new UniformBoundaryLM(16);
        UniformBoundaryLM whitespaceModel = new UniformBoundaryLM(16);
        new TokenizedLM(factory, 4, unknownTokenModel, whitespaceModel, 4.0).train("John Smith");
    }

    /**
     * Trains on {@code "a b c a b d a b e a b"} and checks that all five tokens are in the
     * symbol table, then compares chi-squared independence scores of selected bigrams:
     * (a,b) must score higher than (b,c), and (c,a) higher than (c,e).
     */
    @Test
    public void testChiSquaredIndependence() {
        TokenizedLM lm = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, 3, new UniformBoundaryLM(16), new UniformBoundaryLM(16), 4.0);
        lm.train("a b c a b d a b e a b");

        SymbolTable symbols = lm.symbolTable();
        junit.framework.Assert.assertEquals(5, symbols.numSymbols());

        // Look every symbol up first, then check the identifiers, in the order a..e.
        final String[] names = {"a", "b", "c", "d", "e"};
        final int[] ids = new int[names.length];
        for (int n = 0; n < names.length; ++n) {
            ids[n] = symbols.symbolToID(names[n]);
        }
        for (int id : ids) {
            junit.framework.Assert.assertTrue(id >= 0);
        }

        final int a = ids[0];
        final int b = ids[1];
        final int c = ids[2];
        final int e = ids[4];
        junit.framework.Assert.assertTrue(lm.chiSquaredIndependence(new int[] {a, b}) > lm.chiSquaredIndependence(new int[] {b, c}));
        junit.framework.Assert.assertTrue(lm.chiSquaredIndependence(new int[] {c, a}) > lm.chiSquaredIndependence(new int[] {c, e}));
    }

    /**
     * Compares hand-computed log2 estimates with those of a trigram model built on two
     * uniform boundary sub-models, first untrained and then after training on {@code "a"}.
     * Each expectation is a product of the factors named below; every comparison goes
     * through {@code assertEstimate}, which also checks the compiled model.
     *
     * <p>NOTE(review): the variable names follow the apparent roles of the factors
     * (interpolation weights, maximum-likelihood estimates, per-character sub-model
     * probabilities). Those roles are inferred from the arithmetic, not from the model
     * implementation; confirm against {@code TokenizedLM} before relying on them.
     */
    @Test
    public void testConstantSubModels() throws ClassNotFoundException, IOException {
        TokenizedLM lm = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, 3, new UniformBoundaryLM(127), new UniformBoundaryLM(15), 4.0);

        // ---- before any training ----
        double lambda = 0.2;
        double mlEos = 1.0;
        double eos = lambda * mlEos;
        double ws = 0.0625; // 1/16
        this.assertEstimate(Math.log2(eos * ws), lm, "");

        double unk = 1.0 - lambda;
        double tok = 6.103515625E-5; // 1/16384; the same value is used for every single-character token
        this.assertEstimate(Math.log2(unk * eos * tok * ws * ws), lm, "a");

        double wsSpace = ws * 1.0 / 16.0;
        this.assertEstimate(Math.log2(unk * unk * eos * ws * ws * wsSpace * tok * tok), lm, "a b");
        this.assertEstimate(Math.log2(unk * unk * unk * eos * ws * ws * wsSpace * wsSpace * tok * tok * tok), lm, "a b c");
        this.assertEstimate(Math.log2(unk * unk * unk * unk * eos * ws * ws * wsSpace * wsSpace * wsSpace * tok * tok * tok * tok), lm, "a b c d");

        // ---- after training on "a" ----
        lm.train("a");
        lambda = 0.2727272727272727;
        mlEos = 0.6666666666666666;
        eos = lambda * mlEos;
        double lambdaEos = 0.2;
        double eosGivenEos = (1.0 - lambdaEos) * eos;
        this.assertEstimate(Math.log2(eosGivenEos * ws), lm, "");

        double mlAGivenEos = 1.0;
        double mlA = 0.3333333333333333;
        double a = lambda * mlA;
        double aGivenEos = lambdaEos * mlAGivenEos + (1.0 - lambdaEos) * a;

        double lambdaEosA = 0.2;
        double mlEosGivenEosA = 1.0;
        double lambdaA = 0.2;
        double mlEosGivenA = 1.0;
        double eosGivenA = lambdaA * mlEosGivenA + (1.0 - lambdaA) * eos;
        double eosGivenEosA = lambdaEosA * mlEosGivenEosA + (1.0 - lambdaEosA) * eosGivenA;
        this.assertEstimate(Math.log2(aGivenEos * eosGivenEosA * ws * ws), lm, "a");
    }

    /**
     * Checks that the dynamic model and its compiled round-trip agree on a series of
     * probe strings, re-checking after each of three training steps
     * ({@code "a"}, {@code "a b c"}, {@code "x y"}).
     */
    @Test
    public void testTwo() throws ClassNotFoundException, IOException {
        TokenizedLM lm = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, 3, new UniformBoundaryLM(127), new UniformBoundaryLM(15), 4.0);

        // The same four probes are checked before and after the first training step.
        final String[] firstProbes = {"a", "a b", "a a b", "a b a"};
        for (String probe : firstProbes) {
            this.assertEqEstimate(lm, probe);
        }
        lm.train("a");
        for (String probe : firstProbes) {
            this.assertEqEstimate(lm, probe);
        }

        lm.train("a b c");
        for (String probe : new String[] {"a", "a b", "a b e"}) {
            this.assertEqEstimate(lm, probe);
        }

        lm.train("x y");
        for (String probe : new String[] {"x y a b e x y", "", "x"}) {
            this.assertEqEstimate(lm, probe);
        }
    }

    /**
     * Exercises {@code collocationSet}: checks the size and leading entries of the bigram
     * and trigram collocation sets on two small corpora, and that the argument combination
     * {@code (1, 1, 3)} is rejected with an {@link IllegalArgumentException}.
     */
    @Test
    public void testCollocs() {
        TokenizedLM lm = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, 4);
        for (String line : new String[] {"a b c d", "a b e f", "c f e"}) {
            lm.train(line);
        }
        SortedSet<ScoredObject<String[]>> collocs = lm.collocationSet(2, 1, 2);
        junit.framework.Assert.assertEquals(2, collocs.size());
        Iterator<ScoredObject<String[]>> ranked = collocs.iterator();
        Assert.assertArrayEquals(new String[] {"a", "b"}, (Object[]) ranked.next().getObject());
        Assert.assertArrayEquals(new String[] {"c", "d"}, (Object[]) ranked.next().getObject());

        // Fresh model for the trigram case.
        lm = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, 4);
        for (String line : new String[] {"a b c d", "a b c e", "d e f", "f d e", "e f d"}) {
            lm.train(line);
        }
        collocs = lm.collocationSet(3, 1, 2);
        junit.framework.Assert.assertEquals(2, collocs.size());
        Assert.assertArrayEquals(new String[] {"a", "b", "c"}, (Object[]) collocs.iterator().next().getObject());

        try {
            lm.collocationSet(1, 1, 3);
        }
        catch (IllegalArgumentException expected) {
            Asserts.succeed();
            return;
        }
        // Reached only when the call above did not throw.
        junit.framework.Assert.fail();
    }

    /**
     * Returns the result of {@code lm.newTermSet(ngram, minCount, maxReturn, lm2)} as an
     * array, in the set's iteration order.
     *
     * @param lm model whose {@code newTermSet} is queried
     * @param ngram passed through as the first argument
     * @param minCount passed through as the second argument
     * @param maxReturn passed through as the third argument
     * @param lm2 model passed through as the fourth argument
     * @return the scored terms, one array slot per set element
     */
    static ScoredObject[] newTerms(TokenizedLM lm, int ngram, int minCount, int maxReturn, TokenizedLM lm2) {
        SortedSet<ScoredObject<String[]>> terms = lm.newTermSet(ngram, minCount, maxReturn, lm2);
        // An exactly-sized target array is filled in iteration order and returned as-is.
        return terms.toArray(new ScoredObject[terms.size()]);
    }

    /**
     * Returns the result of {@code lm.oldTermSet(ngram, minCount, maxReturn, lm2)} as an
     * array, in the set's iteration order.
     *
     * @param lm model whose {@code oldTermSet} is queried
     * @param ngram passed through as the first argument
     * @param minCount passed through as the second argument
     * @param maxReturn passed through as the third argument
     * @param lm2 model passed through as the fourth argument
     * @return the scored terms, one array slot per set element
     */
    static ScoredObject[] oldTerms(TokenizedLM lm, int ngram, int minCount, int maxReturn, TokenizedLM lm2) {
        SortedSet<ScoredObject<String[]>> terms = lm.oldTermSet(ngram, minCount, maxReturn, lm2);
        // An exactly-sized target array is filled in iteration order and returned as-is.
        return terms.toArray(new ScoredObject[terms.size()]);
    }

    /**
     * Returns the result of {@code lm.frequentTermSet(ngram, maxReturn)} as an array, in the
     * set's iteration order.
     *
     * @param lm model whose {@code frequentTermSet} is queried
     * @param ngram passed through as the first argument
     * @param maxReturn passed through as the second argument
     * @return the scored terms, one array slot per set element
     */
    static ScoredObject[] frequentTerms(TokenizedLM lm, int ngram, int maxReturn) {
        SortedSet<ScoredObject<String[]>> terms = lm.frequentTermSet(ngram, maxReturn);
        // An exactly-sized target array is filled in iteration order and returned as-is.
        return terms.toArray(new ScoredObject[terms.size()]);
    }

    /**
     * Returns the result of {@code lm.infrequentTermSet(ngram, maxReturn)} as an array, in
     * the set's iteration order.
     *
     * @param lm model whose {@code infrequentTermSet} is queried
     * @param ngram passed through as the first argument
     * @param maxReturn passed through as the second argument
     * @return the scored terms, one array slot per set element
     */
    static ScoredObject[] infrequentTerms(TokenizedLM lm, int ngram, int maxReturn) {
        SortedSet<ScoredObject<String[]>> terms = lm.infrequentTermSet(ngram, maxReturn);
        // An exactly-sized target array is filled in iteration order and returned as-is.
        return terms.toArray(new ScoredObject[terms.size()]);
    }

    /**
     * Trains two models on corpora that share the bigram (b,c) but differ in the third
     * token, then checks the leading bigram reported by the new-term and old-term
     * comparisons in both directions, and the full ordering of the frequent and infrequent
     * bigram lists of the first model.
     */
    @Test
    public void testNewAndOldTerms() {
        TokenizedLM first = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, 3);
        TokenizedLM second = new TokenizedLM(IndoEuropeanTokenizerFactory.INSTANCE, 3);
        for (String line : new String[] {"b c d", "b c d", "b c d", "b c f"}) {
            first.train(line);
        }
        for (String line : new String[] {"b c x", "b c x", "b c x", "b c y"}) {
            second.train(line);
        }

        final String[] bc = {"b", "c"};
        final String[] cd = {"c", "d"};
        final String[] cf = {"c", "f"};

        ScoredObject[] newInFirst = TokenizedLMTest.newTerms(first, 2, 1, 3, second);
        Assert.assertArrayEquals(cd, (String[]) newInFirst[0].getObject());
        ScoredObject[] newInSecond = TokenizedLMTest.newTerms(second, 2, 1, 2, first);
        Assert.assertArrayEquals(new String[] {"c", "x"}, (String[]) newInSecond[0].getObject());

        ScoredObject[] oldInFirst = TokenizedLMTest.oldTerms(first, 2, 1, 3, second);
        Assert.assertArrayEquals(cf, (String[]) oldInFirst[0].getObject());
        ScoredObject[] oldInSecond = TokenizedLMTest.oldTerms(second, 2, 1, 3, first);
        Assert.assertArrayEquals(new String[] {"c", "y"}, (String[]) oldInSecond[0].getObject());

        ScoredObject[] frequent = TokenizedLMTest.frequentTerms(first, 2, 10);
        Assert.assertArrayEquals(bc, (String[]) frequent[0].getObject());
        Assert.assertArrayEquals(cd, (String[]) frequent[1].getObject());
        Assert.assertArrayEquals(cf, (String[]) frequent[2].getObject());

        // The infrequent list is expected to be the frequent list reversed.
        ScoredObject[] infrequent = TokenizedLMTest.infrequentTerms(first, 2, 10);
        Assert.assertArrayEquals(bc, (String[]) infrequent[2].getObject());
        Assert.assertArrayEquals(cd, (String[]) infrequent[1].getObject());
        Assert.assertArrayEquals(cf, (String[]) infrequent[0].getObject());
    }

    /**
     * Asserts that the model's log2 estimate for {@code cSeq} is within 0.005 of the expected
     * value, then checks that the compiled round-trip of the model gives the same estimate.
     *
     * @param estimate expected log2 estimate
     * @param lm model under test
     * @param cSeq text to estimate
     */
    private void assertEstimate(double estimate, TokenizedLM lm, CharSequence cSeq) throws ClassNotFoundException, IOException {
        final double tolerance = 0.005;
        double actual = lm.log2Estimate(cSeq);
        junit.framework.Assert.assertEquals(estimate, actual, tolerance);
        String text = cSeq.toString();
        this.assertEqEstimate(lm, text);
    }

    /**
     * Asserts that the model and its compiled, serialized-and-reloaded counterpart produce
     * log2 estimates for {@code cSeq} that differ by at most 0.005.
     *
     * @param lm dynamic model; it is compiled through {@code writeRead} for the comparison
     * @param cSeq text to estimate with both models
     */
    public void assertEqEstimate(TokenizedLM lm, CharSequence cSeq) throws ClassNotFoundException, IOException {
        double dynamicEstimate = lm.log2Estimate(cSeq);
        double compiledEstimate = TokenizedLMTest.writeRead(lm).log2Estimate(cSeq);
        junit.framework.Assert.assertEquals(dynamicEstimate, compiledEstimate, 0.005);
    }

    /**
     * Round-trips a model through its compiled form entirely in memory: the model is
     * compiled to an object stream backed by a byte array, and the resulting bytes are read
     * back as a {@link CompiledTokenizedLM}.
     *
     * @param lm model to compile
     * @return the object read back from the compiled bytes
     * @throws AssertionError if compiling or reading back fails with an
     *     {@link IOException} or {@link ClassNotFoundException}; the original exception is
     *     kept as the cause so its stack trace shows up in the test report
     */
    private static CompiledTokenizedLM writeRead(TokenizedLM lm) {
        try {
            ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
            ObjectOutputStream objOut = new ObjectOutputStream(bytesOut);
            try {
                lm.compileTo(objOut);
                // Push anything still buffered in the object stream into bytesOut before it is snapshotted.
                objOut.flush();
            }
            finally {
                objOut.close();
            }
            ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(bytesOut.toByteArray()));
            try {
                return (CompiledTokenizedLM)objIn.readObject();
            }
            finally {
                objIn.close();
            }
        }
        catch (IOException e) {
            // AssertionError(Object) uses e.toString() as its message and, since e is a Throwable, keeps it as the cause.
            throw new AssertionError(e);
        }
        catch (ClassNotFoundException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Randomized check that training a text {@code n} times one call at a time leaves a model
     * in the same state as a single {@code train(text, n)} call. One hundred random
     * five-token texts are used, each with a random count in {@code [0, 10)}.
     *
     * <p>NOTE(review): the generated "tokens" are the characters U+0000 to U+000F (the raw
     * value of {@code nextInt(16)} cast to {@code char}), not letters, and the generator is
     * unseeded, so a failing input cannot be reproduced from the report alone. Confirm
     * both are intended.
     */
    @Test
    public void testMultipleIncrements() {
        Random rng = new Random();
        IndoEuropeanTokenizerFactory factory = IndoEuropeanTokenizerFactory.INSTANCE;
        TokenizedLM stepwise = new TokenizedLM(factory, 3);
        TokenizedLM batched = new TokenizedLM(factory, 3);
        for (int round = 0; round < 100; ++round) {
            StringBuilder text = new StringBuilder();
            for (int tok = 0; tok < 5; ++tok) {
                text.append((char)rng.nextInt(16)).append(' ');
            }
            int repetitions = rng.nextInt(10);
            this.incrementAssertSynched(stepwise, batched, text, repetitions);
        }
    }

    /**
     * Trains {@code lm1} on {@code cs} with {@code count} separate calls and {@code lm2}
     * with one counted call, then asserts that the two models still agree.
     *
     * @param lm1 model trained one call at a time
     * @param lm2 model trained with a single {@code train(cs, count)} call
     * @param cs training text
     * @param count number of training instances; a value of zero or less skips the loop for
     *     {@code lm1} but is still passed to {@code lm2}
     */
    void incrementAssertSynched(TokenizedLM lm1, TokenizedLM lm2, CharSequence cs, int count) {
        for (int remaining = count; remaining > 0; --remaining) {
            lm1.train(cs);
        }
        lm2.train(cs, count);
        this.assertSynched(lm1, lm2);
    }

    /**
     * Compares the two models on 500 random probes: one hundred rounds, each covering
     * probe lengths 0 through 4 in ascending order.
     *
     * @param lm1 first model
     * @param lm2 second model, expected to give the same estimates as {@code lm1}
     */
    void assertSynched(TokenizedLM lm1, TokenizedLM lm2) {
        final int lengths = 5;
        final int rounds = 100;
        for (int n = 0; n < rounds * lengths; ++n) {
            // n % lengths cycles 0,1,2,3,4 within each round, matching a nested loop over lengths.
            this.assertSynched(lm1, lm2, n % lengths);
        }
    }

    /**
     * Builds a random probe of {@code k} one-character tokens, each followed by a space,
     * and asserts that both models give log2 estimates for it within 1.0E-4 of each other.
     * Each token character is the raw value of {@code nextInt(16)} cast to {@code char},
     * that is, one of U+0000 to U+000F. A fresh unseeded generator is created per call.
     *
     * @param lm1 first model
     * @param lm2 second model
     * @param k number of tokens in the probe; zero or less yields the empty string
     */
    void assertSynched(TokenizedLM lm1, TokenizedLM lm2, int k) {
        Random rng = new Random();
        StringBuilder probe = new StringBuilder();
        int remaining = k;
        while (remaining-- > 0) {
            probe.append((char)rng.nextInt(16)).append(' ');
        }
        double first = lm1.log2Estimate(probe);
        double second = lm2.log2Estimate(probe);
        junit.framework.Assert.assertEquals(first, second, 1.0E-4);
    }
}

