/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter;

import ai.libs.jaicore.basic.reconstruction.ReconstructionInstruction;
import ai.libs.jaicore.basic.reconstruction.ReconstructionUtil;
import ai.libs.jaicore.ml.core.dataset.splitter.ReproducibleSplit;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.splitter.IDatasetSplitter;
import org.api4.java.ai.ml.core.dataset.splitter.IFoldSizeConfigurableRandomDatasetSplitter;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.reconstruction.IReconstructible;
import org.api4.java.common.reconstruction.IReconstructionInstruction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A dataset splitter that creates a 2-fold split with the help of a sampling algorithm: the first fold is a sample
 * of the requested relative size drawn by an algorithm obtained from an {@link ISamplingAlgorithmFactory}, and the
 * second fold is the complement of that sample.
 *
 * <p>If the input dataset is {@link IReconstructible} and carries non-empty reconstruction instructions, each fold
 * receives the source dataset's instructions plus an instruction that re-extracts the fold via
 * {@code getFoldOfSplit}, and the returned list is a {@link ReproducibleSplit}.
 *
 * @param <D> the type of dataset being split
 */
public class FilterBasedDatasetSplitter<D extends IDataset<?>>
implements IDatasetSplitter<D>,
IFoldSizeConfigurableRandomDatasetSplitter<D>,
ILoggingCustomizable {

    /** Absolute tolerance used when checking that two explicitly given fold sizes add up to 1. */
    private static final double FOLD_SIZE_SUM_TOLERANCE = 1e-9;

    private final ISamplingAlgorithmFactory<D, ?> samplerFactory;
    private final double relSampleSize;
    private final Random random;
    private Logger logger = LoggerFactory.getLogger(FilterBasedDatasetSplitter.class);

    /**
     * Creates a splitter without a default random source or fold size. Such a splitter must be used via
     * {@code split(data, random, relativeFoldSizes)}; calling {@code split(data)} fails.
     *
     * @param samplerFactory factory for the sampling algorithm that draws the first fold
     */
    public FilterBasedDatasetSplitter(ISamplingAlgorithmFactory<D, ?> samplerFactory) {
        this(samplerFactory, Double.NaN, null);
    }

    /**
     * Creates a splitter with a default random source and relative size of the first fold.
     *
     * @param samplerFactory factory for the sampling algorithm that draws the first fold
     * @param relSampleSize relative size of the first fold; {@code Double.NaN} means "not configured"
     * @param random random source used to derive the sampler seed; {@code null} means "not configured"
     */
    public FilterBasedDatasetSplitter(ISamplingAlgorithmFactory<D, ?> samplerFactory, double relSampleSize, Random random) {
        this.samplerFactory = Objects.requireNonNull(samplerFactory, "samplerFactory");
        this.relSampleSize = relSampleSize;
        this.random = random;
    }

    /**
     * Splits the data using the random source and first-fold size configured at construction time.
     *
     * @param data the dataset to split
     * @return the two folds
     * @throws IllegalStateException if no random source or no relative sample size was configured
     */
    public List<D> split(D data) throws SplitFailedException, InterruptedException {
        if (this.random == null || Double.isNaN(this.relSampleSize)) {
            throw new IllegalStateException("The splitter has not been initialized with a random source and relative sample size configured. Provide these explicitly in the split method or in the initialization.");
        }
        return this.split(data, this.random, this.relSampleSize);
    }

    /**
     * @return the number of folds produced by every split, which is always 2
     */
    public int getNumberOfFoldsPerSplit() {
        return 2;
    }

    /**
     * Splits the data with a sampler seeded by the next {@code long} drawn from {@code random}.
     *
     * @param data the dataset to split
     * @param random source of the sampler seed
     * @param relativeFoldSizes either {@code [p]} or {@code [p, 1 - p]}, where {@code p} is the relative size of
     *        the first fold
     * @return the two folds
     */
    public List<D> split(D data, Random random, double ... relativeFoldSizes) throws SplitFailedException, InterruptedException {
        Objects.requireNonNull(random, "random");
        return FilterBasedDatasetSplitter.getSplit(data, this.samplerFactory, random.nextLong(), this.logger, relativeFoldSizes);
    }

    /**
     * List-based variant of {@link #getSplit(IDataset, ISamplingAlgorithmFactory, long, double...)}. This is the
     * method referenced by the split-level reconstruction instruction.
     *
     * @param relativeFoldSizes one or two relative fold sizes (see the array-based overload)
     * @throws IllegalArgumentException if the fold size specification is invalid
     */
    public static <D extends IDataset<?>> List<D> getSplit(D data, ISamplingAlgorithmFactory<D, ?> samplerFactory, long seed, List<Double> relativeFoldSizes) throws InterruptedException, SplitFailedException {
        Objects.requireNonNull(relativeFoldSizes, "relativeFoldSizes");
        // Forward ALL given sizes so that empty or over-long specifications are rejected by the shared validation
        // instead of crashing (empty list) or being silently truncated (more than two entries).
        double[] sizes = relativeFoldSizes.stream().mapToDouble(Double::doubleValue).toArray();
        return FilterBasedDatasetSplitter.getSplit(data, samplerFactory, seed, sizes);
    }

    /**
     * Computes a 2-fold split, logging to this class's default logger.
     *
     * @see #getSplit(IDataset, ISamplingAlgorithmFactory, long, Logger, double...)
     */
    public static <D extends IDataset<?>> List<D> getSplit(D data, ISamplingAlgorithmFactory<D, ?> samplerFactory, long seed, double ... relativeFoldSizes) throws InterruptedException, SplitFailedException {
        return FilterBasedDatasetSplitter.getSplit(data, samplerFactory, seed, LoggerFactory.getLogger(FilterBasedDatasetSplitter.class), relativeFoldSizes);
    }

    /**
     * Computes a 2-fold split: the first fold is a sample of {@code round(|data| * relativeFoldSizes[0])} elements
     * drawn by the factory's algorithm seeded with {@code seed}; the second fold is its complement.
     *
     * @param data the (non-empty) dataset to split
     * @param samplerFactory factory for the sampling algorithm; must be {@link IReconstructible} if {@code data} is
     * @param seed seed for the sampler's random source
     * @param logger logger for progress messages; the sampler's logger name is derived from it
     * @param relativeFoldSizes either {@code [p]} or {@code [p, q]} with {@code p} in [0, 1] and {@code p + q = 1}
     * @return the two folds, as a {@link ReproducibleSplit} if the input is reconstructible
     * @throws IllegalArgumentException if {@code data} is empty or the fold sizes are invalid
     * @throws IllegalStateException if {@code data} is reconstructible but {@code samplerFactory} is not
     * @throws SplitFailedException if the sampler fails to create a fold
     */
    public static <D extends IDataset<?>> List<D> getSplit(D data, ISamplingAlgorithmFactory<D, ?> samplerFactory, long seed, Logger logger, double ... relativeFoldSizes) throws InterruptedException, SplitFailedException {
        Objects.requireNonNull(data);
        Objects.requireNonNull(relativeFoldSizes, "relativeFoldSizes");
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Cannot split empty dataset.");
        }
        FilterBasedDatasetSplitter.validateFoldSizes(relativeFoldSizes);
        if (data instanceof IReconstructible && !(samplerFactory instanceof IReconstructible)) {
            throw new IllegalStateException("Given data is reproducible and so should the splitters, but the sampler factory used to create the sampling algorithm is not reproducible.");
        }
        int size = (int) Math.round(data.size() * relativeFoldSizes[0]);
        logger.info("Drawing 2-fold split with size {} for the first fold.", size);
        // NOTE(review): the static type Object is a decompiler artifact; in the compiled source this variable must
        // have the factory's algorithm type (the bound of its second type parameter) so that nextSample() and
        // getComplementOfLastSample() resolve. TODO confirm and restore the proper type.
        Object sampler = samplerFactory.getAlgorithm(size, data, new Random(seed));
        if (sampler instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable) sampler).setLoggerName(logger.getName() + ".sampler");
        }
        try {
            IDataset firstFold = sampler.nextSample();
            logger.debug("Sample for first fold completed, now computing the complement to fill the second fold.");
            IDataset secondFold = sampler.getComplementOfLastSample();
            logger.info("Fold creation completed. Adding reconstruction information.");
            if (data instanceof IReconstructible) {
                if (!ReconstructionUtil.areInstructionsNonEmptyIfReconstructibilityClaimed(data)) {
                    logger.info("Not making the split reproducible since the original data is not reproducible.");
                    return Arrays.asList(firstFold, secondFold);
                }
                List<Double> portionsAsList = new ArrayList<>(relativeFoldSizes.length);
                for (double portion : relativeFoldSizes) {
                    portionsAsList.add(portion);
                }
                // Copy the source instructions so that adding instructions to the folds cannot affect the list
                // obtained from the source dataset's construction plan.
                List<IReconstructionInstruction> instructions = new ArrayList<>(((IReconstructible) data).getConstructionPlan().getInstructions());
                FilterBasedDatasetSplitter.addFoldReconstructionInstructions((IReconstructible) firstFold, instructions, samplerFactory, seed, 0, portionsAsList);
                FilterBasedDatasetSplitter.addFoldReconstructionInstructions((IReconstructible) secondFold, instructions, samplerFactory, seed, 1, portionsAsList);
                ReconstructionUtil.requireNonEmptyInstructionsIfReconstructibilityClaimed(firstFold);
                ReconstructionUtil.requireNonEmptyInstructionsIfReconstructibilityClaimed(secondFold);
                // The split-level instruction targets the List<Double> overload of getSplit, so it must receive the
                // boxed list; Arrays.asList(double[]) would produce a List<double[]> and break reconstruction.
                ReconstructionInstruction inst = new ReconstructionInstruction(FilterBasedDatasetSplitter.class.getName(), "getSplit", new Class[] {IDataset.class, ISamplingAlgorithmFactory.class, long.class, List.class}, new Object[] {"this", samplerFactory, seed, portionsAsList});
                logger.info("Sampling-based split completed, returning two folds of sizes {} and {}.", firstFold.size(), secondFold.size());
                return new ReproducibleSplit(inst, data, new IDataset[] {firstFold, secondFold});
            }
            logger.info("Sampling-based split completed, returning two folds of sizes {} and {}.", firstFold.size(), secondFold.size());
            return Arrays.asList(firstFold, secondFold);
        }
        catch (DatasetCreationException e) {
            throw new SplitFailedException((Throwable) e);
        }
    }

    /** Rejects fold size specifications that are empty, longer than two, out of range, or do not sum to 1. */
    private static void validateFoldSizes(double[] relativeFoldSizes) {
        boolean valid = relativeFoldSizes.length >= 1 && relativeFoldSizes.length <= 2
                // written positively so that NaN is rejected as well
                && relativeFoldSizes[0] >= 0.0 && relativeFoldSizes[0] <= 1.0;
        if (valid && relativeFoldSizes.length == 2) {
            // tolerate floating-point rounding in the sum instead of demanding exact equality with 1.0
            valid = Math.abs(relativeFoldSizes[0] + relativeFoldSizes[1] - 1.0) <= FOLD_SIZE_SUM_TOLERANCE;
        }
        if (!valid) {
            throw new IllegalArgumentException("Invalid fold size specification " + Arrays.toString(relativeFoldSizes));
        }
    }

    /**
     * Appends the source dataset's instructions to a fold, followed by an instruction that re-extracts fold
     * {@code foldIndex} via {@code getFoldOfSplit}.
     */
    private static void addFoldReconstructionInstructions(IReconstructible fold, List<IReconstructionInstruction> sourceInstructions, ISamplingAlgorithmFactory<?, ?> samplerFactory, long seed, int foldIndex, List<Double> portions) {
        for (IReconstructionInstruction instruction : sourceInstructions) {
            fold.addInstruction(instruction);
        }
        fold.addInstruction(new ReconstructionInstruction(FilterBasedDatasetSplitter.class.getName(), "getFoldOfSplit", new Class[] {IDataset.class, ISamplingAlgorithmFactory.class, long.class, int.class, List.class}, new Object[] {"this", samplerFactory, seed, foldIndex, portions}));
    }

    /**
     * Recomputes the whole split and returns a single fold of it. This is the method referenced by the per-fold
     * reconstruction instructions.
     *
     * @param fold index of the fold to return (0 or 1)
     * @return the requested fold
     * @throws IndexOutOfBoundsException if {@code fold} is not a valid fold index
     */
    public static <D extends IDataset<?>> D getFoldOfSplit(D data, ISamplingAlgorithmFactory<D, ?> samplerFactory, long seed, int fold, List<Double> relativeFoldSizes) throws InterruptedException, SplitFailedException {
        return FilterBasedDatasetSplitter.getSplit(data, samplerFactory, seed, relativeFoldSizes).get(fold);
    }

    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    @Override
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
    }
}

