/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.application;

import de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.classification.Classifier;
import de.lmu.ifi.dbs.elki.application.AbstractApplication;
import de.lmu.ifi.dbs.elki.data.ClassLabel;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.AbstractDatabase;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.StaticArrayDatabase;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.FileBasedDatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.MultipleObjectsBundleDatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
import de.lmu.ifi.dbs.elki.evaluation.classification.ConfusionMatrix;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.AbstractHoldout;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.Holdout;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.StratifiedCrossValidation;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.TrainingAndTestSet;
import de.lmu.ifi.dbs.elki.index.IndexFactory;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.statistics.Duration;
import de.lmu.ifi.dbs.elki.logging.statistics.Statistic;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectListParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Parameter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

public class ClassifierHoldoutEvaluationTask<O>
extends AbstractApplication {
    private static final Logging LOG = Logging.getLogger(ClassifierHoldoutEvaluationTask.class);
    protected DatabaseConnection databaseConnection = null;
    protected Collection<IndexFactory<?>> indexFactories;
    protected Classifier<O> algorithm;
    protected Holdout holdout;

    public ClassifierHoldoutEvaluationTask(DatabaseConnection databaseConnection, Collection<IndexFactory<?>> indexFactories, Classifier<O> algorithm, Holdout holdout) {
        this.databaseConnection = databaseConnection;
        this.indexFactories = indexFactories;
        this.algorithm = algorithm;
        this.holdout = holdout;
    }

    public void run() {
        Duration ptime = LOG.newDuration("evaluation.time.load").begin();
        MultipleObjectsBundle allData = this.databaseConnection.loadData();
        this.holdout.initialize(allData);
        LOG.statistics((Statistic)ptime.end());
        Duration time = LOG.newDuration("evaluation.time.total").begin();
        ArrayList<ClassLabel> labels = this.holdout.getLabels();
        int[][] confusion = new int[labels.size()][labels.size()];
        for (int p = 0; p < this.holdout.numberOfPartitions(); ++p) {
            TrainingAndTestSet partition = this.holdout.nextPartitioning();
            Duration dur = LOG.newDuration(((Object)((Object)this)).getClass().getName() + ".fold-" + (p + 1) + ".init.time").begin();
            StaticArrayDatabase db = new StaticArrayDatabase((DatabaseConnection)new MultipleObjectsBundleDatabaseConnection(partition.getTraining()), this.indexFactories);
            db.initialize();
            LOG.statistics((Statistic)dur.end());
            dur = LOG.newDuration(((Object)((Object)this)).getClass().getName() + ".fold-" + (p + 1) + ".train.time").begin();
            Relation lrel = db.getRelation((TypeInformation)TypeUtil.CLASSLABEL, new Object[0]);
            this.algorithm.buildClassifier((Database)db, (Relation<ClassLabel>)lrel);
            LOG.statistics((Statistic)dur.end());
            dur = LOG.newDuration(((Object)((Object)this)).getClass().getName() + ".fold-" + (p + 1) + ".evaluation.time").begin();
            MultipleObjectsBundle test = partition.getTest();
            int lcol = AbstractHoldout.findClassLabelColumn(test);
            int tcol = lcol == 0 ? 1 : 0;
            int l = test.dataLength();
            for (int i = 0; i < l; ++i) {
                Object obj = test.data(i, tcol);
                ClassLabel truelbl = (ClassLabel)test.data(i, lcol);
                ClassLabel predlbl = this.algorithm.classify(obj);
                int pred = Collections.binarySearch(labels, predlbl);
                int real = Collections.binarySearch(labels, truelbl);
                int[] nArray = confusion[pred];
                int n = real;
                nArray[n] = nArray[n] + 1;
            }
            LOG.statistics((Statistic)dur.end());
        }
        LOG.statistics((Statistic)time.end());
        ConfusionMatrix m = new ConfusionMatrix(labels, confusion);
        LOG.statistics((CharSequence)m.toString());
    }

    public static void main(String[] args) {
        ClassifierHoldoutEvaluationTask.runCLIApplication(ClassifierHoldoutEvaluationTask.class, (String[])args);
    }

    public static class Parameterizer<O>
    extends AbstractApplication.Parameterizer {
        public static final OptionID HOLDOUT_ID = new OptionID("evaluation.holdout", "Holdout class used in evaluation.");
        protected DatabaseConnection databaseConnection = null;
        protected Collection<IndexFactory<?>> indexFactories;
        protected Classifier<O> algorithm;
        protected Holdout holdout;

        protected void makeOptions(Parameterization config) {
            ObjectParameter holdoutP;
            ObjectParameter algorithmP;
            ObjectListParameter indexFactoryP;
            super.makeOptions(config);
            ObjectParameter dbcP = new ObjectParameter(AbstractDatabase.Parameterizer.DATABASE_CONNECTION_ID, DatabaseConnection.class, FileBasedDatabaseConnection.class);
            if (config.grab((Parameter)dbcP)) {
                this.databaseConnection = (DatabaseConnection)dbcP.instantiateClass(config);
            }
            if (config.grab((Parameter)(indexFactoryP = new ObjectListParameter(AbstractDatabase.Parameterizer.INDEX_ID, IndexFactory.class, true)))) {
                this.indexFactories = indexFactoryP.instantiateClasses(config);
            }
            if (config.grab((Parameter)(algorithmP = new ObjectParameter(AbstractAlgorithm.ALGORITHM_ID, Classifier.class)))) {
                this.algorithm = (Classifier)algorithmP.instantiateClass(config);
            }
            if (config.grab((Parameter)(holdoutP = new ObjectParameter(HOLDOUT_ID, Holdout.class, StratifiedCrossValidation.class)))) {
                this.holdout = (Holdout)holdoutP.instantiateClass(config);
            }
        }

        protected ClassifierHoldoutEvaluationTask<O> makeInstance() {
            return new ClassifierHoldoutEvaluationTask<O>(this.databaseConnection, this.indexFactories, this.algorithm, this.holdout);
        }
    }
}

