/*
 * Decompiled with CFR 0.152.
 */
package cmu.arktweetnlp.impl;

import cmu.arktweetnlp.impl.ModelSentence;
import cmu.arktweetnlp.impl.Vocabulary;
import cmu.arktweetnlp.util.BasicFileIO;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Triple;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Pair;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/**
 * Coefficients and inference routines for the tagger's sequence model.
 *
 * <p>A token's label scores are the sum of three parameter groups (see
 * {@link #computeLabelScores}):
 * <ul>
 *   <li>a per-label bias;
 *   <li>an edge weight indexed by the previous label, or by {@link #startMarker()} for the first token;
 *   <li>weighted observation features.
 * </ul>
 * Posteriors are a per-token softmax over those scores ({@link #inferPosteriorGivenLabels}).
 *
 * <p>Fields are public and mutable and no method is synchronized.
 */
public class Model {
    /** Label strings; locked before coefficients are allocated. */
    public Vocabulary labelVocab = new Vocabulary();
    /** Observation-feature strings; locked before coefficients are allocated. */
    public Vocabulary featureVocab = new Vocabulary();
    /** Per-label bias weights, indexed {@code [label]}. */
    public double[] biasCoefs;
    /**
     * Transition weights indexed {@code [previousLabel][label]}.
     *
     * <p>The extra last row (index {@link #startMarker()}) holds the weights used for the first token.
     */
    public double[][] edgeCoefs;
    /** Observation weights indexed {@code [featureId][label]}. */
    public double[][] observationFeatureCoefs;
    /** Number of labels; set by {@link #lockdownAfterFeatureExtraction} and {@link #loadModelFromText}. */
    public int numLabels;

    /**
     * Returns the pseudo-label used as the "previous label" of the first token.
     *
     * <p>It is one past the last real label id, which is why {@link #edgeCoefs} has
     * {@code numLabels + 1} rows. Requires a locked label vocabulary.
     */
    public int startMarker() {
        assert this.labelVocab.isLocked();
        return this.labelVocab.size();
    }

    /** Freezes both vocabularies, records the label count, and allocates zero-filled coefficient arrays. */
    public void lockdownAfterFeatureExtraction() {
        this.labelVocab.lock();
        this.featureVocab.lock();
        // Every scoring loop is bounded by numLabels; keep it consistent with the arrays allocated below.
        this.numLabels = this.labelVocab.size();
        this.allocateCoefs(this.labelVocab.size(), this.featureVocab.size());
    }

    /**
     * Allocates zero-filled coefficient arrays.
     *
     * @param labelCount number of real labels; edge weights get one extra row for the start marker
     * @param featureCount number of observation features
     */
    public void allocateCoefs(int labelCount, int featureCount) {
        this.observationFeatureCoefs = new double[featureCount][labelCount];
        this.edgeCoefs = new double[labelCount + 1][labelCount];
        this.biasCoefs = new double[labelCount];
    }

    /**
     * Computes per-token label posteriors.
     *
     * <p>Each token is conditioned on the previous label already stored in
     * {@code modelSentence.edgeFeatures}.
     *
     * @return a {@code [T][labelVocab.size()]} matrix; row {@code t} holds the softmax of token
     *     {@code t}'s label scores
     */
    public double[][] inferPosteriorGivenLabels(ModelSentence modelSentence) {
        double[][] posteriors = new double[modelSentence.T][this.labelVocab.size()];
        double[] scores = new double[this.numLabels];
        for (int t = 0; t < modelSentence.T; ++t) {
            this.computeLabelScores(t, modelSentence, scores);
            softmaxInPlace(scores);
            System.arraycopy(scores, 0, posteriors[t], 0, this.numLabels);
        }
        return posteriors;
    }

    /**
     * Tags the sentence left to right, picking the highest-scoring label at each token.
     *
     * <p>Each chosen label is fed forward as the next token's previous label. Writes
     * {@code modelSentence.labels} and overwrites {@code modelSentence.edgeFeatures}.
     *
     * @param modelSentence the sentence to tag
     * @param computeConfidences when true, also fills {@code modelSentence.confidences} with the
     *     softmax probability of each chosen label
     */
    public void greedyDecode(ModelSentence modelSentence, boolean computeConfidences) {
        final int length = modelSentence.T;
        modelSentence.labels = new int[length];
        if (computeConfidences) {
            modelSentence.confidences = new double[length];
        }
        if (length == 0) {
            // Nothing to tag; edgeFeatures[0] may not even exist for an empty sentence.
            return;
        }
        modelSentence.edgeFeatures[0] = this.startMarker();
        double[] scores = new double[this.numLabels];
        for (int t = 0; t < length; ++t) {
            this.computeLabelScores(t, modelSentence, scores);
            int best = ArrayMath.argmax(scores);
            modelSentence.labels[t] = best;
            if (t + 1 < length) {
                modelSentence.edgeFeatures[t + 1] = best;
            }
            if (computeConfidences) {
                softmaxInPlace(scores);
                modelSentence.confidences[t] = scores[best];
            }
        }
    }

