/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.shiftreduce;

import edu.stanford.nlp.parser.common.ParserConstraint;
import edu.stanford.nlp.parser.metrics.EvaluateTreebank;
import edu.stanford.nlp.parser.shiftreduce.BaseModel;
import edu.stanford.nlp.parser.shiftreduce.CombinationFeatureFactory;
import edu.stanford.nlp.parser.shiftreduce.FeatureFactory;
import edu.stanford.nlp.parser.shiftreduce.ReorderingOracle;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceOptions;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceTrainOptions;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceUtils;
import edu.stanford.nlp.parser.shiftreduce.State;
import edu.stanford.nlp.parser.shiftreduce.TrainingExample;
import edu.stanford.nlp.parser.shiftreduce.TrainingResult;
import edu.stanford.nlp.parser.shiftreduce.TrainingUpdate;
import edu.stanford.nlp.parser.shiftreduce.Transition;
import edu.stanford.nlp.parser.shiftreduce.Weight;
import edu.stanford.nlp.parser.shiftreduce.WeightMap;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.stats.TwoDimensionalIntCounter;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.Scored;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class PerceptronModel
extends BaseModel {
    /** Redwood logging channel for this class. */
    private static final Redwood.RedwoodChannels log = Redwood.channels(PerceptronModel.class);
    /**
     * Magnitude of every perceptron update produced while training. Starts at 1.0; passed as the
     * delta of each TrainingUpdate and multiplied by the configured decayLearningRate every 10
     * training iterations.
     */
    private float learningRate = 1.0f;
    /**
     * Feature string -> Weight. Each Weight scores into an array indexed by transition
     * (sized to transitionIndex) when a state is scored.
     */
    WeightMap featureWeights;
    /** Turns a parser State into the list of feature strings looked up in featureWeights. */
    final FeatureFactory featureFactory;
    /** Two-decimal format for logged model scores and for the dev score in intermediate model file names. */
    private static final NumberFormat NF = new DecimalFormat("0.00");
    /** Zero-padded four-digit iteration number used in intermediate model file names. */
    private static final NumberFormat FILENAME = new DecimalFormat("0000");
    private static final long serialVersionUID = 1L;

    /**
     * Creates an empty (all-zero) model and loads its feature factory.
     *
     * <p>{@code op.featureFactoryClass} is a {@code ;}-separated list of FeatureFactory class names.
     * Each entry may be written as {@code ClassName(arg)}, in which case {@code arg} is passed to
     * {@link ReflectionLoading#loadByReflection} as the single constructor argument. A single entry
     * is used directly; several entries are wrapped in a {@link CombinationFeatureFactory}.
     *
     * @param op parser options; only {@code featureFactoryClass} is read here, the rest go to BaseModel
     * @param transitionIndex index of all transitions this model can score
     * @param knownStates passed through to BaseModel
     * @param rootStates passed through to BaseModel
     * @param rootOnlyStates passed through to BaseModel
     */
    public PerceptronModel(ShiftReduceOptions op, Index<Transition> transitionIndex, Set<String> knownStates, Set<String> rootStates, Set<String> rootOnlyStates) {
        super(op, transitionIndex, knownStates, rootStates, rootOnlyStates);
        this.featureWeights = new WeightMap();
        String[] classes = op.featureFactoryClass.split(";");
        if (classes.length == 1) {
            // Previously a lone "ClassName(arg)" spec was handed verbatim to reflection and failed;
            // parse it the same way as the multi-factory case.
            this.featureFactory = loadFeatureFactory(classes[0]);
        } else {
            FeatureFactory[] factories = new FeatureFactory[classes.length];
            for (int i = 0; i < classes.length; ++i) {
                factories[i] = loadFeatureFactory(classes[i]);
            }
            this.featureFactory = new CombinationFeatureFactory(factories);
        }
    }

    /**
     * Loads one FeatureFactory from a spec of the form {@code ClassName} or {@code ClassName(arg)}.
     * For the parenthesized form, everything between the first {@code (} and the final character
     * (assumed to be the closing {@code )}) is passed as a single String constructor argument.
     */
    private static FeatureFactory loadFeatureFactory(String spec) {
        int paren = spec.indexOf('(');
        if (paren >= 0) {
            String arg = spec.substring(paren + 1, spec.length() - 1);
            return (FeatureFactory)ReflectionLoading.loadByReflection(spec.substring(0, paren), arg);
        }
        return (FeatureFactory)ReflectionLoading.loadByReflection(spec, new Object[0]);
    }

    public PerceptronModel(PerceptronModel other) {
        super(other);
        this.featureFactory = other.featureFactory;
        this.featureWeights = new WeightMap();
        for (String feature : other.featureWeights.keySet()) {
            this.featureWeights.put(feature, new Weight(other.featureWeights.get(feature)));
        }
    }

    /**
     * Logs the score of each model, then replaces this model's weights with their uniform average.
     *
     * @param scoredModels models paired with their (dev) scores; scores are only logged
     * @throws IllegalArgumentException if {@code scoredModels} is empty
     */
    public void averageScoredModels(Collection<ScoredObject<PerceptronModel>> scoredModels) {
        if (scoredModels.isEmpty()) {
            throw new IllegalArgumentException("Cannot average empty models");
        }
        log.info("Averaging " + scoredModels.size() + " models with scores");
        List<PerceptronModel> unwrapped = new ArrayList<PerceptronModel>(scoredModels.size());
        for (ScoredObject<PerceptronModel> scored : scoredModels) {
            log.info(" " + NF.format(scored.score()));
            unwrapped.add(scored.object());
        }
        log.info();
        this.averageModels(unwrapped);
    }

    /**
     * Replaces this model's weights with the uniform average of {@code models}' weights. A feature
     * missing from some model contributes nothing from that model (i.e. counts as zero).
     *
     * <p>The average is built into a fresh map and only installed at the end. Previously this
     * model's map was cleared first, which silently zeroed this model's own contribution whenever
     * {@code this} was one of {@code models}.
     *
     * @param models models to average; iteration order determines float summation order
     * @throws IllegalArgumentException if {@code models} is empty
     */
    public void averageModels(Collection<PerceptronModel> models) {
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Cannot average empty models");
        }
        Set<String> features = Generics.newHashSet();
        for (PerceptronModel model : models) {
            features.addAll(model.featureWeights.keySet());
        }
        // Every model gets the same weight in the average.
        float scale = 1.0f / (float)models.size();
        WeightMap averaged = new WeightMap();
        for (String feature : features) {
            Weight sum = new Weight();
            for (PerceptronModel model : models) {
                if (model.featureWeights.containsKey(feature)) {
                    sum.addScaled(model.featureWeights.get(feature), scale);
                }
            }
            averaged.put(feature, sum);
        }
        this.featureWeights = averaged;
    }

    /**
     * Calls {@code condense()} on every Weight and then drops features whose Weight has size 0
     * afterwards.
     */
    private void condenseFeatures() {
        for (Iterator<String> it = this.featureWeights.keySet().iterator(); it.hasNext(); ) {
            Weight w = this.featureWeights.get(it.next());
            w.condense();
            if (w.size() == 0) {
                it.remove();
            }
        }
    }

    /**
     * Removes every feature that is not in {@code keep}.
     *
     * @param keep features allowed to remain in the model
     */
    private void filterFeatures(Set<String> keep) {
        this.featureWeights.keySet().removeIf(feature -> !keep.contains(feature));
    }

    /**
     * Returns the sum of {@code size()} over every feature's Weight.
     */
    public int numWeights() {
        int total = 0;
        for (Map.Entry<String, Weight> entry : this.featureWeights.entrySet()) {
            Weight w = entry.getValue();
            total += w.size();
        }
        return total;
    }

    /**
     * Returns the largest {@code maxAbs()} over all feature Weights, or 0 when there are none.
     */
    public float maxAbs() {
        float largest = 0.0f;
        for (Map.Entry<String, Weight> entry : this.featureWeights.entrySet()) {
            float candidate = entry.getValue().maxAbs();
            largest = Math.max(largest, candidate);
        }
        return largest;
    }

    /**
     * Logs a summary of one training pass plus the current model size: correct/wrong transition
     * counts, feature and weight counts, total feature-string length, the most common first errors,
     * reorderer outcomes and per-transition-type statistics.
     *
     * @param result aggregated result of a training iteration
     */
    public void outputStats(TrainingResult result) {
        log.info("While training, got " + result.numCorrect + " transitions correct and " + result.numWrong + " transitions wrong");
        log.info("Number of known features: " + this.featureWeights.size());
        log.info("Number of non-zero weights: " + this.numWeights());
        log.info("Weight values maxAbs: " + this.maxAbs());

        int totalLength = 0;
        for (String featureName : this.featureWeights.keySet()) {
            totalLength += featureName.length();
        }
        log.info("Total word length: " + totalLength);
        log.info("Number of transitions: " + this.transitionIndex.size());

        IntCounter<Pair<Integer, Integer>> errorCounts = new IntCounter<Pair<Integer, Integer>>();
        for (Pair<Integer, Integer> error : result.firstErrors) {
            errorCounts.incrementCount(error);
        }
        this.outputFirstErrors(errorCounts);
        this.outputReordererStats(result.reorderSuccess, result.reorderFail);
        this.outputTransitionStats(result);
    }

    /**
     * Recovers the set of POS tags seen in training from the feature names: the tag captured from
     * {@code Q0TQ1T-<tag>-...} features and everything after {@code S0T-} in S0T features. The
     * literal tag {@code .$$.} is always included.
     */
    @Override
    Set<String> tagSet() {
        Pattern queuePairPattern = Pattern.compile("Q0TQ1T-([^-]+)-.*");
        Pattern stackTopPattern = Pattern.compile("S0T-(.*)");
        Set<String> result = Generics.newHashSet();
        for (String name : this.featureWeights.keySet()) {
            Matcher queueMatch = queuePairPattern.matcher(name);
            if (queueMatch.matches()) {
                result.add(queueMatch.group(1));
            }
            Matcher stackMatch = stackTopPattern.matcher(name);
            if (stackMatch.matches()) {
                result.add(stackMatch.group(1));
            }
        }
        result.add(".$$.");
        return result;
    }

    /**
     * Returns the single best-scoring transition for {@code state} given precomputed features,
     * or {@code null} if no transition qualifies (e.g. none are legal when {@code requireLegal}).
     */
    private ScoredObject<Integer> findHighestScoringTransition(State state, List<String> features, boolean requireLegal) {
        Collection<ScoredObject<Integer>> top = this.findHighestScoringTransitions(state, features, requireLegal, 1, null);
        return top.isEmpty() ? null : top.iterator().next();
    }

    /**
     * Featurizes {@code state} and returns up to {@code numTransitions} of its best-scoring
     * transitions; see the private overload for the scoring itself.
     */
    @Override
    public Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
        return this.findHighestScoringTransitions(state, this.featureFactory.featurize(state), requireLegal, numTransitions, constraints);
    }

    /**
     * Scores every transition for {@code state} by summing the Weights of {@code features}, and
     * keeps the {@code numTransitions} highest-scoring ones.
     *
     * @param state state being extended
     * @param features features of {@code state}; unknown features are ignored
     * @param requireLegal if true, transitions that are not legal for {@code state} (under
     *     {@code constraints}) are skipped
     * @param numTransitions maximum number of transitions to return; values &lt;= 0 yield an empty result
     * @param constraints parser constraints for the legality check; may be null
     * @return at most {@code numTransitions} transition indices with their scores. The collection
     *     is a min-heap, so its iteration order is NOT sorted by score.
     */
    private Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> features, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
        float[] scores = new float[this.transitionIndex.size()];
        for (String feature : features) {
            Weight weight = this.featureWeights.get(feature);
            if (weight != null) {
                weight.score(scores);
            }
        }
        // The heap never holds more than numTransitions + 1 entries (and never more than there are
        // transitions + 1). Clamp the capacity so negative, huge or overflowing requests cannot make
        // PriorityQueue throw or over-allocate.
        int capacity = Math.max(1, Math.min(numTransitions, scores.length) + 1);
        PriorityQueue<ScoredObject<Integer>> queue = new PriorityQueue<ScoredObject<Integer>>(capacity, ScoredComparator.ASCENDING_COMPARATOR);
        for (int i = 0; i < scores.length; ++i) {
            if (requireLegal && !((Transition)this.transitionIndex.get(i)).isLegal(state, constraints)) {
                continue;
            }
            queue.add(new ScoredObject<Integer>(i, scores[i]));
            // Ascending order: poll() evicts the lowest score once over the limit.
            if (queue.size() > numTransitions) {
                queue.poll();
            }
        }
        return queue;
    }

    /**
     * Runs the configured training oracle over one example and collects the perceptron updates it
     * would make. This model's weights are only read, never modified.
     *
     * <p>BEAM and REORDER_BEAM run a beam search alongside the gold state. Whenever the best state
     * on the beam diverges from the gold path, an update is recorded:
     * <ul>
     *   <li>BEAM stops once the gold state falls off the beam.</li>
     *   <li>REORDER_BEAM asks the ReorderingOracle to rewrite the remaining gold transitions
     *       instead.</li>
     * </ul>
     *
     * <p>REORDER_ORACLE, EARLY_TERMINATION and GOLD are greedy. Each wrong prediction records an
     * update, after which:
     * <ul>
     *   <li>REORDER_ORACLE follows the prediction if the oracle can reorder.</li>
     *   <li>EARLY_TERMINATION stops.</li>
     *   <li>GOLD follows the gold transition.</li>
     * </ul>
     *
     * @param example gold tree with its transition sequence
     * @return updates, correct/wrong counts, the first (predicted, gold) error, per-class statistics,
     *     and 0/1 flags for whether the reorderer succeeded at least once or only failed
     * @throws IllegalArgumentException if a beam method is configured with a non-positive beam size
     */
    private TrainingResult trainTree(TrainingExample example) {
        int numCorrect = 0;
        int numWrong = 0;
        ArrayList<TrainingUpdate> updates = Generics.newArrayList();
        Pair<Integer, Integer> firstError = null;
        IntCounter<Class<? extends Transition>> correctTransitions = new IntCounter<Class<? extends Transition>>();
        TwoDimensionalIntCounter<Class<? extends Transition>, Class<? extends Transition>> wrongTransitions = new TwoDimensionalIntCounter<Class<? extends Transition>, Class<? extends Transition>>();
        ShiftReduceTrainOptions.TrainingMethod method = this.op.trainOptions().trainingMethod;
        ReorderingOracle reorderer = null;
        if (method == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
            reorderer = new ReorderingOracle(this.op, this.rootOnlyStates);
        }
        // 0/1 flags: did the reorderer succeed at least once / fail without ever succeeding.
        int reorderSuccess = 0;
        int reorderFail = 0;
        if (method == ShiftReduceTrainOptions.TrainingMethod.BEAM || method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
            int beamSize = this.op.trainOptions().beamSize;
            if (beamSize <= 0) {
                throw new IllegalArgumentException("Illegal beam size " + beamSize);
            }
            PriorityQueue<State> agenda = new PriorityQueue<State>(beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
            State goldState = example.initialStateFromGoldTagTree();
            List<Transition> transitions = example.trainTransitions();
            agenda.add(goldState);
            while (transitions.size() > 0) {
                Transition goldTransition = transitions.get(0);
                Transition highestScoringTransitionFromGoldState = null;
                double highestScoreFromGoldState = 0.0;
                PriorityQueue<State> newAgenda = new PriorityQueue<State>(beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
                State highestScoringState = null;
                State highestCurrentState = null;
                for (State currentState : agenda) {
                    boolean isGoldState = goldState.areTransitionsEqual(currentState);
                    List<String> features = this.featureFactory.featurize(currentState);
                    Collection<ScoredObject<Integer>> stateTransitions = this.findHighestScoringTransitions(currentState, features, true, beamSize, null);
                    for (ScoredObject<Integer> transition : stateTransitions) {
                        Transition candidate = (Transition)this.transitionIndex.get(transition.object());
                        State newState = candidate.apply(currentState, transition.score());
                        newAgenda.add(newState);
                        // Ascending heap: poll() drops the lowest-scoring state to keep the beam size.
                        if (newAgenda.size() > beamSize) {
                            newAgenda.poll();
                        }
                        if (highestScoringState == null || highestScoringState.score() < newState.score()) {
                            highestScoringState = newState;
                            highestCurrentState = currentState;
                        }
                        // Track the model's favourite move out of the gold state (for error stats / reordering).
                        if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
                            highestScoringTransitionFromGoldState = candidate;
                            highestScoreFromGoldState = transition.score();
                        }
                    }
                }
                if (method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
                    break;
                }
                if (highestScoringState == null) {
                    System.err.println("Unable to find a best transition!");
                    System.err.println("Previous agenda:");
                    for (State state : agenda) {
                        System.err.println(state);
                    }
                    System.err.println("Gold transitions:");
                    System.err.println(example.transitions);
                    break;
                }
                State newGoldState = goldTransition.apply(goldState, 0.0);
                if (firstError == null && !highestScoringTransitionFromGoldState.equals(goldTransition)) {
                    int predictedIndex = this.transitionIndex.indexOf(highestScoringTransitionFromGoldState);
                    int goldIndex = this.transitionIndex.indexOf(goldTransition);
                    if (predictedIndex < 0) {
                        throw new AssertionError((Object)("Predicted transition not in the index: " + highestScoringTransitionFromGoldState));
                    }
                    if (goldIndex < 0) {
                        throw new AssertionError((Object)("Gold transition not in the index: " + goldTransition));
                    }
                    firstError = new Pair<Integer, Integer>(predictedIndex, goldIndex);
                }
                if (!newGoldState.areTransitionsEqual(highestScoringState)) {
                    ++numWrong;
                    wrongTransitions.incrementCount(goldTransition.getClass(), highestScoringTransitionFromGoldState.getClass());
                    List<String> goldFeatures = this.featureFactory.featurize(goldState);
                    int lastTransition = this.transitionIndex.indexOf(highestScoringState.transitions.peek());
                    // Penalize the move that produced the best (wrong) state, reward the gold move.
                    updates.add(new TrainingUpdate(this.featureFactory.featurize(highestCurrentState), -1, lastTransition, this.learningRate));
                    updates.add(new TrainingUpdate(goldFeatures, this.transitionIndex.indexOf(goldTransition), -1, this.learningRate));
                    if (method == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
                        // Early update: stop once the gold state has fallen off the beam.
                        if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                            break;
                        }
                        transitions.remove(0);
                    } else if (method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
                        if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                            if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
                                if (reorderSuccess == 0) {
                                    reorderFail = 1;
                                }
                                break;
                            }
                            // The reorderer rewrote `transitions`; follow the model's move instead.
                            newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
                            if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                                if (reorderSuccess == 0) {
                                    reorderFail = 1;
                                }
                                break;
                            }
                            reorderSuccess = 1;
                        } else {
                            transitions.remove(0);
                        }
                    }
                } else {
                    ++numCorrect;
                    correctTransitions.incrementCount(goldTransition.getClass());
                    transitions.remove(0);
                }
                goldState = newGoldState;
                agenda = newAgenda;
            }
        } else if (method == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || method == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || method == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
            State state = example.initialStateFromGoldTagTree();
            List<Transition> transitions = example.trainTransitions();
            boolean keepGoing = true;
            while (transitions.size() > 0 && keepGoing) {
                Transition gold = transitions.get(0);
                int goldNum = this.transitionIndex.indexOf(gold);
                List<String> features = this.featureFactory.featurize(state);
                // requireLegal == false: the greedy oracle scores every transition.
                int predictedNum = this.findHighestScoringTransition(state, features, false).object();
                Transition predicted = (Transition)this.transitionIndex.get(predictedNum);
                if (goldNum == predictedNum) {
                    transitions.remove(0);
                    state = gold.apply(state);
                    ++numCorrect;
                    correctTransitions.incrementCount(gold.getClass());
                    continue;
                }
                ++numWrong;
                wrongTransitions.incrementCount(gold.getClass(), predicted.getClass());
                if (firstError == null) {
                    firstError = new Pair<Integer, Integer>(predictedNum, goldNum);
                }
                updates.add(new TrainingUpdate(features, goldNum, predictedNum, this.learningRate));
                switch (method) {
                    case EARLY_TERMINATION:
                        keepGoing = false;
                        break;
                    case GOLD:
                        transitions.remove(0);
                        state = gold.apply(state);
                        break;
                    case REORDER_ORACLE:
                        keepGoing = reorderer.reorder(state, predicted, transitions);
                        if (keepGoing) {
                            state = predicted.apply(state);
                            reorderSuccess = 1;
                        } else if (reorderSuccess == 0) {
                            reorderFail = 1;
                        }
                        break;
                    default:
                        throw new IllegalArgumentException("Unexpected method " + method);
                }
            }
        }
        return new TrainingResult(updates, numCorrect, numWrong, firstError, correctTransitions, wrongTransitions, reorderSuccess, reorderFail);
    }

    /**
     * Runs {@link #trainTree} over every example in the batch and merges the results.
     *
     * <p>Examples are processed sequentially when no {@code wrapper} was supplied or when one
     * training thread is configured; otherwise they are fed through {@code wrapper}. Previously a
     * null wrapper combined with {@code trainingThreads != 1} caused a NullPointerException.
     *
     * @param trainingData examples in this batch
     * @param wrapper multithreaded trainTree executor, or null for single-threaded training
     * @return the merged result of all examples
     */
    private TrainingResult trainBatch(List<TrainingExample> trainingData, MulticoreWrapper<TrainingExample, TrainingResult> wrapper) {
        List<TrainingResult> results = new ArrayList<TrainingResult>();
        if (wrapper == null || this.op.trainOptions.trainingThreads == 1) {
            for (TrainingExample example : trainingData) {
                results.add(this.trainTree(example));
            }
        } else {
            for (TrainingExample example : trainingData) {
                wrapper.put(example);
            }
            // Wait for this batch only; the wrapper stays alive for later batches.
            wrapper.join(false);
            while (wrapper.peek()) {
                results.add(wrapper.poll());
            }
        }
        return new TrainingResult(results);
    }

    /**
     * Wraps this model in a ShiftReduceParser, parses {@code devTreebank}, logs
     * {@code message: F1} and returns the labeled F1 score.
     */
    private double evaluate(Tagger tagger, Treebank devTreebank, String message) {
        ShiftReduceParser parser = new ShiftReduceParser(this.op, this);
        EvaluateTreebank scorer = new EvaluateTreebank(parser.getOp(), null, parser, tagger, parser.getExtraEvals(), parser.getParserQueryEvals());
        scorer.testOnTreebank(devTreebank);
        double f1 = scorer.getLBScore();
        log.info(message + ": " + f1);
        return f1;
    }

    /**
     * Appends truncated copies of some training examples to {@code augmentedData}.
     *
     * <p>An example is considered only if it has more than 10 transitions. It is then chosen with
     * probability {@code augmentFraction} (a random draw is consumed only for such examples). The
     * new example reuses the tree and transitions with a pivot drawn uniformly from
     * {@code [7, size - 4)}.
     */
    static void augmentSubsentences(List<TrainingExample> augmentedData, List<TrainingExample> trainingData, Random random, float augmentFraction) {
        for (TrainingExample example : trainingData) {
            int length = example.transitions.size();
            if (length > 10 && random.nextDouble() < (double)augmentFraction) {
                int pivot = 7 + random.nextInt(length - 10);
                augmentedData.add(new TrainingExample(example.binarizedTree, example.transitions, pivot));
            }
        }
    }

    /**
     * Logs up to the 9 most frequent (predicted, gold) first-error pairs, most frequent first, as
     * {@code gold -> predicted} with their counts. The input counter is not modified.
     *
     * <p>Two fixes:
     * <ul>
     *   <li>Each line now reports its own count. Previously it printed {@code max()} after
     *       decrementing, which is the next entry's count.</li>
     *   <li>Handled keys are removed rather than decremented to zero, so zero-count pairs are never
     *       listed.</li>
     * </ul>
     *
     * @param firstErrors counts keyed by (predicted index, gold index); null or empty logs nothing
     */
    private void outputFirstErrors(IntCounter<Pair<Integer, Integer>> firstErrors) {
        if (firstErrors == null || firstErrors.size() == 0) {
            return;
        }
        IntCounter<Pair<Integer, Integer>> remaining = new IntCounter<Pair<Integer, Integer>>(firstErrors);
        log.info("Most common transition errors: gold -> predicted");
        for (int i = 0; i < 9 && remaining.size() > 0; ++i) {
            Pair<Integer, Integer> mostCommon = remaining.argmax();
            int count = remaining.max();
            remaining.remove(mostCommon);
            Transition predicted = (Transition)this.transitionIndex.get(mostCommon.first());
            Transition gold = (Transition)this.transitionIndex.get(mostCommon.second());
            log.info("  # " + (i + 1) + ": " + gold + " -> " + predicted + " happened " + count + " times");
        }
    }

    /**
     * Logs how many training trees the reorderer helped on and how many it failed on, unless both
     * counts are zero.
     */
    private void outputReordererStats(int numReorderSuccess, int numReorderFail) {
        boolean anyActivity = numReorderSuccess != 0 || numReorderFail != 0;
        if (anyActivity) {
            log.info("Reorderer successfully operated at least once on " + numReorderSuccess + " training trees and failed to do anything useful on " + numReorderFail + " trees");
        }
    }

    /**
     * Logs two multi-line reports:
     * <ul>
     *   <li>correct counts per transition class;</li>
     *   <li>wrong counts per (gold class -> predicted class) pair.</li>
     * </ul>
     * Both are ordered by {@code Counters.toSortedList}.
     */
    private void outputTransitionStats(TrainingResult result) {
        List<String> correctLines = new ArrayList<String>();
        correctLines.add("Got the following transition types correct:");
        for (Class<? extends Transition> type : Counters.toSortedList(result.correctTransitions)) {
            correctLines.add(ShiftReduceUtils.transitionShortName(type) + ": " + result.correctTransitions.getCount(type));
        }
        log.info(StringUtils.join(correctLines, "\n  "));

        IntCounter<Class<? extends Transition>> goldTotals = result.wrongTransitions.totalCounts();
        List<Class<? extends Transition>> goldOrder = Counters.toSortedList(goldTotals);
        List<String> wrongLines = new ArrayList<String>();
        wrongLines.add("Got the following transition types incorrect:");
        for (Class<? extends Transition> goldType : goldOrder) {
            IntCounter<Class<? extends Transition>> byPrediction = result.wrongTransitions.getCounter(goldType);
            List<Class<? extends Transition>> predictionOrder = Counters.toSortedList(byPrediction);
            for (Class<? extends Transition> predictedType : predictionOrder) {
                wrongLines.add(ShiftReduceUtils.transitionShortName(goldType) + " -> " + ShiftReduceUtils.transitionShortName(predictedType) + ": " + byPrediction.getCount(predictedType));
            }
        }
        log.info(StringUtils.join(wrongLines, "\n  "));
    }

    /**
     * Trains this model's weights in place.
     *
     * <p>Each iteration:
     * <ol>
     *   <li>augments and shuffles the data;</li>
     *   <li>trains in batches, applying each batch's updates right away;</li>
     *   <li>applies L2 then L1 regularization if configured;</li>
     *   <li>logs statistics and optionally evaluates on {@code devTreebank}.</li>
     * </ol>
     * Training stops early when dev F1 stalls for {@code stalledIterationLimit} iterations.
     *
     * <p>Afterwards, if {@code averagedModels > 0}, the best-scoring snapshots are averaged. With
     * {@code cvAveragedModels}, the best prefix size is picked on dev. Rare features are then
     * filtered and weights condensed.
     *
     * <p>Fixes:
     * <ul>
     *   <li>The best-model queue and list are typed as {@code ScoredObject<PerceptronModel>}
     *       (the decompiled {@code Scored} types did not compile).</li>
     *   <li>Averaging is skipped when no snapshot was collected, e.g. no dev treebank or no
     *       completed evaluation. Previously this threw "Cannot average empty models".</li>
     *   <li>CV averaging falls back to all snapshots when no prefix beat 0 F1. Previously it
     *       averaged an empty list and threw.</li>
     * </ul>
     *
     * @param serializedPath model path; when intermediate models are saved, its last 7 characters
     *     are replaced by {@code -<iteration>-<score>.ser.gz} (assumes it ends in ".ser.gz")
     * @param tagger tagger used for dev evaluation
     * @param random source of randomness for augmentation and shuffling
     * @param trainingData gold training examples (not modified)
     * @param devTreebank dev set for scoring/early stopping/averaging; may be null
     * @param nThreads a MulticoreWrapper is created when this is not 1
     * @param allowedFeatures if non-null, only these features receive weights and frequency
     *     cutoff filtering is disabled
     */
    private void trainModel(String serializedPath, Tagger tagger, Random random, List<TrainingExample> trainingData, Treebank devTreebank, int nThreads, Set<String> allowedFeatures) {
        double bestScore = 0.0;
        int bestIteration = 0;
        // Ascending heap on dev score, so poll() evicts the worst snapshot once over capacity.
        PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
        if (this.op.trainOptions().averagedModels > 0) {
            bestModels = new PriorityQueue<ScoredObject<PerceptronModel>>(this.op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
        }
        MulticoreWrapper<TrainingExample, TrainingResult> wrapper = null;
        if (nThreads != 1) {
            wrapper = new MulticoreWrapper<TrainingExample, TrainingResult>(this.op.trainOptions.trainingThreads, new TrainTreeProcessor());
        }
        // Feature frequencies are only needed when a cutoff will be applied to a free feature set.
        IntCounter<String> featureFrequencies = null;
        if (this.op.trainOptions().featureFrequencyCutoff > 1 && allowedFeatures == null) {
            featureFrequencies = new IntCounter<String>();
        }
        for (int iteration = 1; iteration <= this.op.trainOptions.trainingIterations; ++iteration) {
            Timing trainingTimer = new Timing();
            List<TrainingResult> results = new ArrayList<TrainingResult>();
            List<TrainingExample> augmentedData = new ArrayList<TrainingExample>(trainingData);
            PerceptronModel.augmentSubsentences(augmentedData, trainingData, random, this.op.trainOptions().augmentSubsentences);
            Collections.shuffle(augmentedData, random);
            log.info("Original list " + trainingData.size() + "; augmented " + augmentedData.size());
            for (int start = 0; start < augmentedData.size(); start += this.op.trainOptions.batchSize) {
                int end = Math.min(start + this.op.trainOptions.batchSize, augmentedData.size());
                TrainingResult trainingResult = this.trainBatch(augmentedData.subList(start, end), wrapper);
                results.add(trainingResult);
                for (TrainingUpdate update : trainingResult.updates) {
                    for (String feature : update.features) {
                        if (allowedFeatures != null && !allowedFeatures.contains(feature)) {
                            continue;
                        }
                        Weight weight = this.featureWeights.get(feature);
                        if (weight == null) {
                            weight = new Weight();
                            this.featureWeights.put(feature, weight);
                        }
                        weight.updateWeight(update.goldTransition, update.delta);
                        weight.updateWeight(update.predictedTransition, -update.delta);
                        if (featureFrequencies != null) {
                            // An update that touches both a gold and a predicted transition counts twice.
                            featureFrequencies.incrementCount(feature, update.goldTransition >= 0 && update.predictedTransition >= 0 ? 2 : 1);
                        }
                    }
                }
            }
            float l2Reg = this.op.trainOptions().l2Reg;
            if (l2Reg > 0.0f) {
                for (Map.Entry<String, Weight> entry : this.featureWeights.entrySet()) {
                    entry.getValue().l2Reg(l2Reg);
                }
            }
            float l1Reg = this.op.trainOptions().l1Reg;
            if (l1Reg > 0.0f) {
                for (Map.Entry<String, Weight> entry : this.featureWeights.entrySet()) {
                    entry.getValue().l1Reg(l1Reg);
                }
            }
            trainingTimer.done("Iteration " + iteration);
            this.outputStats(new TrainingResult(results));
            double devScore = 0.0;
            if (devTreebank != null) {
                devScore = this.evaluate(tagger, devTreebank, "Label F1 for iteration " + iteration);
                if (devScore > bestScore) {
                    log.info("New best dev score (previous best " + bestScore + ")");
                    bestScore = devScore;
                    bestIteration = iteration;
                } else {
                    log.info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                    if (this.op.trainOptions.stalledIterationLimit > 0 && iteration - bestIteration >= this.op.trainOptions.stalledIterationLimit) {
                        log.info("Failed to improve for too long, stopping training");
                        break;
                    }
                }
                log.info("\n\n");
                if (bestModels != null) {
                    // Snapshot a condensed copy so later training does not mutate it.
                    PerceptronModel copy = new PerceptronModel(this);
                    copy.condenseFeatures();
                    bestModels.add(new ScoredObject<PerceptronModel>(copy, devScore));
                    if (bestModels.size() > this.op.trainOptions().averagedModels) {
                        bestModels.poll();
                    }
                }
            }
            if (this.op.trainOptions().saveIntermediateModels && serializedPath != null && this.op.trainOptions.debugOutputFrequency > 0) {
                String tempName = serializedPath.substring(0, serializedPath.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(devScore) + ".ser.gz";
                ShiftReduceParser temp = new ShiftReduceParser(this.op, this);
                temp.saveModel(tempName);
            }
            if (iteration % 10 == 0 && this.op.trainOptions().decayLearningRate > 0.0) {
                this.learningRate = (float)((double)this.learningRate * this.op.trainOptions().decayLearningRate);
            }
        }
        if (wrapper != null) {
            wrapper.join();
        }
        if (bestModels != null) {
            if (bestModels.isEmpty()) {
                log.info("No scored models were collected; skipping model averaging");
            } else if (this.op.trainOptions().cvAveragedModels && devTreebank != null) {
                List<ScoredObject<PerceptronModel>> models = new ArrayList<ScoredObject<PerceptronModel>>();
                while (!bestModels.isEmpty()) {
                    models.add(bestModels.poll());
                }
                // Heap drained worst-first; reverse so prefixes contain the best models.
                Collections.reverse(models);
                double bestF1 = 0.0;
                int bestSize = 0;
                for (int i = 1; i <= models.size(); ++i) {
                    log.info("Testing with " + i + " models averaged together");
                    this.averageScoredModels(models.subList(0, i));
                    double labelF1 = this.evaluate(tagger, devTreebank, "Label F1 for " + i + " models");
                    if (labelF1 > bestF1) {
                        bestF1 = labelF1;
                        bestSize = i;
                    }
                }
                if (bestSize == 0) {
                    // No prefix scored above 0 F1 (or all were NaN); averaging an empty list would throw.
                    bestSize = models.size();
                }
                this.averageScoredModels(models.subList(0, bestSize));
                log.info("Label F1 for " + bestSize + " models: " + bestF1);
            } else {
                this.averageScoredModels(bestModels);
            }
        }
        if (featureFrequencies != null) {
            this.filterFeatures(featureFrequencies.keysAbove(this.op.trainOptions().featureFrequencyCutoff));
        }
        this.condenseFeatures();
    }

    /**
     * Randomly selects a subset of {@code features}.
     * <p>
     * Features are visited in the iteration order of {@code features}, and one
     * {@code random.nextDouble()} draw is made per feature. A feature is kept when its draw is
     * strictly greater than {@code drop}. If that leaves nothing but {@code features} was not
     * empty, the first feature in iteration order is kept, so the result is never empty for
     * non-empty input.
     *
     * @param features candidate feature names
     * @param random   source of the per-feature draws
     * @param drop     threshold a draw must exceed for the feature to be kept
     * @return a new mutable set holding the surviving features
     */
    static Set<String> pruneFeatures(Set<String> features, Random random, double drop) {
        HashSet<String> kept = new HashSet<String>();
        for (String candidate : features) {
            double draw = random.nextDouble();
            if (draw > drop) {
                kept.add(candidate);
            }
        }
        if (kept.isEmpty()) {
            // Guarantee at least one feature survives whenever there was one to choose from.
            Iterator<String> first = features.iterator();
            if (first.hasNext()) {
                kept.add(first.next());
            }
        }
        return kept;
    }

    /**
     * Trains a model starting from a copy of {@code initialModel}.
     * <p>
     * In the simple case one copy of {@code initialModel} is trained on {@code serializedPath}
     * and returned. Multi-pass training is used when {@code retrainAfterCutoff} is set with a
     * positive {@code featureFrequencyCutoff}, or when {@code retrainShards > 1}:
     * <ol>
     *   <li>A first model is trained with no feature restriction, using a temporary path derived
     *       from {@code serializedPath}. It is optionally saved when
     *       {@code saveIntermediateModels} is set.</li>
     *   <li>The feature names left in that model's weights become the allowed feature set. A
     *       fresh copy of {@code initialModel} is filtered to them and retrained.</li>
     *   <li>If {@code retrainShards > 1}, further shards are retrained on random subsets of the
     *       allowed features (see {@link #pruneFeatures}). All shards are then averaged into one
     *       model, which is condensed and evaluated on {@code devTreebank}.</li>
     * </ol>
     *
     * @param op              parser options; training behaviour comes from {@code op.trainOptions()}
     * @param transitionIndex transitions used to build a model when {@code initialModel} is null
     * @param knownStates     passed to the new model when {@code initialModel} is null
     * @param rootStates      passed to the new model when {@code initialModel} is null
     * @param rootOnlyStates  passed to the new model when {@code initialModel} is null
     * @param initialModel    model copied before each training pass, or {@code null} to build one
     * @param serializedPath  path handed to the final training passes; may be {@code null}.
     *                        The first-pass temporary path replaces a trailing {@code ".ser.gz"}
     *                        with {@code "-temp.ser.gz"}, or appends {@code "-temp.ser.gz"} if
     *                        that suffix is absent.
     * @param tagger          tagger passed through to training and evaluation
     * @param random          randomness source for training and for shard feature sampling
     * @param trainingData    training examples
     * @param devTreebank     development treebank passed through to training and evaluation
     * @param nThreads        number of training threads
     * @return the trained model (the averaged model when sharded retraining is enabled)
     */
    public static PerceptronModel trainModel(ShiftReduceOptions op, Index<Transition> transitionIndex, Set<String> knownStates, Set<String> rootStates, Set<String> rootOnlyStates, PerceptronModel initialModel, String serializedPath, Tagger tagger, Random random, List<TrainingExample> trainingData, Treebank devTreebank, int nThreads) {
        if (initialModel == null) {
            initialModel = new PerceptronModel(op, transitionIndex, knownStates, rootStates, rootOnlyStates);
        }
        if (op.trainOptions().retrainAfterCutoff && op.trainOptions().featureFrequencyCutoff > 0 || op.trainOptions().retrainShards > 1) {
            // serializedPath may legitimately be null (the instance trainModel checks for that).
            // Only strip the ".ser.gz" suffix when it is actually present, rather than blindly
            // chopping 7 characters, which throws on null/short paths and mangles other names.
            String tempName = null;
            if (serializedPath != null) {
                String suffix = ".ser.gz";
                String base = serializedPath.endsWith(suffix)
                        ? serializedPath.substring(0, serializedPath.length() - suffix.length())
                        : serializedPath;
                tempName = base + "-temp.ser.gz";
            }
            PerceptronModel currentModel = new PerceptronModel(initialModel);
            currentModel.trainModel(tempName, tagger, random, trainingData, devTreebank, nThreads, null);
            if (op.trainOptions().saveIntermediateModels && tempName != null) {
                ShiftReduceParser temp = new ShiftReduceParser(op, currentModel);
                temp.saveModel(tempName);
            }
            log.info("Beginning retraining");
            // Copy the feature names instead of holding the keySet() view. The view would keep the
            // first-pass model's entire weight map reachable for the rest of training.
            // LinkedHashSet preserves the view's iteration order, so pruneFeatures makes the same
            // random choices.
            Set<String> allowedFeatures = new java.util.LinkedHashSet<String>(currentModel.featureWeights.keySet());
            currentModel = new PerceptronModel(initialModel);
            currentModel.filterFeatures(allowedFeatures);
            currentModel.trainModel(serializedPath, tagger, random, trainingData, devTreebank, nThreads, allowedFeatures);
            if (op.trainOptions().retrainShards > 1) {
                ArrayList<PerceptronModel> shards = Generics.newArrayList();
                shards.add(currentModel);
                for (int i = 1; i < op.trainOptions().retrainShards; ++i) {
                    log.info("Beginning retraining of shard " + (i + 1));
                    Set<String> prunedFeatures = PerceptronModel.pruneFeatures(allowedFeatures, random, op.trainOptions().retrainShardFeatureDrop);
                    currentModel = new PerceptronModel(initialModel);
                    currentModel.filterFeatures(prunedFeatures);
                    currentModel.trainModel(serializedPath, tagger, random, trainingData, devTreebank, nThreads, prunedFeatures);
                    shards.add(currentModel);
                }
                log.info("Averaging " + op.trainOptions().retrainShards + " shards");
                currentModel = new PerceptronModel(initialModel);
                currentModel.averageModels(shards);
                currentModel.condenseFeatures();
                currentModel.evaluate(tagger, devTreebank, "Label F1 for " + op.trainOptions().retrainShards + " averaged shards");
            }
            return currentModel;
        }
        PerceptronModel currentModel = new PerceptronModel(initialModel);
        currentModel.trainModel(serializedPath, tagger, random, trainingData, devTreebank, nThreads, null);
        return currentModel;
    }

    /**
     * Adapter exposing the enclosing model's {@code trainTree} as a {@link ThreadsafeProcessor}.
     * The processor itself has no fields, so {@link #newInstance()} can hand back the same object.
     */
    private class TrainTreeProcessor
    implements ThreadsafeProcessor<TrainingExample, TrainingResult> {
        /** Trains the enclosing model on one example and returns the outcome of that step. */
        @Override
        public TrainingResult process(TrainingExample trainingExample) {
            TrainingResult outcome = PerceptronModel.this.trainTree(trainingExample);
            return outcome;
        }

        /** Returns this instance; there is no per-instance state to duplicate. */
        @Override
        public TrainTreeProcessor newInstance() {
            return this;
        }
    }
}

