/*
 * Decompiled with CFR 0.152.
 */
package cmu.arktweetnlp.impl;

import cmu.arktweetnlp.impl.OptimizerState;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.DiffFunction;
import gnu.trove.set.hash.THashSet;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Set;

/**
 * Driver for the OWL-QN (orthant-wise limited-memory quasi-Newton) minimizer.
 *
 * <p>Each call to {@link #minimize} builds a fresh {@link OptimizerState} and then repeatedly
 * updates the search direction and runs a backtracking line search. It stops when any of these
 * holds:
 * <ul>
 *   <li>the termination criterion drops below the tolerance;</li>
 *   <li>the parameters did not change in an iteration;</li>
 *   <li>{@code maxIters} iterations have run.</li>
 * </ul>
 *
 * <p>Optional simplex constraints are configured through static state ({@link #setConstrained}).
 * That state is shared by every instance and is not synchronized.
 */
public class OWLQN {
    /** Upper bound on the number of iterations performed by {@link #minimize}. */
    private int maxIters = Integer.MAX_VALUE;
    /** Whether {@link #projectWeights} should be applied; kept in sync with numUnconstrainedWeights. */
    private static boolean constrained = false;
    /** Suppresses per-iteration progress output on stderr when true. */
    boolean quiet;
    /** True when this instance created (and therefore owns) its termination criterion. */
    boolean responsibleForTermCrit;
    /** Not referenced in this class; presumably read elsewhere in the project (verify before changing). */
    public static Set<Integer> biasParameters = new THashSet<Integer>();
    /** Decides when the optimization has converged. */
    TerminationCriterion termCrit;
    /** Optional hook invoked after each (non-quiet) iteration. */
    WeightsPrinter printer;
    /**
     * Number of leading weights excluded from the simplex projection.
     * A value of -1 means the optimizer is unconstrained.
     */
    private static int numUnconstrainedWeights = -1;

    /**
     * Creates an optimizer using the default relative-mean-improvement criterion.
     *
     * @param bl whether to suppress progress output
     */
    public OWLQN(boolean bl) {
        this.quiet = bl;
        this.termCrit = new RelativeMeanImprovementCriterion();
        this.responsibleForTermCrit = true;
    }

    /** Creates a non-quiet optimizer with the default termination criterion. */
    public OWLQN() {
        this(false);
    }

    /**
     * Creates an optimizer that uses a caller-supplied termination criterion.
     *
     * @param terminationCriterion criterion owned by the caller (never reset by this class)
     * @param bl whether to suppress progress output
     */
    public OWLQN(TerminationCriterion terminationCriterion, boolean bl) {
        this.quiet = bl;
        this.termCrit = terminationCriterion;
        this.responsibleForTermCrit = false;
    }

    /**
     * Sets whether progress output is suppressed.
     *
     * @param bl true to suppress progress output
     */
    public void setQuiet(boolean bl) {
        this.quiet = bl;
    }

