/*
 * Decompiled with CFR 0.152.
 */
package datafu.opennlp.maxent.quasinewton;

import datafu.opennlp.maxent.quasinewton.DifferentiableFunction;
import datafu.opennlp.model.DataIndexer;
import datafu.opennlp.model.OnePassRealValueDataIndexer;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * Log-likelihood objective for a conditional maximum-entropy model, evaluated over the
 * training events supplied by a {@link DataIndexer}.
 *
 * <p>The parameter vector {@code x} holds one weight per (outcome, feature) pair, stored
 * outcome-major: the weight of outcome {@code oi} for feature {@code fi} is
 * {@code x[oi * numFeatures + fi]}. For every context {@code ci}, each outcome is scored as the
 * sum over the context's active features of {@code featureValue * weight(outcome, feature)}.
 * The scores are turned into probabilities with a softmax, and
 * {@code numTimesEventsSeen[ci] * log p(outcomeList[ci] | ci)} is added to the log-likelihood.
 *
 * <p>Sign conventions: {@link #valueAt} returns the log-likelihood itself, while
 * {@link #gradientAt} returns {@code expectedCount - empiricalCount}, which is the gradient of
 * the <em>negated</em> log-likelihood. NOTE(review): this pairing is kept from the original
 * code; confirm it matches what the consuming optimizer expects before changing either side.
 *
 * <p>The value and gradient of the most recently evaluated point are cached, so querying both
 * at the same {@code x} evaluates the model once. The caches are mutated without any
 * synchronization, so an instance should not be shared between threads.
 */
public class LogLikelihoodFunction implements DifferentiableFunction {
    /** Length of the parameter vector: numOutcomes * numFeatures. */
    private final int domainDimension;
    /** Log-likelihood at {@link #lastX}. */
    private double value;
    /** expectedCount - empiricalCount at {@link #lastX}; null until the first evaluation. */
    private double[] gradient;
    /** Copy of the last evaluated point; null until the first evaluation. */
    private double[] lastX;
    /** Observed (outcome, feature) counts weighted by feature values; computed lazily, once. */
    private double[] empiricalCount;
    private final int numOutcomes;
    private final int numFeatures;
    private final int numContexts;
    /** probModel[ci][oi] = p(oi | context ci) at {@link #lastX}. */
    private final double[][] probModel;
    private final String[] outcomeLabels;
    private final String[] predLabels;
    /** For each feature, the outcomes with a positive empirical count; computed lazily, once. */
    private int[][] outcomePatterns;
    /**
     * Per-context feature values, read only from a {@link OnePassRealValueDataIndexer}; when null
     * every active feature is treated as having value 1.0.
     */
    private final float[][] values;
    private final int[][] contexts;
    private final int[] outcomeList;
    private final int[] numTimesEventsSeen;

    /**
     * Captures the training data exposed by {@code indexer}.
     *
     * @param indexer source of contexts, observed outcomes, event counts and labels; real-valued
     *     feature values are used only when it is a {@link OnePassRealValueDataIndexer}
     */
    public LogLikelihoodFunction(DataIndexer indexer) {
        this.values = indexer instanceof OnePassRealValueDataIndexer ? indexer.getValues() : null;
        this.contexts = indexer.getContexts();
        this.outcomeList = indexer.getOutcomeList();
        this.numTimesEventsSeen = indexer.getNumTimesEventsSeen();
        this.outcomeLabels = indexer.getOutcomeLabels();
        this.predLabels = indexer.getPredLabels();
        this.numOutcomes = this.outcomeLabels.length;
        this.numFeatures = this.predLabels.length;
        this.numContexts = this.contexts.length;
        this.domainDimension = this.numOutcomes * this.numFeatures;
        this.probModel = new double[this.numContexts][this.numOutcomes];
        this.gradient = null;
    }

    /**
     * Returns the log-likelihood of the training events under weights {@code x}.
     *
     * @param x weight vector of length {@link #getDomainDimension()}
     * @return the (non-positive) log-likelihood
     * @throws IllegalArgumentException if {@code x} has the wrong length
     */
    public double valueAt(double[] x) {
        this.evaluateIfNeeded(x);
        return this.value;
    }

    /**
     * Returns {@code expectedCount - empiricalCount} at {@code x}, i.e. the gradient of the
     * negated log-likelihood. The returned array is the internal cache; callers must not modify
     * it.
     *
     * @param x weight vector of length {@link #getDomainDimension()}
     * @return the gradient vector, of length {@link #getDomainDimension()}
     * @throws IllegalArgumentException if {@code x} has the wrong length
     */
    public double[] gradientAt(double[] x) {
        this.evaluateIfNeeded(x);
        return this.gradient;
    }

    /** Returns the number of weights, numOutcomes * numFeatures. */
    public int getDomainDimension() {
        return this.domainDimension;
    }

    /** Returns a fresh all-zero weight vector of length {@link #getDomainDimension()}. */
    public double[] getInitialPoint() {
        return new double[this.domainDimension];
    }

    /** Returns the feature (predicate) labels taken from the indexer. */
    public String[] getPredLabels() {
        return this.predLabels;
    }

    /** Returns the outcome labels taken from the indexer. */
    public String[] getOutcomeLabels() {
        return this.outcomeLabels;
    }

    /**
     * Returns, for each feature index, the sorted outcome indices whose empirical count for that
     * feature is positive. Computed on first use, so this no longer returns null when called
     * before any evaluation.
     */
    public int[][] getOutcomePatterns() {
        if (this.outcomePatterns == null) {
            this.initEmpCount();
        }
        return this.outcomePatterns;
    }

    /** Recomputes value, probabilities and gradient unless {@code x} equals the cached point. */
    private void evaluateIfNeeded(double[] x) {
        if (!this.checkLastX(x)) {
            this.calculate(x);
        }
    }

    /**
     * Evaluates the model at {@code x}, refreshing {@link #probModel}, {@link #value},
     * {@link #gradient} and {@link #lastX}.
     *
     * @throws IllegalArgumentException if {@code x.length != domainDimension}
     */
    private void calculate(double[] x) {
        if (x.length != this.domainDimension) {
            throw new IllegalArgumentException("x is invalid, its dimension is not equal to the function.");
        }
        if (this.empiricalCount == null) {
            this.initEmpCount();
        }
        this.value = this.computeProbModelAndLogLikelihood(x);

        double[] expectedCount = this.computeExpectedCount();
        double[] grad = new double[this.domainDimension];
        for (int i = 0; i < this.domainDimension; ++i) {
            grad[i] = expectedCount[i] - this.empiricalCount[i];
        }
        this.gradient = grad;
        this.lastX = x.clone();
    }

    /**
     * Fills {@link #probModel} with softmax probabilities of every outcome for every context and
     * returns the event-count-weighted log-likelihood of the observed outcomes.
     */
    private double computeProbModelAndLogLikelihood(double[] x) {
        double logLikelihood = 0.0;
        for (int ci = 0; ci < this.numContexts; ++ci) {
            double[] row = this.probModel[ci];

            // Score every outcome, not only the observed one: the softmax normalizer needs all.
            double maxScore = Double.NEGATIVE_INFINITY;
            for (int oi = 0; oi < this.numOutcomes; ++oi) {
                row[oi] = this.score(x, ci, oi);
                maxScore = Math.max(maxScore, row[oi]);
            }

            int observed = this.outcomeList[ci];
            double observedShifted = row[observed] - maxScore;

            // Shift by the max before exponentiating so large weights cannot overflow exp().
            double sumExp = 0.0;
            for (int oi = 0; oi < this.numOutcomes; ++oi) {
                row[oi] = Math.exp(row[oi] - maxScore);
                sumExp += row[oi];
            }
            for (int oi = 0; oi < this.numOutcomes; ++oi) {
                row[oi] /= sumExp;
            }

            int timesSeen = this.numTimesEventsSeen[ci];
            if (timesSeen > 0) {
                // log p(observed) computed in log space so a tiny probability cannot become log(0).
                logLikelihood += timesSeen * (observedShifted - Math.log(sumExp));
            }
        }
        return logLikelihood;
    }

    /** Returns the linear score of outcome {@code oi} in context {@code ci} under weights x. */
    private double score(double[] x, int ci, int oi) {
        int[] features = this.contexts[ci];
        double sum = 0.0;
        for (int af = 0; af < features.length; ++af) {
            double predValue = this.featureValue(ci, af);
            // Zero-valued features contribute nothing; skipping also avoids 0 * Infinity = NaN.
            if (predValue != 0.0) {
                sum += predValue * x[this.indexOf(oi, features[af])];
            }
        }
        return sum;
    }

    /** Returns model-expected feature counts, using the probabilities currently in probModel. */
    private double[] computeExpectedCount() {
        double[] expectedCount = new double[this.domainDimension];
        for (int ci = 0; ci < this.numContexts; ++ci) {
            int[] features = this.contexts[ci];
            double timesSeen = this.numTimesEventsSeen[ci];
            for (int oi = 0; oi < this.numOutcomes; ++oi) {
                double weight = this.probModel[ci][oi] * timesSeen;
                for (int af = 0; af < features.length; ++af) {
                    double predValue = this.featureValue(ci, af);
                    if (predValue != 0.0) {
                        expectedCount[this.indexOf(oi, features[af])] += predValue * weight;
                    }
                }
            }
        }
        return expectedCount;
    }

    /**
     * Returns true when {@code x} equals the cached point element for element. Lengths are
     * compared too, so a wrong-sized {@code x} is never answered from the cache and always
     * reaches the dimension check in calculate.
     */
    private boolean checkLastX(double[] x) {
        return this.lastX != null && Arrays.equals(this.lastX, x);
    }

    /** Maps (outcome, feature) to its position in the outcome-major parameter vector. */
    private int indexOf(int outcomeId, int featureId) {
        return outcomeId * this.numFeatures + featureId;
    }

    /** Returns the value of the af-th active feature of context ci (1.0 when values is null). */
    private double featureValue(int ci, int af) {
        return this.values == null ? 1.0 : this.values[ci][af];
    }

    /**
     * Computes {@link #empiricalCount} from the observed outcomes and derives
     * {@link #outcomePatterns} from it. Neither depends on the weights, so this runs once.
     */
    private void initEmpCount() {
        double[] counts = new double[this.domainDimension];
        for (int ci = 0; ci < this.numContexts; ++ci) {
            int[] features = this.contexts[ci];
            int observed = this.outcomeList[ci];
            for (int af = 0; af < features.length; ++af) {
                // Multiply in double: a float * int product would lose precision for large counts.
                counts[this.indexOf(observed, features[af])] +=
                        this.featureValue(ci, af) * this.numTimesEventsSeen[ci];
            }
        }

        int[][] patterns = new int[this.numFeatures][];
        int[] buffer = new int[this.numOutcomes];
        for (int fi = 0; fi < this.numFeatures; ++fi) {
            int size = 0;
            for (int oi = 0; oi < this.numOutcomes; ++oi) {
                if (counts[this.indexOf(oi, fi)] > 0.0) {
                    buffer[size++] = oi;
                }
            }
            patterns[fi] = Arrays.copyOf(buffer, size);
        }

        this.empiricalCount = counts;
        this.outcomePatterns = patterns;
    }
}

