/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.nlp.annotators.parser.typdep;

import com.johnsnowlabs.nlp.annotators.parser.typdep.ConllData;
import com.johnsnowlabs.nlp.annotators.parser.typdep.DependencyInstance;
import com.johnsnowlabs.nlp.annotators.parser.typdep.DependencyPipe;
import com.johnsnowlabs.nlp.annotators.parser.typdep.LocalFeatureData;
import com.johnsnowlabs.nlp.annotators.parser.typdep.LowRankTensor;
import com.johnsnowlabs.nlp.annotators.parser.typdep.Options;
import com.johnsnowlabs.nlp.annotators.parser.typdep.Parameters;
import com.johnsnowlabs.nlp.annotators.parser.typdep.io.ConllWriter;
import com.johnsnowlabs.nlp.annotators.parser.typdep.util.DependencyLabel;
import java.io.Serializable;
import java.util.ArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Holds the {@link Options}, {@link DependencyPipe} and {@link Parameters} of a typed
 * dependency model and drives both training ({@link #train}) and label prediction
 * ({@link #predictDependency}).
 *
 * <p>The class is {@link Serializable}; the logger is deliberately {@code static} so that a
 * deserialized instance can still log (a {@code transient} instance logger would come back
 * as {@code null}, because field initializers are not re-run on deserialization).
 */
public class TypedDependencyParser
implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Shared logger; the logger name is kept as the short class name used historically. */
    private static final Logger LOGGER = LoggerFactory.getLogger("TypedDependencyParser");

    private Options options;
    private DependencyPipe dependencyPipe;
    private Parameters parameters;

    /** @return the feature pipe currently attached to this parser (may be {@code null} if unset) */
    DependencyPipe getDependencyPipe() {
        return this.dependencyPipe;
    }

    /** @return the model parameters currently attached to this parser (may be {@code null} if unset) */
    Parameters getParameters() {
        return this.parameters;
    }

    /** @return the options currently attached to this parser (may be {@code null} if unset) */
    public Options getOptions() {
        return this.options;
    }

    /** @param dependencyPipe the feature pipe to use for training and prediction */
    void setDependencyPipe(DependencyPipe dependencyPipe) {
        this.dependencyPipe = dependencyPipe;
    }

    /** @param parameters the model parameters to train or to predict with */
    void setParameters(Parameters parameters) {
        this.parameters = parameters;
    }

    /** @param options the options to use for training and prediction */
    public void setOptions(Options options) {
        this.options = options;
    }

    /**
     * Trains the model on the given instances.
     *
     * <p>When at least one tensor rank is positive, {@code gammaLabel < 1} and
     * {@code initTensorWithPretrain} is set, a pre-training pass is run first with the tensor
     * ranks forced to 0 and {@code gammaLabel} forced to 1; the original option values are then
     * restored and the low-rank tensors are initialised from the pre-trained parameters.
     * Otherwise the parameters are randomly initialised. In both cases a regular training
     * pass follows.
     *
     * @param dependencyInstanceArray the training instances
     */
    void train(DependencyInstance[] dependencyInstanceArray) {
        boolean tensorsEnabled = this.options.rankFirstOrderTensor > 0 || this.options.rankSecondOrderTensor > 0;
        if (tensorsEnabled && this.options.gammaLabel < 1.0f && this.options.initTensorWithPretrain) {
            // Keep a copy of the caller's settings, then temporarily disable the tensor part.
            Options originalOptions = Options.newInstance(this.options);
            this.options.rankFirstOrderTensor = 0;
            this.options.rankSecondOrderTensor = 0;
            this.options.gammaLabel = 1.0f;
            // NOTE(review): the pre-training iteration count is written to the backup copy, not to
            // the active options. As written, pre-training runs with the regular iteration count
            // and the main training below runs with numberOfPreTrainingIterations. This looks
            // inverted, but it is kept as-is because changing it alters trained models - confirm
            // the intent against Options.newInstance before swapping.
            originalOptions.setNumberOfTrainingIterations(this.options.numberOfPreTrainingIterations);
            this.pushTensorSettingsToParameters();

            LOGGER.debug("Pre-training:");
            long preTrainingStart = System.currentTimeMillis();
            LOGGER.debug("Running MIRA ... ");
            this.trainIterations(dependencyInstanceArray);

            // Restore the caller's settings before building the tensors.
            this.options = originalOptions;
            this.pushTensorSettingsToParameters();

            LOGGER.debug("Init tensor ... ");
            this.initTensorsFromParameters();

            long preTrainingEnd = System.currentTimeMillis();
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug(String.format("Pre-training took %d ms.%n", preTrainingEnd - preTrainingStart));
            }
        } else {
            this.parameters.randomlyInit();
        }

        LOGGER.debug(" Training:");
        long trainingStart = System.currentTimeMillis();
        LOGGER.debug("Running MIRA ... ");
        this.trainIterations(dependencyInstanceArray);
        long trainingEnd = System.currentTimeMillis();
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug(String.format("Training took %d ms.%n", trainingEnd - trainingStart));
        }
    }

    /** Copies the tensor ranks and {@code gammaLabel} from the active options into the parameters. */
    private void pushTensorSettingsToParameters() {
        this.parameters.setRankFirstOrderTensor(this.options.rankFirstOrderTensor);
        this.parameters.setRankSecondOrderTensor(this.options.rankSecondOrderTensor);
        this.parameters.setGammaLabel(this.options.gammaLabel);
    }

    /**
     * Builds the first- and second-order low-rank tensors, lets the pipe's syntactic feature
     * factory fill them from the current parameters, and decomposes them into the parameter
     * matrices (U, V, WL and U2, V2, W2, X2L, Y2L respectively).
     */
    private void initTensorsFromParameters() {
        int wordFeatureCount = this.parameters.getNumberWordFeatures();
        int labelDimension = this.parameters.getDL();
        LowRankTensor firstOrderTensor = new LowRankTensor(
                new int[]{wordFeatureCount, wordFeatureCount, labelDimension},
                this.options.rankFirstOrderTensor);
        LowRankTensor secondOrderTensor = new LowRankTensor(
                new int[]{wordFeatureCount, wordFeatureCount, wordFeatureCount, labelDimension, labelDimension},
                this.options.rankSecondOrderTensor);
        this.dependencyPipe.getSynFactory().fillParameters(firstOrderTensor, secondOrderTensor, this.parameters);

        ArrayList<float[][]> firstOrderFactors = new ArrayList<float[][]>();
        firstOrderFactors.add(this.parameters.getU());
        firstOrderFactors.add(this.parameters.getV());
        firstOrderFactors.add(this.parameters.getWL());
        firstOrderTensor.decompose(firstOrderFactors);

        ArrayList<float[][]> secondOrderFactors = new ArrayList<float[][]>();
        secondOrderFactors.add(this.parameters.getU2());
        secondOrderFactors.add(this.parameters.getV2());
        secondOrderFactors.add(this.parameters.getW2());
        secondOrderFactors.add(this.parameters.getX2L());
        secondOrderFactors.add(this.parameters.getY2L());
        secondOrderTensor.decompose(secondOrderFactors);

        this.parameters.assignTotal();
        this.parameters.printStat();
    }

    /**
     * Runs {@code options.getNumberOfTrainingIterations()} passes over the training instances.
     * For each instance the labels are predicted for the instance's own heads; when not every
     * non-root position is correct, the parameters are updated and the returned loss is
     * accumulated. Per-iteration loss and accuracy are logged at debug level.
     *
     * @param dependencyInstanceArray the training instances
     */
    private void trainIterations(DependencyInstance[] dependencyInstanceArray) {
        int instanceCount = dependencyInstanceArray.length;
        // Progress is reported every 1000 instances, or every 10% for corpora above 10000.
        int progressInterval = 10000 < instanceCount ? instanceCount / 10 : 1000;
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug(String.format("Number of Training Iterations: %d", this.options.getNumberOfTrainingIterations()));
        }
        for (int iteration = 0; iteration < this.options.getNumberOfTrainingIterations(); ++iteration) {
            double loss = 0.0;
            int correctTotal = 0;
            int tokenTotal = 0;
            long iterationStart = System.currentTimeMillis();
            for (int index = 0; index < instanceCount; ++index) {
                if ((index + 1) % progressInterval == 0 && LOGGER.isDebugEnabled()) {
                    LOGGER.debug(String.format("  %d (time=%ds)", index + 1, (System.currentTimeMillis() - iterationStart) / 1000L));
                }
                DependencyInstance instance = dependencyInstanceArray[index];
                LocalFeatureData localFeatureData = new LocalFeatureData(instance, this);
                int length = instance.getLength();
                int[] heads = instance.getHeads();
                int[] predictedLabels = new int[length];
                localFeatureData.predictLabels(heads, predictedLabels, true);
                int correct = this.getNumberCorrectMatches(instance.getHeads(), instance.getDependencyLabelIds(), heads, predictedLabels);
                // Position 0 is never counted, so a fully correct instance scores length - 1.
                if (correct != length - 1) {
                    // The last argument is the 1-based global update counter across iterations.
                    loss += (double)this.parameters.updateLabel(instance, heads, predictedLabels, localFeatureData, iteration * instanceCount + index + 1);
                }
                correctTotal += correct;
                tokenTotal += length - 1;
            }
            // Avoid dividing by zero when there was nothing to score.
            if (tokenTotal == 0) {
                tokenTotal = 1;
            }
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug(String.format("%n Iter %d loss=%.4f totalNUmberCorrectMatches=%.4f [%ds]%n", iteration + 1, loss, (double)correctTotal / (double)tokenTotal, (System.currentTimeMillis() - iterationStart) / 1000L));
            }
            this.parameters.printStat();
        }
    }

    /**
     * Counts the positions, starting at index 1, where both the head and the label agree
     * between the two (heads, labels) pairs. Index 0 is skipped (presumably the artificial
     * root token - confirm against DependencyInstance).
     *
     * @param nArray  first head array; its length bounds the comparison
     * @param nArray2 first label array
     * @param nArray3 second head array
     * @param nArray4 second label array
     * @return the number of positions at which both head and label match
     */
    private int getNumberCorrectMatches(int[] nArray, int[] nArray2, int[] nArray3, int[] nArray4) {
        int matches = 0;
        for (int i = 1; i < nArray.length; ++i) {
            if (nArray[i] == nArray3[i] && nArray2[i] == nArray4[i]) {
                ++matches;
            }
        }
        return matches;
    }

    /**
     * Predicts dependency labels for the given sentences.
     *
     * <p>Sentences are processed in order and processing stops at the first sentence for which
     * the pipe returns no instance. The result is the label array of the <em>last</em> sentence
     * that was processed; if no sentence could be processed, an array sized after the first
     * sentence and filled with {@code null} is returned. An empty document yields an empty array.
     *
     * @param conllDataArray the sentences, each as an array of CoNLL tokens
     * @param string         the CoNLL format identifier passed through to the pipe
     * @return the dependency labels of the last processed sentence
     */
    DependencyLabel[] predictDependency(ConllData[][] conllDataArray, String string) {
        if (conllDataArray.length == 0) {
            // Nothing to label; previously this raised ArrayIndexOutOfBoundsException.
            return new DependencyLabel[0];
        }
        ConllWriter conllWriter = new ConllWriter(this.options, this.dependencyPipe);
        DependencyLabel[] dependencyLabels = new DependencyLabel[conllDataArray[0].length];
        for (ConllData[] sentence : conllDataArray) {
            DependencyInstance instance = this.dependencyPipe.nextSentence(sentence, string);
            if (instance == null) {
                break;
            }
            LocalFeatureData localFeatureData = new LocalFeatureData(instance, this);
            int[] heads = instance.getHeads();
            int[] predictedLabels = new int[instance.getLength()];
            localFeatureData.predictLabels(heads, predictedLabels, true);
            dependencyLabels = conllWriter.getDependencyLabels(instance, heads, predictedLabels);
        }
        return dependencyLabels;
    }
}

