/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol.CaseControlLikeSampling;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.IDataSource;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.filter.unsupervised.sampling.ISamplingAlgorithm;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;

/**
 * Case-control-like sampling whose acceptance thresholds are derived from a pilot classifier.
 *
 * <p>When {@link #computeAcceptanceThresholds()} runs, the pilot classifier is fitted. It is trained on a
 * sub-sample if a sub-sampling factory was given at construction time, and on the full input otherwise.
 * The trained pilot is then handed to {@link #calculateAcceptanceThresholdsWithTrainedPilot}, which
 * subclasses implement.
 *
 * <p>Note that the classifier instance supplied by the caller is trained in place. The same instance is
 * returned by {@link #getPilotEstimator()}.
 *
 * @param <D> the labeled dataset type being sampled
 */
public abstract class PilotEstimateSampling<D extends ILabeledDataset<? extends ILabeledInstance>>
extends CaseControlLikeSampling<D> {

    /** Seed of the sub-sampler's random source when the caller does not supply its own {@link Random}. */
    private static final long DEFAULT_SUBSAMPLING_SEED = 0L;

    /** Produces the pilot's training subset; {@code null} means "train the pilot on the full input". */
    private final ISamplingAlgorithm<D> subSampler;
    /** Size handed to the sub-sampling factory (set to 1 when no sub-sampler is used). */
    protected int preSampleSize;
    /** The (caller-supplied) classifier that is fitted and then used to derive acceptance thresholds. */
    private final IClassifier pilotEstimator;

    /**
     * Creates a sampler whose pilot is trained on the full input dataset (no sub-sampling).
     *
     * @param input the dataset to sample from
     * @param pilotClassifier the pilot classifier; must not be {@code null}
     */
    protected PilotEstimateSampling(final D input, final IClassifier pilotClassifier) {
        this(input, null, 1, pilotClassifier);
    }

    /**
     * Creates a sampler whose pilot is optionally trained on a sub-sample.
     *
     * <p>The sub-sampler is seeded with a fixed seed ({@value #DEFAULT_SUBSAMPLING_SEED}).
     *
     * @param input the dataset to sample from
     * @param subSamplingFactory factory for the sub-sampler, or {@code null} to train the pilot on the full input
     * @param preSampleSize size passed to {@code subSamplingFactory} for the pilot's training sample
     * @param pilotClassifier the pilot classifier; must not be {@code null}
     */
    protected PilotEstimateSampling(final D input, final ISamplingAlgorithmFactory<D, ?> subSamplingFactory, final int preSampleSize, final IClassifier pilotClassifier) {
        this(input, subSamplingFactory, preSampleSize, pilotClassifier, new Random(DEFAULT_SUBSAMPLING_SEED));
    }

    /**
     * Creates a sampler whose pilot is optionally trained on a sub-sample drawn with the given random source.
     *
     * @param input the dataset to sample from
     * @param subSamplingFactory factory for the sub-sampler, or {@code null} to train the pilot on the full input
     * @param preSampleSize size passed to {@code subSamplingFactory} for the pilot's training sample
     * @param pilotClassifier the pilot classifier; must not be {@code null}
     * @param subSamplingRandom random source for the sub-sampler; must not be {@code null} when
     *        {@code subSamplingFactory} is non-null (ignored otherwise)
     * @throws NullPointerException if {@code pilotClassifier} is {@code null}, or if {@code subSamplingRandom}
     *         is {@code null} while a sub-sampling factory is given
     */
    protected PilotEstimateSampling(final D input, final ISamplingAlgorithmFactory<D, ?> subSamplingFactory, final int preSampleSize, final IClassifier pilotClassifier, final Random subSamplingRandom) {
        super(input);
        this.pilotEstimator = Objects.requireNonNull(pilotClassifier, "pilotClassifier must not be null");
        this.preSampleSize = preSampleSize;
        if (subSamplingFactory != null) {
            Objects.requireNonNull(subSamplingRandom, "subSamplingRandom must not be null when a sub-sampling factory is given");
            this.subSampler = subSamplingFactory.getAlgorithm(preSampleSize, input, subSamplingRandom);
        } else {
            this.subSampler = null;
        }
    }

    /**
     * Fits the pilot classifier and derives the acceptance thresholds from it.
     *
     * <p>The pilot is trained on the sub-sampler's output if a sub-sampler exists, otherwise on the full input.
     * The thresholds are always computed over the full input.
     */
    @Override
    public List<Pair<ILabeledInstance, Double>> computeAcceptanceThresholds() throws InterruptedException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException, AlgorithmException {
        final D input = this.getInput();
        if (this.subSampler != null) {
            final D subSample = this.subSampler.call();
            this.logger.info("Fitting pilot with reduced dataset of {}/{} instances.", subSample.size(), input.size());
            this.pilotEstimator.fit(subSample);
        } else {
            this.logger.info("Fitting pilot with full dataset.");
            this.pilotEstimator.fit(input);
        }
        return this.calculateAcceptanceThresholdsWithTrainedPilot(input, this.pilotEstimator);
    }

    /**
     * Computes per-instance acceptance thresholds using an already-trained pilot classifier.
     *
     * @param dataset the full input dataset
     * @param pilot the pilot classifier, already fitted by {@link #computeAcceptanceThresholds()}
     * @return pairs of instance and acceptance threshold
     */
    public abstract List<Pair<ILabeledInstance, Double>> calculateAcceptanceThresholdsWithTrainedPilot(D dataset, IClassifier pilot) throws InterruptedException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException, AlgorithmException;

    /**
     * Returns the pilot classifier. This is the caller-supplied instance, which is trained in place by
     * {@link #computeAcceptanceThresholds()}.
     *
     * @return the pilot classifier (never {@code null})
     */
    public IClassifier getPilotEstimator() {
        return this.pilotEstimator;
    }
}

