/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.evaluator.multilayer;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator;
import org.deeplearning4j.arbiter.scoring.util.ScoreUtil;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * {@link ModelEvaluator} that scores a trained network by computing a classification
 * {@link Evaluation} over the test data supplied by a {@link DataProvider}.
 *
 * <p>Supported model types are {@link MultiLayerNetwork} and {@link ComputationGraph}.
 */
public class ClassificationEvaluator
implements ModelEvaluator {
    /**
     * Parameters forwarded unchanged to {@code DataProvider.testData(...)}. This is {@code null}
     * when the no-argument constructor is used. What the keys mean is up to the data provider
     * (NOTE(review): confirm against the provider implementations in use).
     */
    private Map<String, Object> params = null;

    /** Creates an evaluator that passes {@code null} parameters to the data provider. */
    public ClassificationEvaluator() {
    }

    /**
     * Creates an evaluator that passes {@code params} to the data provider when requesting test data.
     *
     * @param params parameters for {@code DataProvider.testData(...)}; may be {@code null}. The map
     *               is stored by reference, not copied.
     */
    public ClassificationEvaluator(Map<String, Object> params) {
        this.params = params;
    }

    /**
     * Evaluates {@code model} on the provider's test data.
     *
     * @param model        the model to evaluate; must be a {@link MultiLayerNetwork} or a
     *                     {@link ComputationGraph}
     * @param dataProvider source of the test data; its {@code testData(params)} result is converted
     *                     to a {@link DataSetIterator} via {@link ScoreUtil#getIterator}
     * @return the classification evaluation computed by {@link ScoreUtil#getEvaluation}
     * @throws IllegalArgumentException if {@code model} is {@code null} or is not one of the
     *                                  supported model types
     */
    public Evaluation evaluateModel(Object model, DataProvider dataProvider) {
        // Validate the model type before touching the data provider. An unsupported model then fails
        // fast with a descriptive message, instead of a ClassCastException from a blind cast (or, for
        // null, a later failure inside ScoreUtil).
        if (!(model instanceof MultiLayerNetwork) && !(model instanceof ComputationGraph)) {
            String actual = model == null ? "null" : model.getClass().getName();
            throw new IllegalArgumentException("Unsupported model type: " + actual
                    + " (expected MultiLayerNetwork or ComputationGraph)");
        }
        // Both model types are evaluated against the same DataSetIterator view of the test data.
        // NOTE(review): getSupportedDataTypes() also lists MultiDataSetIterator. Confirm that
        // ScoreUtil.getIterator accepts that input, since this path always yields a DataSetIterator.
        DataSetIterator iterator = ScoreUtil.getIterator(dataProvider.testData(this.params));
        if (model instanceof MultiLayerNetwork) {
            return ScoreUtil.getEvaluation((MultiLayerNetwork) model, iterator);
        }
        return ScoreUtil.getEvaluation((ComputationGraph) model, iterator);
    }

    /**
     * Returns the model classes this evaluator accepts.
     *
     * @return a fixed-size list containing {@link MultiLayerNetwork} and {@link ComputationGraph}
     */
    public List<Class<?>> getSupportedModelTypes() {
        return Arrays.asList(MultiLayerNetwork.class, ComputationGraph.class);
    }

    /**
     * Returns the data iterator types this evaluator advertises support for.
     *
     * @return a fixed-size list containing {@link DataSetIterator} and {@link MultiDataSetIterator}
     */
    public List<Class<?>> getSupportedDataTypes() {
        return Arrays.asList(DataSetIterator.class, MultiDataSetIterator.class);
    }
}

