/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.scikitwrapper;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.basic.ResourceUtil;
import ai.libs.jaicore.ml.classification.singlelabel.SingleLabelClassification;
import ai.libs.jaicore.ml.classification.singlelabel.SingleLabelClassificationPredictionBatch;
import ai.libs.jaicore.ml.core.EScikitLearnProblemType;
import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.ml.regression.singlelabel.SingleTargetRegressionPrediction;
import ai.libs.jaicore.ml.regression.singlelabel.SingleTargetRegressionPredictionBatch;
import ai.libs.jaicore.ml.scikitwrapper.AProcessListener;
import ai.libs.jaicore.ml.scikitwrapper.DefaultProcessListener;
import ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapperConfig;
import ai.libs.jaicore.processes.EOperatingSystem;
import ai.libs.jaicore.processes.ProcessIDNotRetrievableException;
import ai.libs.jaicore.processes.ProcessUtil;
import ai.libs.python.IPythonConfig;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.stream.Collectors;
import org.aeonbits.owner.ConfigCache;
import org.aeonbits.owner.ConfigFactory;
import org.apache.commons.lang3.StringUtils;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.INumericAttribute;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.IPredictionBatch;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.algorithm.Timeout;
import org.jtwig.JtwigModel;
import org.jtwig.JtwigTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ScikitLearnWrapper<P extends IPrediction, B extends IPredictionBatch>
extends ASupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>, P, B>
implements ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> {
    /** Logger shared by all wrapper instances. */
    private static final Logger L = LoggerFactory.getLogger(ScikitLearnWrapper.class);
    /** Wrapper configuration; supplies the temp folder, the model dump directory and the file extensions used by this class. */
    private static final IScikitLearnWrapperConfig CONF = (IScikitLearnWrapperConfig)ConfigCache.getOrCreate(IScikitLearnWrapperConfig.class, (Map[])new Map[0]);
    /** Python configuration (path, command, anaconda environment) read when the command line is assembled. */
    private IPythonConfig pythonConfig = (IPythonConfig)ConfigFactory.create(IPythonConfig.class, (Map[])new Map[0]);
    /** Flag handed to every {@link DefaultProcessListener}; set to {@code true} on MAC and LINUX in the constructor. */
    private boolean listenToPidFromProcess;
    /** Template file resolved from the problem type; rendered into the Python script by the constructor. */
    private File scikitTemplate;
    /** Empty copy of the training data created in {@code fit}; copied again in {@code predict} to hold the test instances. */
    private ILabeledDataset<ILabeledInstance> dataset;
    /** Current problem type; may be overwritten in {@code fit} depending on the label attribute. */
    private EScikitLearnProblemType problemType;
    /** Column indices appended after {@code --targets} on the command line when non-empty. */
    private int[] targetColumns = new int[0];
    /** Identifier derived from the hash of construct instruction and imports; used in script, model and result file names. */
    private final String configurationUID;
    /** Model dump file written in train mode and read in test mode. */
    private File modelFile;
    /** ARFF file holding the training data of the last {@code fit} call. */
    private File trainArff;
    /** Whether {@code fit} runs the script in train mode and dumps the model to {@link #modelFile}. */
    private final boolean withModelDump;
    /** Construction instruction inserted into the script template; also returned by {@code toString()}. */
    private String constructInstruction;
    /** Parsed JSON content of the result file of the last {@code predict} call; {@code null} before the first prediction. */
    private List<List<Double>> rawLastClassificationResults = null;
    /** Seed passed to the script via {@code --seed}. */
    private long seed;
    /** Optional timeout; on LINUX it is turned into a {@code timeout} prefix of the command. */
    private Timeout timeout;

    /**
     * Creates a wrapper for the given construction instruction and renders the
     * corresponding Python script from the problem type's template into the
     * configured temp folder.
     *
     * @param constructInstruction instruction inserted into the template as {@code classifier_construct}; must not be null or empty
     * @param imports import statements inserted into the template; {@code null} is treated as empty
     * @param withModelDump whether {@code fit} trains eagerly and dumps the model to a file
     * @param problemType problem type that determines the script template
     * @throws IOException if the script file cannot be created or written
     */
    public ScikitLearnWrapper(String constructInstruction, String imports, boolean withModelDump, EScikitLearnProblemType problemType) throws IOException {
        this.listenToPidFromProcess = ProcessUtil.getOS() == EOperatingSystem.MAC || ProcessUtil.getOS() == EOperatingSystem.LINUX;
        this.withModelDump = withModelDump;
        this.constructInstruction = constructInstruction;
        this.setProblemType(problemType);
        Map<String, Object> templateValues = this.getTemplateValueMap(constructInstruction, imports);

        // Derive a purely numeric UID from the hash: a leading '-' becomes '1', non-negative hashes get a '0' prefix.
        String hashCode = StringUtils.join((Object[])new String[]{constructInstruction, imports}).hashCode() + "";
        this.configurationUID = hashCode.startsWith("-") ? hashCode.replace("-", "1") : "0" + hashCode;

        if (!CONF.getTempFolder().exists()) {
            CONF.getTempFolder().mkdirs();
        }
        File scriptFile = this.getSKLearnScriptFile();
        if (!scriptFile.createNewFile() && L.isDebugEnabled()) {
            L.debug("Script file for configuration UID {} already exists in {}", this.configurationUID, scriptFile.getAbsolutePath());
        }
        if (CONF.getDeleteFileOnExit()) {
            scriptFile.deleteOnExit();
        }

        JtwigTemplate template = JtwigTemplate.fileTemplate(this.scikitTemplate);
        JtwigModel model = JtwigModel.newModel(templateValues);
        // Close the stream deterministically so the script is fully flushed and no file handle leaks.
        try (OutputStream scriptOut = new FileOutputStream(scriptFile)) {
            template.render(model, scriptOut);
        }
    }

    /**
     * Convenience constructor that delegates to the four-argument constructor
     * with {@code withModelDump} set to {@code true}.
     *
     * @param constructInstruction instruction inserted into the script template
     * @param imports import statements inserted into the script template
     * @param problemType problem type that determines the script template
     * @throws IOException if the script file cannot be created or written
     */
    public ScikitLearnWrapper(String constructInstruction, String imports, EScikitLearnProblemType problemType) throws IOException {
        this(constructInstruction, imports, true, problemType);
    }

    /**
     * Creates a wrapper with model dumping enabled and sets the model file to
     * the given path.
     *
     * @param constructInstruction instruction inserted into the script template
     * @param imports import statements inserted into the script template
     * @param trainedModelPath file stored as the model file of this wrapper
     * @param problemType problem type that determines the script template
     * @throws IOException if the script file cannot be created or written
     */
    public ScikitLearnWrapper(String constructInstruction, String imports, File trainedModelPath, EScikitLearnProblemType problemType) throws IOException {
        this(constructInstruction, imports, true, problemType);
        this.modelFile = trainedModelPath;
    }

    /**
     * @return the problem type currently set for this wrapper
     */
    public EScikitLearnProblemType getProblemType() {
        return this.problemType;
    }

    /**
     * @return the python configuration used when assembling the process command
     */
    public IPythonConfig getPythonConfig() {
        return this.pythonConfig;
    }

    /**
     * Replaces the python configuration used when assembling the process command.
     *
     * @param pythonConfig the new python configuration
     */
    public void setPythonConfig(IPythonConfig pythonConfig) {
        this.pythonConfig = pythonConfig;
    }

    /**
     * Resolves the location of the rendered Python script for this configuration.
     *
     * @return file in the configured temp folder named after the configuration UID plus the python file extension
     * @throws NullPointerException if the configuration UID has not been set yet
     */
    private File getSKLearnScriptFile() {
        String uid = Objects.requireNonNull(this.configurationUID);
        File tempFolder = CONF.getTempFolder();
        String scriptName = uid + CONF.getPythonFileExtension();
        return new File(tempFolder, scriptName);
    }

    /**
     * Resolves the file the script writes its predictions to.
     *
     * @param arffName name of the ARFF file the predictions belong to
     * @return file in the model dump directory named {@code <arffName>_<configurationUID><resultFileExtension>}
     */
    private File getResultFile(String arffName) {
        File dumpDirectory = CONF.getModelDumpsDirectory();
        String resultName = arffName + "_" + this.configurationUID + CONF.getResultFileExtension();
        return new File(dumpDirectory, resultName);
    }

    /**
     * Serializes the training data to an ARFF file, remembers an empty copy of
     * the dataset for later predictions and, if model dumping is enabled, runs
     * the script in train mode to produce the model dump.
     *
     * The problem type is switched to CLASSIFICATION for categorical labels and
     * to REGRESSION for numeric labels (unless it is RUL).
     *
     * @param data the training data
     * @throws TrainingException if the training process reports error output or any other failure occurs
     * @throws InterruptedException if the thread is interrupted while training; it is propagated unchanged
     */
    public void fit(ILabeledDataset<? extends ILabeledInstance> data) throws TrainingException, InterruptedException {
        try {
            CONF.getModelDumpsDirectory().mkdirs();
            String arffName = this.getArffName(data);
            this.trainArff = this.getArffFile(data, arffName);

            // Only the schema of the copy is used later, so widening the element type is harmless here.
            @SuppressWarnings("unchecked")
            ILabeledDataset<ILabeledInstance> emptyCopy = (ILabeledDataset<ILabeledInstance>)data.createEmptyCopy();
            this.dataset = emptyCopy;

            if (data.getLabelAttribute() instanceof ICategoricalAttribute) {
                this.problemType = EScikitLearnProblemType.CLASSIFICATION;
            } else if (data.getLabelAttribute() instanceof INumericAttribute && this.problemType != EScikitLearnProblemType.RUL) {
                this.problemType = EScikitLearnProblemType.REGRESSION;
            }

            if (this.withModelDump) {
                this.modelFile = new File(CONF.getModelDumpsDirectory(), this.configurationUID + "_" + arffName + CONF.getPickleFileExtension());
                ScikitLearnWrapperCommandBuilder skLearnWrapperCommandBuilder = new ScikitLearnWrapperCommandBuilder().withTrainMode().withArffFile(this.trainArff).withOutputFile(this.modelFile);
                skLearnWrapperCommandBuilder.withSeed(this.seed);
                skLearnWrapperCommandBuilder.withTimeout(this.timeout);
                String[] trainCommand = skLearnWrapperCommandBuilder.toCommandArray();
                if (L.isDebugEnabled()) {
                    L.debug("{} run train mode {}", Thread.currentThread().getName(), Arrays.toString(trainCommand));
                }
                DefaultProcessListener listener = new DefaultProcessListener(this.listenToPidFromProcess);
                this.runProcess(trainCommand, listener);
                if (!listener.getErrorOutput().isEmpty()) {
                    L.error("Raise error message: {}", listener.getErrorOutput());
                    // Only the first line of the error output goes into the exception message.
                    throw new TrainingException(listener.getErrorOutput().split("\\n")[0]);
                }
            }
        }
        catch (TrainingException | InterruptedException e) {
            // Interruption must reach the caller as declared instead of being disguised as a training failure.
            throw e;
        }
        catch (Exception e) {
            throw new TrainingException("An exception occurred while training.", e);
        }
    }

    /**
     * Returns the ARFF file for the given dataset inside the temp folder,
     * serializing the dataset only if no file of that name exists yet.
     *
     * @param data dataset to serialize if necessary
     * @param arffName base name (without extension) of the ARFF file
     * @return the ARFF file, freshly written or reused
     * @throws IOException if serializing the dataset fails
     */
    private synchronized File getArffFile(ILabeledDataset<? extends ILabeledInstance> data, String arffName) throws IOException {
        File target = new File(CONF.getTempFolder(), arffName + ".arff");
        if (CONF.getDeleteFileOnExit()) {
            target.deleteOnExit();
        }
        if (!target.exists()) {
            ArffDatasetAdapter.serializeDataset(target, data);
        } else {
            L.debug("Reusing {}.arff", arffName);
        }
        return target;
    }

    /**
     * Predicts a single instance by delegating to the batch prediction with a
     * one-element array and returning the first prediction of the batch.
     *
     * @param instance the instance to predict
     * @return the prediction for the instance
     * @throws PredictionException if the batch prediction fails
     * @throws InterruptedException if the thread is interrupted while predicting
     */
    @Override
    public P predict(ILabeledInstance instance) throws PredictionException, InterruptedException {
        ILabeledInstance[] singleton = new ILabeledInstance[]{instance};
        B batch = this.predict(singleton);
        return (P)batch.get(0);
    }

    /**
     * Predicts a batch of instances by writing them to an ARFF file, running
     * the Python script and parsing the JSON result file it produces.
     *
     * With model dumping enabled the script runs in test mode against the
     * dumped model only; otherwise it runs in train-test mode on the training
     * ARFF of the last {@code fit} call. The two modes are mutually exclusive,
     * so an existing model dump is never retrained and its result file is not
     * overwritten by a second run.
     *
     * @param dTest the instances to predict
     * @return the prediction batch matching the current problem type
     * @throws PredictionException if data cannot be prepared, the process fails or reports errors,
     *         the result cannot be read, or the problem type is unknown
     * @throws InterruptedException if the thread is interrupted while the process runs
     */
    @Override
    @SuppressWarnings("unchecked")
    public B predict(ILabeledInstance[] dTest) throws PredictionException, InterruptedException {
        // Build a dataset with the training schema that contains exactly the test instances.
        ILabeledDataset<ILabeledInstance> data;
        try {
            data = this.dataset.createEmptyCopy();
        }
        catch (DatasetCreationException e1) {
            throw new PredictionException("Could not replicate labeled dataset instance", e1);
        }
        for (ILabeledInstance instance : dTest) {
            data.add(instance);
        }

        CONF.getModelDumpsDirectory().mkdirs();
        String arffName = this.getArffName(data);
        File testArff;
        try {
            testArff = this.getArffFile(data, arffName);
        }
        catch (IOException e1) {
            throw new PredictionException("Could not dump arff file for prediction", e1);
        }
        File outputFile = this.getResultFile(arffName);
        outputFile.getParentFile().mkdirs();

        if (this.withModelDump) {
            ScikitLearnWrapperCommandBuilder skLearnWrapperCommandBuilder = new ScikitLearnWrapperCommandBuilder().withTestMode().withArffFile(testArff).withModelFile(this.modelFile).withOutputFile(outputFile);
            skLearnWrapperCommandBuilder.withSeed(this.seed);
            skLearnWrapperCommandBuilder.withTimeout(this.timeout);
            String[] testCommand = skLearnWrapperCommandBuilder.toCommandArray();
            if (L.isDebugEnabled()) {
                L.debug("Run test mode with {}", Arrays.toString(testCommand));
            }
            try {
                this.runProcess(testCommand, new DefaultProcessListener(this.listenToPidFromProcess));
            }
            catch (IOException e) {
                throw new PredictionException("Could not run scikit-learn classifier.", e);
            }
        } else {
            ScikitLearnWrapperCommandBuilder skLearnWrapperCommandBuilder = new ScikitLearnWrapperCommandBuilder().withTrainTestMode().withArffFile(this.trainArff).withTestArffFile(testArff).withOutputFile(outputFile);
            skLearnWrapperCommandBuilder.withSeed(this.seed);
            skLearnWrapperCommandBuilder.withTimeout(this.timeout);
            String[] testCommand = skLearnWrapperCommandBuilder.toCommandArray();
            if (L.isDebugEnabled()) {
                L.debug("Run train test mode with {}", Arrays.toString(testCommand));
            }
            DefaultProcessListener listener = new DefaultProcessListener(this.listenToPidFromProcess);
            try {
                this.runProcess(testCommand, listener);
                if (!listener.getErrorOutput().isEmpty()) {
                    // Convergence messages are tolerated with a warning; any other error output is fatal.
                    if (listener.getErrorOutput().toLowerCase().contains("convergence")) {
                        L.warn("Learner {} could not converge. Consider increase number of iterations.", this.constructInstruction);
                    } else {
                        throw new PredictionException(listener.getErrorOutput());
                    }
                }
            }
            catch (InterruptedException | PredictionException e) {
                throw e;
            }
            catch (Exception e) {
                throw new PredictionException("Could not run scikit-learn classifier.", e);
            }
        }

        // Read and parse the JSON result written by the script.
        try {
            String fileContent = FileUtil.readFileAsString(outputFile);
            if (CONF.getDeleteFileOnExit()) {
                Files.delete(outputFile.toPath());
            }
            ObjectMapper objMapper = new ObjectMapper();
            List<List<Double>> parsedResults = objMapper.readValue(fileContent, List.class);
            this.rawLastClassificationResults = parsedResults;
        }
        catch (IOException e) {
            throw new PredictionException("Could not read result file or parse the json content to a list.", e);
        }

        if (this.problemType == EScikitLearnProblemType.CLASSIFICATION) {
            // A single value per row is a predicted class index; several values are passed on as an array per row.
            // The emptiness check avoids an IndexOutOfBoundsException when the script returned no rows.
            if (!this.rawLastClassificationResults.isEmpty() && this.rawLastClassificationResults.get(0).size() == 1) {
                int numClasses = ((ICategoricalAttribute)this.dataset.getLabelAttribute()).getLabels().size();
                return (B)new SingleLabelClassificationPredictionBatch(this.rawLastClassificationResults.stream().flatMap(Collection::stream).map(x -> new SingleLabelClassification(numClasses, x.intValue())).collect(Collectors.toList()));
            }
            return (B)new SingleLabelClassificationPredictionBatch(this.rawLastClassificationResults.stream().map(x -> x.stream().mapToDouble(y -> y).toArray()).map(SingleLabelClassification::new).collect(Collectors.toList()));
        }
        if (this.problemType == EScikitLearnProblemType.RUL || this.problemType == EScikitLearnProblemType.REGRESSION) {
            if (L.isInfoEnabled()) {
                L.info("{}", this.rawLastClassificationResults.stream().flatMap(Collection::stream).collect(Collectors.toList()));
            }
            L.debug("#Created construction string: {}", this.constructInstruction);
            return (B)new SingleTargetRegressionPredictionBatch(this.rawLastClassificationResults.stream().flatMap(Collection::stream).map(x -> new SingleTargetRegressionPrediction((double)x)).collect(Collectors.toList()));
        }
        throw new PredictionException("Unknown Problem Type.");
    }

    /**
     * Builds the Python statements that make the modules of a folder available
     * to the script: a {@code sys.path.append} line followed by one import
     * line per {@code .py} file in the folder. An {@code __init__.py} file is
     * created in the folder if it is missing.
     *
     * Entries whose name starts with {@code __}, directories and files without
     * the {@code .py} extension are skipped, since stripping the extension from
     * them would not yield a valid module name.
     *
     * @param importsFolder folder containing the Python modules; may be {@code null}
     * @param keepNamespace {@code true} to emit {@code import <module>}, {@code false} to emit {@code from <module> import *}
     * @return the import statements, or an empty string if the folder is {@code null}, missing, not listable or empty
     * @throws IOException if the {@code __init__.py} file cannot be created
     */
    public static String createImportStatementFromImportFolder(File importsFolder, boolean keepNamespace) throws IOException {
        if (importsFolder == null || !importsFolder.exists()) {
            return "";
        }
        // File.list() returns null when the path is not a directory or cannot be read.
        String[] entryNames = importsFolder.list();
        if (entryNames == null || entryNames.length == 0) {
            return "";
        }
        if (!Arrays.asList(entryNames).contains("__init__.py")) {
            File initFile = new File(importsFolder, "__init__.py");
            if (!initFile.createNewFile() && L.isDebugEnabled()) {
                L.debug("Init file {} exists already", initFile.getAbsolutePath());
            }
        }

        final String pythonExtension = ".py";
        StringBuilder result = new StringBuilder();
        result.append("\n");
        result.append("sys.path.append(r'").append(importsFolder.getAbsolutePath()).append("')\n");

        File[] modules = importsFolder.listFiles();
        if (modules == null) {
            return result.toString();
        }
        for (File module : modules) {
            String fileName = module.getName();
            if (fileName.startsWith("__") || !module.isFile() || !fileName.endsWith(pythonExtension)) {
                continue;
            }
            String moduleName = fileName.substring(0, fileName.length() - pythonExtension.length());
            if (moduleName.isEmpty()) {
                continue;
            }
            if (keepNamespace) {
                result.append("import ").append(moduleName).append("\n");
            } else {
                result.append("from ").append(moduleName).append(" import *\n");
            }
        }
        return result.toString();
    }

    /**
     * Collects the values substituted into the script template.
     *
     * @param constructInstruction construction command stored under {@code classifier_construct}; must not be null or empty
     * @param imports import statements stored under {@code imports}; {@code null} is stored as an empty string
     * @return a map with the two template values
     * @throws AssertionError if the construction command is null or empty
     */
    private Map<String, Object> getTemplateValueMap(String constructInstruction, String imports) {
        boolean missingInstruction = constructInstruction == null || constructInstruction.isEmpty();
        if (missingInstruction) {
            throw new AssertionError("Construction command for classifier must be stated.");
        }
        String importBlock = imports;
        if (importBlock == null) {
            importBlock = "";
        }
        Map<String, Object> values = new HashMap<>();
        values.put("classifier_construct", constructInstruction);
        values.put("imports", importBlock);
        return values;
    }

    /**
     * Turns a collection of module names into Python import statements, one
     * {@code import <name>} per line.
     *
     * @param imports the module names; may be {@code null}
     * @return the import statements joined by line breaks, or an empty string for a {@code null} or empty collection
     */
    public static String getImportString(Collection<String> imports) {
        if (imports == null || imports.isEmpty()) {
            return "";
        }
        String joinedModules = StringUtils.join(imports, "\nimport ");
        return "import " + joinedModules;
    }

    /**
     * @return the parsed result list of the most recent batch prediction, or {@code null} if no prediction has been made yet
     */
    public List<List<Double>> getRawLastClassificationResults() {
        return this.rawLastClassificationResults;
    }

    /**
     * Sets the problem type and, if it differs from the current one, resolves
     * the script template resource belonging to the new type.
     *
     * @param problemType the new problem type
     */
    public void setProblemType(EScikitLearnProblemType problemType) {
        if (this.problemType == problemType) {
            // Nothing changed, keep the template that is already resolved.
            return;
        }
        this.problemType = problemType;
        String templateResource = this.problemType.getRessourceScikitTemplate();
        this.scikitTemplate = new File(ResourceUtil.getResourceAsTempFile(templateResource));
    }

    /**
     * Sets the seed that is passed to the script via the {@code --seed} flag.
     *
     * @param seed the seed value
     */
    public void setSeed(long seed) {
        this.seed = seed;
    }

    /**
     * Sets the timeout handed to the command builder for subsequent process runs.
     *
     * @param timeout the timeout; may be {@code null} to run without one
     */
    public void setTimeout(Timeout timeout) {
        this.timeout = timeout;
    }

    /**
     * Sets the target column indices that are appended after {@code --targets}
     * on the command line when the array is non-empty.
     *
     * @param targetColumns the target column indices
     */
    public void setTargets(int ... targetColumns) {
        this.targetColumns = targetColumns;
    }

    /**
     * Sets the model dump file used in test mode.
     *
     * @param modelFile the model dump file
     */
    public void setModelPath(File modelFile) {
        this.modelFile = modelFile;
    }

    /**
     * @return the model dump file currently set, or {@code null} if none has been set
     */
    public File getModelPath() {
        return this.modelFile;
    }

    /**
     * Derives a purely numeric file name from the hash code of the dataset:
     * the minus sign of a negative hash is replaced by {@code 1}, a
     * non-negative hash is prefixed with {@code 0}.
     *
     * @param data the dataset to name
     * @return the numeric name derived from {@code data.hashCode()}
     */
    private String getArffName(ILabeledDataset<? extends ILabeledInstance> data) {
        String digits = String.valueOf(data.hashCode());
        if (digits.startsWith("-")) {
            return digits.replace("-", "1");
        }
        return "0" + digits;
    }

    /**
     * Starts an external process with the given command inside the temp folder
     * and hands it to the listener.
     *
     * @param parameters the command and its arguments
     * @param listener the listener attached to the started process
     * @throws InterruptedException if the listener is interrupted while listening to the process
     * @throws IOException if the process cannot be started
     */
    private void runProcess(String[] parameters, AProcessListener listener) throws InterruptedException, IOException {
        if (L.isDebugEnabled()) {
            // Arrays.toString yields "[a, b, c]"; drop the commas and the surrounding brackets for the log line.
            String bracketed = Arrays.toString(parameters).replace(",", "");
            String commandLine = bracketed.substring(1, bracketed.length() - 1);
            L.debug("Starting process {}", commandLine);
        }
        ProcessBuilder builder = new ProcessBuilder(parameters);
        builder.directory(CONF.getTempFolder());
        Process started = builder.start();
        try {
            L.debug("Started process with PID: {}", ProcessUtil.getPID(started));
        }
        catch (ProcessIDNotRetrievableException e) {
            L.warn("Could not retrieve process ID.");
        }
        listener.listenTo(started);
    }

    /**
     * Not implemented.
     *
     * @param instance the instance (unused)
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public double[] distributionForInstance(ILabeledInstance instance) {
        throw new UnsupportedOperationException("This method is not yet implemented");
    }

    /**
     * @return the construction instruction this wrapper was created with
     */
    public String toString() {
        return this.constructInstruction;
    }

    /**
     * Builder for the command line that runs the rendered Python script. It
     * reads the script location, python configuration, problem type and target
     * columns from the enclosing wrapper instance.
     */
    private class ScikitLearnWrapperCommandBuilder {
        private static final String ARFF_FLAG = "--arff";
        private static final String TEST_ARFF_FLAG = "--testarff";
        private static final String MODE_FLAG = "--mode";
        private static final String MODEL_FLAG = "--model";
        private static final String OUTPUT_FLAG = "--output";
        private static final String SEED_FLAG = "--seed";
        private String arffFile;
        private String testArffFile;
        private EWrapperExecutionMode mode;
        private String modelFile;
        private String outputFile;
        private long seed;
        private Timeout timeout;

        private ScikitLearnWrapperCommandBuilder() {
        }

        /** Sets the test ARFF file passed via {@code --testarff}. */
        public ScikitLearnWrapperCommandBuilder withTestArffFile(File testArffFile) {
            this.testArffFile = testArffFile.getAbsolutePath();
            return this;
        }

        /** Selects the {@code train} execution mode. */
        public ScikitLearnWrapperCommandBuilder withTrainMode() {
            return this.withMode(EWrapperExecutionMode.TRAIN);
        }

        /** Selects the {@code test} execution mode. */
        public ScikitLearnWrapperCommandBuilder withTestMode() {
            return this.withMode(EWrapperExecutionMode.TEST);
        }

        /** Selects the {@code traintest} execution mode. */
        public ScikitLearnWrapperCommandBuilder withTrainTestMode() {
            return this.withMode(EWrapperExecutionMode.TRAIN_TEST);
        }

        private ScikitLearnWrapperCommandBuilder withMode(EWrapperExecutionMode execMode) {
            this.mode = execMode;
            return this;
        }

        /**
         * Sets the model dump passed via {@code --model}.
         *
         * @throws IllegalArgumentException if the file does not exist
         */
        private ScikitLearnWrapperCommandBuilder withModelFile(File modelFile) {
            if (!modelFile.exists()) {
                throw new IllegalArgumentException("Model dump does not exist");
            }
            this.modelFile = modelFile.getAbsolutePath();
            return this;
        }

        /** Sets the output file passed via {@code --output}. */
        private ScikitLearnWrapperCommandBuilder withOutputFile(File outputFile) {
            this.outputFile = outputFile.getAbsolutePath();
            return this;
        }

        /**
         * Sets the ARFF file passed via {@code --arff}.
         *
         * @throws IllegalArgumentException if the file does not exist
         */
        private ScikitLearnWrapperCommandBuilder withArffFile(File arffFile) {
            if (!arffFile.exists()) {
                throw new IllegalArgumentException("Arff File does not exist.");
            }
            this.arffFile = arffFile.getAbsolutePath();
            return this;
        }

        /** Sets the seed passed via {@code --seed}. */
        private ScikitLearnWrapperCommandBuilder withSeed(long seed) {
            this.seed = seed;
            return this;
        }

        /** Sets the timeout; it only affects the command on LINUX. */
        private ScikitLearnWrapperCommandBuilder withTimeout(Timeout timeout) {
            this.timeout = timeout;
            return this;
        }

        /**
         * Assembles the command array for {@link ProcessBuilder}.
         *
         * On MAC all parameters are joined with spaces and wrapped into
         * {@code sh -c "<command>"}; on other systems they are returned as
         * individual arguments.
         *
         * @return the command and its arguments
         * @throws NullPointerException if mode, output file, ARFF file or python configuration is missing,
         *         or if the model file is missing in test mode
         * @throws IllegalArgumentException if the rendered script file does not exist
         */
        private String[] toCommandArray() {
            Objects.requireNonNull(this.mode);
            Objects.requireNonNull(this.outputFile);
            Objects.requireNonNull(this.arffFile);
            File scriptFile = ScikitLearnWrapper.this.getSKLearnScriptFile();
            if (!scriptFile.exists()) {
                throw new IllegalArgumentException("The wrapped sklearn script " + scriptFile.getAbsolutePath() + " file does not exist");
            }
            // Without a python configuration the python command is unknown; fail with a clear message
            // instead of an unexplained NullPointerException further down.
            IPythonConfig python = Objects.requireNonNull(ScikitLearnWrapper.this.pythonConfig, "No python configuration is set; cannot determine the python command.");

            List<String> processParameters = new ArrayList<>();
            EOperatingSystem os = ProcessUtil.getOS();
            if (python.getAnacondaEnvironment() != null) {
                if (os == EOperatingSystem.MAC) {
                    processParameters.add("source");
                    processParameters.add("~/anaconda3/etc/profile.d/conda.sh");
                    processParameters.add("&&");
                }
                processParameters.add("conda");
                processParameters.add("activate");
                processParameters.add(python.getAnacondaEnvironment());
                processParameters.add("&&");
            }
            if (this.timeout != null && os == EOperatingSystem.LINUX) {
                L.info("Executing with timeout {}s", this.timeout.seconds());
                // The process gets 5 seconds less than the configured timeout (presumably as headroom for the
                // Java side - TODO confirm). Clamp to at least 1 second: a duration of 0 disables the timeout
                // command entirely and a negative duration is rejected by it.
                long processTimeoutSeconds = Math.max(1L, this.timeout.seconds() - 5L);
                processParameters.add("timeout");
                processParameters.add(String.valueOf(processTimeoutSeconds));
            }
            if (python.getPath() != null) {
                processParameters.add(python.getPath() + File.separator + python.getPythonCommand());
            } else {
                processParameters.add(python.getPythonCommand());
            }
            processParameters.add("-u");
            processParameters.add(scriptFile.getAbsolutePath());
            processParameters.addAll(Arrays.asList(MODE_FLAG, this.mode.toString()));
            processParameters.addAll(Arrays.asList(ARFF_FLAG, this.arffFile));
            if (this.testArffFile != null) {
                processParameters.addAll(Arrays.asList(TEST_ARFF_FLAG, this.testArffFile));
            }
            processParameters.addAll(Arrays.asList(OUTPUT_FLAG, this.outputFile));
            if (!ScikitLearnWrapper.this.problemType.getScikitLearnCommandLineFlag().isEmpty()) {
                processParameters.add(ScikitLearnWrapper.this.problemType.getScikitLearnCommandLineFlag());
            }
            processParameters.addAll(Arrays.asList(SEED_FLAG, String.valueOf(this.seed)));
            if (this.mode == EWrapperExecutionMode.TEST) {
                Objects.requireNonNull(this.modelFile);
                processParameters.addAll(Arrays.asList(MODEL_FLAG, this.modelFile));
            }
            if (ScikitLearnWrapper.this.targetColumns != null && ScikitLearnWrapper.this.targetColumns.length > 0) {
                processParameters.add("--targets");
                for (int i : ScikitLearnWrapper.this.targetColumns) {
                    processParameters.add(String.valueOf(i));
                }
            }
            if (os == EOperatingSystem.MAC) {
                // NOTE(review): arguments are joined unquoted, so paths containing spaces will be split by the shell.
                return new String[]{"sh", "-c", String.join(" ", processParameters)};
            }
            return processParameters.toArray(new String[0]);
        }
    }

    /**
     * Execution modes understood by the Python script; {@link #toString()}
     * yields the value passed after the {@code --mode} flag.
     */
    private static enum EWrapperExecutionMode {
        TRAIN("train"),
        TEST("test"),
        TRAIN_TEST("traintest");

        /** Command line representation of this mode. */
        private final String commandLineValue;

        private EWrapperExecutionMode(String name) {
            this.commandLineValue = name;
        }

        @Override
        public String toString() {
            return this.commandLineValue;
        }
    }
}

