/*
 * Decompiled with CFR 0.152.
 */
package hex.deeplearning;

import hex.deeplearning.DeepLearningModelInfo;
import hex.deeplearning.DeepLearningTask;
import water.H2O;
import water.Key;
import water.MRTask;
import water.fvec.Frame;

/**
 * Map/reduce wrapper that runs one {@link DeepLearningTask} per node and then
 * combines the per-node models.
 *
 * <p>Lifecycle:
 * <ul>
 *   <li>{@link #setupLocal()}: each node forks a {@link DeepLearningTask} on the training
 *       frame, seeded with the shared model.</li>
 *   <li>{@link #reduce(DeepLearningTask2)}: sums the nodes' {@code _chunk_node_count} values
 *       and adds their model infos together.</li>
 *   <li>{@link #postGlobal()}: divides the accumulated model by that count and publishes the
 *       result through {@link #model_info()}.</li>
 * </ul>
 *
 * <p>The reduce and global phases assert that the model parameters have
 * {@code _replicate_training_data} set.
 */
public class DeepLearningTask2 extends MRTask<DeepLearningTask2> {
    /** Job key forwarded to each node's {@link DeepLearningTask}. */
    private final Key _jobKey;
    /** Training frame handed to each node's {@link DeepLearningTask}. */
    private final Frame _fr;
    /** Model supplied by the caller; replaced by the combined model in {@link #postGlobal()}. */
    private DeepLearningModelInfo _sharedmodel;
    /** Sync fraction forwarded to each node's task; asserted positive in the constructor. */
    private final float _sync_fraction;
    /** This node's task; after reduction it carries the summed model and node count. */
    private DeepLearningTask _res;
    /** Iteration number forwarded to each node's task. */
    private final int _iteration;

    /**
     * Builds the task. No work is started until the task is executed.
     *
     * @param jobKey        job key passed through to {@link DeepLearningTask}
     * @param train         frame each node's {@link DeepLearningTask} is forked on
     * @param model_info    starting model, returned by {@link #model_info()} until
     *                      {@link #postGlobal()} replaces it
     * @param sync_fraction value passed through to {@link DeepLearningTask}; must be {@code > 0}
     *                      (checked only when assertions are enabled)
     * @param iteration     iteration number passed through to {@link DeepLearningTask}
     */
    public DeepLearningTask2(Key jobKey, Frame train, DeepLearningModelInfo model_info, float sync_fraction, int iteration) {
        assert sync_fraction > 0.0f;
        _jobKey = jobKey;
        _fr = train;
        _sharedmodel = model_info;
        _sync_fraction = sync_fraction;
        _iteration = iteration;
    }

    /**
     * Returns the model currently held by this task.
     *
     * @return the constructor's model before the global phase; afterwards the model divided by
     *         the node count, passed through {@code DeepLearningModelInfo.timeAverage} when
     *         {@code _elastic_averaging} is enabled
     */
    public DeepLearningModelInfo model_info() {
        return _sharedmodel;
    }

    /**
     * Per-node setup: creates this node's {@link DeepLearningTask} and forks it on the
     * training frame.
     */
    public void setupLocal() {
        super.setupLocal();
        // This task is handed to the nested task as its completer, and the pending count is
        // bumped before forking — presumably so this node's work is not considered finished
        // until the nested task completes (CountedCompleter pattern; verify against MRTask).
        final DeepLearningTask nodeTask =
                new DeepLearningTask(_jobKey, _sharedmodel, _sync_fraction, _iteration, (H2O.H2OCountedCompleter) this);
        _res = nodeTask;
        addToPendingCount(1);
        // NOTE(review): the trailing `true` is assumed to request a node-local run — confirm
        // against the dfork signature.
        nodeTask.dfork(null, _fr, true);
    }

    /**
     * Folds another node's result into this one.
     *
     * <p>If this side has no result yet, it adopts {@code drt}'s task. Otherwise it adds
     * {@code drt}'s chunk/node count and model into its own.
     *
     * @param drt the task whose result is merged into this one
     */
    public void reduce(DeepLearningTask2 drt) {
        final DeepLearningTask incoming = drt._res;
        if (_res != null) {
            _res._chunk_node_count += incoming._chunk_node_count;
            _res.model_info().add(incoming.model_info());
        } else {
            _res = incoming;
        }
        assert _res.model_info().get_params()._replicate_training_data;
    }

    /**
     * Finishes the global phase and publishes the combined model in {@code _sharedmodel}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Divides the accumulated model by the summed {@code _chunk_node_count}.</li>
     *   <li>Moves the "processed local" counter into the global one and resets the local
     *       counter to zero.</li>
     *   <li>Publishes the result, time-averaged first when {@code _elastic_averaging} is
     *       set.</li>
     * </ol>
     */
    protected void postGlobal() {
        assert _res.model_info().get_params()._replicate_training_data;
        super.postGlobal();
        final DeepLearningModelInfo combined = _res.model_info();
        // The reduced model is a sum over nodes; dividing by the count turns it into a mean.
        combined.div(_res._chunk_node_count);
        combined.add_processed_global(combined.get_processed_local());
        combined.set_processed_local(0L);
        if (combined.get_params()._elastic_averaging) {
            _sharedmodel = DeepLearningModelInfo.timeAverage(combined);
        } else {
            _sharedmodel = combined;
        }
    }
}

