/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol.PilotEstimateSampling;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.IInstance;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Case-control-like sampling that draws a sample without replacement, where each instance's
 * selection probability is proportional to a weight derived from a trained pilot classifier.
 *
 * <p>Let {@code m} be the mean probability that the pilot estimator assigns to the true labels of the
 * dataset. An instance whose true label is the pilot's top prediction receives weight
 * {@code 12 * m + 1 - p(true label)}. Any other instance receives weight
 * {@code 10 * m + 1 + p(predicted label)}. Instances on which prediction fails receive weight 0 and
 * are never selected. Selected instances get an acceptance threshold of 1.0, all others 0.0.
 *
 * @param <D> the type of labeled dataset being sampled
 */
public class ClassifierWeightedSampling<D extends ILabeledDataset<? extends ILabeledInstance>>
extends PilotEstimateSampling<D> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ClassifierWeightedSampling.class);

    /**
     * Creates the sampler.
     *
     * @param pilotEstimator classifier used as the pilot estimator (handed to the superclass)
     * @param rand source of randomness used to seed the weighted index distribution
     * @param dataset the dataset to sample from (handed to the superclass)
     */
    public ClassifierWeightedSampling(IClassifier pilotEstimator, Random rand, D dataset) {
        super(dataset, pilotEstimator);
        this.rand = rand;
    }

    /**
     * Computes the mean probability that {@link #getPilotEstimator()} assigns to each instance's true label.
     * An instance whose prediction fails is logged and contributes 0 to the sum, but still counts in the
     * denominator. Returns NaN for an empty dataset.
     */
    private double getMean(ILabeledDataset<?> instances) {
        double sum = 0.0;
        for (ILabeledInstance instance : instances) {
            try {
                // NOTE(review): uses getPilotEstimator() rather than the pilot passed to
                // calculateAcceptanceThresholdsWithTrainedPilot; presumably the same trained model — confirm.
                sum += this.getPilotEstimator().predict(instance).getProbabilityOfLabel(instance.getLabel());
            } catch (Exception e) {
                restoreInterruptFlag(e);
                LOGGER.error("Unexpected error in pilot estimator", e);
            }
        }
        return sum / instances.size();
    }

    /**
     * Computes acceptance thresholds by drawing {@link #getSampleSize()} distinct instances of
     * {@code dataset}, each with probability proportional to its pilot-derived weight (see class Javadoc).
     *
     * @param dataset the dataset whose instances receive thresholds
     * @param pilot the trained pilot classifier used to compute per-instance weights
     * @return one (instance, threshold) pair per instance of {@code dataset}, in dataset order; the
     *         threshold is 1.0 for selected instances and 0.0 otherwise
     * @throws IllegalStateException if the sample size exceeds the number of instances with a positive weight
     */
    @Override
    public List<Pair<ILabeledInstance, Double>> calculateAcceptanceThresholdsWithTrainedPilot(D dataset, IClassifier pilot) {
        double[] weights = this.computeWeights(dataset, pilot);
        Set<Integer> selectedIndices = this.drawDistinctIndices(weights);
        List<Pair<ILabeledInstance, Double>> thresholds = new ArrayList<>(weights.length);
        for (int i = 0; i < weights.length; i++) {
            ILabeledInstance instance = dataset.get(i);
            double threshold = selectedIndices.contains(i) ? 1.0 : 0.0;
            thresholds.add(new Pair<>(instance, threshold));
        }
        return thresholds;
    }

    /** Computes one sampling weight per instance of {@code dataset} from the pilot's predictions. */
    private double[] computeWeights(D dataset, IClassifier pilot) {
        double mean = this.getMean(dataset);
        double baseValue = 10.0 * mean + 1.0;
        double addForRightClassification = baseValue + 2.0 * mean;
        double[] weights = new double[dataset.size()];
        for (int i = 0; i < weights.length; i++) {
            ILabeledInstance instance = dataset.get(i);
            try {
                IPrediction prediction = pilot.predict(instance);
                Object predictedLabel = prediction.getLabelWithHighestProbability();
                Object trueLabel = instance.getLabel();
                // Labels are objects: compare by value; '==' fails for equal boxed numbers or strings.
                if (Objects.equals(predictedLabel, trueLabel)) {
                    weights[i] = addForRightClassification - prediction.getProbabilityOfLabel(trueLabel);
                } else {
                    weights[i] = baseValue + prediction.getProbabilityOfLabel(predictedLabel);
                }
            } catch (Exception e) {
                restoreInterruptFlag(e);
                LOGGER.warn("Pilot estimator failed on instance {}; it will not be sampled.", i, e);
                weights[i] = 0.0;
            }
        }
        return weights;
    }

    /**
     * Draws {@link #getSampleSize()} distinct indices into {@code weights}. Each draw picks an index with
     * probability proportional to its weight, so indices with weight 0 are never drawn.
     */
    private Set<Integer> drawDistinctIndices(double[] weights) {
        int sampleSize = this.getSampleSize();
        Set<Integer> selected = new HashSet<>();
        if (sampleSize <= 0) {
            return selected;
        }
        long sampleable = Arrays.stream(weights).filter(w -> w > 0.0).count();
        if (sampleSize > sampleable) {
            // The rejection loop below could never collect enough distinct indices and would spin forever.
            throw new IllegalStateException("Requested sample size " + sampleSize
                    + " exceeds the number of instances with positive sampling weight (" + sampleable + ").");
        }
        // Index array sized from the weights so both arrays always have the same length.
        int[] indices = IntStream.range(0, weights.length).toArray();
        EnumeratedIntegerDistribution distribution = new EnumeratedIntegerDistribution(indices, weights);
        distribution.reseedRandomGenerator(this.rand.nextLong());
        // Sampling without replacement by rejection: repeated draws are absorbed by the set.
        while (selected.size() < sampleSize) {
            selected.add(distribution.sample());
        }
        return selected;
    }

    /** Re-asserts the thread's interrupt flag when a broad catch has swallowed an InterruptedException. */
    private static void restoreInterruptFlag(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
    }
}

