/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.basic.algorithm.events.AlgorithmEvent;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.IInstance;
import ai.libs.jaicore.ml.core.dataset.sampling.SampleElementAddedEvent;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.SimpleRandomSampling;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.WaitForSamplingStepEvent;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.IStratiAmountSelector;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.IStratiAssigner;
import java.util.Collection;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class StratifiedSampling<I extends IInstance>
extends ASamplingAlgorithm<I> {
    private Logger logger = LoggerFactory.getLogger(StratifiedSampling.class);
    private IStratiAmountSelector<I> stratiAmountSelector;
    private IStratiAssigner<I> stratiAssigner;
    private Random random;
    private IDataset<I>[] strati = null;
    private IDataset<I> datasetCopy;
    private ExecutorService executorService;
    private boolean allDatapointsAssigned = false;
    private boolean simpleRandomSamplingStarted;

    public StratifiedSampling(IStratiAmountSelector<I> stratiAmountSelector, IStratiAssigner<I> stratiAssigner, Random random, IDataset<I> input) {
        super(input);
        this.stratiAmountSelector = stratiAmountSelector;
        this.stratiAssigner = stratiAssigner;
        this.random = random;
    }

    public AlgorithmEvent nextWithException() throws InterruptedException, AlgorithmException {
        switch (this.getState()) {
            case created: {
                this.sample = ((IDataset)this.getInput()).createEmpty();
                if (!this.allDatapointsAssigned) {
                    this.datasetCopy = ((IDataset)this.getInput()).createEmpty();
                    this.datasetCopy.addAll((Collection)this.getInput());
                    this.stratiAmountSelector.setNumCPUs(this.getNumCPUs());
                    this.stratiAssigner.setNumCPUs(this.getNumCPUs());
                    this.strati = new IDataset[this.stratiAmountSelector.selectStratiAmount(this.datasetCopy)];
                    for (int i = 0; i < this.strati.length; ++i) {
                        this.strati[i] = ((IDataset)this.getInput()).createEmpty();
                    }
                    this.stratiAssigner.init(this.datasetCopy, this.strati.length);
                }
                this.simpleRandomSamplingStarted = false;
                this.executorService = Executors.newCachedThreadPool();
                return this.activate();
            }
            case active: {
                if (this.sample.size() < this.sampleSize) {
                    if (!this.allDatapointsAssigned) {
                        IInstance datapoint = (IInstance)this.datasetCopy.remove(0);
                        int assignedStrati = this.stratiAssigner.assignToStrati(datapoint);
                        if (assignedStrati < 0 || assignedStrati >= this.strati.length) {
                            throw new AlgorithmException("No existing strati for index " + assignedStrati);
                        }
                        this.strati[assignedStrati].add(datapoint);
                        if (this.datasetCopy.isEmpty()) {
                            this.allDatapointsAssigned = true;
                        }
                        return new SampleElementAddedEvent(this.getId());
                    }
                    if (!this.simpleRandomSamplingStarted) {
                        this.startSimpleRandomSamplingForStrati();
                        this.simpleRandomSamplingStarted = true;
                        return new WaitForSamplingStepEvent(this.getId());
                    }
                    if (this.executorService.isTerminated()) {
                        return this.terminate();
                    }
                    Thread.sleep(100L);
                    return new WaitForSamplingStepEvent(this.getId());
                }
                return this.terminate();
            }
            case inactive: {
                if (this.sample.size() < this.sampleSize) {
                    throw new AlgorithmException("Expected sample size was not reached before termination");
                }
                return this.terminate();
            }
        }
        throw new IllegalStateException("Unknown algorithm state " + this.getState());
    }

    private void startSimpleRandomSamplingForStrati() {
        int i;
        int[] sampleSizeForStrati = new int[this.strati.length];
        for (i = 0; i < this.strati.length; ++i) {
            sampleSizeForStrati[i] = Math.round((float)((double)this.sampleSize.intValue() * ((double)this.strati[i].size() / (double)((IDataset)this.getInput()).size())));
        }
        i = 0;
        while (i < this.strati.length) {
            int index = i++;
            this.executorService.execute(() -> {
                SimpleRandomSampling<I> simpleRandomSampling = new SimpleRandomSampling<I>(this.random, this.strati[index]);
                simpleRandomSampling.setSampleSize(sampleSizeForStrati[index]);
                try {
                    IDataset iDataset = this.sample;
                    synchronized (iDataset) {
                        this.sample.addAll(simpleRandomSampling.call());
                    }
                }
                catch (Exception e) {
                    this.logger.error("Unexpected exception during simple random sampling!", (Throwable)e);
                }
            });
        }
        this.executorService.shutdown();
    }

    public IDataset<I>[] getStrati() {
        return this.strati;
    }

    public void setStrati(IDataset<I>[] strati) {
        this.strati = strati;
        this.allDatapointsAssigned = true;
    }
}

