/*
 * Decompiled with CFR 0.152.
 */
package cmu.arktweetnlp;

import cmu.arktweetnlp.impl.Model;
import cmu.arktweetnlp.impl.ModelSentence;
import cmu.arktweetnlp.impl.OWLQN;
import cmu.arktweetnlp.impl.Sentence;
import cmu.arktweetnlp.impl.features.FeatureExtractor;
import cmu.arktweetnlp.io.CoNLLReader;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.DiffFunction;
import java.io.IOException;
import java.util.ArrayList;

/**
 * Command-line trainer for a {@link Model}.
 *
 * <p>Reads training examples with {@link CoNLLReader}, builds the label vocabulary, extracts
 * features, and fits the model coefficients with {@link OWLQN}. The objective minimized is the
 * negative log-likelihood of the training sentences plus {@code 0.5 * l2penalty * ||w||^2}; the
 * L1 penalty ({@code l1penalty}) is passed to the OWL-QN optimizer rather than added here. The
 * trained model is saved with {@link Model#saveModelAsText}.
 */
public class Train {
    /** L2 regularization strength; added to the objective as {@code 0.5 * l2penalty * ||w||^2}. */
    public double l2penalty = 2.0;
    /** L1 regularization strength, forwarded to {@link OWLQN#minimize}. */
    public double l1penalty = 0.25;
    /** Convergence tolerance forwarded to {@link OWLQN#minimize}. */
    public double tol = 1.0E-7;
    /** Maximum number of optimizer iterations. */
    public int maxIter = 500;
    /** Optional model file used to warm-start the coefficients; {@code null} for a cold start. */
    public String modelLoadFilename = null;
    /** Training examples file, read with {@link CoNLLReader#readFile}. */
    public String examplesFilename = null;
    /** Output path for the trained model (text format). */
    public String modelSaveFilename = null;
    /** When true, {@link #main} dumps extracted features instead of training. */
    public boolean dumpFeatures = false;

    private ArrayList<Sentence> lSentences = new ArrayList<Sentence>();
    private ArrayList<ModelSentence> mSentences = new ArrayList<ModelSentence>();
    /** Total number of tokens across {@link #lSentences}; used to normalize the log-likelihood printout. */
    private int numTokens = 0;
    private Model model = new Model();

    Train() {
    }

    /**
     * Reads the examples, extracts features, then re-runs feature extraction in dump mode so the
     * extractor prints the features it computes.
     *
     * @throws IOException if the examples file cannot be read
     */
    public void doFeatureDumping() throws IOException {
        readTrainingSentences(examplesFilename);
        constructLabelVocab();
        extractFeatures();
        dumpFeatures();
    }

    /**
     * Full training pipeline: read examples, build vocabularies and features, optionally
     * warm-start from {@link #modelLoadFilename}, optimize, and save to {@link #modelSaveFilename}.
     *
     * @throws IOException if reading the examples/warm-start model or writing the output fails
     */
    public void doTraining() throws IOException {
        readTrainingSentences(examplesFilename);
        constructLabelVocab();
        extractFeatures();
        model.lockdownAfterFeatureExtraction();
        if (modelLoadFilename != null) {
            readWarmStartModel();
        }
        optimizationLoop();
        model.saveModelAsText(modelSaveFilename);
    }

    /**
     * Loads the training sentences from {@code filename} and recounts {@link #numTokens}.
     *
     * @param filename path to the training examples
     * @throws IOException if the file cannot be read
     */
    public void readTrainingSentences(String filename) throws IOException {
        lSentences = CoNLLReader.readFile(filename);
        // Reset so repeated calls reflect only the sentences just loaded, not a running total.
        numTokens = 0;
        for (Sentence sentence : lSentences) {
            numTokens += sentence.T();
        }
    }