    /**
     * Minimizes {@code function} starting from {@code initial}.
     *
     * @param function the differentiable objective
     * @param initial the starting point handed to {@link OptimizerState}
     * @param l1weight the L1 regularization weight
     * @param tol convergence tolerance; iteration stops once the criterion value falls below it
     * @param m the L-BFGS memory parameter
     * @return the optimizer state's final {@code newX} vector
     */
    public double[] minimize(DiffFunction function, double[] initial, double l1weight, double tol, int m) {
        OptimizerState state = new OptimizerState(function, initial, m, l1weight, this.quiet);
        if (this.responsibleForTermCrit && this.termCrit instanceof RelativeMeanImprovementCriterion) {
            // The default criterion remembers past objective values. Any history left over from an
            // earlier call belongs to a different run (the OptimizerState above is brand new), so
            // discard it rather than comparing against unrelated values.
            ((RelativeMeanImprovementCriterion) this.termCrit).reset();
        }
        if (!this.quiet) {
            System.err.printf("Optimizing function of %d variables with OWL-QN parameters:\n", state.dim);
            System.err.printf("   l1 regularization weight: %f.\n", l1weight);
            System.err.printf("   L-BFGS memory parameter (m): %d\n", m);
            System.err.printf("   Convergence tolerance: %f\n\n", tol);
            System.err.printf("Iter    n:\tnew_value\tdf\t(conv_crit)\tline_search\n");
            System.err.printf("Iter    0:\t%.4e\t\t(***********)\t", state.value);
        }
        StringBuilder message = new StringBuilder();
        // Let the criterion observe the starting point; its return value is not used here.
        this.termCrit.getValue(state, message);
        for (int iter = 0; iter < this.maxIters; ++iter) {
            message.setLength(0);
            state.updateDir();
            state.backTrackingLineSearch();
            double critValue = this.termCrit.getValue(state, message);
            if (!this.quiet) {
                int nonZero = ArrayMath.countNonZero(state.newX);
                System.err.printf("Iter %4d:\t%.4e\t%d", state.iter, state.value, nonZero);
                System.err.print("\t" + message.toString());
                if (this.printer != null) {
                    this.printer.printWeights();
                }
            }
            if (this.arrayEquals(state.x, state.newX)) {
                // Printed regardless of the quiet flag.
                System.err.println("Warning: Stopping OWL-QN since there was no change in the parameters in the last iteration.  This probably means convergence has been reached.");
                break;
            }
            if (critValue < tol) {
                break;
            }
            state.shift();
        }
        if (!this.quiet) {
            System.err.println();
            System.err.printf("Finished with optimization.  %d/%d non-zero weights.\n", ArrayMath.countNonZero(state.newX), state.newX.length);
        }
        return state.newX;
    }

