/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.trainer;

import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.trainer.BaseEarlyStoppingTrainer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Early-stopping trainer for {@link ComputationGraph} models.
 *
 * <p>Training data may be supplied either as a {@link DataSetIterator} or as a
 * {@link MultiDataSetIterator}. The {@link DataSetIterator} form requires the graph to have exactly
 * one input array and one output array.
 *
 * <p>Each minibatch is routed by the graph configuration:
 * <ul>
 *   <li>backprop enabled: {@link ComputationGraph#fit};</li>
 *   <li>backprop disabled, pretraining enabled: {@link ComputationGraph#pretrain};</li>
 *   <li>both disabled: an {@link IllegalStateException}.</li>
 * </ul>
 */
public class EarlyStoppingGraphTrainer
extends BaseEarlyStoppingTrainer<ComputationGraph> {
    /** The graph being trained; assigned exactly once during construction. */
    private final ComputationGraph net;

    /**
     * Creates a trainer over a single-input/single-output graph with no early-stopping listener.
     *
     * @param esConfig early-stopping configuration
     * @param net      graph to train; must have exactly one input and one output array
     * @param train    training data
     * @throws IllegalStateException if {@code net} does not have exactly one input and one output array
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, DataSetIterator train) {
        this(esConfig, net, train, null);
    }

    /**
     * Creates a trainer over a single-input/single-output graph fed by a {@link DataSetIterator}.
     *
     * @param esConfig early-stopping configuration
     * @param net      graph to train; must have exactly one input and one output array
     * @param train    training data
     * @param listener early-stopping listener; the 3-argument constructor passes {@code null} here
     * @throws IllegalStateException if {@code net} does not have exactly one input and one output array
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, DataSetIterator train, EarlyStoppingListener<ComputationGraph> listener) {
        super(esConfig, net, train, null, listener);
        // A plain DataSet carries one features array and one labels array, so it can only feed a
        // graph with exactly one input and one output.
        if (net.getNumInputArrays() != 1 || net.getNumOutputArrays() != 1) {
            throw new IllegalStateException("Cannot do early stopping training on ComputationGraph with DataSetIterator: graph does not have 1 input and 1 output array");
        }
        this.net = net;
    }

    /**
     * Creates a trainer over a graph fed by a {@link MultiDataSetIterator}.
     *
     * <p>No input/output arity check is done here.
     *
     * @param esConfig early-stopping configuration
     * @param net      graph to train
     * @param train    training data
     * @param listener early-stopping listener (passed through to the base trainer)
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, MultiDataSetIterator train, EarlyStoppingListener<ComputationGraph> listener) {
        super(esConfig, net, null, train, listener);
        this.net = net;
    }

    /**
     * Trains the graph on a single {@link org.nd4j.linalg.dataset.DataSet} minibatch.
     *
     * @param ds minibatch to train on
     * @throws IllegalStateException if the graph configuration disables both backprop and pretraining
     */
    @Override
    protected void fit(org.nd4j.linalg.dataset.DataSet ds) {
        if (resolvePretrainOnly()) {
            // Typed local pins overload resolution to pretrain(DataSetIterator).
            DataSetIterator single = new SingletonDataSetIterator(ds);
            this.net.pretrain(single);
        } else {
            // Typed local pins overload resolution to fit(api.DataSet).
            DataSet batch = ds;
            this.net.fit(batch);
        }
    }

    /**
     * Trains the graph on a single {@link MultiDataSet} minibatch.
     *
     * @param mds minibatch to train on
     * @throws IllegalStateException if the graph configuration disables both backprop and pretraining
     */
    @Override
    protected void fit(MultiDataSet mds) {
        if (resolvePretrainOnly()) {
            // Typed local pins overload resolution to pretrain(MultiDataSetIterator).
            MultiDataSetIterator single = new SingletonMultiDataSetIterator(mds);
            this.net.pretrain(single);
        } else {
            this.net.fit(mds);
        }
    }

    /**
     * Decides how the graph configuration says a minibatch should be consumed.
     *
     * <p>Shared by both {@code fit} overloads so their gating cannot drift apart.
     *
     * @return {@code false} when backprop is enabled (use {@code fit}); {@code true} when backprop
     *         is disabled but pretraining is enabled (use {@code pretrain})
     * @throws IllegalStateException if backprop and pretraining are both disabled
     */
    private boolean resolvePretrainOnly() {
        if (this.net.getConfiguration().isBackprop()) {
            return false;
        }
        if (!this.net.getConfiguration().isPretrain()) {
            throw new IllegalStateException("Cannot train - network configuration has both isBackprop == false and isPretrain == false");
        }
        return true;
    }
}

