/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.chunk;

import com.aliasi.chunk.CompiledEstimator;
import com.aliasi.chunk.Node;
import com.aliasi.chunk.OutcomeCounter;
import com.aliasi.chunk.Tags;
import com.aliasi.symbol.SymbolTableCompiler;
import com.aliasi.tokenizer.TokenCategorizer;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.TreeSet;

/**
 * Trainable estimator that collects counts for two back-off models and
 * compiles them into the binary form read back by {@link CompiledEstimator}.
 *
 * <ul>
 * <li><b>Tag model</b>: predicts a tag from the previous tag, then additionally
 *     the previous token, then additionally the token before that.</li>
 * <li><b>Token model</b>: predicts a token from its tag, then additionally the
 *     previous tag, then additionally the previous token.</li>
 * </ul>
 *
 * <p>When training goes through {@link #trainOutcome}, the previous tag is first
 * converted with {@code Tags.toInnerTag}.
 *
 * <p>Instances are mutable and not synchronized. Compiling via
 * {@link #compileTo(ObjectOutput)} also mutates this estimator: token categories
 * are added to the token symbol table and estimates are compiled on both node
 * trees.
 */
final class TrainableEstimator implements Compilable {

    /** Boundary token used for positions before the start and after the end of a sequence. */
    private static final String BOUNDARY_TOKEN = ".";

    /** Tag used for the boundary positions of a sequence. */
    private static final String BOUNDARY_TAG = "O";

    /** Root of the tag-model tree: previous tag, then previous token, then token before that. */
    private final Node mRootTagNode;

    /** Root of the token-model tree: tag, then previous tag, then previous token. */
    private final Node mRootTokenNode;

    private final SymbolTableCompiler mTokenSymbolTable = new SymbolTableCompiler();
    private final SymbolTableCompiler mTagSymbolTable = new SymbolTableCompiler();
    private double mLambdaFactor;
    private double mLogUniformVocabEstimate;
    private final TokenCategorizer mTokenCategorizer;

    /**
     * Constructs an estimator with the given smoothing parameters.
     *
     * <p>The values are stored as given. Unlike the setters, this constructor
     * does not validate them.
     *
     * @param lambdaFactor factor passed to {@code Node.compileEstimates} at compile time
     * @param logUniformVocabEstimate log estimate written at the end of the compiled model
     * @param categorizer categorizer whose categories are added to the token symbol
     *     table at compile time and which is serialized with the compiled model
     */
    public TrainableEstimator(double lambdaFactor, double logUniformVocabEstimate, TokenCategorizer categorizer) {
        mLambdaFactor = lambdaFactor;
        mLogUniformVocabEstimate = logUniformVocabEstimate;
        mTokenCategorizer = categorizer;
        mRootTagNode = new Node(null, mTagSymbolTable, null);
        mRootTokenNode = new Node(null, mTokenSymbolTable, null);
        // "O" is the tag handle() uses at sequence boundaries, so register it up front.
        mTagSymbolTable.addSymbol(BOUNDARY_TAG);
    }

    /**
     * Constructs an estimator with lambda factor {@code 4.0} and log uniform
     * vocabulary estimate {@code Math.log(1.0E-6)}.
     *
     * @param categorizer token categorizer used at compile time
     */
    public TrainableEstimator(TokenCategorizer categorizer) {
        this(4.0, Math.log(1.0E-6), categorizer);
    }

    /**
     * Sets the factor passed to {@code Node.compileEstimates} when compiling.
     *
     * @param lambdaFactor a finite, non-negative factor
     * @throws IllegalArgumentException if the factor is negative, NaN or infinite
     */
    public void setLambdaFactor(double lambdaFactor) {
        if (lambdaFactor < 0.0 || Double.isNaN(lambdaFactor) || Double.isInfinite(lambdaFactor)) {
            // Zero is accepted by the check above, so the message says ">= 0".
            throw new IllegalArgumentException("Lambda factor must be >= 0. Was=" + lambdaFactor);
        }
        mLambdaFactor = lambdaFactor;
    }

    /**
     * Sets the log uniform vocabulary estimate written with the compiled model.
     *
     * @param estimate a finite, strictly negative value
     * @throws IllegalArgumentException if the estimate is non-negative, NaN or infinite
     */
    public void setLogUniformVocabularyEstimate(double estimate) {
        if (estimate >= 0.0 || Double.isNaN(estimate) || Double.isInfinite(estimate)) {
            throw new IllegalArgumentException("Log vocab estimate must be < 0. Was=" + estimate);
        }
        mLogUniformVocabEstimate = estimate;
    }

