/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.perceptron;

import java.io.IOException;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.perceptron.PerceptronModel;
import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
import opennlp.tools.monitoring.IterDeltaAccuracyUnderTolerance;
import opennlp.tools.monitoring.StopCriteria;
import opennlp.tools.monitoring.TrainingMeasure;
import opennlp.tools.monitoring.TrainingProgressMonitor;
import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a {@link PerceptronModel} from indexed events using mistake-driven
 * perceptron updates, optionally averaging the parameters over iterations.
 *
 * <p>Training parameters read by {@link #doTrain(DataIndexer)}:
 * <ul>
 *   <li>{@code "UseAverage"} (default {@code true}) - return averaged parameters;</li>
 *   <li>{@code "UseSkippedAveraging"} (default {@code false}) - only sample some
 *       iterations for the average; implies {@code "UseAverage"};</li>
 *   <li>{@code "StepSizeDecrease"} (default {@code 0.0}) - only applied when positive;</li>
 *   <li>{@code "Tolerance"} (default {@link #TOLERANCE_DEFAULT}).</li>
 * </ul>
 *
 * <p>Instances keep the indexed training data in fields while training, so a single
 * instance should not be used for concurrent training runs.
 */
public class PerceptronTrainer
extends AbstractEventTrainer {
    private static final Logger logger = LoggerFactory.getLogger(PerceptronTrainer.class);
    public static final String PERCEPTRON_VALUE = "PERCEPTRON";
    public static final double TOLERANCE_DEFAULT = 1.0E-5;

    /** Number of distinct events, taken from the length of {@link #contexts}. */
    private int numUniqueEvents;
    /** Total number of events as reported by the data indexer. */
    private int numEvents;
    /** Number of predicates (length of {@link #predLabels}). */
    private int numPreds;
    /** Number of outcomes (length of {@link #outcomeLabels}). */
    private int numOutcomes;
    /** Predicate indices active in each unique event. */
    private int[][] contexts;
    /** Per-predicate values for each unique event, or {@code null} if none are supplied. */
    private float[][] values;
    /** Target outcome index for each unique event. */
    private int[] outcomeList;
    /** How often each unique event occurred. */
    private int[] numTimesEventsSeen;
    private String[] outcomeLabels;
    private String[] predLabels;
    // NOTE(review): this field is written by setTolerance but never read in this class;
    // the stop criterion is obtained from the training configuration/parameters instead.
    private double tolerance = TOLERANCE_DEFAULT;
    /** Step size decrease applied each iteration; {@code null} means "keep the step size constant". */
    private Double stepSizeDecrease;
    /** Whether only a subset of the iterations contributes to the averaged parameters. */
    private boolean useSkippedAveraging;

    /** Creates a trainer without explicit training parameters. */
    public PerceptronTrainer() {
    }

    /**
     * Creates a trainer configured with the given training parameters.
     *
     * @param parameters the training parameters handed to the superclass
     */
    public PerceptronTrainer(TrainingParameters parameters) {
        super(parameters);
    }

    /**
     * Runs the superclass validation and additionally checks the algorithm name.
     *
     * @throws IllegalArgumentException if an algorithm name is configured and it is
     *         not {@link #PERCEPTRON_VALUE}
     */
    @Override
    public void validate() {
        super.validate();
        String algorithmName = this.getAlgorithm();
        if (algorithmName != null && !PERCEPTRON_VALUE.equals(algorithmName)) {
            throw new IllegalArgumentException("algorithmName must be PERCEPTRON");
        }
    }

    /** {@inheritDoc} This trainer always answers {@code false}. */
    @Override
    public boolean isSortAndMerge() {
        return false;
    }

    /**
     * Reads the perceptron-specific settings from the training parameters, applies
     * them to this trainer and trains a model on the indexed data.
     *
     * @param indexer the indexed training events
     * @return the trained model
     * @throws IOException declared by the overridden method; not thrown by this body directly
     */
    @Override
    public AbstractModel doTrain(DataIndexer indexer) throws IOException {
        int iterations = this.getIterations();
        int cutoff = this.getCutoff();
        boolean skippedAveraging = this.trainingParameters.getBooleanParameter("UseSkippedAveraging", false);
        // Skipped averaging is a flavour of averaging, so it forces averaging on.
        boolean useAverage = skippedAveraging
                || this.trainingParameters.getBooleanParameter("UseAverage", true);
        double decrease = this.trainingParameters.getDoubleParameter("StepSizeDecrease", 0.0);
        double configuredTolerance = this.trainingParameters.getDoubleParameter("Tolerance", TOLERANCE_DEFAULT);

        this.setSkippedAveraging(skippedAveraging);
        if (decrease > 0.0) {
            this.setStepSizeDecrease(decrease);
        }
        this.setTolerance(configuredTolerance);
        return this.trainModel(iterations, indexer, cutoff, useAverage);
    }

    /**
     * Sets the tolerance value.
     *
     * @param tolerance the tolerance; negative values are rejected, zero is accepted
     * @throws IllegalArgumentException if {@code tolerance} is negative
     */
    public void setTolerance(double tolerance) {
        if (tolerance < 0.0) {
            throw new IllegalArgumentException("tolerance must be a positive number but is " + tolerance + "!");
        }
        this.tolerance = tolerance;
    }

    /**
     * Enables the per-iteration step size decrease.
     *
     * <p>NOTE(review): the accepted range is 0..100, which reads like a percentage,
     * but training multiplies the step size by {@code (1 - decrease)} each iteration,
     * i.e. treats the value as a fraction. Values above 1 therefore flip the sign of
     * the step size. Confirm the intended unit before changing either side.
     *
     * @param decrease the decrease to apply each iteration
     * @throws IllegalArgumentException if {@code decrease} is outside {@code [0, 100]}
     */
    public void setStepSizeDecrease(double decrease) {
        if (decrease < 0.0 || decrease > 100.0) {
            throw new IllegalArgumentException("decrease must be between 0 and 100 but is " + decrease + "!");
        }
        this.stepSizeDecrease = decrease;
    }

    /**
     * Switches skipped averaging on or off. It only has an effect when training
     * with averaging enabled.
     *
     * @param averaging {@code true} to sample only some iterations for the average
     */
    public void setSkippedAveraging(boolean averaging) {
        this.useSkippedAveraging = averaging;
    }

    /**
     * Trains a model with parameter averaging enabled.
     *
     * @param iterations the maximum number of training iterations
     * @param di the indexed training events
     * @param cutoff passed through to {@link #trainModel(int, DataIndexer, int, boolean)}
     * @return the trained model
     */
    public AbstractModel trainModel(int iterations, DataIndexer di, int cutoff) {
        return this.trainModel(iterations, di, cutoff, true);
    }

    /**
     * Copies the indexed data into this trainer and computes the model parameters.
     *
     * @param iterations the maximum number of training iterations
     * @param di the indexed training events
     * @param cutoff not used by this method
     * @param useAverage whether the returned model uses averaged parameters
     * @return the trained model
     */
    public AbstractModel trainModel(int iterations, DataIndexer di, int cutoff, boolean useAverage) {
        logger.info("Incorporating indexed data for training... ");
        this.contexts = di.getContexts();
        this.values = di.getValues();
        this.numTimesEventsSeen = di.getNumTimesEventsSeen();
        this.numEvents = di.getNumEvents();
        this.numUniqueEvents = this.contexts.length;
        this.outcomeLabels = di.getOutcomeLabels();
        this.outcomeList = di.getOutcomeList();
        this.predLabels = di.getPredLabels();
        this.numPreds = this.predLabels.length;
        this.numOutcomes = this.outcomeLabels.length;
        logger.info("done.");
        logger.info("\tNumber of Event Tokens: {} \n\t Number of Outcomes: {} \n\t Number of Predicates: {}",
                this.numUniqueEvents, this.numOutcomes, this.numPreds);
        logger.info("Computing model parameters...");
        Context[] finalParameters = this.findParameters(iterations, useAverage);
        logger.info("...done.");
        return new PerceptronModel(finalParameters, this.predLabels, this.outcomeLabels);
    }

    /**
     * Runs the perceptron training loop.
     *
     * @param iterations the maximum number of passes over the training events
     * @param useAverage whether to return the averaged instead of the final parameters
     * @return one context per predicate holding a parameter for every outcome
     */
    private MutableContext[] findParameters(int iterations, boolean useAverage) {
        logger.info("Performing {} iterations.", iterations);
        int[] allOutcomesPattern = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; ++oi) {
            allOutcomesPattern[oi] = oi;
        }

        // Freshly allocated double arrays are all zeros, so no explicit reset is needed.
        MutableContext[] params = this.newContexts(allOutcomesPattern, true);
        // evalParams shares the params array, so updates below are visible to evaluation.
        EvalParameters evalParams = new EvalParameters(params, this.numOutcomes);
        // Only populated when averaging; otherwise the slots stay null and are never touched.
        MutableContext[] summedParams = this.newContexts(allOutcomesPattern, useAverage);

        TrainingProgressMonitor progressMonitor = this.getTrainingProgressMonitor(this.trainingConfiguration);
        StopCriteria<Double> stopCriteria = this.getStopCriteria(this.trainingConfiguration);

        // Training accuracies of the three preceding iterations, oldest first.
        double prevAccuracy1 = 0.0;
        double prevAccuracy2 = 0.0;
        double prevAccuracy3 = 0.0;
        int numTimesSummed = 0;
        double stepsize = 1.0;

        for (int i = 1; i <= iterations; ++i) {
            if (this.stepSizeDecrease != null) {
                // NOTE(review): the value is applied as a fraction here although the
                // setter accepts 0..100; see setStepSizeDecrease.
                stepsize *= 1.0 - this.stepSizeDecrease;
            }

            int numCorrect = 0;
            for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
                int targetOutcome = this.outcomeList[ei];
                // Each unique event is replayed once per occurrence.
                for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ++ni) {
                    int maxOutcome = this.predictOutcome(ei, evalParams);
                    if (maxOutcome == targetOutcome) {
                        ++numCorrect;
                        continue;
                    }
                    // Mistake: reward the target outcome and penalise the predicted one
                    // for every predicate active in this event.
                    for (int ci = 0; ci < this.contexts[ei].length; ++ci) {
                        int pi = this.contexts[ei][ci];
                        double delta = this.values == null
                                ? stepsize
                                : stepsize * (double) this.values[ei][ci];
                        params[pi].updateParameter(targetOutcome, delta);
                        params[pi].updateParameter(maxOutcome, -delta);
                    }
                }
            }

            double trainingAccuracy = (double) numCorrect / (double) this.numEvents;
            if (i < 10 || i % 10 == 0) {
                progressMonitor.finishedIteration(i, numCorrect, this.numEvents, TrainingMeasure.ACCURACY, trainingAccuracy);
            }

            // With plain averaging every iteration contributes. With skipped averaging
            // only the first 19 iterations and perfect-square iterations contribute.
            // (Previously the condition collapsed to "useAverage", so the skipped
            // averaging flag had no effect.)
            boolean doAveraging = useAverage
                    && (!this.useSkippedAveraging || i < 20 || PerceptronTrainer.isPerfectSquare(i));
            if (doAveraging) {
                ++numTimesSummed;
                for (int pi = 0; pi < this.numPreds; ++pi) {
                    double[] current = params[pi].getParameters();
                    for (int aoi = 0; aoi < this.numOutcomes; ++aoi) {
                        summedParams[pi].updateParameter(aoi, current[aoi]);
                    }
                }
            }

            // Stop once the accuracy change against each of the last three iterations
            // satisfies the stop criterion.
            if (stopCriteria.test(prevAccuracy1 - trainingAccuracy)
                    && stopCriteria.test(prevAccuracy2 - trainingAccuracy)
                    && stopCriteria.test(prevAccuracy3 - trainingAccuracy)) {
                progressMonitor.finishedTraining(iterations, stopCriteria);
                break;
            }
            prevAccuracy1 = prevAccuracy2;
            prevAccuracy2 = prevAccuracy3;
            prevAccuracy3 = trainingAccuracy;
        }

        if (!progressMonitor.isTrainingFinished()) {
            progressMonitor.finishedTraining(iterations, null);
        }
        progressMonitor.display(true);
        this.trainingStats(evalParams);

        if (!useAverage) {
            return params;
        }
        // Guard against 0/0 = NaN parameters when no iteration was run (iterations <= 0).
        if (numTimesSummed > 0) {
            for (int pi = 0; pi < this.numPreds; ++pi) {
                double[] sums = summedParams[pi].getParameters();
                for (int aoi = 0; aoi < this.numOutcomes; ++aoi) {
                    summedParams[pi].setParameter(aoi, sums[aoi] / (double) numTimesSummed);
                }
            }
        }
        return summedParams;
    }

    /**
     * Allocates one slot per predicate and, if requested, fills each slot with a
     * context covering all outcomes whose parameters start at zero.
     *
     * @param allOutcomesPattern the outcome indices shared by every context
     * @param populate whether to create the contexts or leave the slots {@code null}
     * @return the array of {@code numPreds} slots
     */
    private MutableContext[] newContexts(int[] allOutcomesPattern, boolean populate) {
        MutableContext[] result = new MutableContext[this.numPreds];
        if (populate) {
            for (int pi = 0; pi < this.numPreds; ++pi) {
                result[pi] = new MutableContext(allOutcomesPattern, new double[this.numOutcomes]);
            }
        }
        return result;
    }

    /**
     * Evaluates one unique event against the given parameters (without normalization)
     * and returns the index of the highest-scoring outcome.
     *
     * @param ei index of the unique event
     * @param evalParams the parameters to score with
     * @return the arg-max outcome index
     */
    private int predictOutcome(int ei, EvalParameters evalParams) {
        double[] modelDistribution = new double[this.numOutcomes];
        float[] eventValues = this.values != null ? this.values[ei] : null;
        PerceptronModel.eval(this.contexts[ei], eventValues, modelDistribution, evalParams, false);
        return ArrayMath.argmax(modelDistribution);
    }

    /**
     * Computes and logs the accuracy of the given parameters on the training events,
     * counting each unique event as often as it was seen.
     *
     * @param evalParams the parameters to evaluate
     * @return the number of correct predictions divided by the total number of events
     */
    private double trainingStats(EvalParameters evalParams) {
        int numCorrect = 0;
        for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
            for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ++ni) {
                if (this.predictOutcome(ei, evalParams) == this.outcomeList[ei]) {
                    ++numCorrect;
                }
            }
        }
        double trainingAccuracy = (double) numCorrect / (double) this.numEvents;
        logger.info("Stats: ({}/{}) {}", numCorrect, this.numEvents, trainingAccuracy);
        return trainingAccuracy;
    }

    /**
     * @param n the number to test
     * @return {@code true} if the truncated square root of {@code n}, squared, equals {@code n}
     */
    private static boolean isPerfectSquare(int n) {
        int root = (int) StrictMath.sqrt(n);
        return root * root == n;
    }

    /**
     * Picks the stop criterion from the training configuration, falling back to an
     * {@link IterDeltaAccuracyUnderTolerance} built from the training parameters.
     */
    private StopCriteria<Double> getStopCriteria(TrainingConfiguration trainingConfig) {
        return trainingConfig != null && trainingConfig.stopCriteria() != null
                ? trainingConfig.stopCriteria()
                : new IterDeltaAccuracyUnderTolerance(this.trainingParameters);
    }

    /**
     * Picks the progress monitor from the training configuration, falling back to a
     * {@link DefaultTrainingProgressMonitor}.
     */
    private TrainingProgressMonitor getTrainingProgressMonitor(TrainingConfiguration trainingConfig) {
        return trainingConfig != null && trainingConfig.progMon() != null
                ? trainingConfig.progMon()
                : new DefaultTrainingProgressMonitor();
    }
}