    /** Registers every label seen in the training data, then locks the label vocabulary. */
    public void constructLabelVocab() {
        for (Sentence sentence : lSentences) {
            for (String label : sentence.labels) {
                model.labelVocab.num(label);
            }
        }
        model.labelVocab.lock();
        model.numLabels = model.labelVocab.size();
    }

    /**
     * Runs a feature extractor with {@code dumpMode} enabled over every training sentence. The
     * computed {@link ModelSentence}s are discarded; only the extractor's dump output matters.
     *
     * @throws IOException if the feature extractor fails
     */
    public void dumpFeatures() throws IOException {
        FeatureExtractor extractor = new FeatureExtractor(model, true);
        extractor.dumpMode = true;
        for (Sentence sentence : lSentences) {
            ModelSentence scratch = new ModelSentence(sentence.T());
            extractor.computeFeatures(sentence, scratch);
        }
    }

    /**
     * Computes features for every training sentence and appends the results to {@link #mSentences}.
     *
     * @throws IOException if the feature extractor fails
     */
    public void extractFeatures() throws IOException {
        System.out.println("Extracting features");
        FeatureExtractor extractor = new FeatureExtractor(model, true);
        for (Sentence sentence : lSentences) {
            ModelSentence modelSentence = new ModelSentence(sentence.T());
            extractor.computeFeatures(sentence, modelSentence);
            mSentences.add(modelSentence);
        }
    }

    /**
     * Initializes the current model's coefficients from the model stored in
     * {@link #modelLoadFilename}, for the features the two models share. Expects the feature
     * vocabulary to be locked already (asserted).
     *
     * @throws IOException if the warm-start model cannot be read
     */
    public void readWarmStartModel() throws IOException {
        assert model.featureVocab.isLocked();
        Model warmStartModel = Model.loadModelFromText(modelLoadFilename);
        Model.copyCoefsForIntersectingFeatures(warmStartModel, model);
    }

    /** Minimizes the regularized objective with OWL-QN and stores the result in the model. */
    public void optimizationLoop() {
        OWLQN optimizer = new OWLQN();
        optimizer.setMaxIters(maxIter);
        optimizer.setQuiet(false);
        optimizer.setWeightsPrinting(new MyWeightsPrinter());
        double[] initial = model.convertCoefsToFlat();
        // NOTE(review): the trailing 5 is presumably the L-BFGS history size -- confirm in OWLQN.
        double[] optimum = optimizer.minimize(new GradientCalculator(), initial, l1penalty, tol, 5);
        model.setCoefsFromFlat(optimum);
    }

    /** Adds the gradient of {@code 0.5 * l2penalty * ||weights||^2} into {@code grad} in place. */
    private void addL2regularizerGradient(double[] grad, double[] weights) {
        assert grad.length == weights.length;
        for (int i = 0; i < weights.length; ++i) {
            grad[i] += l2penalty * weights[i];
        }
    }

    /** Returns {@code 0.5 * l2penalty * ||weights||^2}. */
    private double regularizerValue(double[] weights) {
        double sumSq = 0.0;
        for (double w : weights) {
            sumSq += w * w;
        }
        return 0.5 * l2penalty * sumSq;
    }

