/*
 * Decompiled with CFR 0.152.
 */
package net.sf.tweety.machinelearning;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import net.sf.tweety.commons.util.Pair;
import net.sf.tweety.machinelearning.Category;
import net.sf.tweety.machinelearning.ClassificationTester;
import net.sf.tweety.machinelearning.Classifier;
import net.sf.tweety.machinelearning.Observation;
import net.sf.tweety.machinelearning.Trainer;
import net.sf.tweety.machinelearning.TrainingSet;

public class CrossValidator<S extends Observation, T extends Category>
extends ClassificationTester<S, T> {
    private int fold;

    public CrossValidator(int fold) {
        if (fold < 2) {
            throw new IllegalArgumentException("Number of partitions must be greater or equal to 2.");
        }
        this.fold = fold;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public double test(Trainer<S, T> trainer, TrainingSet<S, T> trainingSet) {
        ArrayList partitions = new ArrayList();
        for (int i = 0; i < this.fold; ++i) {
            partitions.add(new TrainingSet());
        }
        for (Category cat : trainingSet.getCategories()) {
            int i = 0;
            for (Pair pair : trainingSet.getObservations(cat)) {
                ((TrainingSet)partitions.get(i % this.fold)).add(pair);
                ++i;
            }
        }
        double perf = 0.0;
        for (int i = 0; i < this.fold; ++i) {
            void var8_16;
            TrainingSet actualTrainingSet = new TrainingSet();
            boolean bl = false;
            while (var8_16 < this.fold) {
                if (i != var8_16) {
                    actualTrainingSet.addAll((Collection)partitions.get((int)var8_16));
                }
                ++var8_16;
            }
            Classifier classifier = trainer.train(actualTrainingSet);
            perf += this.test(classifier, (TrainingSet)partitions.get(i));
        }
        return perf / (double)this.fold;
    }
}

