/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol.PilotEstimateSampling;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.IInstance;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;

/**
 * Case-control sampling whose acceptance thresholds are driven by the per-instance loss of a
 * trained pilot classifier. The loss is defined as 1 minus the probability the pilot assigns
 * to the instance's true label.
 */
public class LocalCaseControlSampling extends PilotEstimateSampling<ILabeledDataset<?>> {

    /** How often (in processed instances) the termination/timeout check is performed. */
    private static final int TERMINATION_CHECK_INTERVAL = 100;

    /**
     * Creates a new local case-control sampler.
     *
     * @param rand random source, stored in the inherited {@code rand} field
     * @param preSampleSize value stored in the inherited {@code preSampleSize} field
     * @param input the dataset to sample from
     * @param pilot the pilot classifier handed to the superclass
     */
    public LocalCaseControlSampling(final Random rand, final int preSampleSize, final ILabeledDataset<?> input, final IClassifier pilot) {
        super(input, pilot);
        this.rand = rand;
        this.preSampleSize = preSampleSize;
    }

    /**
     * Computes cumulative acceptance boundaries for all instances of {@code instances}.
     *
     * <p>Each instance's share is its pilot loss divided by the total loss over the dataset. The
     * returned pairs hold the running (cumulative) sum of these shares, in iteration order, so the
     * last boundary is (up to rounding) 1.0. If the pilot has zero loss on every instance, all
     * instances receive the same share {@code 1/n} instead of the undefined {@code 0/0}.
     *
     * @param instances the dataset whose instances receive boundaries
     * @param pilotEstimator the already trained pilot classifier
     * @return one (instance, cumulative boundary) pair per instance, in iteration order; empty
     *         if {@code instances} is empty
     * @throws AlgorithmTimeoutedException if the algorithm timed out during the computation
     * @throws InterruptedException if the thread was interrupted (also when raised by the pilot)
     * @throws AlgorithmExecutionCanceledException if the algorithm was canceled
     */
    @Override
    public List<Pair<ILabeledInstance, Double>> calculateAcceptanceThresholdsWithTrainedPilot(final ILabeledDataset<?> instances, final IClassifier pilotEstimator)
            throws AlgorithmTimeoutedException, InterruptedException, AlgorithmExecutionCanceledException {
        // Single pass over the pilot: each prediction is made exactly once, so the normalising sum
        // is guaranteed to match the losses used for the boundaries.
        List<ILabeledInstance> orderedInstances = new ArrayList<>();
        List<Double> losses = new ArrayList<>();
        double sumOfDistributionLosses = 0.0;
        int processed = 0;
        for (ILabeledInstance instance : instances) {
            if (processed++ % TERMINATION_CHECK_INTERVAL == 0) {
                this.checkAndConductTermination();
            }
            double loss = this.computeLoss(pilotEstimator, instance);
            orderedInstances.add(instance);
            losses.add(loss);
            sumOfDistributionLosses += loss;
        }

        int n = orderedInstances.size();
        // A perfect pilot yields a total loss of 0; fall back to a uniform distribution rather
        // than producing NaN boundaries.
        boolean uniform = !(sumOfDistributionLosses > 0.0);
        List<Pair<ILabeledInstance, Double>> probabilityBoundaries = new ArrayList<>(n);
        double boundaryOfCurrentInstance = 0.0;
        for (int j = 0; j < n; j++) {
            if (j % TERMINATION_CHECK_INTERVAL == 0) {
                this.checkAndConductTermination();
            }
            double share = uniform ? 1.0 / n : losses.get(j) / sumOfDistributionLosses;
            boundaryOfCurrentInstance += share;
            probabilityBoundaries.add(new Pair<>(orderedInstances.get(j), boundaryOfCurrentInstance));
        }
        return probabilityBoundaries;
    }

    /**
     * Returns {@code 1 - P(true label)} according to the pilot. Any prediction failure is treated
     * as maximal loss (1.0), except interruption, which is propagated.
     */
    private double computeLoss(final IClassifier pilotEstimator, final ILabeledInstance instance) throws InterruptedException {
        try {
            // The (IInstance) cast mirrors the original call site and pins the same predict overload.
            return 1.0 - pilotEstimator.predict((IInstance) instance).getProbabilityOfLabel(instance.getLabel());
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Do not swallow cancellation requests; the caller declares InterruptedException.
                throw (InterruptedException) e;
            }
            // Best effort: an instance the pilot cannot evaluate is considered maximally informative.
            return 1.0;
        }
    }
}