    /**
     * Command-line entry point; see {@link #usage()} for the accepted arguments. Any malformed
     * invocation (unknown option, option missing its value, missing filenames) prints usage and
     * exits with status 1.
     *
     * @param args command-line arguments
     * @throws IOException if reading input or writing the model fails
     */
    public static void main(String[] args) throws IOException {
        Train train = new Train();
        if (args.length < 2 || containsHelpFlag(args)) {
            usage();
        }
        int i = 0;
        while (i < args.length && args[i].startsWith("-")) {
            String opt = args[i];
            if (opt.equals("--warm-start")) {
                train.modelLoadFilename = optionValue(args, i);
                i += 2;
            } else if (opt.equals("--max-iter")) {
                train.maxIter = Integer.parseInt(optionValue(args, i));
                i += 2;
            } else if (opt.equals("--dump-feat")) {
                train.dumpFeatures = true;
                i += 1;
            } else if (opt.equals("--l2")) {
                train.l2penalty = Double.parseDouble(optionValue(args, i));
                i += 2;
            } else if (opt.equals("--l1")) {
                train.l1penalty = Double.parseDouble(optionValue(args, i));
                i += 2;
            } else {
                usage();
                return; // not reached: usage() exits
            }
        }
        if (train.dumpFeatures) {
            // Dump mode needs only the examples file; guard against it being absent.
            if (i >= args.length) {
                usage();
            }
            train.examplesFilename = args[i];
            train.doFeatureDumping();
            System.exit(0);
        }
        if (args.length - i < 2) {
            usage();
        }
        train.examplesFilename = args[i];
        train.modelSaveFilename = args[i + 1];
        train.doTraining();
    }

    /** Returns true if any argument is {@code -h} or {@code --help}. */
    private static boolean containsHelpFlag(String[] args) {
        for (String arg : args) {
            if (arg.equals("-h") || arg.equals("--help")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the value following the option at {@code args[i]}, or prints usage and exits if the
     * option is the last argument.
     */
    private static String optionValue(String[] args, int i) {
        if (i + 1 >= args.length) {
            System.out.println("Missing value for option " + args[i]);
            usage();
        }
        return args[i + 1];
    }

    /** Prints command-line usage and exits with status 1. */
    public static void usage() {
        System.out.println("Train [options] <ExamplesFilename> <ModelOutputFilename>\n"
                + "Options:\n"
                + "  --max-iter <n>\n"
                + "  --warm-start <modelfile>    Initializes at weights of this model.  discards base features that aren't in training set.\n"
                + "  --dump-feat                 Show extracted features, instead of training. Useful for debugging/analyzing feature extractors.\n"
                + "  --l2 <penalty>              L2 regularization strength (default 2.0).\n"
                + "  --l1 <penalty>              L1 regularization strength (default 0.25).\n");
        System.exit(1);
    }

    /**
     * Optimizer callback that prints the training log-likelihood per token for the current
     * model coefficients.
     */
    public class MyWeightsPrinter
    implements OWLQN.WeightsPrinter {
        @Override
        public void printWeights() {
            double logLik = 0.0;
            for (ModelSentence modelSentence : Train.this.mSentences) {
                logLik += Train.this.model.computeLogLik(modelSentence);
            }
            System.out.printf("\tTokLL %.6f\t", logLik / (double) Train.this.numTokens);
        }
    }

    /**
     * Objective for the optimizer: negative training log-likelihood plus the L2 penalty, and
     * its gradient. Both methods load {@code weights} into the model before evaluating.
     */
    private class GradientCalculator
    implements DiffFunction {
        private GradientCalculator() {
        }

        @Override
        public int domainDimension() {
            return Train.this.model.flatIDsize();
        }

        @Override
        public double valueAt(double[] weights) {
            Train.this.model.setCoefsFromFlat(weights);
            double logLik = 0.0;
            for (ModelSentence modelSentence : Train.this.mSentences) {
                logLik += Train.this.model.computeLogLik(modelSentence);
            }
            return -logLik + Train.this.regularizerValue(weights);
        }

        @Override
        public double[] derivativeAt(double[] weights) {
            double[] grad = new double[Train.this.model.flatIDsize()];
            Train.this.model.setCoefsFromFlat(weights);
            // NOTE(review): computeGradient appears to accumulate into grad across sentences -- confirm.
            for (ModelSentence modelSentence : Train.this.mSentences) {
                Train.this.model.computeGradient(modelSentence, grad);
            }
            // Negate: valueAt minimizes the *negative* log-likelihood.
            ArrayMath.multiplyInPlace(grad, -1.0);
            Train.this.addL2regularizerGradient(grad, weights);
            return grad;
        }
    }
}