    /**
     * Posterior inference without gold labels. Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    public double[][] inferPosteriorForUnknownLabels(ModelSentence modelSentence) {
        throw new UnsupportedOperationException("Unimplemented");
    }

    /**
     * Finds the best label sequence by Viterbi search and writes it to {@code modelSentence.labels}.
     *
     * <p>At each step the label scores conditioned on every possible previous label are
     * log-normalized with {@code ArrayUtil.logNormalize}. They are then added to the best
     * accumulated score of that previous label.
     */
    public void viterbiDecode(ModelSentence modelSentence) {
        final int length = modelSentence.T;
        modelSentence.labels = new int[length];
        if (length == 0) {
            return;
        }
        final int start = this.startMarker();
        int[][] backPointers = new int[length][this.numLabels];
        double[][] bestScores = new double[length][this.numLabels];

        double[] previous = new double[this.numLabels];
        this.computeVitLabelScores(0, start, modelSentence, previous);
        ArrayUtil.logNormalize(previous);
        bestScores[0] = previous;
        Arrays.fill(backPointers[0], start);

        for (int t = 1; t < length; ++t) {
            // candidates[p][k]: best score ending in label k at t, given previous label p.
            double[][] candidates = new double[this.numLabels][this.numLabels];
            for (int p = 0; p < this.numLabels; ++p) {
                this.computeVitLabelScores(t, p, modelSentence, candidates[p]);
                ArrayUtil.logNormalize(candidates[p]);
                candidates[p] = ArrayUtil.add(candidates[p], previous[p]);
            }
            for (int k = 0; k < this.numLabels; ++k) {
                double[] column = this.getColumn(candidates, k);
                int bestPrev = ArrayUtil.argmax(column);
                backPointers[t][k] = bestPrev;
                bestScores[t][k] = column[bestPrev];
            }
            previous = bestScores[t];
        }

        // Follow back-pointers from the best final label until the start marker is reached.
        modelSentence.labels[length - 1] = ArrayUtil.argmax(bestScores[length - 1]);
        int back = backPointers[length - 1][modelSentence.labels[length - 1]];
        for (int t = length - 2; t >= 0 && back != start; --t) {
            modelSentence.labels[t] = back;
            back = backPointers[t][back];
        }
        assert back == start;
    }

    /** Returns column {@code col} of {@code matrix} as a new array with one entry per row. */
    private double[] getColumn(double[][] matrix, int col) {
        double[] column = new double[matrix.length];
        for (int row = 0; row < matrix.length; ++row) {
            column[row] = matrix[row][col];
        }
        return column;
    }

    /**
     * Minimum-Bayes-risk decoding from unconditioned posteriors.
     *
     * <p>Currently always throws, because {@link #inferPosteriorForUnknownLabels} is not implemented.
     */
    public void mbrDecode(ModelSentence modelSentence) {
        double[][] posteriors = this.inferPosteriorForUnknownLabels(modelSentence);
        for (int t = 0; t < modelSentence.T; ++t) {
            modelSentence.labels[t] = ArrayMath.argmax(posteriors[t]);
        }
    }

    /**
     * Zeroes {@code dArray} and then adds bias, edge and observation-feature scores.
     *
     * <p>The edge scores use the previous label in {@code modelSentence.edgeFeatures[n]}.
     * Only the first {@code numLabels} entries receive scores.
     */
    public void computeLabelScores(int n, ModelSentence modelSentence, double[] dArray) {
        Arrays.fill(dArray, 0.0);
        this.computeBiasScores(dArray);
        this.computeEdgeScores(n, modelSentence, dArray);
        this.computeObservedFeatureScores(n, modelSentence, dArray);
    }

    /**
     * Like {@link #computeLabelScores}, but for Viterbi search.
     *
     * <p>The previous label {@code prevLabel} is supplied explicitly instead of being read from
     * {@code modelSentence.edgeFeatures}.
     */
    public void computeVitLabelScores(int position, int prevLabel, ModelSentence modelSentence, double[] dArray) {
        Arrays.fill(dArray, 0.0);
        this.computeBiasScores(dArray);
        this.viterbiEdgeScores(prevLabel, modelSentence, dArray);
        this.computeObservedFeatureScores(position, modelSentence, dArray);
    }

