/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.core.dataset.DatasetDeriver;
import ai.libs.jaicore.ml.core.filter.sampling.SampleElementAddedEvent;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.SimpleRandomSampling;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.WaitForSamplingStepEvent;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.IStratiAmountSelector;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.IStratiAssigner;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.IInstance;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.algorithm.IAlgorithm;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stratified sampling over an in-memory dataset.
 *
 * <p>The algorithm is driven step by step through {@link #nextWithException()}:
 * <ol>
 * <li>{@code CREATED}: the {@link IStratiAmountSelector} decides how many strati are used, one
 * {@link DatasetDeriver} is created per stratum and the {@link IStratiAssigner} is initialized.</li>
 * <li>{@code ACTIVE}, first call: every data point of the input is assigned to exactly one stratum.</li>
 * <li>{@code ACTIVE}, second call: each stratum receives a share of the sample size proportional to its
 * size, is sub-sampled with {@link SimpleRandomSampling}, and the combined sample is shuffled.</li>
 * <li>{@code ACTIVE}, third call: the algorithm terminates.</li>
 * </ol>
 *
 * @param <D> the type of dataset that is sampled
 */
public class StratifiedSampling<D extends IDataset<?>> extends ASamplingAlgorithm<D> {
    private Logger logger = LoggerFactory.getLogger(StratifiedSampling.class);
    private final IStratiAmountSelector stratiAmountSelector;
    private final IStratiAssigner stratiAssigner;
    private final Random random;
    private DatasetDeriver<D>[] stratiBuilder = null;
    private boolean allDatapointsAssigned = false;
    private boolean simpleRandomSamplingStarted;

    /**
     * Creates a stratified sampler.
     *
     * @param stratiAmountSelector decides how many strati the input is split into
     * @param stratiAssigner assigns each data point of the input to one stratum
     * @param random source of randomness for distributing samples and for sub-sampling the strati
     * @param input the dataset to sample from
     */
    public StratifiedSampling(final IStratiAmountSelector stratiAmountSelector, final IStratiAssigner stratiAssigner, final Random random, final D input) {
        super(input);
        this.stratiAmountSelector = stratiAmountSelector;
        this.stratiAssigner = stratiAssigner;
        this.random = random;
    }

    /**
     * Performs the next step of the sampling process (see the class documentation for the sequence of steps).
     *
     * @return the event describing the step just performed
     * @throws AlgorithmException if a data point is assigned to a non-existing stratum, the sample cannot be
     *         created, or the expected sample size was not reached before termination
     */
    @Override
    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException {
        switch (this.getState()) {
            case CREATED:
                if (!this.allDatapointsAssigned) {
                    this.initializeStrati();
                }
                this.simpleRandomSamplingStarted = false;
                this.logger.info("Stratified sampler initialized.");
                return this.activate();
            case ACTIVE:
                if (!this.allDatapointsAssigned) {
                    this.assignAllDatapointsToStrati();
                    return new SampleElementAddedEvent(this);
                }
                if (!this.simpleRandomSamplingStarted) {
                    try {
                        this.startSimpleRandomSamplingForStrati();
                    } catch (DatasetCreationException e) {
                        throw new AlgorithmException("Could not create sample from strati.", e);
                    }
                    this.simpleRandomSamplingStarted = true;
                    return new WaitForSamplingStepEvent(this);
                }
                this.logger.info("Stratified sampling completed.");
                return this.terminate();
            case INACTIVE:
                // The sample may never have been built if the algorithm was stopped early.
                if (this.sample == null || this.sample.size() < this.sampleSize) {
                    throw new AlgorithmException("Expected sample size was not reached before termination");
                }
                return this.terminate();
            default:
                throw new IllegalStateException("Unknown algorithm state " + this.getState());
        }
    }

    /**
     * Determines the number of strati, creates one (initially empty) deriver per stratum and initializes the
     * strati assigner. Fails if the input dataset was modified by the selector or the assigner.
     */
    @SuppressWarnings("unchecked") // generic array creation; the array only ever holds DatasetDeriver<D>
    private void initializeStrati() throws InterruptedException, AlgorithmException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException {
        D input = this.getInput();
        int dsHash = input.hashCode();
        this.stratiAmountSelector.setNumCPUs(this.getNumCPUs());
        this.stratiAssigner.setNumCPUs(this.getNumCPUs());
        int numStrati = this.stratiAmountSelector.selectStratiAmount(input);
        if (numStrati <= 0) {
            throw new IllegalStateException("No strati have been defined.");
        }
        this.stratiBuilder = (DatasetDeriver<D>[]) Array.newInstance(DatasetDeriver.class, numStrati);
        for (int i = 0; i < this.stratiBuilder.length; ++i) {
            this.stratiBuilder[i] = new DatasetDeriver<>(input);
        }
        this.stratiAssigner.init(input, this.stratiBuilder.length);
        // Guard against selector/assigner implementations that mutate the dataset they inspect.
        if (input.hashCode() != dsHash) {
            throw new IllegalStateException("Original dataset has been modified!");
        }
    }

    /**
     * Assigns every data point of the input to the stratum chosen by the strati assigner and verifies that
     * all data points were collected.
     *
     * @throws AlgorithmException if the assigner returns an index outside the range of existing strati
     */
    private void assignAllDatapointsToStrati() throws InterruptedException, AlgorithmException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException {
        this.logger.info("Starting to sort all datapoints into their strati.");
        D dataset = this.getInput();
        int n = dataset.size();
        for (int i = 0; i < n; ++i) {
            IInstance datapoint = dataset.get(i);
            // Checking for timeout/cancel on every element would be needlessly expensive.
            if (i % 100 == 0) {
                this.checkAndConductTermination();
            }
            this.logger.debug("Computing statrum for next data point {}", datapoint);
            int assignedStratum = this.stratiAssigner.assignToStrati(datapoint);
            if (assignedStratum < 0 || assignedStratum >= this.stratiBuilder.length) {
                throw new AlgorithmException("No existing strati for index " + assignedStratum);
            }
            this.stratiBuilder[assignedStratum].add(i);
            this.logger.debug("Added data point {} to stratum {}. {} datapoints remaining.", datapoint, assignedStratum, n - i - 1);
        }
        this.allDatapointsAssigned = true;
        int totalItemsAssigned = 0;
        for (DatasetDeriver<D> d : this.stratiBuilder) {
            this.logger.debug("Elements in stratum: {}", d.currentSizeOfTarget());
            totalItemsAssigned += d.currentSizeOfTarget();
        }
        this.logger.info("Finished stratum assignments. Assigned {} data points in total.", totalItemsAssigned);
        if (totalItemsAssigned != n) {
            throw new IllegalStateException("Not all data have been collected.");
        }
    }

    /**
     * Draws the configured number of samples from the strati (proportionally to their sizes), combines the
     * sub-samples and stores the shuffled result in {@code this.sample}.
     *
     * @throws IllegalStateException if no valid sample size is configured, the sample size exceeds the input
     *         size, or any consistency check on the drawn sample fails
     * @throws DatasetCreationException if a stratum or the final sample cannot be built
     */
    private void startSimpleRandomSamplingForStrati() throws InterruptedException, DatasetCreationException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException {
        if (this.sampleSize < 0) {
            throw new IllegalStateException("No valid sample size specified");
        }
        this.logger.info("Now drawing simple random elements in each stratum.");
        int[] stratumSizes = new int[this.stratiBuilder.length];
        for (int i = 0; i < this.stratiBuilder.length; ++i) {
            int currentSize = this.stratiBuilder[i].currentSizeOfTarget();
            if (currentSize < 0) {
                throw new IllegalStateException("Builder for stratum " + i + " has a negative current target size: " + currentSize);
            }
            stratumSizes[i] = currentSize;
        }
        int[] sampleSizeForStrati = computeSampleSizesForStrati(stratumSizes, this.getInput().size(), this.sampleSize, this.random);

        DatasetDeriver<D> sampleDeriver = new DatasetDeriver<>(this.getInput());
        for (int i = 0; i < this.stratiBuilder.length; ++i) {
            DatasetDeriver<D> stratumBuilder = this.stratiBuilder[i];
            D stratum = stratumBuilder.build();
            if (stratum.isEmpty()) {
                this.logger.warn("{}-th stratum is empty!", i);
                continue;
            }
            if (sampleSizeForStrati[i] == 0) {
                this.logger.warn("No samples for stratum {}", i);
                continue;
            }
            // The whole stratum is requested: no need to run a sub-sampler.
            if (sampleSizeForStrati[i] == stratum.size()) {
                sampleDeriver.addIndices(stratumBuilder.getIndicesOfNewInstancesInOriginalDataset());
                continue;
            }
            this.checkAndConductTermination();
            SimpleRandomSampling<D> simpleRandomSampling = new SimpleRandomSampling<>(this.random, stratum);
            simpleRandomSampling.setSampleSize(sampleSizeForStrati[i]);
            this.logger.info("Setting sample size for {}-th stratus to {}", i, sampleSizeForStrati[i]);
            try {
                this.logger.debug("Calling SimpleRandomSampling");
                simpleRandomSampling.call();
                this.logger.debug("SimpleRandomSampling finished");
            } catch (InterruptedException e) {
                throw e;
            } catch (Exception e) {
                // Logged only; whether usable indices were chosen is verified right below.
                this.logger.error("Unexpected exception during simple random sampling!", e);
            }
            if (simpleRandomSampling.getChosenIndices().size() != sampleSizeForStrati[i]) {
                throw new IllegalStateException("Number of samples drawn for stratum " + i + " is " + simpleRandomSampling.getChosenIndices().size() + ", but it should be " + sampleSizeForStrati[i]);
            }
            sampleDeriver.addIndices(stratumBuilder.getIndicesOfNewInstancesInOriginalDataset(simpleRandomSampling.getChosenIndices()));
        }
        if (sampleDeriver.currentSizeOfTarget() != this.sampleSize) {
            throw new IllegalStateException("The deriver says that the target has " + sampleDeriver.currentSizeOfTarget() + " elements, but it should have been configured for " + this.sampleSize);
        }
        this.checkAndConductTermination();
        this.logger.info("Strati sub-samples completed, building the final sample and shuffling it.");
        this.sample = sampleDeriver.build();
        if (this.sample.size() != this.sampleSize) {
            throw new IllegalStateException("The sample deriver has produced a sample with " + this.sample.size() + " elements while it should have " + this.sampleSize);
        }
        Collections.shuffle(this.sample, this.random);
        this.logger.info("Overall stratified shuffled sample completed.");
    }

    /**
     * Splits {@code sampleSize} over the strati proportionally to their sizes.
     *
     * <p>Each stratum first receives {@code floor(size * sampleSize / totalInputSize)} elements. The elements
     * left over by rounding down are handed out one each to randomly chosen, distinct strati that still have
     * at least one element not yet requested. Empty strati are therefore never chosen, so every amount
     * returned can actually be drawn from its stratum.
     *
     * @param stratumSizes number of data points in each stratum
     * @param totalInputSize number of data points in the input dataset
     * @param sampleSize requested total sample size (non-negative)
     * @param random source of randomness for choosing which strati receive the leftover elements
     * @return the number of elements to draw from each stratum; the entries sum up to {@code sampleSize}
     * @throws IllegalStateException if {@code sampleSize} exceeds {@code totalInputSize} or the amounts cannot
     *         be distributed consistently
     */
    private static int[] computeSampleSizesForStrati(final int[] stratumSizes, final int totalInputSize, final int sampleSize, final Random random) {
        if (sampleSize > totalInputSize) {
            throw new IllegalStateException("The sample size " + sampleSize + " exceeds the number of available data points " + totalInputSize + ".");
        }
        int[] sampleSizeForStrati = new int[stratumSizes.length];
        List<Integer> fillupStrati = new ArrayList<>();
        int numSamplesTotal = 0;
        for (int i = 0; i < stratumSizes.length; ++i) {
            sampleSizeForStrati[i] = (int) Math.floor(stratumSizes[i] * ((double) sampleSize / totalInputSize));
            if (sampleSizeForStrati[i] < 0) {
                throw new IllegalStateException("Determined negative stratum size " + sampleSizeForStrati[i] + " for " + i + "-th stratum.");
            }
            numSamplesTotal += sampleSizeForStrati[i];
            // Only strati with spare elements may receive a leftover sample; in particular, an empty stratum
            // must never be chosen because nothing could be drawn from it.
            if (sampleSizeForStrati[i] < stratumSizes[i]) {
                fillupStrati.add(i);
            }
        }
        while (numSamplesTotal < sampleSize) {
            if (fillupStrati.isEmpty()) {
                throw new IllegalStateException("Cannot distribute the remaining " + (sampleSize - numSamplesTotal) + " samples: no stratum has unused elements left.");
            }
            Collections.shuffle(fillupStrati, random);
            sampleSizeForStrati[fillupStrati.remove(0)]++;
            ++numSamplesTotal;
        }
        int stratiSumCheck = 0;
        for (int stratumSampleSize : sampleSizeForStrati) {
            stratiSumCheck += stratumSampleSize;
        }
        if (stratiSumCheck != sampleSize) {
            throw new IllegalStateException("The total number of samples assigned within the strati is " + stratiSumCheck + ", but it should be " + sampleSize + ".");
        }
        return sampleSizeForStrati;
    }

    /**
     * Sets the logger of this sampler and propagates derived logger names to the strati assigner
     * ({@code .assigner}) and the strati amount selector ({@code .stratiamountselector}) if they support it.
     *
     * @param loggername the new logger name
     */
    @Override
    public void setLoggerName(final String loggername) {
        this.logger = LoggerFactory.getLogger(loggername);
        if (this.stratiAssigner instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable) this.stratiAssigner).setLoggerName(loggername + ".assigner");
        }
        if (this.stratiAmountSelector instanceof ILoggingCustomizable) {
            // A single object may implement both roles; do not overwrite the name it just received.
            if (this.stratiAmountSelector != this.stratiAssigner) {
                ((ILoggingCustomizable) this.stratiAmountSelector).setLoggerName(loggername + ".stratiamountselector");
            } else {
                this.logger.info("Strati assigner and amount selector are the same object. Using .assigner for logging.");
            }
        }
    }

    /**
     * @return the name of the logger currently used by this sampler
     */
    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }
}

