/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.impl;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Base class for score functions that operate on {@link MultiLayerNetwork} and
 * {@link ComputationGraph} models.
 *
 * <p>The public {@code score} entry points obtain the test data (from a
 * {@link DataProvider} or a {@link DataSource}) and then dispatch, based on the
 * runtime types of the model and the data, to one of the four abstract
 * {@code score} overloads that subclasses implement.
 *
 * <p>This class declares no fields, so all instances compare equal to each
 * other (subject to {@link #canEqual(Object)}) and share the same hash code.
 */
public abstract class BaseNetScoreFunction implements ScoreFunction {

    /**
     * Scores {@code model} against the test data supplied by {@code dataProvider}.
     *
     * @param model          the network to score; see {@link #score(Object, Object)} for the accepted types
     * @param dataProvider   source of the test data; must not be {@code null}
     * @param dataParameters passed through unchanged to {@code dataProvider.testData(...)}
     * @return the score computed by the matching abstract overload
     * @throws RuntimeException if the test data is {@code null} or of an unsupported type
     */
    public double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters) {
        Object testData = dataProvider.testData(dataParameters);
        return score(model, testData);
    }

    /**
     * Instantiates {@code dataSource} through its no-arg constructor, optionally configures it,
     * and scores {@code model} against the test data it returns.
     *
     * @param model                the network to score; see {@link #score(Object, Object)} for the accepted types
     * @param dataSource           the {@link DataSource} implementation to instantiate
     * @param dataSourceProperties configuration handed to {@code DataSource.configure(...)};
     *                             when {@code null} the configure step is skipped
     * @return the score computed by the matching abstract overload
     * @throws RuntimeException if instantiating or configuring the data source fails (the original
     *                          exception is preserved as the cause), or if the test data is
     *                          {@code null} or of an unsupported type
     */
    public double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
        DataSource ds;
        try {
            ds = dataSource.newInstance();
            if (dataSourceProperties != null) {
                ds.configure(dataSourceProperties);
            }
        } catch (Exception e) {
            // Note: this also covers failures raised by configure(), not only by the constructor.
            throw new RuntimeException("Error creating DataSource instance - missing no-arg constructor?", e);
        }
        return score(model, ds.testData());
    }

    /**
     * Dispatches to the abstract {@code score} overload matching the runtime types of
     * {@code model} and {@code testData}.
     *
     * <p>A {@link DataSetIteratorFactory} is unwrapped via {@code create()} and the resulting
     * iterator is scored. Any model that is not a {@link MultiLayerNetwork} is cast to
     * {@link ComputationGraph} once a supported data type has been recognised.
     *
     * @param model    a {@link MultiLayerNetwork} or a {@link ComputationGraph}
     * @param testData a {@link DataSetIterator}, {@link MultiDataSetIterator} or
     *                 {@link DataSetIteratorFactory}
     * @return the score computed by the selected overload
     * @throws RuntimeException   if {@code testData} is {@code null} or of an unsupported type
     * @throws ClassCastException if {@code model} is a non-null object that is neither a
     *                            {@link MultiLayerNetwork} nor a {@link ComputationGraph}
     *                            and {@code testData} is of a supported type
     */
    protected double score(Object model, Object testData) {
        // The order of the instanceof checks in each branch is kept exactly as before: it
        // decides which overload wins when an object implements more than one of these types.
        if (model instanceof MultiLayerNetwork) {
            if (testData instanceof DataSetIterator) {
                return score((MultiLayerNetwork) model, (DataSetIterator) testData);
            }
            if (testData instanceof MultiDataSetIterator) {
                return score((MultiLayerNetwork) model, (MultiDataSetIterator) testData);
            }
            if (testData instanceof DataSetIteratorFactory) {
                return score((MultiLayerNetwork) model, ((DataSetIteratorFactory) testData).create());
            }
            throw unknownDataType(testData);
        }

        // Casts are performed only after the data type matched, so unsupported data is still
        // reported as such regardless of the model type.
        if (testData instanceof DataSetIterator) {
            return score((ComputationGraph) model, (DataSetIterator) testData);
        }
        if (testData instanceof DataSetIteratorFactory) {
            return score((ComputationGraph) model, ((DataSetIteratorFactory) testData).create());
        }
        if (testData instanceof MultiDataSetIterator) {
            return score((ComputationGraph) model, (MultiDataSetIterator) testData);
        }
        throw unknownDataType(testData);
    }

    /**
     * Builds the exception reported for unsupported test data. Null-safe, so missing test data
     * yields a descriptive message rather than a {@link NullPointerException}.
     */
    private static RuntimeException unknownDataType(Object testData) {
        Object type = (testData == null) ? null : testData.getClass();
        return new RuntimeException("Unknown type of data: " + type);
    }

    /**
     * @return a fixed-size list containing {@code MultiLayerNetwork.class} and
     *         {@code ComputationGraph.class}
     */
    public List<Class<?>> getSupportedModelTypes() {
        return Arrays.asList(MultiLayerNetwork.class, ComputationGraph.class);
    }

    /**
     * @return a fixed-size list containing {@code DataSetIterator.class} and
     *         {@code MultiDataSetIterator.class}
     */
    public List<Class<?>> getSupportedDataTypes() {
        return Arrays.asList(DataSetIterator.class, MultiDataSetIterator.class);
    }

    /**
     * Scores a {@link MultiLayerNetwork} on data from a {@link DataSetIterator}.
     *
     * @param net      the network to score
     * @param iterator the test data
     * @return the score
     */
    public abstract double score(MultiLayerNetwork net, DataSetIterator iterator);

    /**
     * Scores a {@link MultiLayerNetwork} on data from a {@link MultiDataSetIterator}.
     *
     * @param net      the network to score
     * @param iterator the test data
     * @return the score
     */
    public abstract double score(MultiLayerNetwork net, MultiDataSetIterator iterator);

    /**
     * Scores a {@link ComputationGraph} on data from a {@link DataSetIterator}.
     *
     * @param graph    the network to score
     * @param iterator the test data
     * @return the score
     */
    public abstract double score(ComputationGraph graph, DataSetIterator iterator);

    /**
     * Scores a {@link ComputationGraph} on data from a {@link MultiDataSetIterator}.
     *
     * @param graph    the network to score
     * @param iterator the test data
     * @return the score
     */
    public abstract double score(ComputationGraph graph, MultiDataSetIterator iterator);

    /**
     * Two instances are equal when the other object is a {@code BaseNetScoreFunction} whose
     * {@link #canEqual(Object)} accepts this instance; there is no field state to compare.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BaseNetScoreFunction)) {
            return false;
        }
        BaseNetScoreFunction other = (BaseNetScoreFunction) o;
        return other.canEqual(this);
    }

    /**
     * Hook used by {@link #equals(Object)}; subclasses may override it to restrict which
     * objects they can be equal to.
     *
     * @param other the object asking to be compared with this one
     * @return {@code true} if {@code other} is a {@code BaseNetScoreFunction}
     */
    protected boolean canEqual(Object other) {
        return other instanceof BaseNetScoreFunction;
    }

    /**
     * Returns a constant, consistent with {@link #equals(Object)} for a class without fields.
     */
    @Override
    public int hashCode() {
        return 1;
    }
}