    /** Adds each label's bias weight to {@code dArray}. */
    public void computeBiasScores(double[] dArray) {
        for (int k = 0; k < this.numLabels; ++k) {
            dArray[k] += this.biasCoefs[k];
        }
    }

    /** Adds the edge weights for the previous label stored in {@code modelSentence.edgeFeatures[n]}. */
    public void computeEdgeScores(int n, ModelSentence modelSentence, double[] dArray) {
        double[] row = this.edgeCoefs[modelSentence.edgeFeatures[n]];
        for (int k = 0; k < this.numLabels; ++k) {
            dArray[k] += row[k];
        }
    }

    /**
     * Adds the edge weights for an explicitly given previous label.
     *
     * <p>{@code modelSentence} is not used.
     */
    public void viterbiEdgeScores(int prevLabel, ModelSentence modelSentence, double[] dArray) {
        double[] row = this.edgeCoefs[prevLabel];
        for (int k = 0; k < this.numLabels; ++k) {
            dArray[k] += row[k];
        }
    }

    /** Adds {@code weight[feature][label] * value} for every observation feature of token {@code n}. */
    public void computeObservedFeatureScores(int n, ModelSentence modelSentence, double[] dArray) {
        // Feature-major iteration unboxes each pair once. Every dArray[k] still accumulates its
        // terms in feature-list order, so the sums are unchanged.
        for (Pair<Integer, Double> pair : modelSentence.observationFeatures.get(n)) {
            double[] row = this.observationFeatureCoefs[pair.first];
            double value = pair.second;
            for (int k = 0; k < this.numLabels; ++k) {
                dArray[k] += row[k] * value;
            }
        }
    }

    /**
     * Returns the element-wise product of three equal-length arrays.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    public double[] ThreewiseMultiply(double[] dArray, double[] dArray2, double[] dArray3) {
        if (dArray.length != dArray2.length || dArray2.length != dArray3.length) {
            throw new IllegalArgumentException("array lengths differ: " + dArray.length + ", "
                    + dArray2.length + ", " + dArray3.length);
        }
        double[] product = new double[dArray.length];
        for (int i = 0; i < product.length; ++i) {
            product[i] = dArray[i] * dArray2[i] * dArray3[i];
        }
        return product;
    }

    /**
     * Adds this sentence's log-likelihood gradient into {@code dArray}.
     *
     * <p>{@code dArray} is laid out as in {@link #flatIDsize()}. For every token and label the
     * residual (1 if gold, else 0) minus the posterior is added to that label's bias, to its edge
     * weight from the previous label in {@code edgeFeatures}, and, scaled by the feature value, to
     * each active observation weight.
     */
    public void computeGradient(ModelSentence modelSentence, double[] dArray) {
        assert dArray.length == this.flatIDsize();
        double[][] posteriors = this.inferPosteriorGivenLabels(modelSentence);
        for (int t = 0; t < modelSentence.T; ++t) {
            int prevLabel = modelSentence.edgeFeatures[t];
            int goldLabel = modelSentence.labels[t];
            for (int k = 0; k < this.numLabels; ++k) {
                double residual = (goldLabel == k ? 1.0 : 0.0) - posteriors[t][k];
                dArray[this.biasFeature_to_flatID(k)] += residual;
                dArray[this.edgeFeature_to_flatID(prevLabel, k)] += residual;
                for (Pair<Integer, Double> pair : modelSentence.observationFeatures.get(t)) {
                    dArray[this.observationFeature_to_flatID(pair.first, k)] += residual * pair.second;
                }
            }
        }
    }

    /** Returns the sum over tokens of the log posterior of the gold label. */
    public double computeLogLik(ModelSentence modelSentence) {
        double[][] posteriors = this.inferPosteriorGivenLabels(modelSentence);
        double logLik = 0.0;
        for (int t = 0; t < modelSentence.T; ++t) {
            logLik += Math.log(posteriors[t][modelSentence.labels[t]]);
        }
        return logLik;
    }

    /** Copies a flat parameter vector (layout of {@link #flatIDsize()}) into the coefficient arrays. */
    public void setCoefsFromFlat(double[] dArray) {
        for (int k = 0; k < this.numLabels; ++k) {
            this.biasCoefs[k] = dArray[this.biasFeature_to_flatID(k)];
        }
        for (int prev = 0; prev < this.numLabels + 1; ++prev) {
            for (int k = 0; k < this.numLabels; ++k) {
                this.edgeCoefs[prev][k] = dArray[this.edgeFeature_to_flatID(prev, k)];
            }
        }
        for (int f = 0; f < this.featureVocab.size(); ++f) {
            for (int k = 0; k < this.numLabels; ++k) {
                this.observationFeatureCoefs[f][k] = dArray[this.observationFeature_to_flatID(f, k)];
            }
        }
    }