    /**
     * Exact element-wise equality using {@code ==}.
     * Under this comparison NaN never matches and 0.0 matches -0.0.
     *
     * @return true iff both arrays have the same length and every pair of elements is {@code ==}
     */
    private boolean arrayEquals(double[] dArray, double[] dArray2) {
        if (dArray.length != dArray2.length) {
            return false;
        }
        for (int i = 0; i < dArray.length; ++i) {
            if (dArray[i] != dArray2[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets the maximum number of iterations for {@link #minimize}.
     *
     * @param n the iteration limit
     */
    public void setMaxIters(int n) {
        this.maxIters = n;
    }

    /** @return the maximum number of iterations for {@link #minimize} */
    public int getMaxIters() {
        return this.maxIters;
    }

    /**
     * Installs a hook called after each non-quiet iteration.
     *
     * @param weightsPrinter the hook, or null to disable it
     */
    public void setWeightsPrinting(WeightsPrinter weightsPrinter) {
        this.printer = weightsPrinter;
    }

    /**
     * Enables or disables the simplex constraint on the full weight vector (affects all instances).
     *
     * @param bl true to constrain every weight (0 unconstrained weights); false to disable (-1)
     */
    public static void setConstrained(boolean bl) {
        constrained = bl;
        numUnconstrainedWeights = bl ? 0 : -1;
    }

    /**
     * Constrains all but the first {@code n} weights (affects all instances).
     *
     * @param n number of leading weights left unconstrained; a negative value disables constraints
     */
    public static void setConstrained(int n) {
        numUnconstrainedWeights = n;
        constrained = n >= 0;
    }

    /** @return whether a simplex constraint is currently configured */
    public static boolean isConstrained() {
        return constrained;
    }

    /**
     * Applies the configured constraint to a weight vector.
     *
     * <p>Behavior depends on the number of leading unconstrained weights:
     * <ul>
     *   <li>When unconstrained ({@code numUnconstrainedWeights < 0}), {@code dArray} is returned
     *       as-is; this is the identity projection.</li>
     *   <li>When it is 0, the whole vector is projected in place by {@link #project}.</li>
     *   <li>Otherwise the first {@code numUnconstrainedWeights} entries are copied unchanged, the
     *       remainder is projected, and a new array is returned. The input is left untouched.</li>
     * </ul>
     *
     * @param dArray the weights to project
     * @return the projected weights
     */
    protected static double[] projectWeights(double[] dArray) {
        final int free = numUnconstrainedWeights;
        if (free < 0) {
            // Previously this path indexed dArray[-1] and threw; unconstrained means "no projection".
            return dArray;
        }
        if (free == 0) {
            return OWLQN.project(dArray);
        }
        double[] tail = new double[dArray.length - free];
        System.arraycopy(dArray, free, tail, 0, tail.length);
        double[] projectedTail = OWLQN.project(tail);
        double[] result = new double[dArray.length];
        System.arraycopy(dArray, 0, result, 0, free);
        System.arraycopy(projectedTail, 0, result, free, projectedTail.length);
        return result;
    }

    /**
     * Projects {@code dArray} IN PLACE onto the set of non-negative vectors summing to 1, and
     * returns the same array.
     *
     * <p>Each pass subtracts the excess mass (sum - 1), spread evenly, from the entries not yet
     * clamped. Any entry that goes negative is clamped to 0 and excluded from later passes. This
     * repeats until a pass clamps nothing. This appears to be Michelot's iterative simplex
     * projection (review note: confirm against the intended reference).
     *
     * @param dArray the vector to project; it is modified
     * @return {@code dArray}, after projection
     */
    public static double[] project(double[] dArray) {
        final int dim = dArray.length;
        boolean[] clamped = new boolean[dim];
        int numClamped = 0;
        boolean feasible;
        do {
            double sum = 0.0;
            for (double w : dArray) {
                sum += w;
            }
            double shift = (sum - 1.0) / (double) (dim - numClamped);
            feasible = true;
            for (int i = 0; i < dim; ++i) {
                dArray[i] = clamped[i] ? 0.0 : dArray[i] - shift;
                if (dArray[i] < 0.0) {
                    feasible = false;
                    clamped[i] = true;
                    ++numClamped;
                    dArray[i] = 0.0;
                }
            }
        } while (!feasible);
        return dArray;
    }

    /** Callback invoked after each non-quiet iteration of {@link #minimize}. */
    public static interface WeightsPrinter {
        public void printWeights();
    }

    /**
     * Default termination criterion based on the relative improvement of the objective.
     *
     * <p>It compares the current objective value against the value {@code numItersToAvg}
     * evaluations earlier and returns the improvement relative to {@code |current value|}.
     * Until enough history exists it returns positive infinity.
     */
    static class RelativeMeanImprovementCriterion
    implements TerminationCriterion {
        int numItersToAvg;
        Queue<Double> prevVals;

        RelativeMeanImprovementCriterion() {
            this(10);
        }

        RelativeMeanImprovementCriterion(int n) {
            this.numItersToAvg = n;
            this.prevVals = new LinkedList<Double>();
        }

        /** Forgets all previously observed objective values. */
        void reset() {
            this.prevVals.clear();
        }

        @Override
        public double getValue(OptimizerState optimizerState, StringBuilder stringBuilder) {
            double result = Double.POSITIVE_INFINITY;
            if (this.prevVals.size() >= this.numItersToAvg) {
                double oldest = this.prevVals.peek();
                if (this.prevVals.size() == this.numItersToAvg) {
                    this.prevVals.poll();
                }
                // NOTE(review): the divisor is the queue size after dropping the oldest value.
                // In practice that is numItersToAvg - 1, although the improvement spans
                // numItersToAvg evaluations, and numItersToAvg == 1 divides by zero. It is kept
                // unchanged so convergence behavior (and trained models) do not shift.
                double avgImprovement = (oldest - optimizerState.getValue()) / (double) this.prevVals.size();
                double relAvgImprovement = avgImprovement / Math.abs(optimizerState.getValue());
                stringBuilder.append("  (").append(String.format("%.4e", relAvgImprovement)).append(") ");
                result = relAvgImprovement;
            } else {
                stringBuilder.append("  (wait for " + this.numItersToAvg + " iters) ");
            }
            this.prevVals.offer(optimizerState.getValue());
            return result;
        }
    }

    /** Strategy deciding when {@link #minimize} has converged. */
    static interface TerminationCriterion {
        /**
         * Observes the current state and returns a convergence measure.
         * {@link #minimize} stops once the returned value is below its tolerance.
         *
         * @param var1 the current optimizer state
         * @param var2 receives a short progress message
         * @return the convergence measure
         */
        public double getValue(OptimizerState var1, StringBuilder var2);
    }
}

