/*
 * Decompiled with CFR 0.152.
 */
package elki.application;

import elki.Algorithm;
import elki.application.AbstractApplication;
import elki.classification.Classifier;
import elki.data.ClassLabel;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.AbstractDatabase;
import elki.database.Database;
import elki.database.StaticArrayDatabase;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.relation.Relation;
import elki.datasource.DatabaseConnection;
import elki.datasource.FileBasedDatabaseConnection;
import elki.datasource.MultipleObjectsBundleDatabaseConnection;
import elki.datasource.bundle.MultipleObjectsBundle;
import elki.evaluation.classification.ConfusionMatrix;
import elki.evaluation.classification.holdout.Holdout;
import elki.evaluation.classification.holdout.StratifiedCrossValidation;
import elki.evaluation.classification.holdout.TrainingAndTestSet;
import elki.index.IndexFactory;
import elki.logging.Logging;
import elki.logging.statistics.Duration;
import elki.logging.statistics.Statistic;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.ObjectListParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

public class ClassifierHoldoutEvaluationTask<O>
extends AbstractApplication {
    private static final Logging LOG = Logging.getLogger(ClassifierHoldoutEvaluationTask.class);
    protected DatabaseConnection databaseConnection = null;
    protected Collection<? extends IndexFactory<?>> indexFactories;
    protected Classifier<O> algorithm;
    protected Holdout holdout;

    public ClassifierHoldoutEvaluationTask(DatabaseConnection databaseConnection, Collection<? extends IndexFactory<?>> indexFactories, Classifier<O> algorithm, Holdout holdout) {
        this.databaseConnection = databaseConnection;
        this.indexFactories = indexFactories;
        this.algorithm = algorithm;
        this.holdout = holdout;
    }

    public void run() {
        Duration ptime = LOG.newDuration("evaluation.time.load").begin();
        MultipleObjectsBundle allData = this.databaseConnection.loadData();
        this.holdout.initialize(allData);
        LOG.statistics((Statistic)ptime.end());
        Duration time = LOG.newDuration("evaluation.time.total").begin();
        ArrayList<ClassLabel> labels = this.holdout.getLabels();
        int[][] confusion = new int[labels.size()][labels.size()];
        for (int p = 0; p < this.holdout.numberOfPartitions(); ++p) {
            TrainingAndTestSet partition = this.holdout.nextPartitioning();
            String fold = ((Object)((Object)this)).getClass().getName() + ".fold-" + (p + 1);
            Duration dur = LOG.newDuration(fold + ".train.init").begin();
            StaticArrayDatabase db = new StaticArrayDatabase((DatabaseConnection)new MultipleObjectsBundleDatabaseConnection(partition.getTraining()), this.indexFactories);
            db.initialize();
            LOG.statistics((Statistic)dur.end());
            dur = LOG.newDuration(fold + ".train.time").begin();
            Relation lrel = db.getRelation((TypeInformation)TypeUtil.CLASSLABEL, new Object[0]);
            this.algorithm.buildClassifier((Database)db, (Relation<ClassLabel>)lrel);
            LOG.statistics((Statistic)dur.end());
            dur = LOG.newDuration(fold + ".test.init").begin();
            StaticArrayDatabase testdb = new StaticArrayDatabase((DatabaseConnection)new MultipleObjectsBundleDatabaseConnection(partition.getTest()));
            testdb.initialize();
            Relation testdata = testdb.getRelation(this.algorithm.getInputTypeRestriction()[0], new Object[0]);
            Relation testlabels = testdb.getRelation((TypeInformation)TypeUtil.CLASSLABEL, new Object[0]);
            LOG.statistics((Statistic)dur.end());
            dur = LOG.newDuration(fold + ".evaluation.time").begin();
            DBIDIter iter = testdata.iterDBIDs();
            while (iter.valid()) {
                ClassLabel predlbl = this.algorithm.classify(testdata.get((DBIDRef)iter));
                ClassLabel truelbl = (ClassLabel)testlabels.get((DBIDRef)iter);
                int pred = Collections.binarySearch(labels, predlbl);
                int real = Collections.binarySearch(labels, truelbl);
                int[] nArray = confusion[pred];
                int n = real;
                nArray[n] = nArray[n] + 1;
                iter.advance();
            }
            LOG.statistics((Statistic)dur.end());
        }
        LOG.statistics((Statistic)time.end());
        ConfusionMatrix m = new ConfusionMatrix(labels, confusion);
        LOG.statistics((CharSequence)m.toString());
    }

    public static void main(String[] args) {
        ClassifierHoldoutEvaluationTask.runCLIApplication(ClassifierHoldoutEvaluationTask.class, (String[])args);
    }

    public static class Par<O>
    extends AbstractApplication.Par {
        public static final OptionID HOLDOUT_ID = new OptionID("evaluation.holdout", "Holdout class used in evaluation.");
        protected DatabaseConnection databaseConnection;
        protected Collection<? extends IndexFactory<?>> indexFactories;
        protected Classifier<O> algorithm;
        protected Holdout holdout;

        public void configure(Parameterization config) {
            super.configure(config);
            new ObjectParameter(AbstractDatabase.Par.DATABASE_CONNECTION_ID, DatabaseConnection.class, FileBasedDatabaseConnection.class).grab(config, x -> {
                this.databaseConnection = x;
            });
            new ObjectListParameter(AbstractDatabase.Par.INDEX_ID, IndexFactory.class).setOptional(true).grab(config, x -> {
                this.indexFactories = x;
            });
            new ObjectParameter(Algorithm.Utils.ALGORITHM_ID, Classifier.class).grab(config, x -> {
                this.algorithm = x;
            });
            new ObjectParameter(HOLDOUT_ID, Holdout.class, StratifiedCrossValidation.class).grab(config, x -> {
                this.holdout = x;
            });
        }

        public ClassifierHoldoutEvaluationTask<O> make() {
            return new ClassifierHoldoutEvaluationTask<O>(this.databaseConnection, this.indexFactories, this.algorithm, this.holdout);
        }
    }
}

