/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol.PilotEstimateSampling;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.IInstance;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;

/**
 * Pilot-estimate based sampler whose acceptance probabilities follow the OSMAC idea (Optimal Subsampling
 * Method under the A-optimality Criterion, Wang, Zhu &amp; Ma, 2018).
 *
 * <p>For every instance, the pilot classifier passed to
 * {@link #calculateAcceptanceThresholdsWithTrainedPilot(ILabeledDataset, IClassifier)} yields a loss
 * {@code 1 - P(trueLabel | x)}. The acceptance probability of the instance is proportional to that loss
 * multiplied by the Euclidean norm of its numeric attribute vector.
 *
 * @param <D> the labeled dataset type that is sampled
 */
public class OSMAC<D extends ILabeledDataset<? extends ILabeledInstance>>
extends PilotEstimateSampling<D> {

    /**
     * Creates an OSMAC sampler.
     *
     * @param rand  random source; stored in the inherited {@code rand} field
     * @param input the dataset to sample from (forwarded to the superclass)
     * @param pilot the pilot classifier (forwarded to the superclass)
     */
    public OSMAC(Random rand, D input, IClassifier pilot) {
        super(input, pilot);
        this.rand = rand;
    }

    /**
     * Creates an OSMAC sampler with an explicit pre-sampling configuration.
     *
     * <p>NOTE(review): {@code subSamplingFactory} and {@code preSampleSize} are only forwarded to the
     * superclass; presumably they control how the pre-sample for the pilot is drawn — confirm there.
     *
     * @param rand               random source; stored in the inherited {@code rand} field
     * @param input              the dataset to sample from (forwarded to the superclass)
     * @param subSamplingFactory factory forwarded to the superclass
     * @param preSampleSize      size forwarded to the superclass
     * @param pilot              the pilot classifier (forwarded to the superclass)
     */
    public OSMAC(Random rand, D input, ISamplingAlgorithmFactory<D, ?> subSamplingFactory, int preSampleSize, IClassifier pilot) {
        super(input, subSamplingFactory, preSampleSize, pilot);
        this.rand = rand;
    }

    /**
     * Computes cumulative acceptance thresholds for all instances, in dataset order.
     *
     * <p>Each instance {@code i} gets the weight {@code (1 - P_pilot(y_i | x_i)) * ||x_i||_2}. The i-th
     * returned boundary is the sum of the normalized weights of instances {@code 0..i}, so the boundaries are
     * non-decreasing and the last one equals 1.0 up to floating-point rounding. If the total weight is not
     * positive, for example because the pilot is perfect everywhere or all feature vectors are zero, uniform
     * thresholds {@code (i + 1) / n} are returned instead of NaN values.
     *
     * <p>Attribute values that are {@code null} or {@code "?"} are treated as missing and ignored for the norm.
     * The termination check of the algorithm is performed every 100 instances in both passes.
     *
     * @param instances      the dataset whose instances receive thresholds
     * @param pilotEstimator the pilot classifier used to estimate per-instance losses
     * @return one (instance, cumulative boundary) pair per instance, in the order of {@code instances}
     * @throws IllegalArgumentException             if an attribute value is neither missing nor a {@link Number}
     * @throws AlgorithmTimeoutedException          propagated from {@code checkAndConductTermination()}
     * @throws InterruptedException                 propagated from the termination check or the pilot's prediction
     * @throws AlgorithmExecutionCanceledException  propagated from {@code checkAndConductTermination()}
     */
    @Override
    public List<Pair<ILabeledInstance, Double>> calculateAcceptanceThresholdsWithTrainedPilot(D instances, IClassifier pilotEstimator) throws AlgorithmTimeoutedException, InterruptedException, AlgorithmExecutionCanceledException {
        int n = instances.size();
        double[] weights = new double[n];
        double sumOfWeights = 0.0;
        for (int i = 0; i < n; ++i) {
            if (i % 100 == 0) {
                this.checkAndConductTermination();
            }
            ILabeledInstance instance = instances.get(i);
            // Norm first (may throw on illegal attributes) before querying the pilot, as the original did.
            double featureNorm = euclideanNorm(instance.getAttributes());
            weights[i] = pilotLoss(pilotEstimator, instance) * featureNorm;
            sumOfWeights += weights[i];
        }

        // A non-positive (or NaN) total makes the weighted scheme undefined; dividing by it would turn every
        // boundary into NaN. Degrade to uniform sampling instead.
        boolean useWeights = sumOfWeights > 0.0;
        List<Pair<ILabeledInstance, Double>> probabilityBoundaries = new ArrayList<>(n);
        double boundaryOfCurrentInstance = 0.0;
        for (int i = 0; i < n; ++i) {
            if (i % 100 == 0) {
                this.checkAndConductTermination();
            }
            boundaryOfCurrentInstance += useWeights ? weights[i] / sumOfWeights : 1.0 / n;
            ILabeledInstance instance = instances.get(i);
            probabilityBoundaries.add(new Pair<ILabeledInstance, Double>(instance, boundaryOfCurrentInstance));
        }
        return probabilityBoundaries;
    }

    /**
     * Returns the Euclidean (L2) norm of the numeric attribute values, skipping missing values.
     *
     * @param attributes the attribute values of one instance; {@code null} entries and {@code "?"} count as missing
     * @return {@code sqrt(sum of squares)} over all non-missing values (0.0 if none are present)
     * @throws IllegalArgumentException if a non-missing value is not a {@link Number}
     */
    private static double euclideanNorm(final Object[] attributes) {
        double sumOfSquares = 0.0;
        for (Object attributeVal : attributes) {
            if (attributeVal == null || "?".equals(attributeVal)) {
                continue;
            }
            if (!(attributeVal instanceof Number)) {
                throw new IllegalArgumentException("Illegal non-double attribute value " + attributeVal);
            }
            double value = ((Number) attributeVal).doubleValue();
            sumOfSquares += value * value;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Returns the pilot's loss {@code 1 - P(label | instance)} on a single instance.
     *
     * <p>Any failure of the pilot other than an interruption is deliberately treated as the loss {@code 1.0}
     * (best effort), so that such an instance still receives a weight.
     *
     * @param pilotEstimator the pilot classifier
     * @param instance       the instance to evaluate
     * @return the estimated loss, or 1.0 if the pilot failed
     * @throws InterruptedException if the prediction was interrupted
     */
    private static double pilotLoss(final IClassifier pilotEstimator, final ILabeledInstance instance) throws InterruptedException {
        try {
            return 1.0 - pilotEstimator.predict(instance).getProbabilityOfLabel(instance.getLabel());
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Never swallow an interruption; the caller declares InterruptedException and must be able to stop.
                throw (InterruptedException) e;
            }
            return 1.0;
        }
    }
}

