/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.casecontrol;

import ai.libs.jaicore.basic.algorithm.events.AlgorithmEvent;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.dataset.DatasetCreationException;
import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.ILabeledAttributeArrayInstance;
import ai.libs.jaicore.ml.core.dataset.sampling.SampleElementAddedEvent;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.casecontrol.CaseControlLikeSampling;
import ai.libs.jaicore.ml.core.dataset.weka.WekaInstances;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.functions.Logistic;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NumericToNominal;

/**
 * Case-control-like sampling whose final selection probabilities are derived from a pilot estimate.
 * <p>
 * On initialization, a pilot sample is drawn without replacement. It uses boundaries computed from
 * the class occurrences of the input. A Weka {@link Logistic} model is then fitted on this pilot
 * sample. Subclasses turn the fitted model into the final per-instance probability boundaries via
 * {@link #calculateFinalInstanceBoundaries}. Each subsequent active step adds one more instance to
 * the sample, again drawn without replacement.
 *
 * @param <I> instance type
 * @param <D> dataset type; currently the input must be a {@link WekaInstances}
 */
public abstract class PilotEstimateSampling<I extends ILabeledAttributeArrayInstance<?>, D extends IDataset<I>>
extends CaseControlLikeSampling<I, D> {
    private static final Logger LOGGER = LoggerFactory.getLogger(PilotEstimateSampling.class);

    /**
     * Size of the pilot sample. Values below 1 are replaced by half the input size during
     * initialization.
     */
    protected int preSampleSize;

    /** The instance most recently drawn in the active phase, or set via {@link #setChosenInstance}. */
    private I chosenInstance = null;

    /**
     * Creates the sampler.
     *
     * @param input the dataset to sample from
     * @throws IllegalArgumentException if {@code input} is not a {@link WekaInstances}
     */
    protected PilotEstimateSampling(D input) {
        super(input);
        if (!(input instanceof WekaInstances)) {
            throw new IllegalArgumentException("Pilot Estimate Sampling currently only works with WekaInstances. The signature is kept general to avoid refactoring later on.");
        }
    }

    public I getChosenInstance() {
        return this.chosenInstance;
    }

    public void setChosenInstance(I chosenInstance) {
        this.chosenInstance = chosenInstance;
    }

    /**
     * Performs one step of the sampling algorithm.
     *
     * @return in state CREATED, the event returned by {@code activate()}; in state ACTIVE, a
     *         {@link SampleElementAddedEvent} per added instance, or the event returned by
     *         {@code terminate()} once the sample is full; {@code null} in state INACTIVE
     * @throws AlgorithmException if initialization fails
     * @throws InterruptedException if the thread is interrupted during initialization
     */
    public AlgorithmEvent nextWithException() throws AlgorithmException, InterruptedException {
        LOGGER.info("Executing next step.");
        switch (this.getState()) {
            case CREATED:
                // Return the event produced by activate() instead of discarding it.
                return this.doInitStep();
            case ACTIVE:
                return this.doActiveStep();
            case INACTIVE:
                this.doInactiveStep();
                return null;
            default:
                throw new IllegalStateException("Unknown algorithm state " + this.getState());
        }
    }

    /** Adds one more instance to the sample, or terminates once the requested sample size is reached. */
    private AlgorithmEvent doActiveStep() throws AlgorithmException, InterruptedException {
        if (this.sample.size() >= this.sampleSize) {
            return this.terminate();
        }
        // Rejection sampling without replacement.
        // NOTE(review): this loop does not terminate if sampleSize exceeds the number of distinct
        // instances reachable through probabilityBoundaries.
        do {
            this.chosenInstance = this.drawInstance(this.probabilityBoundaries);
        } while (this.sample.contains(this.chosenInstance));
        this.sample.add(this.chosenInstance);
        return new SampleElementAddedEvent(this.getId());
    }

    /**
     * Creates the empty output sample and, if needed, computes the final probability boundaries.
     * <p>
     * Boundaries are (re)computed unless both {@code probabilityBoundaries} and
     * {@code chosenInstance} are already set. The computation runs in four stages:
     * <ol>
     * <li>Class-occurrence based boundaries are computed.</li>
     * <li>A pilot sample is drawn with them.</li>
     * <li>A logistic model is trained on that pilot sample.</li>
     * <li>The subclass derives the final boundaries from the model.</li>
     * </ol>
     *
     * @return the event returned by {@code activate()}
     * @throws AlgorithmException if a dataset copy cannot be created or any other step fails
     * @throws InterruptedException if the thread is interrupted
     */
    @SuppressWarnings("unchecked")
    private AlgorithmEvent doInitStep() throws AlgorithmException, InterruptedException {
        try {
            D input = this.getInput();
            this.sample = (D) input.createEmpty();
            if (this.probabilityBoundaries == null || this.chosenInstance == null) {
                if (this.preSampleSize < 1) {
                    this.preSampleSize = input.size() / 2;
                }
                D sampleCopy = (D) input.createEmpty();
                for (I instance : input) {
                    sampleCopy.add(instance);
                }
                HashMap<Object, Integer> classOccurrences = this.countClassOccurrences(sampleCopy);
                int numberOfClasses = classOccurrences.keySet().size();
                this.probabilityBoundaries = this.calculateInstanceBoundaries(classOccurrences, numberOfClasses);

                // Drawing without replacement can never yield more than input.size() instances.
                // Capping the target prevents an endless loop when preSampleSize was set too large.
                int pilotTarget = Math.min(this.preSampleSize, input.size());
                D pilotEstimateSample = (D) input.createEmpty();
                for (int i = 0; i < pilotTarget; i++) {
                    I candidate;
                    do {
                        candidate = this.drawInstance(this.probabilityBoundaries);
                    } while (pilotEstimateSample.contains(candidate));
                    pilotEstimateSample.add(candidate);
                }

                Classifier pilotEstimator = this.trainPilotEstimator(pilotEstimateSample);
                this.probabilityBoundaries = this.calculateFinalInstanceBoundaries(sampleCopy, pilotEstimator);
            }
        }
        catch (DatasetCreationException e1) {
            throw new AlgorithmException((Throwable)e1, "Could not create a copy of the dataset.");
        }
        catch (InterruptedException e) {
            // Propagate interruption unchanged instead of wrapping it below.
            throw e;
        }
        catch (Exception e) {
            throw new AlgorithmException((Throwable)e, "Unexpected error");
        }
        return this.activate();
    }

    /**
     * Picks one instance by roulette-wheel selection.
     * <p>
     * It returns the first instance whose boundary exceeds a random draw. The boundaries are
     * presumably cumulative, i.e. non-decreasing. If no boundary exceeds the draw, the last instance
     * is returned.
     *
     * @param boundaries (instance, boundary) pairs; must not be empty
     * @return the selected instance
     */
    private I drawInstance(final List<Pair<I, Double>> boundaries) {
        double r = this.rand.nextDouble();
        for (Pair<I, Double> boundary : boundaries) {
            if (boundary.getY() > r) {
                return boundary.getX();
            }
        }
        return boundaries.get(boundaries.size() - 1).getX();
    }

    /**
     * Fits the pilot logistic-regression model on the given pilot sample.
     * <p>
     * The last attribute is converted to nominal first, presumably because it is the class attribute
     * and Weka's {@link Logistic} requires a nominal class.
     *
     * @param pilotEstimateSample the pilot sample; must be a {@link WekaInstances}
     * @return the trained classifier
     * @throws Exception if filtering or training fails (Weka APIs declare {@code Exception})
     */
    private Classifier trainPilotEstimator(final D pilotEstimateSample) throws Exception {
        Instances pilotEstimateInstances = (Instances) ((WekaInstances) pilotEstimateSample).getList();
        NumericToNominal numericToNominal = new NumericToNominal();
        numericToNominal.setOptions(new String[] { "-R", "last" });
        numericToNominal.setInputFormat(pilotEstimateInstances);
        pilotEstimateInstances = Filter.useFilter(pilotEstimateInstances, numericToNominal);
        Logistic pilotEstimator = new Logistic();
        pilotEstimator.buildClassifier(pilotEstimateInstances);
        return pilotEstimator;
    }

    /**
     * Computes the final (instance, boundary) pairs from the pilot estimator.
     *
     * @param var1 a copy of the input dataset
     * @param var2 the classifier trained on the pilot sample
     * @return the final probability boundaries used during the active phase
     */
    abstract ArrayList<Pair<I, Double>> calculateFinalInstanceBoundaries(D var1, Classifier var2);
}

