/*
 * Decompiled with CFR 0.152.
 */
package elki.evaluation.classification.holdout;

import elki.datasource.bundle.MultipleObjectsBundle;
import elki.evaluation.classification.holdout.RandomizedHoldout;
import elki.evaluation.classification.holdout.TrainingAndTestSet;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.random.RandomFactory;
import java.util.ArrayList;
import java.util.Random;

/**
 * Disjoint cross-validation holdout.
 * <p>
 * Every object is assigned to exactly one of {@code nfold} folds, drawn
 * uniformly at random. Each fold is used once as the test set; all remaining
 * objects form the corresponding training set.
 * <p>
 * Because the fold of each object is drawn independently, fold sizes are not
 * guaranteed to be equal, and a fold may be empty on small inputs.
 */
public class DisjointCrossValidation
extends RandomizedHoldout {
    /** Number of folds to generate. */
    protected int nfold;

    /** Index of the fold that the next call to {@link #nextPartitioning()} uses as test set. */
    protected int fold;

    /** Fold index assigned to each object, indexed by the object's position in the bundle. */
    protected int[] assignment;

    /** Number of objects assigned to each fold. */
    protected int[] sizes;

    /**
     * Constructor.
     *
     * @param random source of randomness used to draw the fold assignment
     * @param nfold number of folds; must be at least 1
     * @throws IllegalArgumentException if {@code nfold < 1}
     */
    public DisjointCrossValidation(RandomFactory random, int nfold) {
        super(random);
        // Mirrors the GREATER_EQUAL_ONE_INT constraint enforced by Par, so that
        // direct construction fails early instead of later in initialize().
        if (nfold < 1) {
            throw new IllegalArgumentException("Number of folds must be at least 1, but was " + nfold);
        }
        this.nfold = nfold;
    }

    /**
     * Starts a new cross-validation run on the given data.
     * <p>
     * Assigns every object to a random fold, records the size of each fold, and
     * resets the fold counter so that the next partitioning uses fold 0.
     *
     * @param bundle data to partition
     */
    @Override
    public void initialize(MultipleObjectsBundle bundle) {
        super.initialize(bundle);
        fold = 0;
        final Random rnd = random.getSingleThreadedRandom();
        sizes = new int[nfold];
        assignment = new int[bundle.dataLength()];
        for (int i = 0; i < assignment.length; ++i) {
            final int p = rnd.nextInt(nfold);
            assignment[i] = p;
            ++sizes[p];
        }
    }

    /**
     * Returns the number of partitionings this holdout produces.
     *
     * @return the number of folds
     */
    @Override
    public int numberOfPartitions() {
        return nfold;
    }

    /**
     * Produces the next training/test split.
     * <p>
     * The objects assigned to the current fold form the test set; all other
     * objects form the training set. Every column of the input bundle is copied
     * into both result bundles, and the original object order is preserved.
     *
     * @return the next partitioning, or {@code null} once all folds have been
     *         produced
     * @throws IllegalStateException if {@link #initialize} has not been called
     */
    @Override
    public TrainingAndTestSet nextPartitioning() {
        if (fold >= nfold) {
            return null;
        }
        if (assignment == null) {
            throw new IllegalStateException("initialize() must be called before nextPartitioning().");
        }
        final int size = bundle.dataLength();
        final int tesize = sizes[fold];
        final int trsize = size - tesize;
        MultipleObjectsBundle training = new MultipleObjectsBundle();
        MultipleObjectsBundle test = new MultipleObjectsBundle();
        // Process column-wise: each column is split into a training and a test part.
        for (int c = 0, cs = bundle.metaLength(); c < cs; ++c) {
            ArrayList<Object> tr = new ArrayList<>(trsize);
            ArrayList<Object> te = new ArrayList<>(tesize);
            for (int i = 0; i < size; ++i) {
                (assignment[i] != fold ? tr : te).add(bundle.data(i, c));
            }
            training.appendColumn(bundle.meta(c), tr);
            test.appendColumn(bundle.meta(c), te);
        }
        ++fold;
        return new TrainingAndTestSet(training, test, labels);
    }

    /**
     * Parameterization class.
     */
    public static class Par
    extends RandomizedHoldout.Par {
        /** Default number of folds. */
        public static final int N_DEFAULT = 10;

        /** Option ID for the number of folds. */
        public static final OptionID NFOLD_ID = new OptionID("nfold", "Number of folds for cross-validation.");

        /** Configured number of folds. */
        protected int nfold = N_DEFAULT;

        /**
         * Reads the number of folds (default {@link #N_DEFAULT}, at least 1) in
         * addition to the options of the parent parameterizer.
         *
         * @param config parameterization to read from
         */
        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            new IntParameter(NFOLD_ID, N_DEFAULT)
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT)
                .grab(config, x -> nfold = x);
        }

        /**
         * Creates the configured cross-validation holdout.
         *
         * @return new instance using the configured random source and fold count
         */
        public DisjointCrossValidation make() {
            return new DisjointCrossValidation(random, nfold);
        }
    }
}

