/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.experiments;

import ai.libs.jaicore.logging.LoggerUtil;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.core.evaluation.measure.singlelabel.EMulticlassMeasure;
import ai.libs.jaicore.ml.experiments.IMultiClassClassificationExperimentDatabase;
import ai.libs.jaicore.ml.experiments.MLExperiment;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.google.common.collect.ContiguousSet;
import com.google.common.collect.DiscreteDomain;
import com.google.common.collect.Range;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.lang.reflect.Method;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

public abstract class MultiClassClassificationExperimentRunner {
    private static final Logger logger = LoggerFactory.getLogger(MultiClassClassificationExperimentRunner.class);
    private final File datasetFolder;
    private final List<File> availableDatasets;
    private final String[] classifiers;
    private final Map<String, String[]> setups;
    private final int numberOfSetups;
    private final int[] timeoutsInSeconds;
    private final int numberOfRunsPerExperiment;
    private final float trainingPortion;
    private final int numberOfCPUs;
    private final int memoryInMB;
    private final EMulticlassMeasure performanceMeasure;
    private final IMultiClassClassificationExperimentDatabase database;
    private final int totalExperimentSize;
    private Collection<MLExperiment> experimentsConductedEarlier;

    public MultiClassClassificationExperimentRunner(File datasetFolder, String[] classifiers, Map<String, String[]> setups, int[] timeoutsInSeconds, int numberOfRunsPerExperiment, float trainingPortion, int numberOfCPUs, int memoryInMB, EMulticlassMeasure performanceMeasure, IMultiClassClassificationExperimentDatabase logger) throws IOException {
        this.datasetFolder = datasetFolder;
        this.availableDatasets = this.getAvailableDatasets(datasetFolder);
        this.classifiers = classifiers;
        this.setups = setups;
        this.timeoutsInSeconds = timeoutsInSeconds;
        this.numberOfRunsPerExperiment = numberOfRunsPerExperiment;
        this.trainingPortion = trainingPortion;
        this.numberOfCPUs = numberOfCPUs;
        this.memoryInMB = memoryInMB;
        this.performanceMeasure = performanceMeasure;
        this.database = logger;
        int tmpNumberOfSetups = 0;
        for (String[] setupsOfClassifier : this.setups.values()) {
            tmpNumberOfSetups += setupsOfClassifier.length;
        }
        this.numberOfSetups = tmpNumberOfSetups;
        this.totalExperimentSize = classifiers.length * this.availableDatasets.size() * this.numberOfSetups * numberOfRunsPerExperiment * timeoutsInSeconds.length;
        System.out.println("Available datasets: ");
        AtomicInteger i = new AtomicInteger();
        this.availableDatasets.stream().forEach(ds -> System.out.println("\t" + i.getAndIncrement() + ": " + ds.getName()));
        System.out.println("Available algorithms: ");
        i.set(0);
        Arrays.asList(classifiers).stream().forEach(c -> System.out.println("\t" + i.getAndIncrement() + ": " + c));
    }

    protected abstract Classifier getConfiguredClassifier(int var1, String var2, String var3, int var4, int var5, int var6, EMulticlassMeasure var7);

