/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.nlp.annotators.parser.typdep;

import com.johnsnowlabs.nlp.annotators.parser.typdep.DependencyArcList;
import com.johnsnowlabs.nlp.annotators.parser.typdep.DependencyInstance;
import com.johnsnowlabs.nlp.annotators.parser.typdep.DependencyPipe;
import com.johnsnowlabs.nlp.annotators.parser.typdep.Options;
import com.johnsnowlabs.nlp.annotators.parser.typdep.Parameters;
import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParser;
import com.johnsnowlabs.nlp.annotators.parser.typdep.feature.SyntacticFeatureFactory;
import com.johnsnowlabs.nlp.annotators.parser.typdep.util.FeatureVector;
import com.johnsnowlabs.nlp.annotators.parser.typdep.util.ScoreCollector;
import java.util.Arrays;

/**
 * Per-sentence scoring state used by the typed dependency parser to assign arc labels.
 *
 * <p>On construction a word feature vector is built for every token of the sentence and projected
 * through the model {@link Parameters}: {@code projectU/projectV}, each {@code
 * options.rankFirstOrderTensor} wide, and {@code projectU2/projectV2/projectW2}, each {@code
 * options.rankSecondOrderTensor} wide. The object also owns the scratch tables used by the tree
 * dynamic program in {@link #predictLabels(int[], int[], boolean)}.
 *
 * <p>NOTE(review): those scratch tables are overwritten on every prediction without any
 * synchronization, so an instance should not be shared between threads.
 */
public class LocalFeatureData {
    /** Sentence being labeled; every label feature is extracted against it. */
    private final DependencyInstance dependencyInstance;
    private final DependencyPipe pipe;
    private final SyntacticFeatureFactory synFactory;
    private final Options options;
    private final Parameters parameters;
    private final int sentenceLength;
    private final int numberOfLabelTypes;
    /** Weight of the feature ("theta") label score; the tensor label score is weighted {@code 1 - gammaL}. */
    private final float gammaL;

    // The following fields are package-private and non-final; they are presumably read elsewhere in
    // this package — confirm before narrowing their visibility.
    /** Word feature vector of each token, from {@code synFactory.createWordFeatures}. */
    FeatureVector[] wordFvs;
    /** {@code parameters.projectU} of each token's word vector. */
    float[][] wpU;
    /** {@code parameters.projectV} of each token's word vector. */
    float[][] wpV;
    /** {@code parameters.projectU2} of each token's word vector. */
    float[][] wpU2;
    /** {@code parameters.projectV2} of each token's word vector. */
    float[][] wpV2;
    /** {@code parameters.projectW2} of each token's word vector. */
    float[][] wpW2;

    /**
     * DP table filled by {@link #treeDP}: {@code scoresOrProbabilities[t][k]} is the best total score
     * of the labels below token {@code t}, given that {@code t} itself carries label {@code k}.
     */
    private final float[][] scoresOrProbabilities;

    /**
     * {@code labelScores[t][j][k]}: score of token {@code t} taking label {@code j} while its head
     * takes label {@code k}; {@code -Infinity} marks pruned combinations.
     */
    private final float[][][] labelScores;

    /**
     * Builds the word features and projections for every token of {@code dependencyInstance}.
     *
     * @param dependencyInstance sentence to score
     * @param typedDependencyParser supplies the dependency pipe, options and model parameters
     */
    LocalFeatureData(DependencyInstance dependencyInstance, TypedDependencyParser typedDependencyParser) {
        this.dependencyInstance = dependencyInstance;
        this.pipe = typedDependencyParser.getDependencyPipe();
        this.synFactory = this.pipe.getSynFactory();
        this.options = typedDependencyParser.getOptions();
        this.parameters = typedDependencyParser.getParameters();
        this.sentenceLength = dependencyInstance.getLength();
        this.numberOfLabelTypes = this.pipe.getTypes().length;
        this.gammaL = this.options.gammaLabel;

        int firstOrderRank = this.options.rankFirstOrderTensor;
        int secondOrderRank = this.options.rankSecondOrderTensor;
        this.wordFvs = new FeatureVector[this.sentenceLength];
        this.wpU = new float[this.sentenceLength][firstOrderRank];
        this.wpV = new float[this.sentenceLength][firstOrderRank];
        this.wpU2 = new float[this.sentenceLength][secondOrderRank];
        this.wpV2 = new float[this.sentenceLength][secondOrderRank];
        this.wpW2 = new float[this.sentenceLength][secondOrderRank];
        this.scoresOrProbabilities = new float[this.sentenceLength][this.numberOfLabelTypes];
        this.labelScores = new float[this.sentenceLength][this.numberOfLabelTypes][this.numberOfLabelTypes];

        for (int token = 0; token < this.sentenceLength; ++token) {
            FeatureVector wordFeatures = this.synFactory.createWordFeatures(dependencyInstance, token);
            this.wordFvs[token] = wordFeatures;
            // The projection rows were all allocated above, so no null fallback is needed.
            this.parameters.projectU(wordFeatures, this.wpU[token]);
            this.parameters.projectV(wordFeatures, this.wpV[token]);
            this.parameters.projectU2(wordFeatures, this.wpU2[token]);
            this.parameters.projectV2(wordFeatures, this.wpV2[token]);
            this.parameters.projectW2(wordFeatures, this.wpW2[token]);
        }
    }

