/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.scikitwrapper;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.basic.ResourceUtil;
import ai.libs.jaicore.ml.classification.singlelabel.SingleLabelClassification;
import ai.libs.jaicore.ml.classification.singlelabel.SingleLabelClassificationPredictionBatch;
import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.ml.scikitwrapper.AProcessListener;
import ai.libs.jaicore.ml.scikitwrapper.DefaultProcessListener;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassificationPredictionBatch;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.jtwig.JtwigModel;
import org.jtwig.JtwigTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ScikitLearnWrapper
extends ASupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>, ISingleLabelClassification, ISingleLabelClassificationPredictionBatch>
implements ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> {
    // File-name suffixes for the generated Python script, the model dump and the prediction result file.
    private static final String PYTHON_FILE_EXT = ".py";
    private static final String MODEL_DUMP_FILE_EXT = ".pcl";
    private static final String RESULT_FILE_EXT = ".json";
    // Class-wide SLF4J logger.
    private static final Logger L = LoggerFactory.getLogger(ScikitLearnWrapper.class);
    // Relative path, so it is resolved against the JVM working directory. Generated scripts and ARFF
    // files are written here, and spawned processes use it as their working directory.
    private static final File TMP_FOLDER = new File("tmp");
    // Not referenced in this file; SCIKIT_TEMPLATE below repeats the same literal.
    private static final String RES_SCIKIT_TEMPLATE_PATH = "sklearn/scikit_template.twig.py";
    // Template file handed to Jtwig in the constructor.
    // NOTE(review): presumably the classpath resource is extracted to a temp file here - confirm in ResourceUtil.
    private static final File SCIKIT_TEMPLATE = new File(ResourceUtil.getResourceAsTempFile((String)"sklearn/scikit_template.twig.py"));
    // Sub-folder of TMP_FOLDER that receives model dumps and prediction result files.
    // Must stay below TMP_FOLDER: static initialisers run in textual order.
    private static final File MODEL_DUMPS_DIRECTORY = new File(TMP_FOLDER, "model_dumps");
    // The next two flags are not referenced in the visible code.
    private static final boolean VERBOSE = false;
    private static final boolean DELETE_TEMPORARY_FILES_ON_EXIT = true;
    // Empty copy of the training data, taken in fit(); predict() copies it again to hold the test instances.
    private ILabeledDataset<ILabeledInstance> dataset;
    // Defaults to classification; REGRESSION adds the --regression flag to every command.
    private ProblemType problemType = ProblemType.CLASSIFICATION;
    // Column indices appended after --targets when the array is non-empty.
    private int[] targetColumns = new int[0];
    // Derived in the constructor from the hash of construct instruction and imports; names the script
    // file and is part of the model-dump and result file names.
    private final String configurationUID;
    // Model dump location; set by fit() (when dumping is enabled), by a constructor, or by setModelPath().
    private File modelFile;
    // ARFF file of the training data, set by fit().
    private File trainArff;
    // When true, fit() does not launch a training process and writes no model dump.
    private final boolean withoutModelDump;
    // The constructor expression given by the caller; also returned by toString().
    private String constructInstruction;
    // Parsed JSON content of the most recent result file; null until predict() has succeeded once.
    private List<List<Double>> rawLastClassificationResults = null;

    /**
     * Creates a wrapper and renders the Python script that will be executed for training and prediction.
     *
     * <p>The configuration UID is derived from the hash code of the joined construct instruction and
     * imports; a leading minus sign is replaced by {@code 1}, non-negative hashes are prefixed with
     * {@code 0}. The script is written to {@code tmp/<uid>.py} and registered for deletion on JVM exit.
     *
     * @param constructInstruction Python expression that constructs the learner; must not be null or empty
     * @param imports Python import statements placed into the script; {@code null} is treated as empty
     * @param withoutModelDump if {@code true}, {@code fit} does not launch a training process and no model dump is written
     * @throws IOException if the script file cannot be created or written
     * @throws AssertionError if {@code constructInstruction} is null or empty
     */
    public ScikitLearnWrapper(String constructInstruction, String imports, boolean withoutModelDump) throws IOException {
        this.withoutModelDump = withoutModelDump;
        this.constructInstruction = constructInstruction;
        Map<String, Object> templateValues = this.getTemplateValueMap(constructInstruction, imports);
        String hashCode = StringUtils.join((Object[])new String[]{constructInstruction, imports}).hashCode() + "";
        this.configurationUID = hashCode.startsWith("-") ? hashCode.replace("-", "1") : "0" + hashCode;
        if (!TMP_FOLDER.exists()) {
            TMP_FOLDER.mkdirs();
        }
        File scriptFile = this.getSKLearnScriptFile();
        if (!scriptFile.createNewFile() && L.isDebugEnabled()) {
            L.debug("Script file for configuration UID {} already exists in {}", this.configurationUID, scriptFile.getAbsolutePath());
        }
        scriptFile.deleteOnExit();
        JtwigTemplate template = JtwigTemplate.fileTemplate((File)SCIKIT_TEMPLATE);
        JtwigModel model = JtwigModel.newModel(templateValues);
        // Close the stream deterministically so the script is fully flushed and the file handle is
        // released before a Python process is started on it.
        try (OutputStream scriptStream = new FileOutputStream(scriptFile)) {
            template.render(model, scriptStream);
        }
    }

    /**
     * Creates a wrapper with model dumping enabled (delegates with {@code withoutModelDump = false}).
     *
     * @param constructInstruction Python expression that constructs the learner
     * @param imports Python import statements for the generated script
     * @throws IOException if the script file cannot be created or written
     */
    public ScikitLearnWrapper(String constructInstruction, String imports) throws IOException {
        this(constructInstruction, imports, false);
    }

    /**
     * Creates a wrapper with model dumping enabled and an already existing model dump.
     *
     * <p>NOTE(review): {@code predict} dereferences the dataset field, which is only assigned in
     * {@code fit}, and {@code fit} overwrites the model file when dumping is enabled - confirm how a
     * pre-trained model is meant to be used without calling {@code fit}.
     *
     * @param constructInstruction Python expression that constructs the learner
     * @param imports Python import statements for the generated script
     * @param trainedModelPath model dump file stored as the current model file
     * @throws IOException if the script file cannot be created or written
     */
    public ScikitLearnWrapper(String constructInstruction, String imports, File trainedModelPath) throws IOException {
        this(constructInstruction, imports, false);
        this.modelFile = trainedModelPath;
    }

    /**
     * Locates the generated Python script of this configuration inside the temp folder.
     *
     * @return the file {@code <configurationUID>.py} below the temp folder
     * @throws NullPointerException if the configuration UID has not been assigned yet
     */
    private File getSKLearnScriptFile() {
        Objects.requireNonNull(this.configurationUID);
        String scriptName = this.configurationUID + PYTHON_FILE_EXT;
        return new File(TMP_FOLDER, scriptName);
    }

    /**
     * Builds the location of the JSON file that receives the predictions for a given ARFF file.
     *
     * @param arffName name of the ARFF file (without extension) the predictions belong to
     * @return the file {@code <arffName>_<configurationUID>.json} inside the model dump directory
     */
    private File getResultFile(String arffName) {
        String resultName = arffName + "_" + this.configurationUID + RESULT_FILE_EXT;
        return new File(MODEL_DUMPS_DIRECTORY, resultName);
    }

    /**
     * Serializes the training data to ARFF and, unless model dumping is disabled, runs the generated
     * script in train mode so that a model dump is written.
     *
     * <p>An empty copy of {@code data} is remembered and later used by {@code predict} as the
     * container for test instances. Any output on the process' error stream is treated as a training
     * failure; the first line of that output becomes the exception message.
     *
     * @param data the training data
     * @throws TrainingException if serialization fails, the process cannot be run, or it reports an error
     * @throws InterruptedException if the thread is interrupted while training; it is propagated
     *         unchanged (not wrapped) so callers can react to cancellation
     */
    public void fit(ILabeledDataset<? extends ILabeledInstance> data) throws TrainingException, InterruptedException {
        try {
            MODEL_DUMPS_DIRECTORY.mkdirs();
            String arffName = this.getArffName(data);
            this.trainArff = this.getArffFile(data, arffName);
            this.dataset = data.createEmptyCopy();
            if (this.withoutModelDump) {
                // Without a model dump, training is deferred to the train-test run issued by predict.
                return;
            }
            this.modelFile = new File(MODEL_DUMPS_DIRECTORY, this.configurationUID + "_" + arffName + MODEL_DUMP_FILE_EXT);
            String[] trainCommand = new SKLearnWrapperCommandBuilder().withTrainMode().withArffFile(this.trainArff).withOutputFile(this.modelFile).toCommandArray();
            if (L.isDebugEnabled()) {
                L.debug("{} run train mode {}", Thread.currentThread().getName(), Arrays.toString(trainCommand));
            }
            DefaultProcessListener listener = new DefaultProcessListener(false);
            this.runProcess(trainCommand, listener);
            String errorOutput = listener.getErrorOutput();
            if (!errorOutput.isEmpty()) {
                L.error("Raise error message");
                throw new TrainingException(errorOutput.split("\\n")[0]);
            }
        }
        catch (TrainingException | InterruptedException e) {
            // Both are part of the declared contract; wrapping an interrupt would hide the cancellation.
            throw e;
        }
        catch (Exception e) {
            throw new TrainingException("An exception occurred while training.", e);
        }
    }

    /**
     * Provides the ARFF file for a dataset, writing it only if no file with that name exists yet.
     *
     * <p>The file is registered for deletion on JVM exit in either case.
     *
     * @param data dataset to serialize if the file is missing
     * @param arffName file name without the {@code .arff} extension
     * @return the ARFF file inside the temp folder
     * @throws IOException if the dataset cannot be serialized
     */
    private File getArffFile(ILabeledDataset<? extends ILabeledInstance> data, String arffName) throws IOException {
        File target = new File(TMP_FOLDER, arffName + ".arff");
        target.deleteOnExit();
        if (!target.exists()) {
            ArffDatasetAdapter.serializeDataset(target, data);
        } else {
            L.debug("Reusing {}.arff", (Object)arffName);
        }
        return target;
    }

    /**
     * Predicts the label of a single instance by delegating to the batch variant.
     *
     * @param instance the instance to classify
     * @return the first (and only) entry of the resulting prediction batch
     * @throws PredictionException if the batch prediction fails
     * @throws InterruptedException if the thread is interrupted during prediction
     */
    @Override
    public ISingleLabelClassification predict(ILabeledInstance instance) throws PredictionException, InterruptedException {
        ILabeledInstance[] singleton = new ILabeledInstance[]{instance};
        ISingleLabelClassificationPredictionBatch batch = this.predict(singleton);
        return batch.get(0);
    }

    /**
     * Predicts labels for the given instances by running the generated script in a separate process.
     *
     * <p>The instances are copied into an empty copy of the training dataset, serialized to ARFF and
     * handed to the script. Exactly one process is launched:
     * <ul>
     *   <li>with model dumping enabled, the script runs in test mode against the dumped model;</li>
     *   <li>without model dumping, it runs in train-test mode on the training ARFF and the test ARFF.</li>
     * </ul>
     * Any output on the process' error stream is treated as failure; its last line becomes the
     * exception message. The JSON result file is read, deleted and parsed; the parsed structure is
     * kept and exposed via {@code getRawLastClassificationResults}.
     *
     * @param dTest the instances to classify
     * @return one prediction per value found in the (flattened) result lists
     * @throws PredictionException if data preparation, process execution or result parsing fails
     * @throws InterruptedException if the thread is interrupted during prediction
     * @throws IllegalArgumentException if a required ARFF or model file does not exist
     */
    @Override
    public ISingleLabelClassificationPredictionBatch predict(ILabeledInstance[] dTest) throws PredictionException, InterruptedException {
        ILabeledDataset data;
        try {
            data = this.dataset.createEmptyCopy();
        }
        catch (DatasetCreationException e1) {
            throw new PredictionException("Could not replicate labeled dataset instance", e1);
        }
        Arrays.stream(dTest).forEach(arg_0 -> data.add(arg_0));
        MODEL_DUMPS_DIRECTORY.mkdirs();
        String arffName = this.getArffName((ILabeledDataset<? extends ILabeledInstance>)data);
        File testArff;
        try {
            testArff = this.getArffFile((ILabeledDataset<? extends ILabeledInstance>)data, arffName);
        }
        catch (IOException e1) {
            throw new PredictionException("Could not dump arff file for prediction", e1);
        }
        File outputFile = this.getResultFile(arffName);
        outputFile.getParentFile().mkdirs();

        // The two modes are mutually exclusive: running train-test after a test run would retrain the
        // model and overwrite the result that was just produced from the dump.
        String[] testCommand;
        if (this.withoutModelDump) {
            testCommand = new SKLearnWrapperCommandBuilder().withTrainTestMode().withArffFile(this.trainArff).withTestArffFile(testArff).withOutputFile(outputFile).toCommandArray();
            if (L.isDebugEnabled()) {
                L.debug("Run train test mode with {}", Arrays.toString(testCommand));
            }
        } else {
            testCommand = new SKLearnWrapperCommandBuilder().withTestMode().withArffFile(testArff).withModelFile(this.modelFile).withOutputFile(outputFile).toCommandArray();
            if (L.isDebugEnabled()) {
                L.debug("Run test mode with {}", Arrays.toString(testCommand));
            }
        }

        DefaultProcessListener listener = new DefaultProcessListener(false);
        try {
            this.runProcess(testCommand, listener);
            String errorOutput = listener.getErrorOutput();
            if (!errorOutput.isEmpty()) {
                String[] message = errorOutput.split("\\n");
                throw new PredictionException(message[message.length - 1].trim());
            }
        }
        catch (InterruptedException | PredictionException e) {
            throw e;
        }
        catch (Exception e) {
            throw new PredictionException("Could not run scikit-learn classifier.", e);
        }

        try {
            String fileContent = FileUtil.readFileAsString((File)outputFile);
            // The result file is single-use; remove it before parsing, as soon as its content is in memory.
            Files.delete(outputFile.toPath());
            ObjectMapper objMapper = new ObjectMapper();
            this.rawLastClassificationResults = (List)objMapper.readValue(fileContent, List.class);
        }
        catch (IOException e) {
            throw new PredictionException("Could not read result file or parse the json content to a list", e);
        }
        return new SingleLabelClassificationPredictionBatch(this.rawLastClassificationResults.stream().flatMap(Collection::stream).map(x -> new SingleLabelClassification((int)x.doubleValue())).collect(Collectors.toList()));
    }

    /**
     * Builds Python statements that make every module of a folder importable in the generated script.
     *
     * <p>The result appends the folder to {@code sys.path} and adds one import line per regular
     * {@code *.py} file whose name does not start with {@code __}. Sub-directories and files with other
     * extensions are skipped, because stripping a fixed three-character suffix from them would yield
     * invalid module names. If the folder has no {@code __init__.py}, an empty one is created.
     *
     * @param importsFolder folder containing the Python modules; may be {@code null}
     * @param keepNamespace if {@code true}, emits {@code import <module>}; otherwise {@code from <module> import *}
     * @return the import statements, or an empty string if the folder is {@code null}, missing, not a
     *         directory, or empty
     * @throws IOException if the {@code __init__.py} file cannot be created
     */
    public static String createImportStatementFromImportFolder(File importsFolder, boolean keepNamespace) throws IOException {
        final String moduleSuffix = ".py";
        final String initFileName = "__init__.py";
        if (importsFolder == null) {
            return "";
        }
        // listFiles() yields null when the path does not exist or is not a directory.
        File[] entries = importsFolder.listFiles();
        if (entries == null || entries.length == 0) {
            return "";
        }
        boolean hasInitFile = false;
        for (File entry : entries) {
            if (initFileName.equals(entry.getName())) {
                hasInitFile = true;
                break;
            }
        }
        if (!hasInitFile) {
            File initFile = new File(importsFolder, initFileName);
            if (!initFile.createNewFile() && L.isDebugEnabled()) {
                L.debug("Init file {} exists already", initFile.getAbsolutePath());
            }
        }
        StringBuilder result = new StringBuilder();
        result.append("\n");
        result.append("sys.path.append(r'").append(importsFolder.getAbsolutePath()).append("')\n");
        for (File module : entries) {
            String fileName = module.getName();
            if (fileName.startsWith("__") || !fileName.endsWith(moduleSuffix) || !module.isFile()) {
                continue;
            }
            String moduleName = fileName.substring(0, fileName.length() - moduleSuffix.length());
            if (moduleName.isEmpty()) {
                // A file named just ".py" has no module name to import.
                continue;
            }
            if (keepNamespace) {
                result.append("import ").append(moduleName).append("\n");
            } else {
                result.append("from ").append(moduleName).append(" import *\n");
            }
        }
        return result.toString();
    }

    /**
     * Assembles the values that are substituted into the script template.
     *
     * @param constructInstruction Python expression constructing the learner; stored under {@code classifier_construct}
     * @param imports Python import statements, stored under {@code imports}; {@code null} becomes the empty string
     * @return a fresh mutable map with exactly these two entries
     * @throws AssertionError if {@code constructInstruction} is null or empty
     */
    private Map<String, Object> getTemplateValueMap(String constructInstruction, String imports) {
        boolean instructionMissing = constructInstruction == null || constructInstruction.isEmpty();
        if (instructionMissing) {
            throw new AssertionError((Object)"Construction command for classifier must be stated.");
        }
        Map<String, Object> values = new HashMap<>();
        values.put("classifier_construct", constructInstruction);
        values.put("imports", imports == null ? "" : imports);
        return values;
    }

    /**
     * Turns a collection of module names into Python import statements, one per line.
     *
     * @param imports module names; may be {@code null}
     * @return {@code import a\nimport b...}, or an empty string for a {@code null} or empty collection
     */
    public static String getImportString(Collection<String> imports) {
        if (imports == null || imports.isEmpty()) {
            return "";
        }
        String joinedModules = StringUtils.join(imports, (String)"\nimport ");
        return "import " + joinedModules;
    }

    /**
     * Returns the structure parsed from the JSON result file of the most recent successful batch prediction.
     *
     * @return the internally held list (not a copy), or {@code null} if no prediction has succeeded yet
     */
    public List<List<Double>> getRawLastClassificationResults() {
        return this.rawLastClassificationResults;
    }

    /**
     * Sets the problem type used when building script commands.
     *
     * @param problemType {@code REGRESSION} causes the {@code --regression} flag to be added to commands
     */
    public void setProblemType(ProblemType problemType) {
        this.problemType = problemType;
    }

    /**
     * Sets the target column indices passed to the script after the {@code --targets} flag.
     *
     * <p>The array reference is stored as given (no defensive copy). A {@code null} or empty array
     * means no {@code --targets} argument is added.
     *
     * @param targetColumns indices of the target columns
     */
    public void setTargets(int ... targetColumns) {
        this.targetColumns = targetColumns;
    }

    /**
     * Replaces the model dump file used for test-mode predictions.
     *
     * @param modelFile the model dump to use from now on
     */
    public void setModelPath(File modelFile) {
        this.modelFile = modelFile;
    }

    /**
     * Returns the model dump file currently associated with this wrapper.
     *
     * @return the model file, or {@code null} if none has been set
     */
    public File getModelPath() {
        return this.modelFile;
    }

    /**
     * Derives a file-name-safe identifier from the hash code of a dataset.
     *
     * <p>Negative hash codes have their minus sign replaced by {@code 1}; all others are prefixed
     * with {@code 0}.
     *
     * @param data the dataset to name
     * @return the identifier used as ARFF file name (without extension)
     */
    private String getArffName(ILabeledDataset<? extends ILabeledInstance> data) {
        String digits = Integer.toString(data.hashCode());
        if (digits.startsWith("-")) {
            return digits.replace("-", "1");
        }
        return "0" + digits;
    }

    /**
     * Starts an external process in the temp folder and hands it to the given listener.
     *
     * @param parameters the command and its arguments
     * @param listener receives the started process
     * @throws InterruptedException propagated from the listener
     * @throws IOException if the process cannot be started, or propagated from the listener
     */
    private void runProcess(String[] parameters, AProcessListener listener) throws InterruptedException, IOException {
        if (L.isDebugEnabled()) {
            // Arrays.toString yields "[a, b]"; drop the commas and the surrounding brackets for the log.
            String bracketed = Arrays.toString(parameters).replace(",", "");
            String call = bracketed.substring(1, bracketed.length() - 1);
            L.debug("Starting process {}", (Object)call);
        }
        ProcessBuilder builder = new ProcessBuilder(parameters);
        builder.directory(TMP_FOLDER);
        Process process = builder.start();
        listener.listenTo(process);
    }

    /**
     * Not implemented.
     *
     * @param instance ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public double[] distributionForInstance(ILabeledInstance instance) {
        throw new UnsupportedOperationException("This method is not yet implemented");
    }

    /**
     * Returns the Python construct instruction this wrapper was created with.
     *
     * @return the construct instruction passed to the constructor
     */
    @Override
    public String toString() {
        return this.constructInstruction;
    }

    /**
     * Fluent builder for the argument vector that launches the generated Python script.
     *
     * <p>The builder reads the script location, the problem type and the target columns from the
     * enclosing wrapper. Mode, ARFF file and output file are mandatory; a model file is additionally
     * required in test mode.
     */
    private class SKLearnWrapperCommandBuilder {
        private static final String MODE_FLAG = "--mode";
        private static final String ARFF_FLAG = "--arff";
        private static final String TEST_ARFF_FLAG = "--testarff";
        private static final String OUTPUT_FLAG = "--output";
        private static final String REGRESSION_FLAG = "--regression";
        private static final String MODEL_FLAG = "--model";
        private static final String TARGETS_FLAG = "--targets";

        private WrapperExecutionMode mode;
        private String arffFile;
        private String testArffFile;
        private String outputFile;
        private String modelFile;

        private SKLearnWrapperCommandBuilder() {
        }

        /** Selects train mode. */
        public SKLearnWrapperCommandBuilder withTrainMode() {
            return this.withMode(WrapperExecutionMode.TRAIN);
        }

        /** Selects test mode, which additionally requires a model file. */
        public SKLearnWrapperCommandBuilder withTestMode() {
            return this.withMode(WrapperExecutionMode.TEST);
        }

        /** Selects combined train-test mode. */
        public SKLearnWrapperCommandBuilder withTrainTestMode() {
            return this.withMode(WrapperExecutionMode.TRAIN_TEST);
        }

        private SKLearnWrapperCommandBuilder withMode(WrapperExecutionMode execMode) {
            this.mode = execMode;
            return this;
        }

        /**
         * Sets the mandatory ARFF file.
         *
         * @throws IllegalArgumentException if the file does not exist
         */
        private SKLearnWrapperCommandBuilder withArffFile(File arffFile) {
            if (arffFile.exists()) {
                this.arffFile = arffFile.getAbsolutePath();
                return this;
            }
            throw new IllegalArgumentException("Arff File does not exist.");
        }

        /** Sets the optional test ARFF file; its existence is not checked. */
        public SKLearnWrapperCommandBuilder withTestArffFile(File testArffFile) {
            this.testArffFile = testArffFile.getAbsolutePath();
            return this;
        }

        /**
         * Sets the model dump to load.
         *
         * @throws IllegalArgumentException if the file does not exist
         */
        private SKLearnWrapperCommandBuilder withModelFile(File modelFile) {
            if (modelFile.exists()) {
                this.modelFile = modelFile.getAbsolutePath();
                return this;
            }
            throw new IllegalArgumentException("Model dump does not exist");
        }

        /** Sets the mandatory output file. */
        private SKLearnWrapperCommandBuilder withOutputFile(File outputFile) {
            this.outputFile = outputFile.getAbsolutePath();
            return this;
        }

        /** Appends a flag followed by its value. */
        private void appendOption(List<String> command, String flag, String value) {
            command.add(flag);
            command.add(value);
        }

        /**
         * Assembles the command: {@code python -u <script> --mode .. --arff .. [--testarff ..]
         * --output .. [--regression] [--model ..] [--targets ..]}.
         *
         * @return the command as an argument array
         * @throws NullPointerException if mode, output file or ARFF file is unset, or the model file is unset in test mode
         * @throws IllegalArgumentException if the generated script file does not exist
         */
        private String[] toCommandArray() {
            Objects.requireNonNull(this.mode);
            Objects.requireNonNull(this.outputFile);
            Objects.requireNonNull(this.arffFile);
            File script = ScikitLearnWrapper.this.getSKLearnScriptFile();
            if (!script.exists()) {
                throw new IllegalArgumentException("The wrapped sklearn script " + script.getAbsolutePath() + " file does not exist");
            }

            List<String> command = new ArrayList<>();
            command.add("python");
            command.add("-u");
            command.add(script.getAbsolutePath());
            this.appendOption(command, MODE_FLAG, this.mode.toString());
            this.appendOption(command, ARFF_FLAG, this.arffFile);
            if (this.testArffFile != null) {
                this.appendOption(command, TEST_ARFF_FLAG, this.testArffFile);
            }
            this.appendOption(command, OUTPUT_FLAG, this.outputFile);
            if (ScikitLearnWrapper.this.problemType == ProblemType.REGRESSION) {
                command.add(REGRESSION_FLAG);
            }
            if (this.mode == WrapperExecutionMode.TEST) {
                Objects.requireNonNull(this.modelFile);
                this.appendOption(command, MODEL_FLAG, this.modelFile);
            }
            int[] targets = ScikitLearnWrapper.this.targetColumns;
            if (targets != null && targets.length > 0) {
                command.add(TARGETS_FLAG);
                for (int column : targets) {
                    command.add(String.valueOf(column));
                }
            }
            return command.toArray(new String[0]);
        }
    }

    /**
     * Execution modes of the generated script; {@link #toString()} yields the value passed after
     * the mode flag on the command line.
     */
    private static enum WrapperExecutionMode {
        TRAIN("train"),
        TEST("test"),
        TRAIN_TEST("traintest");

        /** Command-line spelling of this mode. */
        private final String cliValue;

        WrapperExecutionMode(String cliValue) {
            this.cliValue = cliValue;
        }

        @Override
        public String toString() {
            return cliValue;
        }
    }

    /**
     * Kind of learning problem the wrapped learner solves. {@code REGRESSION} makes the command
     * builder add the {@code --regression} flag; the wrapper defaults to {@code CLASSIFICATION}.
     */
    public static enum ProblemType {
        REGRESSION,
        CLASSIFICATION;

    }
}