    /**
     * Trains both models on one tagged sequence.
     *
     * <p>Each position is trained with the preceding tag and the two preceding
     * tokens as context. Positions before the start use token {@code "."} and
     * tag {@code "O"}. After the last token, a boundary outcome (token
     * {@code "."}, tag {@code "O"}) is trained. An empty sequence is ignored.
     *
     * @param tokens the tokens of the sequence
     * @param tags the tag for each token; must be parallel to {@code tokens}
     * @throws IllegalArgumentException if {@code tokens} and {@code tags} differ in length
     */
    public void handle(String[] tokens, String[] tags) {
        if (tokens.length != tags.length) {
            throw new IllegalArgumentException("Tokens and tags must be the same length."
                + " Found tokens.length=" + tokens.length
                + " tags.length=" + tags.length);
        }
        int length = tokens.length;
        if (length == 0) {
            return;
        }
        for (int i = 0; i < length; ++i) {
            String previousTag = i > 0 ? tags[i - 1] : BOUNDARY_TAG;
            trainOutcome(tokens[i], tags[i], previousTag,
                         tokenOrBoundary(tokens, i - 1),
                         tokenOrBoundary(tokens, i - 2));
        }
        // Closing boundary outcome, conditioned on the last real position.
        trainOutcome(BOUNDARY_TOKEN, BOUNDARY_TAG, tags[length - 1],
                     tokens[length - 1], tokenOrBoundary(tokens, length - 2));
    }

    /** Returns {@code tokens[index]}, or the boundary token when {@code index} is before the start. */
    private static String tokenOrBoundary(String[] tokens, int index) {
        return index >= 0 ? tokens[index] : BOUNDARY_TOKEN;
    }

    /**
     * Writes a compiled form of this estimator.
     *
     * <p>Compiling also mutates this estimator (see {@code Externalizer.writeExternal}).
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Externalizer(this));
    }

    /**
     * Registers the tag and token symbols, then trains both models on one outcome.
     *
     * @param token the token being predicted
     * @param tag the tag being predicted
     * @param tagMinus1 the previous tag, converted with {@code Tags.toInnerTag}; may be null
     * @param tokenMinus1 the previous token; may be null
     * @param tokenMinus2 the token before the previous one; may be null
     */
    public void trainOutcome(String token, String tag, String tagMinus1, String tokenMinus1, String tokenMinus2) {
        mTagSymbolTable.addSymbol(tag);
        mTokenSymbolTable.addSymbol(token);
        String tagMinus1Interior = tagMinus1 == null ? null : Tags.toInnerTag(tagMinus1);
        trainTokenModel(token, tag, tagMinus1Interior, tokenMinus1);
        trainTagModel(tag, tagMinus1Interior, tokenMinus1, tokenMinus2);
    }

    /** Generates symbols from both node trees and adds every token category to the token table. */
    private void generateSymbols() {
        mRootTagNode.generateSymbols();
        mRootTokenNode.generateSymbols();
        for (String category : mTokenCategorizer.categories()) {
            mTokenSymbolTable.addSymbol(category);
        }
    }

    /**
     * Increments {@code token} at the token-model contexts
     * (tag), (tag, tagMinus1) and (tag, tagMinus1, tokenMinus1).
     *
     * <p>Training stops at the first null context element. Nothing is trained
     * if {@code tag} or {@code token} is null.
     */
    public void trainTokenModel(String token, String tag, String tagMinus1, String tokenMinus1) {
        if (tag == null || token == null) {
            return;
        }
        Node nodeTag = mRootTokenNode.getOrCreateChild(tag, null, mTagSymbolTable);
        nodeTag.incrementOutcome(token, mTokenSymbolTable);
        if (tagMinus1 == null) {
            return;
        }
        // The parent node is passed as the second argument (presumably the back-off node).
        Node nodeTagTag1 = nodeTag.getOrCreateChild(tagMinus1, nodeTag, mTagSymbolTable);
        nodeTagTag1.incrementOutcome(token, mTokenSymbolTable);
        if (tokenMinus1 == null) {
            return;
        }
        Node nodeTagTag1W1 = nodeTagTag1.getOrCreateChild(tokenMinus1, nodeTagTag1, mTokenSymbolTable);
        nodeTagTag1W1.incrementOutcome(token, mTokenSymbolTable);
    }