    /** Packs the coefficient arrays into a new flat vector (layout of {@link #flatIDsize()}). */
    public double[] convertCoefsToFlat() {
        double[] flat = new double[this.flatIDsize()];
        for (int k = 0; k < this.numLabels; ++k) {
            flat[this.biasFeature_to_flatID(k)] = this.biasCoefs[k];
        }
        for (int prev = 0; prev < this.numLabels + 1; ++prev) {
            for (int k = 0; k < this.numLabels; ++k) {
                flat[this.edgeFeature_to_flatID(prev, k)] = this.edgeCoefs[prev][k];
            }
        }
        for (int f = 0; f < this.featureVocab.size(); ++f) {
            for (int k = 0; k < this.numLabels; ++k) {
                flat[this.observationFeature_to_flatID(f, k)] = this.observationFeatureCoefs[f][k];
            }
        }
        return flat;
    }

    /**
     * Returns the length of the flat parameter vector.
     *
     * <p>With {@code L = labelVocab.size()} and {@code F = featureVocab.size()}, the vector holds
     * {@code L} biases, then {@code (L + 1) * L} edge weights, then {@code F * L} observation weights.
     */
    public int flatIDsize() {
        int labels = this.labelVocab.size();
        int features = this.featureVocab.size();
        return labels + (labels + 1) * labels + features * labels;
    }

    /** Flat index of the bias weight for {@code label}. */
    private int biasFeature_to_flatID(int label) {
        return label;
    }

    /** Flat index of the edge weight {@code prevLabel -> label}, stored after the biases. */
    private int edgeFeature_to_flatID(int prevLabel, int label) {
        int labels = this.labelVocab.size();
        return labels + prevLabel * labels + label;
    }

    /** Flat index of observation weight {@code (featureId, label)}, stored after biases and edges. */
    private int observationFeature_to_flatID(int featureId, int label) {
        int labels = this.labelVocab.size();
        return labels + (labels + 1) * labels + featureId * labels + label;
    }

    /**
     * Writes the model as tab-separated text that {@link #loadModelFromText} can read back.
     *
     * <p>The file contains, in order:
     * <ul>
     *   <li>one {@code ***BIAS***} line per label;
     *   <li>one {@code ***EDGE***} line per (previous, current) label pair;
     *   <li>one {@code feature<TAB>label<TAB>weight} line per non-zero observation weight.
     * </ul>
     * Numbers are formatted with {@link Locale#ROOT} so the loader's {@code Double.parseDouble}
     * can read them regardless of the default locale.
     *
     * @throws IOException if writing fails
     */
    public void saveModelAsText(String string) throws IOException {
        BufferedWriter bufferedWriter = BasicFileIO.openFileToWriteUTF8(string);
        // Closing the PrintWriter also closes the wrapped BufferedWriter.
        try (PrintWriter out = new PrintWriter(bufferedWriter)) {
            for (int k = 0; k < this.numLabels; ++k) {
                out.printf(Locale.ROOT, "***BIAS***\t%s\t%g\n", this.labelVocab.name(k), this.biasCoefs[k]);
            }
            for (int prev = 0; prev < this.numLabels + 1; ++prev) {
                for (int k = 0; k < this.numLabels; ++k) {
                    out.printf(Locale.ROOT, "***EDGE***\t%s %s\t%s\n", prev, k, this.edgeCoefs[prev][k]);
                }
            }
            assert this.featureVocab.size() == this.observationFeatureCoefs.length;
            for (int f = 0; f < this.featureVocab.size(); ++f) {
                for (int k = 0; k < this.numLabels; ++k) {
                    double coef = this.observationFeatureCoefs[f][k];
                    if (coef == 0.0) {
                        continue;
                    }
                    out.printf(Locale.ROOT, "%s\t%s\t%g\n", this.featureVocab.name(f), this.labelVocab.name(k), coef);
                }
            }
            // PrintWriter swallows IOExceptions; surface them instead of leaving a truncated model.
            if (out.checkError()) {
                throw new IOException("Error while writing model to " + string);
            }
        }
    }

