/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.infiles.stratified.sampling;

import ai.libs.jaicore.basic.TempFileHandler;
import ai.libs.jaicore.ml.core.filter.sampling.SampleElementAddedEvent;
import ai.libs.jaicore.ml.core.filter.sampling.infiles.AFileSamplingAlgorithm;
import ai.libs.jaicore.ml.core.filter.sampling.infiles.ArffUtilities;
import ai.libs.jaicore.ml.core.filter.sampling.infiles.ReservoirSampling;
import ai.libs.jaicore.ml.core.filter.sampling.infiles.stratified.sampling.IStratiFileAssigner;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.WaitForSamplingStepEvent;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.api4.java.algorithm.IAlgorithm;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stratified sampling of an ARFF file that streams the input from disk instead of loading it into memory.
 * <p>
 * The algorithm is driven by repeated calls to {@link #nextWithException()} and runs in three phases:
 * <ol>
 * <li>every line of the input's datapoint section is read, one line per step; each non-blank line that does not
 * start with {@code '%'} is handed to the {@link IStratiFileAssigner};</li>
 * <li>once all lines are streamed, one {@link ReservoirSampling} task per stratum reported by
 * {@code getAllCreatedStrati()} is started on a background executor, and the algorithm waits for all of them;</li>
 * <li>the collected sample is written to the output file and the algorithm terminates.</li>
 * </ol>
 * Each stratum receives {@code floor(sampleSize * stratumCount / datapointAmount)} samples. The remainder up to
 * {@code sampleSize} is handed out one sample each to randomly chosen, distinct strata. Because the stratum tasks
 * append to the sample concurrently, the order of the written lines depends on thread scheduling.
 */
public class StratifiedFileSampling
extends AFileSamplingAlgorithm {
    /** Number of streamed or written datapoints between two checks for timeout/cancellation. */
    private static final int TERMINATION_CHECK_INTERVAL = 100;
    /** Time in milliseconds to sleep between two polls of the reservoir sampling executor. */
    private static final long SAMPLING_POLL_INTERVAL_MS = 100L;

    private final Logger logger = LoggerFactory.getLogger(StratifiedFileSampling.class);
    /** Used both for distributing the remainder samples and by every per-stratum reservoir sampling task. */
    private final Random random;
    /** Provides the per-stratum temp files (keyed by the strati map keys) and each task's temp output file. */
    private final TempFileHandler tempFileHandler;
    private final IStratiFileAssigner assigner;
    /** Reader over the input file; positioned at the datapoint section after initialization. */
    private BufferedReader reader;
    /** Number of datapoints in the input file as counted during initialization. */
    private int datapointAmount;
    /** Number of input lines already read in the streaming phase. */
    private int streamedDatapoints;
    private boolean stratiSamplingStarted;
    private boolean stratiSamplingFinished;
    /** Runs one reservoir sampling task per stratum. */
    private ExecutorService executorService;
    /** Collected sample lines; appended to concurrently by the stratum tasks while holding this list's monitor. */
    private List<String> sample;

    /**
     * Creates a stratified sampler for an ARFF file.
     *
     * @param random source of randomness for the remainder distribution and the per-stratum reservoir sampling
     * @param stratiFileAssigner assigns each datapoint line to a stratum
     * @param input the ARFF file to sample from
     */
    public StratifiedFileSampling(Random random, IStratiFileAssigner stratiFileAssigner, File input) {
        super(input);
        this.random = random;
        this.assigner = stratiFileAssigner;
        this.tempFileHandler = new TempFileHandler();
    }

    /**
     * Performs the next step of the algorithm; see the class documentation for the phases.
     *
     * @return the event describing the step that was performed
     * @throws AlgorithmException if the input cannot be read, the output cannot be written, or the algorithm is
     *         terminated before the sample is complete
     */
    @Override
    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmException, AlgorithmTimeoutedException {
        switch (this.getState()) {
            case CREATED:
                return this.initialize();
            case ACTIVE:
                if (this.streamedDatapoints < this.datapointAmount) {
                    if (this.streamedDatapoints % TERMINATION_CHECK_INTERVAL == 0) {
                        this.checkAndConductTermination();
                    }
                    return this.assignNextDatapoint();
                }
                // After streaming the counter no longer advances, so check on every step; otherwise a timeout or
                // cancellation would be ignored for the whole time the reservoir sampling is running.
                this.checkAndConductTermination();
                if (!this.stratiSamplingStarted) {
                    this.logger.debug("All datapoints are assigned, now sampling from strati.");
                    this.closeInputReader();
                    this.stratiSamplingStarted = true;
                    this.startReservoirSamplingForStrati(this.assigner.getAllCreatedStrati());
                    return new WaitForSamplingStepEvent(this);
                }
                if (!this.stratiSamplingFinished) {
                    if (this.executorService.isTerminated()) {
                        this.stratiSamplingFinished = true;
                    } else {
                        Thread.sleep(SAMPLING_POLL_INTERVAL_MS);
                    }
                    return new WaitForSamplingStepEvent(this);
                }
                return this.writeSampleAndTerminate();
            case INACTIVE:
                if (this.streamedDatapoints < this.datapointAmount || !this.stratiSamplingStarted || !this.stratiSamplingFinished) {
                    throw new AlgorithmException("Expected sample size was not reached before termination");
                }
                return this.terminate();
            default:
                break;
        }
        this.cleanUp();
        throw new IllegalStateException("Unknown algorithm state " + this.getState());
    }

    /** Hands the ARFF header to the assigner, counts the datapoints and positions the input reader at the data. */
    private IAlgorithmEvent initialize() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmException, AlgorithmTimeoutedException {
        try {
            this.assigner.setArffHeader(ArffUtilities.extractArffHeader((File) this.getInput()));
            this.assigner.setTempFileHandler(this.tempFileHandler);
            this.datapointAmount = ArffUtilities.countDatasetEntries((File) this.getInput(), true);
            this.streamedDatapoints = 0;
            this.stratiSamplingStarted = false;
            this.stratiSamplingFinished = false;
            this.sample = new ArrayList<String>();
            this.reader = new BufferedReader(new FileReader((File) this.getInput()));
            this.executorService = Executors.newCachedThreadPool();
            ArffUtilities.skipWithReaderToDatapoints(this.reader);
            return this.activate();
        } catch (IOException e) {
            // Do not leak the file handle if positioning the reader failed after it was opened.
            if (this.reader != null) {
                try {
                    this.reader.close();
                } catch (IOException closeException) {
                    e.addSuppressed(closeException);
                }
            }
            throw new AlgorithmException("Was not able to count the datapoints.", e);
        }
    }

    /** Reads the next input line and hands it to the assigner unless it is blank or a {@code %} comment. */
    private IAlgorithmEvent assignNextDatapoint() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmException, AlgorithmTimeoutedException {
        try {
            String datapoint = this.reader.readLine();
            if (datapoint != null) {
                String trimmed = datapoint.trim();
                if (!trimmed.isEmpty() && trimmed.charAt(0) != '%') {
                    this.assigner.assignDatapoint(datapoint);
                }
            }
            ++this.streamedDatapoints;
            return new SampleElementAddedEvent(this);
        } catch (IOException e) {
            throw new AlgorithmException("Was not able to read datapoint line form input file", e);
        }
    }

    /** Closes the input reader, wrapping a failure into an {@link AlgorithmException}. */
    private void closeInputReader() throws AlgorithmException {
        try {
            this.reader.close();
        } catch (IOException e) {
            throw new AlgorithmException("Was not able to close input file reader.", e);
        }
    }

    /** Verifies the sample size, writes all sample lines to the output file and terminates. */
    private IAlgorithmEvent writeSampleAndTerminate() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmException, AlgorithmTimeoutedException {
        try {
            if (this.sample.size() != this.sampleSize.intValue()) {
                throw new IllegalStateException("Will write " + this.sample.size() + " instead of " + this.sampleSize + " instances.");
            }
            int written = 0;
            // Iterate instead of indexing so the write loop stays linear regardless of the list implementation.
            for (String datapoint : this.sample) {
                if (written % TERMINATION_CHECK_INTERVAL == 0) {
                    this.checkAndConductTermination();
                }
                this.outputFileWriter.write(datapoint + "\n");
                ++written;
            }
            return this.terminate();
        } catch (IOException e) {
            throw new AlgorithmException("Was not able to write datapoint into output file.", e);
        }
    }

    /**
     * Stops any running reservoir sampling tasks, closes the input reader if one was opened and removes the temp
     * files. Safe to call before the algorithm was initialized.
     */
    @Override
    protected void cleanUp() {
        if (this.executorService != null) {
            this.executorService.shutdownNow();
        }
        if (this.reader != null) {
            try {
                this.reader.close();
            } catch (IOException e) {
                this.logger.warn("Could not close input file reader during clean up.", e);
            }
        }
        this.tempFileHandler.cleanUp();
    }

    /**
     * Computes the per-stratum sample sizes and starts one reservoir sampling task per stratum. The executor is
     * shut down afterwards, so it terminates once all tasks are done.
     *
     * @param strati maps the temp-file id of each stratum to its number of datapoints
     */
    private void startReservoirSamplingForStrati(Map<String, Integer> strati) {
        this.logger.info("Start reservoir sampling for strati.");
        int[] sampleSizeForStrati = this.computeSampleSizesForStrati(strati);
        // Iterates the same, unmodified map as computeSampleSizesForStrati, so indices line up.
        int index = 0;
        for (Map.Entry<String, Integer> entry : strati.entrySet()) {
            String stratumFileId = entry.getKey();
            int stratumSampleSize = sampleSizeForStrati[index++];
            this.executorService.execute(() -> this.sampleStratum(stratumFileId, stratumSampleSize));
        }
        this.executorService.shutdown();
    }

    /**
     * Assigns each stratum its proportional share of the sample size (rounded down). The remaining samples go to
     * randomly chosen strata, at most one extra sample per stratum.
     *
     * @param strati maps the temp-file id of each stratum to its number of datapoints
     * @return sample size per stratum, in the iteration order of {@code strati.entrySet()}
     * @throws IllegalStateException if the remainder cannot be distributed or the sizes do not add up
     */
    private int[] computeSampleSizesForStrati(Map<String, Integer> strati) {
        int targetSize = this.sampleSize.intValue();
        int[] sampleSizeForStrati = new int[strati.size()];
        List<Integer> fillupStrati = new ArrayList<Integer>();
        int numOfSamplesThatWillBeCreated = 0;
        int i = 0;
        for (Map.Entry<String, Integer> entry : strati.entrySet()) {
            sampleSizeForStrati[i] = proportionalShare(targetSize, entry.getValue().intValue(), this.datapointAmount);
            numOfSamplesThatWillBeCreated += sampleSizeForStrati[i];
            fillupStrati.add(i);
            ++i;
        }
        while (numOfSamplesThatWillBeCreated < targetSize) {
            if (fillupStrati.isEmpty()) {
                throw new IllegalStateException("Cannot distribute the remaining " + (targetSize - numOfSamplesThatWillBeCreated) + " samples: all " + sampleSizeForStrati.length + " strati already received a fill-up sample.");
            }
            Collections.shuffle(fillupStrati, this.random);
            int indexForNextFillUp = fillupStrati.remove(0);
            ++sampleSizeForStrati[indexForNextFillUp];
            ++numOfSamplesThatWillBeCreated;
        }
        if (numOfSamplesThatWillBeCreated != targetSize) {
            throw new IllegalStateException("The strati sum up to a size of " + numOfSamplesThatWillBeCreated + " instead of " + this.sampleSize + ".");
        }
        return sampleSizeForStrati;
    }

    /**
     * Returns {@code floor(sampleSize * stratumSize / total)}, computed exactly in long arithmetic, or 0 if
     * {@code total} is not positive. Floating-point intermediates could round the floor up or down for large inputs.
     */
    private static int proportionalShare(int sampleSize, int stratumSize, int total) {
        if (total <= 0) {
            return 0;
        }
        return (int) ((long) sampleSize * stratumSize / total);
    }

    /**
     * Runs reservoir sampling on one stratum's temp file and appends the sampled datapoint lines to the shared
     * sample. Executed on the executor; failures are logged, and the resulting short sample is detected before
     * writing the output.
     */
    private void sampleStratum(String stratumFileId, int stratumSampleSize) {
        try {
            String outputFile = this.tempFileHandler.createTempFile();
            ReservoirSampling reservoirSampling = new ReservoirSampling(this.random, this.tempFileHandler.getTempFile(stratumFileId));
            reservoirSampling.setSampleSize(stratumSampleSize);
            reservoirSampling.setOutputFileName(this.tempFileHandler.getTempFile(outputFile).getAbsolutePath());
            reservoirSampling.call();
            try (BufferedReader sampledReader = this.tempFileHandler.getFileReaderForTempFile(outputFile)) {
                ArffUtilities.skipWithReaderToDatapoints(sampledReader);
                String line;
                while ((line = sampledReader.readLine()) != null) {
                    String trimmed = line.trim();
                    if (trimmed.isEmpty() || trimmed.charAt(0) == '%') {
                        continue;
                    }
                    synchronized (this.sample) {
                        this.sample.add(line);
                    }
                }
            }
        } catch (Exception e) {
            this.logger.error("Unexpected exception during reservoir sampling!", e);
        }
    }
}