    /**
     * Increments {@code tag} at the tag-model contexts
     * (tagMinus1), (tagMinus1, tokenMinus1) and (tagMinus1, tokenMinus1, tokenMinus2).
     *
     * <p>Training stops at the first null context element. Nothing is trained
     * if {@code tag} or {@code tagMinus1} is null.
     */
    public void trainTagModel(String tag, String tagMinus1, String tokenMinus1, String tokenMinus2) {
        if (tag == null || tagMinus1 == null) {
            return;
        }
        Node nodeTag1 = mRootTagNode.getOrCreateChild(tagMinus1, null, mTagSymbolTable);
        nodeTag1.incrementOutcome(tag, mTagSymbolTable);
        if (tokenMinus1 == null) {
            return;
        }
        Node nodeTag1W1 = nodeTag1.getOrCreateChild(tokenMinus1, nodeTag1, mTokenSymbolTable);
        nodeTag1W1.incrementOutcome(tag, mTagSymbolTable);
        if (tokenMinus2 == null) {
            return;
        }
        Node nodeTag1W1W2 = nodeTag1W1.getOrCreateChild(tokenMinus2, nodeTag1W1, mTokenSymbolTable);
        nodeTag1W1W2.incrementOutcome(tag, mTagSymbolTable);
    }

    /** Trains only the unconditioned (tag-only context) level of the token model. */
    public void trainTokenOutcome(String token, String tag) {
        trainTokenModel(token, tag, null, null);
    }

    /** Returns the number of nodes in the tag-model tree. */
    public int numTagNodes() {
        return mRootTagNode.numNodes();
    }

    /** Returns the number of outcome counters in the tag-model tree. */
    public int numTagOutcomes() {
        return mRootTagNode.numCounters();
    }

    /** Returns the number of nodes in the token-model tree. */
    public int numTokenNodes() {
        return mRootTokenNode.numNodes();
    }

    /** Returns the number of outcome counters in the token-model tree. */
    public int numTokenOutcomes() {
        return mRootTokenNode.numCounters();
    }

    /**
     * Prunes both trees with their respective thresholds.
     *
     * @param thresholdTag threshold passed to {@code Node.prune} for the tag tree
     * @param thresholdToken threshold passed to {@code Node.prune} for the token tree
     */
    public void prune(int thresholdTag, int thresholdToken) {
        mRootTagNode.prune(thresholdTag);
        mRootTokenNode.prune(thresholdToken);
    }

    /**
     * Adds {@code countToAdd} observations of every legal tag transition to the
     * first level of the tag model.
     *
     * <p>A transition is legal when {@code Tags.illegalSequence} returns false.
     *
     * @param countToAdd number of times each legal (previous tag, tag) pair is trained
     */
    public void smoothTags(int countToAdd) {
        String[] tags = mTagSymbolTable.symbols();
        for (String previousTag : tags) {
            for (String tag : tags) {
                if (Tags.illegalSequence(previousTag, tag)) {
                    continue;
                }
                // NOTE(review): unlike trainOutcome, previousTag is not passed through
                // Tags.toInnerTag here; confirm this is intended.
                for (int k = 0; k < countToAdd; ++k) {
                    trainTagModel(tag, previousTag, null, null);
                }
            }
        }
    }

    /**
     * Compiles estimates on the tree rooted at {@code rootNode}, then writes it.
     *
     * <p>Layout: the node count, the node records, the outcome count, then the
     * outcome records.
     */
    private void writeEstimator(Node rootNode, ObjectOutput out) throws IOException {
        rootNode.compileEstimates(mLambdaFactor);
        indexNodes(rootNode);
        out.writeInt(rootNode.numNodes());
        writeNodes(rootNode, out);
        out.writeInt(rootNode.numCounters());
        writeOutcomes(rootNode, out);
    }

    /**
     * Returns the node's child keys in sorted order.
     *
     * <p>All traversals used for writing ({@link #indexNodes}, {@link #writeNodes}
     * and {@link #writeOutcomes}) must visit nodes in the same breadth-first order.
     * Node records are written sequentially and must sit at the indices assigned
     * by {@code indexNodes}, and outcome blocks must follow the node order.
     */
    private static TreeSet<String> sortedChildren(Node node) {
        return new TreeSet<String>(node.children());
    }

