/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater.graph;

import java.util.HashMap;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Multi-layer updater for a {@link ComputationGraph}.
 *
 * <p>Supplies the graph-specific hooks required by {@link BaseMultiLayerUpdater}: the ordered list
 * of layers, the flattened gradient view, the parameters, and the mini-batch flag. All of these
 * are read from the wrapped graph.
 */
public class ComputationGraphUpdater
extends BaseMultiLayerUpdater<ComputationGraph> {
    /**
     * Cached result of {@link #getOrderedLayers()}; {@code null} until first computed.
     *
     * <p>Deliberately has no explicit initializer. A field initializer runs after {@code super(...)}
     * returns, so it would discard any value cached if the superclass constructor calls
     * {@link #getOrderedLayers()}. NOTE(review): confirm whether BaseMultiLayerUpdater does so.
     */
    protected Layer[] orderedLayers;

    /**
     * Creates an updater for {@code graph} with no pre-existing updater state.
     *
     * @param graph the computation graph whose layers this updater serves
     */
    public ComputationGraphUpdater(ComputationGraph graph) {
        this(graph, (INDArray) null);
    }

    /**
     * Creates an updater for {@code graph}, passing {@code updaterState} to the superclass.
     *
     * <p>After superclass construction, rebuilds {@code layersByName} (declared in the superclass).
     * It maps each layer's configured name to the layer, visiting layers in
     * {@link #getOrderedLayers()} order. If two layers share a name, the later one wins.
     *
     * @param graph        the computation graph whose layers this updater serves
     * @param updaterState existing updater state, or {@code null} for none
     *                     (NOTE(review): how null is handled is up to the superclass)
     */
    public ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState) {
        super(graph, updaterState);
        this.layersByName = new HashMap<>();
        for (Layer layer : this.getOrderedLayers()) {
            this.layersByName.put(layer.conf().getLayer().getLayerName(), layer);
        }
    }

    /**
     * Returns the graph's layers in the topological order of their vertices. Vertices that do not
     * hold a layer are skipped.
     *
     * <p>The array is computed on first call and cached in {@link #orderedLayers}; later calls return
     * the same array instance. Its length is {@code network.getNumLayers()}. Presumably the ordering
     * matches how parameters are laid out in the flattened views — confirm against ComputationGraph.
     *
     * @return the layers, topologically ordered
     */
    @Override
    protected Layer[] getOrderedLayers() {
        if (this.orderedLayers != null) {
            return this.orderedLayers;
        }
        GraphVertex[] vertices = this.network.getVertices();
        int[] topologicalOrdering = this.network.topologicalSortOrder();
        Layer[] out = new Layer[this.network.getNumLayers()];
        int j = 0;
        for (int vertexIndex : topologicalOrdering) {
            GraphVertex currentVertex = vertices[vertexIndex];
            if (currentVertex.hasLayer()) {
                out[j++] = currentVertex.getLayer();
            }
        }
        this.orderedLayers = out;
        return this.orderedLayers;
    }

    /**
     * Returns the graph's flattened gradients array. If the graph has none yet, this first calls
     * {@code initGradientsView()} on it.
     *
     * @return the graph's flattened gradients view
     */
    @Override
    protected INDArray getFlattenedGradientsView() {
        if (this.network.getFlattenedGradients() == null) {
            this.network.initGradientsView();
        }
        return this.network.getFlattenedGradients();
    }

    /**
     * Returns the graph's parameters, exactly as given by {@code network.params()}.
     *
     * @return the graph's parameters
     */
    @Override
    protected INDArray getParams() {
        return this.network.params();
    }

    /**
     * Reports whether the graph's configuration is flagged as mini-batch.
     *
     * @return {@code network.conf().isMiniBatch()}
     */
    @Override
    protected boolean isMiniBatch() {
        return this.network.conf().isMiniBatch();
    }
}