    /**
     * Builds the label-feature difference between the parse stored in {@code instance} and the parse
     * given by {@code heads}/{@code labels}.
     *
     * <p>For each token {@code i >= 1}:
     * <ul>
     *   <li>If its label differs, the order-1 label features of the instance parse are added and
     *       those of the given parse are subtracted.</li>
     *   <li>If its label or its (instance) head's label differs, the same is done with the order-2
     *       label features.</li>
     * </ul>
     *
     * @param instance source of the reference heads and label ids
     * @param heads heads of the parse to subtract
     * @param labels label ids of the parse to subtract
     * @return reference features minus given-parse features
     */
    FeatureVector getLabeledFeatureDifference(DependencyInstance instance, int[] heads, int[] labels) {
        FeatureVector difference = new FeatureVector();
        int[] refHeads = instance.getHeads();
        int[] refLabels = instance.getDependencyLabelIds();
        // NOTE(review): the loop bound and the feature extraction use this.dependencyInstance, while the
        // reference arrays come from the argument; callers presumably pass the same sentence — confirm.
        for (int i = 1; i < this.dependencyInstance.getLength(); ++i) {
            int head = refHeads[i];
            boolean labelDiffers = refLabels[i] != labels[i];
            if (labelDiffers) {
                difference.addEntries(getLabelFeature(refHeads, refLabels, i, 1));
                difference.addEntries(getLabelFeature(heads, labels, i, 1), -1.0f);
            }
            if (labelDiffers || refLabels[head] != labels[head]) {
                difference.addEntries(getLabelFeature(refHeads, refLabels, i, 2));
                difference.addEntries(getLabelFeature(heads, labels, i, 2), -1.0f);
            }
        }
        return difference;
    }

    /** Collects the label features of the given order for token {@code token} into a new vector. */
    private FeatureVector getLabelFeature(int[] heads, int[] labels, int token, int order) {
        FeatureVector features = new FeatureVector();
        this.synFactory.createLabelFeatures(features, this.dependencyInstance, heads, labels, token, order);
        return features;
    }

    /**
     * Fills {@link #labelScores} for the tree {@code heads}, then decodes the best labels into
     * {@code labels}.
     *
     * @param heads head index of every token
     * @param labels output label ids; also used as scratch while scoring. Entry 0 is overwritten with
     *     the instance's own label for token 0.
     * @param addLoss when true, label 0 becomes a candidate and 1.0 is added to every label that
     *     differs from the instance's stored label (presumably loss-augmented training — confirm)
     * @param arcs child lists of {@code heads}
     */
    private void predictLabelsDP(int[] heads, int[] labels, boolean addLoss, DependencyArcList arcs) {
        // Label 0 is only a candidate when the loss term is added.
        int minLabel = addLoss ? 0 : 1;
        // Loop-invariant lookups, hoisted out of the O(n * L^2) scoring loops below.
        int[] posTags = this.dependencyInstance.getXPosTagIds();
        boolean[][][] pruneLabel = this.pipe.getPruneLabel();
        int[] instanceLabels = this.dependencyInstance.getDependencyLabelIds();

        for (int dep = 1; dep < this.sentenceLength; ++dep) {
            int head = heads[dep];
            // Direction code passed to the Parameters dot products: 1 when the head index is larger.
            int headDir = head > dep ? 1 : 2;
            int grandHead = heads[head];
            int grandDir = grandHead > head ? 1 : 2;
            float[][] depScores = this.labelScores[dep];

            for (int depLabel = minLabel; depLabel < this.numberOfLabelTypes; ++depLabel) {
                if (!pruneLabel[posTags[head]][posTags[dep]][depLabel]) {
                    // Label disallowed for this (head tag, dependent tag) pair.
                    Arrays.fill(depScores[depLabel], Float.NEGATIVE_INFINITY);
                    continue;
                }
                labels[dep] = depLabel;

                float firstOrder = 0.0f;
                if (this.gammaL > 0.0f) {
                    firstOrder += this.gammaL * getLabelScoreTheta(heads, labels, dep, 1);
                }
                if (this.gammaL < 1.0f) {
                    firstOrder += (1.0f - this.gammaL)
                            * this.parameters.dotProductL(this.wpU[head], this.wpV[dep], depLabel, headDir);
                }
                float loss = addLoss && instanceLabels[dep] != depLabel ? 1.0f : 0.0f;

                for (int headLabel = minLabel; headLabel < this.numberOfLabelTypes; ++headLabel) {
                    float secondOrder = 0.0f;
                    // A head whose own head is -1 contributes no second-order term.
                    if (grandHead != -1) {
                        if (pruneLabel[posTags[grandHead]][posTags[head]][headLabel]) {
                            labels[head] = headLabel;
                            if (this.gammaL > 0.0f) {
                                secondOrder += this.gammaL * getLabelScoreTheta(heads, labels, dep, 2);
                            }
                            if (this.gammaL < 1.0f) {
                                secondOrder += (1.0f - this.gammaL) * this.parameters.dotProduct2L(
                                        this.wpU2[grandHead], this.wpV2[head], this.wpW2[dep],
                                        headLabel, depLabel, grandDir, headDir);
                            }
                        } else {
                            secondOrder = Float.NEGATIVE_INFINITY;
                        }
                    }
                    depScores[depLabel][headLabel] = firstOrder + secondOrder + loss;
                }
            }
        }

        treeDP(0, arcs, minLabel);
        // Token 0 keeps the instance's label; its children are decoded relative to it.
        labels[0] = instanceLabels[0];
        computeDependencyLabels(0, arcs, labels, minLabel);
    }

