/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter;

import ai.libs.jaicore.basic.reconstruction.ReconstructionInstruction;
import ai.libs.jaicore.ml.core.dataset.splitter.ReproducibleSplit;
import ai.libs.jaicore.ml.core.filter.FilterBasedDatasetSplitter;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.LabelBasedStratifiedSamplingFactory;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.SimpleRandomSamplingFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.common.reconstruction.IReconstructible;
import org.api4.java.common.reconstruction.IReconstructionInstruction;

public class SplitterUtil {
    private SplitterUtil() {
    }

    public static <D extends ILabeledDataset<?>> List<D> getLabelStratifiedTrainTestSplit(D dataset, long seed, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        boolean isReproducible = dataset instanceof IReconstructible;
        List folds = new FilterBasedDatasetSplitter<D>(new LabelBasedStratifiedSamplingFactory(), relativeTrainSize, new Random(seed)).split(dataset);
        if (!isReproducible) {
            return folds;
        }
        try {
            ReconstructionInstruction instruction = new ReconstructionInstruction(SplitterUtil.class.getMethod("getLabelStratifiedTrainTestSplit", ILabeledDataset.class, Long.TYPE, Double.TYPE), new Object[]{"this", seed, relativeTrainSize});
            return new ReproducibleSplit(instruction, dataset, (IDataset[])new ILabeledDataset[]{(ILabeledDataset)folds.get(0), (ILabeledDataset)folds.get(1)});
        }
        catch (NoSuchMethodException | SecurityException e) {
            throw new SplitFailedException((Throwable)e);
        }
    }

    public static List<ILabeledDataset<?>> getLabelStratifiedTrainTestSplit(ILabeledDataset<?> dataset, Random random, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, random.nextLong(), relativeTrainSize);
    }

    public static ILabeledDataset<?> getTrainFoldOfLabelStratifiedTrainTestSplit(ILabeledDataset<?> dataset, Random random, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, random, relativeTrainSize).get(0);
    }

    public static ILabeledDataset<?> getTrainFoldOfLabelStratifiedTrainTestSplit(ILabeledDataset<?> dataset, long seed, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, seed, relativeTrainSize).get(0);
    }

    public static ILabeledDataset<?> getTestFoldOfLabelStratifiedTrainTestSplit(ILabeledDataset<?> dataset, Random random, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, random, relativeTrainSize).get(1);
    }

    public static ILabeledDataset<?> getTestFoldOfLabelStratifiedTrainTestSplit(ILabeledDataset<?> dataset, long seed, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, seed, relativeTrainSize).get(1);
    }

    public static List<ILabeledDataset<?>> getSimpleTrainTestSplit(ILabeledDataset<?> dataset, long seed, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        boolean isReproducible = dataset instanceof IReconstructible;
        List<ILabeledDataset<?>> folds = SplitterUtil.getSimpleTrainTestSplit(dataset, new Random(seed), relativeTrainSize);
        if (!isReproducible) {
            return folds;
        }
        try {
            IReconstructible rDataset = (IReconstructible)dataset;
            IReconstructible trainFold = (IReconstructible)folds.get(0);
            IReconstructible testFold = (IReconstructible)folds.get(1);
            rDataset.getConstructionPlan().getInstructions().forEach(i -> {
                trainFold.addInstruction(i);
                testFold.addInstruction(i);
            });
            trainFold.addInstruction((IReconstructionInstruction)new ReconstructionInstruction(SplitterUtil.class.getMethod("getTrainFoldOfLabelStratifiedTrainTestSplit", ILabeledDataset.class, Long.TYPE, Double.TYPE), new Object[]{"this", seed, relativeTrainSize}));
            testFold.addInstruction((IReconstructionInstruction)new ReconstructionInstruction(SplitterUtil.class.getMethod("getTestFoldOfLabelStratifiedTrainTestSplit", ILabeledDataset.class, Long.TYPE, Double.TYPE), new Object[]{"this", seed, relativeTrainSize}));
            ReconstructionInstruction instruction = new ReconstructionInstruction(SplitterUtil.class.getMethod("getLabelStratifiedTrainTestSplit", ILabeledDataset.class, Long.TYPE, Double.TYPE), new Object[]{"this", seed, relativeTrainSize});
            return new ReproducibleSplit(instruction, dataset, (IDataset[])new ILabeledDataset[]{folds.get(0), folds.get(1)});
        }
        catch (NoSuchMethodException | SecurityException e) {
            throw new SplitFailedException((Throwable)e);
        }
    }

    public static List<ILabeledDataset<?>> getSimpleTrainTestSplit(ILabeledDataset<?> dataset, Random random, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return new FilterBasedDatasetSplitter(new SimpleRandomSamplingFactory(), relativeTrainSize, random).split(dataset);
    }

    public static ILabeledDataset<?> getTrainFoldOfSimpleTrainTestSplit(ILabeledDataset<?> dataset, Random random, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getSimpleTrainTestSplit(dataset, random, relativeTrainSize).get(0);
    }

    public static ILabeledDataset<?> getTrainFoldOfSimpleTrainTestSplit(ILabeledDataset<?> dataset, long seed, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getSimpleTrainTestSplit(dataset, seed, relativeTrainSize).get(0);
    }

    public static ILabeledDataset<?> getTestFoldOfSimpleTrainTestSplit(ILabeledDataset<?> dataset, Random random, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getSimpleTrainTestSplit(dataset, random, relativeTrainSize).get(1);
    }

    public static ILabeledDataset<?> getTestFoldOfSimpleTrainTestSplit(ILabeledDataset<?> dataset, long seed, double relativeTrainSize) throws SplitFailedException, InterruptedException {
        return SplitterUtil.getSimpleTrainTestSplit(dataset, seed, relativeTrainSize).get(1);
    }

    public static List<ILabeledDataset<ILabeledInstance>> getRealizationOfSplitSpecification(ILabeledDataset<? extends ILabeledInstance> dataset, Collection<? extends Collection<Integer>> splitSpec) throws DatasetCreationException, InterruptedException {
        ArrayList<ILabeledDataset<ILabeledInstance>> split = new ArrayList<ILabeledDataset<ILabeledInstance>>(splitSpec.size());
        for (Collection<Integer> collection : splitSpec) {
            ILabeledDataset foldDataset = dataset.createEmptyCopy();
            for (int index : collection) {
                foldDataset.add((Object)((ILabeledInstance)dataset.get(index)));
            }
            split.add((ILabeledDataset<ILabeledInstance>)foldDataset);
        }
        return split;
    }
}

