/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.util;

import ai.libs.jaicore.basic.algorithm.AlgorithmExecutionCanceledException;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.IInstance;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.AttributeBasedStratiAmountSelectorAndAssigner;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.DiscretizationHelper;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.StratifiedSampling;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;

/**
 * Splits a dataset into a training part and a test part using stratified sampling.
 *
 * <p>The training part is drawn with {@link StratifiedSampling}. The test part is a fresh dataset
 * (created via {@code createEmpty()}) holding every instance of the original dataset that is not
 * removed by {@code removeAll(trainingData)}.
 *
 * @param <I> the instance type of the dataset
 */
public class StratifiedSplit<I extends IInstance> {
    /*
     * Interval count handed to the strati selector/assigner together with the EQUAL_SIZE
     * discretization strategy. Presumably this is the number of bins used when discretizing the
     * stratification attribute; confirm against AttributeBasedStratiAmountSelectorAndAssigner.
     */
    private static final int DISCRETIZATION_INTERVALS = 10;

    private final IDataset<I> dataset;
    private IDataset<I> trainingData;
    private IDataset<I> testData;
    private final long seed;

    /**
     * Creates a splitter for the given dataset.
     *
     * @param dataset the dataset to split; must not be {@code null}
     * @param seed seed for the {@link Random} instance handed to the sampling algorithm
     * @throws NullPointerException if {@code dataset} is {@code null}
     */
    public StratifiedSplit(IDataset<I> dataset, long seed) {
        this.dataset = Objects.requireNonNull(dataset, "dataset");
        this.seed = seed;
    }

    /**
     * Performs the split and stores the results for {@link #getTrainingData()} and
     * {@link #getTestData()}.
     *
     * <p>The results are only published when the whole split succeeds. If this method throws,
     * the results of any previous successful call are left untouched.
     *
     * @param trainPortion fraction of the dataset to put into the training part; must lie in
     *     {@code [0, 1]}. The training sample size is {@code (int) (trainPortion * size)}, i.e.
     *     truncated.
     * @throws IllegalArgumentException if {@code trainPortion} is NaN or outside {@code [0, 1]}
     * @throws AlgorithmException if the sampling fails, is cancelled, or the current thread is
     *     interrupted. On interruption the thread's interrupt flag is restored before throwing.
     */
    public void doSplit(double trainPortion) throws AlgorithmException {
        if (Double.isNaN(trainPortion) || trainPortion < 0.0 || trainPortion > 1.0) {
            throw new IllegalArgumentException("trainPortion must be in [0, 1] but was " + trainPortion);
        }
        Random r = new Random(this.seed);
        // Stratify on the attribute at index getNumberOfAttributes(), one past the last regular
        // attribute index. Presumably this is the target/class column; confirm against IDataset.
        List<Integer> attributeIndices = Collections.singletonList(this.dataset.getNumberOfAttributes());
        AttributeBasedStratiAmountSelectorAndAssigner selectorAndAssigner = new AttributeBasedStratiAmountSelectorAndAssigner(attributeIndices, DiscretizationHelper.DiscretizationStrategy.EQUAL_SIZE, DISCRETIZATION_INTERVALS);
        // The same object serves as both the strati amount selector and the strati assigner.
        StratifiedSampling stratifiedSampling = new StratifiedSampling(selectorAndAssigner, selectorAndAssigner, r, this.dataset);
        int sampleSize = (int)(trainPortion * (double)this.dataset.size());
        stratifiedSampling.setSampleSize(sampleSize);
        try {
            IDataset<I> sampledTrainingData = stratifiedSampling.call();
            IDataset<I> remainingTestData = this.dataset.createEmpty();
            remainingTestData.addAll(this.dataset);
            remainingTestData.removeAll(sampledTrainingData);
            // Publish both parts together so a failure cannot leave them inconsistent.
            this.trainingData = sampledTrainingData;
            this.testData = remainingTestData;
        }
        catch (AlgorithmExecutionCanceledException e) {
            throw withCause(new AlgorithmException("Stratified split has been cancelled"), e);
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag, and signal the failure instead of returning as if the split succeeded.
            Thread.currentThread().interrupt();
            throw withCause(new AlgorithmException("Stratified split has been interrupted"), e);
        }
    }

    /**
     * Attaches {@code cause} to {@code exception}. {@code initCause} is used because this class
     * only relies on the message-only {@link AlgorithmException} constructor.
     */
    private static AlgorithmException withCause(AlgorithmException exception, Throwable cause) {
        exception.initCause(cause);
        return exception;
    }

    /**
     * Returns the training part of the last successful split.
     *
     * @return the training data, or {@code null} if {@link #doSplit(double)} has not yet completed successfully
     */
    public IDataset<I> getTrainingData() {
        return this.trainingData;
    }

    /**
     * Returns the test part of the last successful split.
     *
     * @return the test data, or {@code null} if {@link #doSplit(double)} has not yet completed successfully
     */
    public IDataset<I> getTestData() {
        return this.testData;
    }
}