    /** Scores the order-{@code order} label features of {@code token} against {@code parameters.getParamsL()}. */
    private float getLabelScoreTheta(int[] heads, int[] labels, int token, int order) {
        ScoreCollector collector = new ScoreCollector(this.parameters.getParamsL());
        this.synFactory.createLabelFeatures(collector, this.dependencyInstance, heads, labels, token, order);
        return collector.getScore();
    }

    /**
     * Bottom-up pass: for every label {@code k >= minLabel} of {@code node}, stores in
     * {@code scoresOrProbabilities[node][k]} the sum over its children of the best
     * {@code childScore[j] + labelScores[child][j][k]}.
     */
    private void treeDP(int node, DependencyArcList arcs, int minLabel) {
        float[] nodeScores = this.scoresOrProbabilities[node];
        Arrays.fill(nodeScores, 0.0f);
        int start = arcs.startIndex(node);
        int end = arcs.endIndex(node);
        for (int a = start; a < end; ++a) {
            int child = arcs.get(a);
            treeDP(child, arcs, minLabel);
            float[] childScores = this.scoresOrProbabilities[child];
            float[][] childLabelScores = this.labelScores[child];
            for (int nodeLabel = minLabel; nodeLabel < this.numberOfLabelTypes; ++nodeLabel) {
                // Seed with the first candidate; strict '>' keeps the earliest maximum.
                float best = childScores[minLabel] + childLabelScores[minLabel][nodeLabel];
                for (int childLabel = minLabel + 1; childLabel < this.numberOfLabelTypes; ++childLabel) {
                    float candidate = childScores[childLabel] + childLabelScores[childLabel][nodeLabel];
                    if (candidate > best) {
                        best = candidate;
                    }
                }
                nodeScores[nodeLabel] += best;
            }
        }
    }

    /**
     * Top-down backtrace: gives each child of {@code node} the label that maximizes its subtree score
     * plus its arc score under the label already chosen for {@code node}, then recurses. A child whose
     * candidates all score {@code -Infinity} keeps its current entry in {@code labels}.
     */
    private void computeDependencyLabels(int node, DependencyArcList arcs, int[] labels, int minLabel) {
        int nodeLabel = labels[node];
        int start = arcs.startIndex(node);
        int end = arcs.endIndex(node);
        for (int a = start; a < end; ++a) {
            int child = arcs.get(a);
            int bestLabel = labels[child]; // kept when no candidate beats -Infinity
            float bestScore = Float.NEGATIVE_INFINITY;
            for (int childLabel = minLabel; childLabel < this.numberOfLabelTypes; ++childLabel) {
                float score = this.scoresOrProbabilities[child][childLabel]
                        + this.labelScores[child][childLabel][nodeLabel];
                if (score > bestScore) {
                    bestScore = score;
                    bestLabel = childLabel;
                }
            }
            labels[child] = bestLabel;
            computeDependencyLabels(child, arcs, labels, minLabel);
        }
    }

    /**
     * Predicts a label for every token of the tree described by {@code heads}.
     *
     * @param heads head index of every token
     * @param labels output array that receives the predicted label ids (entry 0 is set to the
     *     instance's own label)
     * @param addLoss whether to allow label 0 and add 1.0 to labels that differ from the instance's
     *     stored labels
     */
    void predictLabels(int[] heads, int[] labels, boolean addLoss) {
        DependencyArcList arcs = new DependencyArcList(heads);
        predictLabelsDP(heads, labels, addLoss, arcs);
    }
}

