/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.casecontrol;

import ai.libs.jaicore.basic.algorithm.AlgorithmExecutionCanceledException;
import ai.libs.jaicore.basic.algorithm.events.AlgorithmEvent;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.IInstance;
import ai.libs.jaicore.ml.core.dataset.sampling.SampleElementAddedEvent;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.casecontrol.CaseControlLikeSampling;
import ai.libs.jaicore.ml.core.dataset.weka.WekaInstances;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instance;
import weka.core.Instances;

public class ClassifierWeightedSampling<I extends IInstance>
extends CaseControlLikeSampling<I> {
    private Logger logger = LoggerFactory.getLogger(ClassifierWeightedSampling.class);
    private Classifier pilotEstimator;
    private EnumeratedIntegerDistribution finalDistribution;
    private double addForRightClassification;
    private double baseValue;

    public ClassifierWeightedSampling(Random rand, Instances instances, IDataset<I> input) {
        super(input);
        this.rand = rand;
        this.pilotEstimator = new NaiveBayes();
        try {
            this.pilotEstimator.buildClassifier(instances);
        }
        catch (Exception e) {
            this.logger.error("Cannot build pilot estimator", (Throwable)e);
        }
        double mid = this.getMean(instances);
        this.baseValue = 10.0 * mid + 1.0;
        this.addForRightClassification = this.baseValue + 2.0 * mid;
    }

    public AlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmException {
        switch (this.getState()) {
            case created: {
                this.sample = ((IDataset)this.getInput()).createEmpty();
                IDataset sampleCopy = ((IDataset)this.getInput()).createEmpty();
                for (IInstance instance : (IDataset)this.getInput()) {
                    sampleCopy.add(instance);
                }
                this.finalDistribution = this.calculateFinalInstanceBoundariesWithDiscaring((Instances)((WekaInstances)sampleCopy).getList(), this.pilotEstimator);
                this.finalDistribution.reseedRandomGenerator(this.rand.nextLong());
                return this.activate();
            }
            case active: {
                if (this.sample.size() < this.sampleSize) {
                    IInstance choosenInstance;
                    while (this.sample.contains(choosenInstance = (IInstance)((IDataset)this.getInput()).get(this.finalDistribution.sample()))) {
                    }
                    this.sample.add(choosenInstance);
                    return new SampleElementAddedEvent(this.getId());
                }
                return this.terminate();
            }
            case inactive: {
                this.doInactiveStep();
                break;
            }
            default: {
                throw new IllegalStateException("Unknown algorithm state " + this.getState());
            }
        }
        return null;
    }

    private EnumeratedIntegerDistribution calculateFinalInstanceBoundariesWithDiscaring(Instances instances, Classifier pilotEstimator) {
        double[] weights = new double[instances.size()];
        for (int i = 0; i < instances.size(); ++i) {
            try {
                double clazz = this.pilotEstimator.classifyInstance(instances.get(i));
                if (clazz == instances.get(i).classValue()) {
                    weights[i] = this.addForRightClassification - pilotEstimator.distributionForInstance(instances.get(i))[(int)instances.get(i).classValue()];
                    continue;
                }
                weights[i] = this.baseValue + pilotEstimator.distributionForInstance(instances.get(i))[(int)clazz];
                continue;
            }
            catch (Exception e) {
                weights[i] = 0.0;
            }
        }
        int[] indices = IntStream.range(0, ((IDataset)this.getInput()).size()).toArray();
        return new EnumeratedIntegerDistribution(indices, weights);
    }

    private double getMean(Instances instances) {
        double sum = 0.0;
        for (Instance instance : instances) {
            try {
                sum += this.pilotEstimator.distributionForInstance(instance)[(int)instance.classValue()];
            }
            catch (Exception e) {
                this.logger.error("Unexpected error in pilot estimator", (Throwable)e);
            }
        }
        return sum / (double)instances.size();
    }
}

