/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.splitter;

import ai.libs.jaicore.basic.reconstruction.ReconstructionInstruction;
import ai.libs.jaicore.basic.reconstruction.ReconstructionUtil;
import ai.libs.jaicore.ml.core.dataset.splitter.DatasetSplitSet;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.SimpleRandomSampling;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.splitter.IFoldSizeConfigurableRandomDatasetSplitter;
import org.api4.java.ai.ml.core.dataset.splitter.IRandomDatasetSplitter;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.evaluation.execution.IDatasetSplitSet;
import org.api4.java.ai.ml.core.evaluation.execution.IDatasetSplitSetGenerator;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.reconstruction.IReconstructible;
import org.api4.java.common.reconstruction.IReconstructionInstruction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Splits a dataset into randomly drawn, disjoint holdout folds whose relative sizes are given by "portions".
 *
 * <p>If the given portions sum to less than 1 (up to {@link #SUM_TOLERANCE}), an additional fold receiving the
 * remaining mass is appended. The last fold always receives every item not drawn into an earlier fold, so the
 * folds together cover the complete dataset regardless of rounding.
 *
 * @param <D> the dataset type being split
 */
public class RandomHoldoutSplitter<D extends IDataset<?>>
implements IRandomDatasetSplitter<D>,
IDatasetSplitSetGenerator<D>,
ILoggingCustomizable,
IFoldSizeConfigurableRandomDatasetSplitter<D> {

    /** Numerical slack used when deciding whether a sum of portions equals 1 (matches the former 0.99999999 bound). */
    private static final double SUM_TOLERANCE = 1.0E-8;

    private final Random rand;
    private final double[] portions;
    private Logger logger = LoggerFactory.getLogger(RandomHoldoutSplitter.class);

    /**
     * Creates a splitter with an unseeded random source.
     *
     * @param portions relative fold sizes; each must be non-negative and their sum must lie in (0, 1]
     */
    public RandomHoldoutSplitter(double ... portions) {
        this(new Random(), portions);
    }

    /**
     * Creates a splitter drawing split seeds from the given random source.
     *
     * @param rand random source used to draw one seed per split; must not be null
     * @param portions relative fold sizes; each must be non-negative and their sum must lie in (0, 1]
     * @throws IllegalArgumentException if a portion is negative/NaN or the sum is not in (0, 1]
     */
    public RandomHoldoutSplitter(Random rand, double ... portions) {
        requireValidPortions(portions);
        double portionSum = Arrays.stream(portions).sum();
        if (!(portionSum > 0.0) || portionSum > 1.0 + SUM_TOLERANCE) {
            throw new IllegalArgumentException("The sum of the given portions must not be less or equal 0 or larger than 1. Given portions: " + Arrays.toString(portions));
        }
        if (rand == null) {
            throw new NullPointerException("rand");
        }
        this.rand = rand;
        // Uses the same tolerance as createSplit so that the number of folds reported here matches what is produced.
        this.portions = completePortions(portions, portionSum);
    }

    /**
     * Splits the data into folds using the default class logger.
     *
     * @see #createSplit(IDataset, long, Logger, double...)
     */
    public static <D extends IDataset<?>> List<D> createSplit(D data, long seed, double ... portions) throws SplitFailedException, InterruptedException {
        return RandomHoldoutSplitter.createSplit(data, seed, LoggerFactory.getLogger(RandomHoldoutSplitter.class), portions);
    }

    /**
     * Shuffles a copy of the data with the given seed and partitions it into disjoint folds.
     *
     * <p>Fold {@code i} (for every fold but the last) receives {@code round(portion_i * |data|)} items, clamped to the
     * number of items still available. The last fold receives all remaining items, so no item is lost or duplicated.
     *
     * @param data the dataset to split (not modified; a copy is shuffled)
     * @param seed seed for the shuffle and the per-fold sampling
     * @param logger logger for progress messages
     * @param pPortions relative fold sizes; each non-negative, sum at most 1. If the sum is below 1, a final fold with
     *        the remaining mass is appended.
     * @return one dataset per (possibly completed) portion, in portion order
     * @throws IllegalArgumentException if a portion is negative/NaN or the portions sum to more than 1
     * @throws SplitFailedException if copying or sampling the dataset fails
     * @throws InterruptedException if interrupted while copying or sampling
     */
    public static <D extends IDataset<?>> List<D> createSplit(D data, long seed, Logger logger, double ... pPortions) throws SplitFailedException, InterruptedException {
        requireValidPortions(pPortions);
        double portionsSum = Arrays.stream(pPortions).sum();
        if (portionsSum > 1.0 + SUM_TOLERANCE) {
            throw new IllegalArgumentException("Sum of portions must not be greater than 1.");
        }
        // Defensive copy: this array is embedded into reconstruction instructions of the produced folds.
        final double[] portions = completePortions(pPortions, portionsSum);
        logger.info("Creating new split with {} folds.", portions.length);
        List<D> folds = new ArrayList<>(portions.length);
        int totalItems = data.size();
        try {
            @SuppressWarnings("unchecked") // createCopy() returns a copy of the same dataset type as data
            D remaining = (D) data.createCopy();
            Collections.shuffle(remaining, new Random(seed));
            int lastFold = portions.length - 1;
            // The last fold is decided by index, not by floating-point "remaining mass", which drifts (e.g. 1-0.6-0.3-0.1 > 0).
            for (int numFold = 0; numFold < lastFold; ++numFold) {
                double portion = portions[numFold];
                // Sizes are rounded relative to the total; clamp so rounding never asks for more items than remain.
                int sampleSize = (int) Math.min(Math.round(portion * (double) totalItems), (long) remaining.size());
                SimpleRandomSampling<D> subSampler = new SimpleRandomSampling<>(new Random(seed), remaining);
                subSampler.setSampleSize(sampleSize);
                logger.debug("Computing fold of size {}/{}, i.e. a portion of {}", sampleSize, totalItems, portion);
                D fold = subSampler.call();
                RandomHoldoutSplitter.addReconstructionInfo(data, fold, seed, numFold, portions);
                folds.add(fold);
                remaining = subSampler.getComplementOfLastSample();
                logger.debug("Reduced the data by the fold. Remaining items: {}", remaining.size());
            }
            logger.debug("This is the last fold, which exhausts the complete original data, so no more sampling will be conducted.");
            folds.add(remaining);
            RandomHoldoutSplitter.addReconstructionInfo(data, remaining, seed, lastFold, portions);
        }
        catch (DatasetCreationException | AlgorithmException | AlgorithmExecutionCanceledException | AlgorithmTimeoutedException e) {
            throw new SplitFailedException(e);
        }
        return folds;
    }

    /**
     * Rejects negative and NaN portions.
     *
     * @param portions the portions to validate
     * @throws IllegalArgumentException if any portion is negative or NaN
     */
    private static void requireValidPortions(double[] portions) {
        for (double portion : portions) {
            if (!(portion >= 0.0)) {
                throw new IllegalArgumentException("Portions must be non-negative numbers. Given portions: " + Arrays.toString(portions));
            }
        }
    }

    /**
     * Returns a copy of the portions, extended by a final portion carrying the missing mass if they sum to less than 1.
     *
     * @param portions the given portions
     * @param portionSum their (precomputed) sum
     * @return a fresh array whose entries sum to 1 (up to {@link #SUM_TOLERANCE})
     */
    private static double[] completePortions(double[] portions, double portionSum) {
        if (portionSum >= 1.0 - SUM_TOLERANCE) {
            return portions.clone();
        }
        double[] completed = Arrays.copyOf(portions, portions.length + 1);
        completed[portions.length] = 1.0 - portionSum;
        return completed;
    }

    /**
     * If the source data carries a non-empty construction plan, copies it into the fold and appends an instruction
     * that re-creates this fold via {@link #getFoldOfSplit(IDataset, long, int, double...)}.
     */
    private static void addReconstructionInfo(IDataset<?> data, IDataset<?> fold, long seed, int numFold, double[] portions) {
        if (!(data instanceof IReconstructible) || !ReconstructionUtil.areInstructionsNonEmptyIfReconstructibilityClaimed(data)) {
            return;
        }
        IReconstructible reconstructibleFold = (IReconstructible) fold;
        ((IReconstructible) data).getConstructionPlan().getInstructions().forEach(reconstructibleFold::addInstruction);
        reconstructibleFold.addInstruction((IReconstructionInstruction) new ReconstructionInstruction(RandomHoldoutSplitter.class.getName(), "getFoldOfSplit", new Class<?>[] { IDataset.class, long.class, int.class, double[].class }, new Object[] { "this", seed, numFold, portions }));
    }

    /**
     * Recomputes a split and returns one of its folds; used as the reconstruction entry point.
     *
     * @param data the original dataset
     * @param seed the seed of the split
     * @param fold zero-based index of the fold to return
     * @param portions the portions of the split
     * @return the requested fold
     */
    public static <D extends IDataset<?>> D getFoldOfSplit(D data, long seed, int fold, double ... portions) throws SplitFailedException, InterruptedException {
        return RandomHoldoutSplitter.createSplit(data, seed, portions).get(fold);
    }

    /**
     * Splits the data using this splitter's configured portions.
     *
     * <p>NOTE(review): the {@code random} argument is ignored; the split seed is drawn from the splitter's own random
     * source, whereas {@link #split(IDataset, Random, double...)} uses its argument. Confirm which is intended.
     */
    @Override
    public List<D> split(D data, Random random) throws SplitFailedException, InterruptedException {
        return RandomHoldoutSplitter.createSplit(data, this.rand.nextLong(), this.logger, this.portions);
    }

    /** @return the number of folds per split, including an appended remainder fold if any */
    public int getNumberOfFoldsPerSplit() {
        return this.portions.length;
    }

    /** @return always 1; a holdout split set consists of a single split */
    public int getNumSplitsPerSet() {
        return 1;
    }

    /** @return the number of folds per split, including an appended remainder fold if any */
    public int getNumFoldsPerSplit() {
        return this.portions.length;
    }

    /** Produces a split set containing exactly one split of the given data. */
    @Override
    public IDatasetSplitSet<D> nextSplitSet(D data) throws InterruptedException, SplitFailedException {
        return new DatasetSplitSet<>(Arrays.asList(this.split(data)));
    }

    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    @Override
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
    }

    @Override
    public String toString() {
        return "RandomHoldoutSplitter [rand=" + this.rand + ", portions=" + Arrays.toString(this.portions) + "]";
    }

    /**
     * Splits the data with explicitly given relative fold sizes, seeding the split from {@code random}.
     *
     * @param data the dataset to split
     * @param random source of the split seed
     * @param relativeFoldSizes relative fold sizes (see {@link #createSplit(IDataset, long, Logger, double...)})
     */
    public List<D> split(D data, Random random, double ... relativeFoldSizes) throws SplitFailedException, InterruptedException {
        return RandomHoldoutSplitter.createSplit(data, random.nextLong(), this.logger, relativeFoldSizes);
    }
}