    /**
     * Reads a model in the format written by {@link #saveModelAsText}.
     *
     * <p>The bias section defines the label vocabulary. The observation section may be empty.
     *
     * @throws IOException if reading fails
     */
    public static Model loadModelFromText(String string) throws IOException {
        Model model = new Model();
        List<Double> biasValues = new ArrayList<>();
        List<int[]> edgeKeys = new ArrayList<>();
        List<Double> edgeValues = new ArrayList<>();
        List<int[]> observationKeys = new ArrayList<>();
        List<Double> observationValues = new ArrayList<>();

        try (BufferedReader reader = BasicFileIO.openFileOrResource(string)) {
            String line = reader.readLine();

            // Section 1: "***BIAS***<TAB>label<TAB>weight", one per label, in label-id order.
            while (line != null) {
                String[] fields = line.split("\t");
                if (!fields[0].equals("***BIAS***")) {
                    break;
                }
                model.labelVocab.num(fields[1]);
                biasValues.add(Double.parseDouble(fields[2]));
                line = reader.readLine();
            }
            model.labelVocab.lock();
            model.numLabels = model.labelVocab.size();

            // Section 2: "***EDGE***<TAB>prevLabelId labelId<TAB>weight".
            while (line != null) {
                String[] fields = line.split("\t");
                if (!fields[0].equals("***EDGE***")) {
                    break;
                }
                String[] ids = fields[1].split(" ");
                edgeKeys.add(new int[] {Integer.parseInt(ids[0]), Integer.parseInt(ids[1])});
                edgeValues.add(Double.parseDouble(fields[2]));
                line = reader.readLine();
            }

            // Section 3 (rest of the file): "feature<TAB>label<TAB>weight".
            while (line != null) {
                String[] fields = line.split("\t");
                int featureId = model.featureVocab.num(fields[0]);
                int labelId = model.labelVocab.num(fields[1]);
                observationKeys.add(new int[] {featureId, labelId});
                observationValues.add(Double.parseDouble(fields[2]));
                line = reader.readLine();
            }
        }

        // Coefficients can only be allocated once the feature vocabulary size is known.
        model.featureVocab.lock();
        model.allocateCoefs(model.labelVocab.size(), model.featureVocab.size());
        for (int k = 0; k < model.numLabels; ++k) {
            model.biasCoefs[k] = biasValues.get(k);
        }
        for (int e = 0; e < edgeKeys.size(); ++e) {
            int[] key = edgeKeys.get(e);
            model.edgeCoefs[key[0]][key[1]] = edgeValues.get(e);
        }
        for (int o = 0; o < observationKeys.size(); ++o) {
            int[] key = observationKeys.get(o);
            model.observationFeatureCoefs[key[0]][key[1]] = observationValues.get(o);
        }
        return model;
    }

    /**
     * Warm-starts {@code model2} from {@code model}.
     *
     * <p>Copies the bias and edge weights. Observation weights are copied for every feature name
     * present in both feature vocabularies.
     *
     * @throws IllegalArgumentException if the label vocabularies differ in size or in any label name
     */
    public static void copyCoefsForIntersectingFeatures(Model model, Model model2) {
        int labelCount = model.numLabels;
        if (labelCount != model2.numLabels) {
            throw new IllegalArgumentException("label vocabs must be same size for warm-start");
        }
        for (int k = 0; k < labelCount; ++k) {
            if (!model2.labelVocab.name(k).equals(model.labelVocab.name(k))) {
                throw new IllegalArgumentException("label vocabs must agree for warm-start");
            }
        }
        model2.biasCoefs = ArrayUtil.copy(model.biasCoefs);
        model2.edgeCoefs = ArrayUtil.copy(model.edgeCoefs);
        for (int f = 0; f < model.featureVocab.size(); ++f) {
            String featureName = model.featureVocab.name(f);
            if (!model2.featureVocab.contains(featureName)) {
                continue;
            }
            int targetId = model2.featureVocab.num(featureName);
            model2.observationFeatureCoefs[targetId] = ArrayUtil.copy(model.observationFeatureCoefs[f]);
        }
    }

    /**
     * Replaces {@code scores} with its softmax.
     *
     * <p>The maximum is subtracted before exponentiating so large scores do not overflow to
     * {@code Infinity} and produce {@code NaN} probabilities.
     */
    private static void softmaxInPlace(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            if (s > max) {
                max = s;
            }
        }
        double sum = 0.0;
        for (int k = 0; k < scores.length; ++k) {
            scores[k] = Math.exp(scores[k] - max);
            sum += scores[k];
        }
        for (int k = 0; k < scores.length; ++k) {
            scores[k] /= sum;
        }
    }
}