    public void runAll() throws Exception {
        this.experimentsConductedEarlier = this.database.getExperimentsForWhichARunExists();
        for (int k = 0; k < this.totalExperimentSize; ++k) {
            try {
                this.runSpecific(k);
                continue;
            }
            catch (ExperimentAlreadyConductedException e) {
                System.out.println(e.getMessage());
                continue;
            }
            catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public void runAny() throws Exception {
        this.experimentsConductedEarlier = this.database.getExperimentsForWhichARunExists();
        ArrayList indices = new ArrayList(ContiguousSet.create((Range)Range.closed((Comparable)Integer.valueOf(0), (Comparable)Integer.valueOf(this.totalExperimentSize - 1)), (DiscreteDomain)DiscreteDomain.integers()).asList());
        Collections.shuffle(indices);
        Iterator iterator = indices.iterator();
        while (iterator.hasNext()) {
            int index = (Integer)iterator.next();
            try {
                this.runSpecific(index);
                return;
            }
            catch (ExperimentAlreadyConductedException e) {
                System.out.println(e.getMessage());
            }
            catch (Exception e) {
                e.printStackTrace();
            }
        }
        while (true) {
            // Infinite loop
        }
    }

    public void runSpecific(int k) throws Exception {
        int numberOfDatasets = this.availableDatasets.size();
        int numberOfSeeds = this.numberOfRunsPerExperiment;
        int numberOfTimeouts = this.timeoutsInSeconds.length;
        System.out.println("Number of runs (seeds) per dataset/algo-combination: " + numberOfSeeds);
        int frameSizeForTimeout = this.totalExperimentSize / numberOfTimeouts;
        int frameSizeForSeed = frameSizeForTimeout / numberOfSeeds;
        int frameSizeForSetup = frameSizeForSeed / this.numberOfSetups;
        int frameSizeForDataset = frameSizeForSetup / numberOfDatasets;
        if (k >= this.totalExperimentSize) {
            throw new IllegalArgumentException("Only " + this.totalExperimentSize + " experiments defined.");
        }
        int timeoutId = (int)Math.floor((float)(k / frameSizeForTimeout) * 1.0f);
        int indexWithinTimeout = k % frameSizeForTimeout;
        int seedId = (int)Math.floor((float)(indexWithinTimeout / frameSizeForSeed) * 1.0f);
        int indexWithinSeed = indexWithinTimeout % frameSizeForSeed;
        int datasetId = (int)Math.floor((float)(indexWithinSeed / frameSizeForDataset) * 1.0f);
        int indexWithinDataset = indexWithinSeed % frameSizeForDataset;
        int algoAndSetupId = (int)Math.floor((float)(indexWithinDataset / frameSizeForSetup) * 1.0f);
        System.out.println("Running experiment " + k + "/" + this.totalExperimentSize + ". The setup is: " + timeoutId + "/" + seedId + "/" + datasetId + "//" + algoAndSetupId + "(timeout/seed/dataset/algo-setup-id)");
        this.runExperiment(datasetId, timeoutId, seedId, algoAndSetupId);
    }

    private int getAlgoIdForAlgoSetupId(int algoSetupId) {
        int counter = 0;
        for (int i = 0; i < this.classifiers.length; ++i) {
            if (algoSetupId >= (counter += this.setups.get(this.classifiers[i]).length)) continue;
            return i;
        }
        return -1;
    }

    private int getSetupIdForAlgoSetupId(int algoSetupId) {
        int counter = 0;
        for (int i = 0; i < this.classifiers.length; ++i) {
            String[] setupsOfThisClassifier = this.setups.get(this.classifiers[i]);
            for (int j = 0; j < setupsOfThisClassifier.length; ++j) {
                if (counter == algoSetupId) {
                    return j;
                }
                ++counter;
            }
        }
        return -1;
    }

    public void runExperiment(int datasetId, int timeoutId, int seedId, int algoAndSetupId) throws Exception {
        String datasetName = this.availableDatasets.get(datasetId).getName();
        datasetName = datasetName.substring(0, datasetName.lastIndexOf("."));
        int algoId = this.getAlgoIdForAlgoSetupId(algoAndSetupId);
        int setupId = this.getSetupIdForAlgoSetupId(algoAndSetupId);
        String algo = this.classifiers[algoId];
        String algoMode = this.setups.get(algo)[setupId];
        int timeoutInSeconds = this.timeoutsInSeconds[timeoutId];
        if (this.performanceMeasure != EMulticlassMeasure.ERROR_RATE) {
            throw new IllegalArgumentException("Currently the only supported performance measure is errorRate");
        }
        MLExperiment exp = new MLExperiment(new File(this.datasetFolder + File.separator + this.availableDatasets.get(datasetId)).getAbsolutePath(), algo, algoMode, seedId, timeoutInSeconds, this.numberOfCPUs, this.memoryInMB, this.performanceMeasure.toString());
        if (this.experimentsConductedEarlier != null && this.experimentsConductedEarlier.contains(exp)) {
            throw new ExperimentAlreadyConductedException("Experiment " + exp + " has already been conducted");
        }
        try {
            System.out.println("Now configuring classifier ...");
            Classifier c = this.getConfiguredClassifier(seedId, algo, algoMode, timeoutInSeconds, this.numberOfCPUs, this.memoryInMB, this.performanceMeasure);
            Collection<MLExperiment> experiments = this.database.getExperimentsForWhichARunExists();
            if (experiments.contains(exp)) {
                throw new ExperimentAlreadyConductedException("Experiment has already been conducted, but rather recently: " + exp);
            }
            int runId = this.database.createRunIfDoesNotExist(exp);
            if (runId < 0) {
                throw new ExperimentAlreadyConductedException("Experiment has already been conducted, but quite recently: " + exp);
            }
            System.out.println("The assigned runId for this experiment is " + runId);
            Random r = new Random(seedId);
            Instances data = this.getKthInstances(this.datasetFolder, datasetId);
            data.setClassIndex(data.numAttributes() - 1);
            Collection<Integer>[] overallSplitIndices = WekaUtil.getStratifiedSplitIndices(data, r, this.trainingPortion);
            List<Instances> overallSplit = WekaUtil.realizeSplit(data, overallSplitIndices);
            Instances internalData = overallSplit.get(0);
            Instances testData = overallSplit.get(1);
            ObjectMapper om = new ObjectMapper();
            ArrayNode an = om.createArrayNode();
            overallSplitIndices[0].stream().sorted().forEach(v -> an.add(v));
            System.out.println("Data were split into " + internalData.size() + "/" + testData.size());
            HashMap<String, String> runUpdate = new HashMap<String, String>();
            runUpdate.put("rows_for_training", an.toString());
            this.database.updateExperiment(exp, runUpdate);
            this.database.associatedRunWithClassifier(runId, c);
            System.out.println("Classifier configured. Determining result files.");
            System.out.println("Invoking " + this.getExperimentDescription(datasetId, c, seedId) + " with setup " + algoMode + " and timeout " + this.timeoutsInSeconds[timeoutId] + "s");
            long start = System.currentTimeMillis();
            try {
                c.buildClassifier(internalData);
                long end = System.currentTimeMillis();
                System.out.println("Search has finished. Runtime: " + (float)(end - start) / 1000.0f + " s");
                int mistakes = 0;
                Method m = MethodUtils.getMatchingAccessibleMethod(c.getClass(), (String)"classifyInstances", (Class[])new Class[]{Instances.class});
                if (m != null) {
                    Object predictions = (double[])m.invoke((Object)c, testData);
                    for (int i = 0; i < ((Object)predictions).length; ++i) {
                        if (predictions[i] == testData.get(i).classValue()) continue;
                        ++mistakes;
                    }
                } else {
                    for (Instance i : testData) {
                        if (i.classValue() == c.classifyInstance(i)) continue;
                        ++mistakes;
                    }
                }
                double error = (float)mistakes * 10000.0f / (float)testData.size();
                System.out.println("Sending error Rate " + error + " to logger.");
                this.database.addResultEntry(runId, error);
            }
            catch (Throwable e) {
                logger.error("Experiment failed. Details:\n{}", (Object)LoggerUtil.getExceptionInfo((Throwable)e));
                System.out.println("Sending error Rate -10000 to logger.");
                try {
                    this.database.addResultEntry(runId, -10000.0);
                }
                catch (Exception e1) {
                    logger.error("Could not write result to database. Details:\n{}", (Object)LoggerUtil.getExceptionInfo((Throwable)e1));
                }
            }
        }
        catch (Exception e) {
            logger.error("Experiment failed. Details:\n{}", (Object)LoggerUtil.getExceptionInfo((Throwable)e));
        }
    }

    public String getExperimentDescription(int datasetId, Classifier algorithm, int seed) {
        return algorithm + "-" + this.availableDatasets.get(datasetId).getName() + "-" + seed;
    }

    public List<File> getAvailableDatasets(File folder) throws IOException {
        ArrayList files = new ArrayList();
        try (Stream<Path> paths = Files.walk(folder.toPath(), new FileVisitOption[0]);){
            paths.filter(f -> f.getParent().toFile().equals(folder) && f.toFile().getAbsolutePath().endsWith(".arff")).forEach(f -> files.add(f.toFile()));
        }
        return files.stream().sorted().collect(Collectors.toList());
    }

    public Instances getKthInstances(File folder, int k) throws IOException {
        File f = this.getAvailableDatasets(folder).get(k);
        System.out.println("Selecting " + f);
        Instances inst = new Instances((Reader)new BufferedReader(new FileReader(f)));
        inst.setRelationName(f.getAbsolutePath().replace(File.separator, "/"));
        return inst;
    }

    public IMultiClassClassificationExperimentDatabase getLogger() {
        return this.database;
    }

    class ExperimentAlreadyConductedException
    extends Exception {
        public ExperimentAlreadyConductedException(String message) {
            super(message);
        }
    }
}

