/*
 * Decompiled with CFR 0.152.
 */
package ai.idylnlp.models.opennlp.training;

import ai.idylnlp.model.nlp.subjects.SubjectOfTrainingOrEvaluation;
import ai.idylnlp.model.training.FMeasureModelValidationResult;
import ai.idylnlp.models.ModelOperationsUtils;
import ai.idylnlp.models.ObjectStreamUtils;
import ai.idylnlp.models.opennlp.training.model.ModelCrossValidationOperations;
import ai.idylnlp.models.opennlp.training.model.ModelSeparateDataValidationOperations;
import ai.idylnlp.models.opennlp.training.model.ModelTrainingOperations;
import ai.idylnlp.models.opennlp.training.model.TrainingAlgorithm;
import ai.idylnlp.opennlp.custom.encryption.OpenNLPEncryptionFactory;
import ai.idylnlp.training.definition.model.TrainingDefinitionReader;
import com.neovisionaries.i18n.LanguageCode;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.LinkedList;
import opennlp.tools.cmdline.namefind.NameEvaluationErrorListener;
import opennlp.tools.namefind.BioCodec;
import opennlp.tools.namefind.NameFinderME;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.namefind.TokenNameFinderCrossValidator;
import opennlp.tools.namefind.TokenNameFinderEvaluationMonitor;
import opennlp.tools.namefind.TokenNameFinderEvaluator;
import opennlp.tools.namefind.TokenNameFinderFactory;
import opennlp.tools.namefind.TokenNameFinderModel;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.SequenceCodec;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.eval.FMeasure;
import org.apache.commons.lang.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Trains and evaluates OpenNLP token name finder (entity) models.
 *
 * <p>An instance is bound to one entity type and one feature generator XML
 * definition. The static {@link #train(TrainingDefinitionReader)} and
 * {@link #crossValidate(TrainingDefinitionReader, int)} entry points build an
 * instance from a training definition and dispatch on the configured algorithm.
 *
 * <p>Training and separate-data evaluation install the model encryption key on
 * the shared {@link OpenNLPEncryptionFactory} default instance and always clear
 * it again before returning, whether or not the operation succeeded.
 */
public class EntityModelOperations
implements ModelTrainingOperations,
ModelSeparateDataValidationOperations<FMeasureModelValidationResult>,
ModelCrossValidationOperations<FMeasureModelValidationResult> {
    private static final Logger LOGGER = LogManager.getLogger(EntityModelOperations.class);
    /** Charset used to turn the feature generator XML into bytes for OpenNLP. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");
    private final String type;
    private final String featureGeneratorXml;

    /**
     * Trains an entity model as described by a training definition.
     *
     * <p>The quasi-Newton specific settings (threads, L1, L2, M, Max) are only
     * read when the definition selects the MAXENT_QN algorithm, so a perceptron
     * definition does not need to provide them.
     *
     * @param reader provides the training definition and the feature generator XML
     * @return the value returned by the algorithm-specific training method
     *         (the serialized model's ID, or an empty string if writing the model failed)
     * @throws IOException if the algorithm name is missing or not recognized, or if
     *         training fails with an I/O error
     */
    public static String train(TrainingDefinitionReader reader) throws IOException {
        String type = reader.getTrainingDefinition().getModel().getType();
        String featureGeneratorXml = reader.getFeatures();
        EntityModelOperations ops = new EntityModelOperations(type, featureGeneratorXml);
        SubjectOfTrainingOrEvaluation subjectOfTraining = ModelOperationsUtils.getSubjectOfTrainingOrEvaluation(reader);
        String modelFile = reader.getTrainingDefinition().getModel().getFile();
        String language = reader.getTrainingDefinition().getModel().getLanguage();
        String encryptionKey = reader.getTrainingDefinition().getModel().getEncryptionkey();
        int cutOff = reader.getTrainingDefinition().getAlgorithm().getCutoff().intValue();
        int iterations = reader.getTrainingDefinition().getAlgorithm().getIterations().intValue();
        String algorithm = reader.getTrainingDefinition().getAlgorithm().getName();
        LanguageCode languageCode = LanguageCode.getByCodeIgnoreCase(language);
        // Null-safe comparison: a missing algorithm name falls through to the
        // "Invalid algorithm" IOException below instead of a NullPointerException.
        if (StringUtils.equalsIgnoreCase(algorithm, TrainingAlgorithm.PERCEPTRON.getName())) {
            return ops.trainPerceptron(subjectOfTraining, modelFile, languageCode, encryptionKey, cutOff, iterations);
        }
        if (StringUtils.equalsIgnoreCase(algorithm, TrainingAlgorithm.MAXENT_QN.getName())) {
            int threads = reader.getTrainingDefinition().getAlgorithm().getThreads().intValue();
            double l1 = reader.getTrainingDefinition().getAlgorithm().getL1().doubleValue();
            double l2 = reader.getTrainingDefinition().getAlgorithm().getL2().doubleValue();
            int m = reader.getTrainingDefinition().getAlgorithm().getM().intValue();
            int max = reader.getTrainingDefinition().getAlgorithm().getMax().intValue();
            return ops.trainMaxEntQN(subjectOfTraining, modelFile, languageCode, encryptionKey, cutOff, iterations, threads, l1, l2, m, max);
        }
        throw new IOException("Invalid algorithm specified in the training definition file: " + algorithm);
    }

    /**
     * Cross-validates an entity model as described by a training definition.
     *
     * <p>The quasi-Newton specific settings (L1, L2, M, Max) are only read when
     * the definition selects the MAXENT_QN algorithm.
     *
     * @param reader provides the training definition and the feature generator XML
     * @param folds the number of cross-validation folds
     * @return the overall and per-fold F-measure results
     * @throws IOException if the algorithm name is missing or not recognized, or if
     *         evaluation fails with an I/O error
     */
    public static FMeasureModelValidationResult crossValidate(TrainingDefinitionReader reader, int folds) throws IOException {
        String language = reader.getTrainingDefinition().getModel().getLanguage();
        int iterations = reader.getTrainingDefinition().getAlgorithm().getIterations().intValue();
        int cutoff = reader.getTrainingDefinition().getAlgorithm().getCutoff().intValue();
        String featureGeneratorXml = reader.getFeatures();
        String type = reader.getTrainingDefinition().getModel().getType();
        String algorithm = reader.getTrainingDefinition().getAlgorithm().getName();
        LanguageCode languageCode = LanguageCode.getByCodeIgnoreCase(language);
        SubjectOfTrainingOrEvaluation subjectOfTraining = ModelOperationsUtils.getSubjectOfTrainingOrEvaluation(reader);
        EntityModelOperations entityModelOperations = new EntityModelOperations(type, featureGeneratorXml);
        if (StringUtils.equalsIgnoreCase(algorithm, TrainingAlgorithm.PERCEPTRON.getName())) {
            return entityModelOperations.crossValidationEvaluatePerceptron(subjectOfTraining, languageCode, iterations, cutoff, folds);
        }
        if (StringUtils.equalsIgnoreCase(algorithm, TrainingAlgorithm.MAXENT_QN.getName())) {
            double l1 = reader.getTrainingDefinition().getAlgorithm().getL1().doubleValue();
            double l2 = reader.getTrainingDefinition().getAlgorithm().getL2().doubleValue();
            int m = reader.getTrainingDefinition().getAlgorithm().getM().intValue();
            int max = reader.getTrainingDefinition().getAlgorithm().getMax().intValue();
            return entityModelOperations.crossValidationEvaluateMaxEntQN(subjectOfTraining, languageCode, iterations, cutoff, folds, l1, l2, m, max);
        }
        throw new IOException("Invalid algorithm specified in the training definition file: " + algorithm);
    }

    /**
     * Creates operations for one entity type.
     *
     * @param type the entity type passed to OpenNLP when training and cross-validating
     * @param featureGeneratorXml the feature generator definition; it is encoded as
     *        UTF-8 and handed to OpenNLP, so it must not be {@code null}
     */
    public EntityModelOperations(String type, String featureGeneratorXml) {
        this.type = type;
        this.featureGeneratorXml = featureGeneratorXml;
    }

    /**
     * Cross-validates using the MAXENT_QN algorithm.
     *
     * @param subjectOfTraining source of the annotated samples
     * @param language language of the samples; its alpha-3 code is passed to OpenNLP
     * @param iterations value for the "Iterations" training parameter
     * @param cutOff value for the "Cutoff" training parameter
     * @param folds the number of cross-validation folds
     * @param l1 value for the "L1Cost" training parameter
     * @param l2 value for the "L2Cost" training parameter
     * @param m value for the "NumOfUpdates" training parameter
     * @param max value for the "MaxFctEval" training parameter
     * @return the overall and per-fold F-measure results
     * @throws IOException if reading the samples or evaluating fails
     */
    @Override
    public FMeasureModelValidationResult crossValidationEvaluateMaxEntQN(SubjectOfTrainingOrEvaluation subjectOfTraining, LanguageCode language, int iterations, int cutOff, int folds, double l1, double l2, int m, int max) throws IOException {
        TrainingParameters trainParams = newTrainingParameters(TrainingAlgorithm.MAXENT_QN.getAlgorithm(), cutOff, iterations);
        addQuasiNewtonParameters(trainParams, l1, l2, m, max);
        return runCrossValidation(subjectOfTraining, language, trainParams, folds);
    }

    /**
     * Cross-validates using the PERCEPTRON algorithm.
     *
     * @param subjectOfTraining source of the annotated samples
     * @param language language of the samples; its alpha-3 code is passed to OpenNLP
     * @param iterations value for the "Iterations" training parameter
     * @param cutOff value for the "Cutoff" training parameter
     * @param folds the number of cross-validation folds
     * @return the overall and per-fold F-measure results
     * @throws IOException if reading the samples or evaluating fails
     */
    @Override
    public FMeasureModelValidationResult crossValidationEvaluatePerceptron(SubjectOfTrainingOrEvaluation subjectOfTraining, LanguageCode language, int iterations, int cutOff, int folds) throws IOException {
        TrainingParameters trainParams = newTrainingParameters(TrainingAlgorithm.PERCEPTRON.getAlgorithm(), cutOff, iterations);
        return runCrossValidation(subjectOfTraining, language, trainParams, folds);
    }

    /**
     * Evaluates an existing model file against separate evaluation data.
     *
     * @param subjectOfTraining source of the annotated evaluation samples
     * @param modelFileName path of the model file to load
     * @param encryptionKey key installed on the encryption factory while the model
     *        is loaded and evaluated
     * @return the F-measure result of the evaluation
     * @throws IOException if the model or the samples cannot be read
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"}) // element type of ObjectStreamUtils.getObjectStream is not visible here
    public FMeasureModelValidationResult separateDataEvaluate(SubjectOfTrainingOrEvaluation subjectOfTraining, String modelFileName, String encryptionKey) throws IOException {
        LOGGER.info("Doing model evaluation using separate training data.");
        ObjectStream sampleStream = ObjectStreamUtils.getObjectStream(subjectOfTraining);
        try {
            OpenNLPEncryptionFactory.getDefault().setKey(encryptionKey);
            try {
                TokenNameFinderModel model = new TokenNameFinderModel(new File(modelFileName));
                NameFinderME nameFinderME = new NameFinderME(model);
                TokenNameFinderEvaluator evaluator = new TokenNameFinderEvaluator((TokenNameFinder)nameFinderME, new TokenNameFinderEvaluationMonitor[0]);
                evaluator.evaluate(sampleStream);
                return new FMeasureModelValidationResult(toResultFMeasure(evaluator.getFMeasure()));
            }
            finally {
                // Never leave the key installed on the shared factory, even on failure.
                OpenNLPEncryptionFactory.getDefault().clearKey();
            }
        }
        finally {
            closeQuietly(sampleStream);
        }
    }

    /**
     * Trains a model with the PERCEPTRON algorithm and writes it to {@code modelFile}.
     *
     * @param subjectOfTraining source of the annotated training samples
     * @param modelFile path of the model file to write
     * @param language language of the samples; its alpha-3 code is passed to OpenNLP
     * @param encryptionKey key installed on the encryption factory for the duration
     *        of training and serialization
     * @param cutOff value for the "Cutoff" training parameter
     * @param iterations value for the "Iterations" training parameter
     * @return the value returned by the model's {@code serialize} call, or an empty
     *         string if writing the model failed (the failure is logged)
     * @throws IOException if reading the samples or training fails, or if closing
     *         the model file fails
     */
    @Override
    public String trainPerceptron(SubjectOfTrainingOrEvaluation subjectOfTraining, String modelFile, LanguageCode language, String encryptionKey, int cutOff, int iterations) throws IOException {
        LOGGER.info("Beginning entity model training. Output model will be: {}", (Object)modelFile);
        TrainingParameters trainParams = newTrainingParameters(TrainingAlgorithm.PERCEPTRON.getAlgorithm(), cutOff, iterations);
        SequenceCodec<String> sequenceCodec = new BioCodec();
        return trainAndSerialize(subjectOfTraining, modelFile, language, encryptionKey, trainParams, sequenceCodec);
    }

    /**
     * Trains a model with the MAXENT_QN algorithm and writes it to {@code modelFile}.
     *
     * @param subjectOfTraining source of the annotated training samples
     * @param modelFile path of the model file to write
     * @param language language of the samples; its alpha-3 code is passed to OpenNLP
     * @param encryptionKey key installed on the encryption factory for the duration
     *        of training and serialization
     * @param cutOff value for the "Cutoff" training parameter
     * @param iterations value for the "Iterations" training parameter
     * @param threads value for the "Threads" training parameter
     * @param l1 value for the "L1Cost" training parameter
     * @param l2 value for the "L2Cost" training parameter
     * @param m value for the "NumOfUpdates" training parameter
     * @param max value for the "MaxFctEval" training parameter
     * @return the value returned by the model's {@code serialize} call, or an empty
     *         string if writing the model failed (the failure is logged)
     * @throws IOException if reading the samples or training fails, or if closing
     *         the model file fails
     */
    @Override
    public String trainMaxEntQN(SubjectOfTrainingOrEvaluation subjectOfTraining, String modelFile, LanguageCode language, String encryptionKey, int cutOff, int iterations, int threads, double l1, double l2, int m, int max) throws IOException {
        LOGGER.info("Beginning entity model training with {} threads. Output model will be: {}", (Object)threads, (Object)modelFile);
        TrainingParameters trainParams = newTrainingParameters(TrainingAlgorithm.MAXENT_QN.getAlgorithm(), cutOff, iterations);
        trainParams.put("Threads", Integer.toString(threads));
        addQuasiNewtonParameters(trainParams, l1, l2, m, max);
        SequenceCodec<String> sequenceCodec = TokenNameFinderFactory.instantiateSequenceCodec(null);
        return trainAndSerialize(subjectOfTraining, modelFile, language, encryptionKey, trainParams, sequenceCodec);
    }

    /** Builds the training parameters shared by every algorithm. */
    private static TrainingParameters newTrainingParameters(String algorithm, int cutOff, int iterations) {
        TrainingParameters trainParams = new TrainingParameters();
        trainParams.put("Cutoff", Integer.toString(cutOff));
        trainParams.put("Iterations", Integer.toString(iterations));
        trainParams.put("Algorithm", algorithm);
        return trainParams;
    }

    /** Adds the parameters that only the MAXENT_QN algorithm is given. */
    private static void addQuasiNewtonParameters(TrainingParameters trainParams, double l1, double l2, int m, int max) {
        trainParams.put("L1Cost", String.valueOf(l1));
        trainParams.put("L2Cost", String.valueOf(l2));
        trainParams.put("NumOfUpdates", String.valueOf(m));
        trainParams.put("MaxFctEval", String.valueOf(max));
    }

    /** Runs a k-fold cross-validation with the given parameters and converts the scores. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // element type of ObjectStreamUtils.getObjectStream is not visible here
    private FMeasureModelValidationResult runCrossValidation(SubjectOfTrainingOrEvaluation subjectOfTraining, LanguageCode language, TrainingParameters trainParams, int folds) throws IOException {
        LOGGER.info("Doing model evaluation using cross-validation with {} folds.", (Object)folds);
        ObjectStream sampleStream = ObjectStreamUtils.getObjectStream(subjectOfTraining);
        try {
            byte[] featureGeneratorBytes = this.featureGeneratorXml.getBytes(UTF_8);
            HashMap<String, Object> resources = new HashMap<String, Object>();
            NameEvaluationErrorListener monitor = new NameEvaluationErrorListener();
            TokenNameFinderCrossValidator evaluator = new TokenNameFinderCrossValidator(language.getAlpha3().toString(), this.type, trainParams, featureGeneratorBytes, resources, new TokenNameFinderEvaluationMonitor[]{monitor});
            evaluator.evaluate(sampleStream, folds);
            LinkedList<ai.idylnlp.model.training.FMeasure> perFold = new LinkedList<ai.idylnlp.model.training.FMeasure>();
            for (FMeasure foldScore : evaluator.getFMeasures()) {
                perFold.add(toResultFMeasure(foldScore));
            }
            return new FMeasureModelValidationResult(toResultFMeasure(evaluator.getFMeasure()), perFold);
        }
        finally {
            closeQuietly(sampleStream);
        }
    }

    /** Trains a model from the subject's samples and writes it to {@code modelFile}. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // element type of ObjectStreamUtils.getObjectStream is not visible here
    private String trainAndSerialize(SubjectOfTrainingOrEvaluation subjectOfTraining, String modelFile, LanguageCode language, String encryptionKey, TrainingParameters trainParams, SequenceCodec<String> sequenceCodec) throws IOException {
        ObjectStream sampleStream = ObjectStreamUtils.getObjectStream(subjectOfTraining);
        try {
            byte[] featureGeneratorBytes = this.featureGeneratorXml.getBytes(UTF_8);
            HashMap<String, Object> resources = new HashMap<String, Object>();
            TokenNameFinderFactory tokenNameFinderFactory = TokenNameFinderFactory.create(TokenNameFinderFactory.class.getName(), featureGeneratorBytes, resources, sequenceCodec);
            OpenNLPEncryptionFactory.getDefault().setKey(encryptionKey);
            try {
                TokenNameFinderModel model = NameFinderME.train(language.getAlpha3().toString(), this.type, sampleStream, trainParams, tokenNameFinderFactory);
                return serializeModel(model, modelFile);
            }
            finally {
                // Runs even when training or closing the model file throws, so the
                // key is never left installed on the shared factory.
                OpenNLPEncryptionFactory.getDefault().clearKey();
            }
        }
        finally {
            closeQuietly(sampleStream);
        }
    }

    /**
     * Writes the model to {@code modelFile}. A failure while opening or writing the
     * file is logged and reported to the caller as an empty model ID; a failure
     * while closing the file propagates as an {@link IOException}.
     */
    private static String serializeModel(TokenNameFinderModel model, String modelFile) throws IOException {
        String modelId = "";
        OutputStream modelOut = null;
        try {
            modelOut = new BufferedOutputStream(new FileOutputStream(modelFile));
            modelId = model.serialize(modelOut);
        }
        catch (Exception ex) {
            LOGGER.error("Unable to create the model.", (Throwable)ex);
        }
        finally {
            if (modelOut != null) {
                modelOut.close();
            }
        }
        return modelId;
    }

    /** Copies an OpenNLP F-measure into the result model's F-measure type. */
    private static ai.idylnlp.model.training.FMeasure toResultFMeasure(FMeasure score) {
        return new ai.idylnlp.model.training.FMeasure(score.getPrecisionScore(), score.getRecallScore(), score.getFMeasure());
    }

    /**
     * Closes the sample stream on a best-effort basis. A close failure is only
     * logged so it can neither mask an earlier exception nor fail an operation
     * that has already produced its result.
     */
    private static void closeQuietly(ObjectStream<?> stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        }
        catch (Exception ex) {
            LOGGER.warn("Unable to close the sample stream.", (Throwable)ex);
        }
    }
}