    /** Assigns consecutive indices to all nodes in breadth-first order, with sorted children. */
    private static void indexNodes(Node rootNode) {
        Deque<Node> nodeQueue = new ArrayDeque<Node>();
        nodeQueue.addLast(rootNode);
        int index = 0;
        while (!nodeQueue.isEmpty()) {
            Node node = nodeQueue.removeFirst();
            node.setIndex(index++);
            for (String childName : sortedChildren(node)) {
                nodeQueue.addLast(node.getChild(childName));
            }
        }
    }

    /**
     * Writes one record per node in breadth-first order.
     *
     * <p>Each record contains: the symbol ID, the offset of the node's first
     * outcome, the index of its first child, {@code oneMinusLambda}, and the
     * back-off node index ({@code -1} if none).
     */
    private static void writeNodes(Node rootNode, ObjectOutput out) throws IOException {
        Deque<Node> nodeQueue = new ArrayDeque<Node>();
        nodeQueue.addLast(rootNode);
        int outcomesIndex = 0;
        // One past the last child index written so far. A leaf writes this as its
        // (empty) first-child index, presumably so a reader can take a node's child
        // range as [firstChild, next node's firstChild) -- confirm against CompiledEstimator.
        int nextChildIndex = 0;
        while (!nodeQueue.isEmpty()) {
            Node node = nodeQueue.removeFirst();
            out.writeInt(node.getSymbolID());
            out.writeInt(outcomesIndex);
            outcomesIndex += node.outcomes().size();
            TreeSet<String> children = sortedChildren(node);
            if (children.isEmpty()) {
                out.writeInt(nextChildIndex);
            } else {
                Node firstChild = node.getChild(children.first());
                out.writeInt(firstChild.index());
                nextChildIndex = firstChild.index() + children.size();
                for (String childName : children) {
                    nodeQueue.addLast(node.getChild(childName));
                }
            }
            out.writeFloat(node.oneMinusLambda());
            out.writeInt(node.backoffNode() == null ? -1 : node.backoffNode().index());
        }
    }

    /**
     * Writes each node's outcomes (symbol ID and estimate) in the same
     * breadth-first node order used by {@link #writeNodes}.
     */
    private static void writeOutcomes(Node rootNode, ObjectOutput out) throws IOException {
        Deque<Node> nodeQueue = new ArrayDeque<Node>();
        nodeQueue.addLast(rootNode);
        while (!nodeQueue.isEmpty()) {
            Node node = nodeQueue.removeFirst();
            for (String outcome : node.outcomes()) {
                OutcomeCounter outcomeCounter = node.getOutcome(outcome);
                out.writeInt(outcomeCounter.getSymbolID());
                out.writeFloat(outcomeCounter.estimate());
            }
            for (String childName : sortedChildren(node)) {
                nodeQueue.addLast(node.getChild(childName));
            }
        }
    }

    /** Serialization proxy that writes the compiled model and reads it back as a {@link CompiledEstimator}. */
    static class Externalizer extends AbstractExternalizable {
        private static final long serialVersionUID = 4179100933315980535L;

        /** Estimator to compile; null when built with the no-arg constructor. */
        final TrainableEstimator mEstimator;

        /**
         * Creates an externalizer with no estimator (presumably used for
         * deserialization). {@link #writeExternal} must not be called on such
         * an instance.
         */
        public Externalizer() {
            this(null);
        }

        public Externalizer(TrainableEstimator estimator) {
            mEstimator = estimator;
        }

        @Override
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            return new CompiledEstimator(in);
        }

        /**
         * Writes the compiled model in this order: the token categorizer, the tag
         * symbol table, the token symbol table, the tag-model estimator, the
         * token-model estimator, and the log uniform vocabulary estimate.
         *
         * <p>Symbols are generated from the estimator's trees and token categories
         * before the symbol tables are written.
         */
        @Override
        public void writeExternal(ObjectOutput objOut) throws IOException {
            AbstractExternalizable.compileOrSerialize(mEstimator.mTokenCategorizer, objOut);
            mEstimator.generateSymbols();
            mEstimator.mTagSymbolTable.compileTo(objOut);
            mEstimator.mTokenSymbolTable.compileTo(objOut);
            mEstimator.writeEstimator(mEstimator.mRootTagNode, objOut);
            mEstimator.writeEstimator(mEstimator.mRootTokenNode, objOut);
            objOut.writeDouble(mEstimator.mLogUniformVocabEstimate);
        }
    }
}

