/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.util;

import ai.libs.jaicore.basic.algorithm.AlgorithmExecutionCanceledException;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.dataset.DatasetCreationException;
import ai.libs.jaicore.ml.core.dataset.INumericLabeledAttributeArrayInstance;
import ai.libs.jaicore.ml.core.dataset.IOrderedLabeledAttributeArrayDataset;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.AttributeBasedStratiAmountSelectorAndAssigner;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.DiscretizationHelper;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.StratifiedSampling;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;

/**
 * Splits a labeled dataset into a training part and a test part using stratified sampling.
 *
 * <p>The training part is drawn by {@link StratifiedSampling}. The test part is every instance of
 * the original dataset that is not contained in the training part. Sampling is seeded, so repeated
 * calls to {@link #doSplit(double)} with the same portion produce the same split.
 *
 * <p>This class is not thread-safe. The results of the most recent successful
 * {@link #doSplit(double)} call are exposed through {@link #getTrainingData()} and
 * {@link #getTestData()}.
 *
 * @param <I> instance type
 * @param <L> label type
 * @param <D> dataset type
 */
public class StratifiedSplit<I extends INumericLabeledAttributeArrayInstance<L>, L, D extends IOrderedLabeledAttributeArrayDataset<I, L>> {

    /**
     * Value passed as the third argument to {@link AttributeBasedStratiAmountSelectorAndAssigner}
     * together with {@code EQUAL_SIZE} discretization. Presumably this is the number of categories
     * numeric values are discretized into — TODO confirm against that class.
     */
    private static final int DISCRETIZATION_CATEGORIES = 10;

    private final D dataset;
    private final long seed;
    private D trainingData;
    private D testData;

    /**
     * Creates a splitter for the given dataset.
     *
     * @param dataset the dataset to split; must not be {@code null}
     * @param seed seed for the random generator used during sampling
     * @throws NullPointerException if {@code dataset} is {@code null}
     */
    public StratifiedSplit(final D dataset, final long seed) {
        this.dataset = Objects.requireNonNull(dataset, "dataset");
        this.seed = seed;
    }

    /**
     * Performs the stratified split.
     *
     * <p>The training set receives {@code (int) (trainPortion * dataset.size())} instances. The test
     * set holds the remaining instances. Both results are published only if the whole split
     * succeeds, so a failed call leaves the previous results untouched.
     *
     * @param trainPortion fraction of the dataset to put into the training set, in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code trainPortion} is NaN or outside {@code [0, 1]}
     * @throws AlgorithmException if sampling is cancelled or interrupted, or if the test set cannot
     *     be created. On interruption the thread's interrupt flag is restored before throwing.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // the sampling API is used through raw types here
    public void doSplit(final double trainPortion) throws AlgorithmException {
        // Written as a negated range check so that NaN is rejected as well.
        if (!(trainPortion >= 0.0 && trainPortion <= 1.0)) {
            throw new IllegalArgumentException("trainPortion must be in [0, 1] but was " + trainPortion);
        }

        final Random random = new Random(this.seed);
        // Stratify on attribute index getNumberOfAttributes(), i.e. one past the last regular
        // attribute. This presumably addresses the label column — TODO confirm against
        // AttributeBasedStratiAmountSelectorAndAssigner.
        final List<Integer> attributeIndices = Collections.singletonList(this.dataset.getNumberOfAttributes());
        final AttributeBasedStratiAmountSelectorAndAssigner selectorAndAssigner = new AttributeBasedStratiAmountSelectorAndAssigner(attributeIndices,
                DiscretizationHelper.DiscretizationStrategy.EQUAL_SIZE, DISCRETIZATION_CATEGORIES);
        final StratifiedSampling stratifiedSampling = new StratifiedSampling(selectorAndAssigner, selectorAndAssigner, random, this.dataset);
        stratifiedSampling.setSampleSize((int) (trainPortion * this.dataset.size()));

        final D sampledTrainingData;
        final D remainingTestData;
        try {
            sampledTrainingData = (D) stratifiedSampling.call();
            remainingTestData = (D) this.dataset.createEmpty();
            remainingTestData.addAll(this.dataset);
            // NOTE(review): removeAll drops every element equal to some training instance. If instance
            // equality is value-based, duplicates of a sampled instance also leave the test set. On
            // list-backed datasets this is O(n*m) — confirm both are acceptable.
            remainingTestData.removeAll((Collection<?>) sampledTrainingData);
        } catch (AlgorithmExecutionCanceledException e) {
            throw withCause(new AlgorithmException("Stratified split has been cancelled"), e);
        } catch (InterruptedException e) {
            // Restore the interrupt flag, and also fail the call: results must not be left
            // null or stale.
            Thread.currentThread().interrupt();
            throw withCause(new AlgorithmException("Stratified split has been interrupted"), e);
        } catch (DatasetCreationException e) {
            throw withCause(new AlgorithmException("Could not create an empty copy of the given dataset."), e);
        }

        // Publish both results together so the two fields always come from the same split.
        this.trainingData = sampledTrainingData;
        this.testData = remainingTestData;
    }

    /**
     * Attaches {@code cause} to {@code exception} without relying on a cause-taking constructor.
     *
     * @param exception the exception to enrich; its cause must not have been initialized yet
     * @param cause the underlying failure
     * @return {@code exception}, for use in a {@code throw} statement
     */
    private static AlgorithmException withCause(final AlgorithmException exception, final Throwable cause) {
        exception.initCause(cause);
        return exception;
    }

    /**
     * Returns the training set produced by the last successful {@link #doSplit(double)} call.
     *
     * @return the training set, or {@code null} if no split has succeeded yet
     */
    public D getTrainingData() {
        return this.trainingData;
    }

    /**
     * Returns the test set produced by the last successful {@link #doSplit(double)} call.
     *
     * @return the test set, or {@code null} if no split has succeeded yet
     */
    public D getTestData() {
        return this.testData;
    }
}

