/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.basic.algorithm.events.AlgorithmEvent;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.dataset.DatasetCreationException;
import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.IOrderedDataset;
import ai.libs.jaicore.ml.core.dataset.sampling.SampleElementAddedEvent;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.SimpleRandomSampling;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.WaitForSamplingStepEvent;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.IStratiAmountSelector;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.IStratiAssigner;
import java.util.Collection;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class StratifiedSampling<I, D extends IOrderedDataset<I>>
extends ASamplingAlgorithm<I, D> {
    private Logger logger = LoggerFactory.getLogger(StratifiedSampling.class);
    private IStratiAmountSelector<D> stratiAmountSelector;
    private IStratiAssigner<I, D> stratiAssigner;
    private Random random;
    private IDataset[] strati = null;
    private D datasetCopy;
    private boolean allDatapointsAssigned = false;
    private boolean simpleRandomSamplingStarted;

    public StratifiedSampling(IStratiAmountSelector<D> stratiAmountSelector, IStratiAssigner<I, D> stratiAssigner, Random random, D input) {
        super(input);
        this.stratiAmountSelector = stratiAmountSelector;
        this.stratiAssigner = stratiAssigner;
        this.random = random;
    }

    public AlgorithmEvent nextWithException() throws InterruptedException, AlgorithmException {
        switch (this.getState()) {
            case CREATED: {
                try {
                    this.sample = (IOrderedDataset)((IOrderedDataset)this.getInput()).createEmpty();
                    if (!this.allDatapointsAssigned) {
                        this.datasetCopy = (IOrderedDataset)((IOrderedDataset)this.getInput()).createEmpty();
                        this.datasetCopy.addAll((Collection)this.getInput());
                        this.stratiAmountSelector.setNumCPUs(this.getNumCPUs());
                        this.stratiAssigner.setNumCPUs(this.getNumCPUs());
                        this.strati = new IDataset[this.stratiAmountSelector.selectStratiAmount(this.datasetCopy)];
                        for (int i = 0; i < this.strati.length; ++i) {
                            this.strati[i] = ((IOrderedDataset)this.getInput()).createEmpty();
                        }
                        this.stratiAssigner.init(this.datasetCopy, this.strati.length);
                    }
                    this.simpleRandomSamplingStarted = false;
                }
                catch (DatasetCreationException e) {
                    throw new AlgorithmException((Throwable)e, "Could not create a copy of the dataset.");
                }
                return this.activate();
            }
            case ACTIVE: {
                if (((IOrderedDataset)this.sample).size() < this.sampleSize) {
                    if (!this.allDatapointsAssigned) {
                        Object datapoint = this.datasetCopy.remove(0);
                        int assignedStrati = this.stratiAssigner.assignToStrati(datapoint);
                        if (assignedStrati < 0 || assignedStrati >= this.strati.length) {
                            throw new AlgorithmException("No existing strati for index " + assignedStrati);
                        }
                        this.strati[assignedStrati].add(datapoint);
                        if (this.datasetCopy.isEmpty()) {
                            this.allDatapointsAssigned = true;
                        }
                        return new SampleElementAddedEvent(this.getId());
                    }
                    if (!this.simpleRandomSamplingStarted) {
                        this.startSimpleRandomSamplingForStrati();
                        this.simpleRandomSamplingStarted = true;
                        return new WaitForSamplingStepEvent(this.getId());
                    }
                    return this.terminate();
                }
                return this.terminate();
            }
            case INACTIVE: {
                if (((IOrderedDataset)this.sample).size() < this.sampleSize) {
                    throw new AlgorithmException("Expected sample size was not reached before termination");
                }
                return this.terminate();
            }
        }
        throw new IllegalStateException("Unknown algorithm state " + this.getState());
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void startSimpleRandomSamplingForStrati() {
        int i;
        int[] sampleSizeForStrati = new int[this.strati.length];
        for (i = 0; i < this.strati.length; ++i) {
            sampleSizeForStrati[i] = Math.round((float)((double)this.sampleSize.intValue() * ((double)this.strati[i].size() / (double)((IOrderedDataset)this.getInput()).size())));
        }
        for (i = 0; i < this.strati.length; ++i) {
            SimpleRandomSampling simpleRandomSampling = new SimpleRandomSampling(this.random, (IOrderedDataset)this.strati[i]);
            simpleRandomSampling.setSampleSize(sampleSizeForStrati[i]);
            try {
                IOrderedDataset iOrderedDataset = (IOrderedDataset)this.sample;
                synchronized (iOrderedDataset) {
                    ((IOrderedDataset)this.sample).addAll(simpleRandomSampling.call());
                    continue;
                }
            }
            catch (Exception e) {
                this.logger.error("Unexpected exception during simple random sampling!", (Throwable)e);
            }
        }
    }

    public IDataset[] getStrati() {
        return this.strati;
    }

    public void setStrati(IDataset[] strati) {
        this.strati = strati;
    }
}

