/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.samediff;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.temp.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Runtime layer that executes a layer whose computation is defined with SameDiff.
 *
 * <p>The SameDiff graph is built lazily by {@link #doInit()} the first time the layer is run
 * forward or backward. Construction is delegated to the configuration's {@code defineLayer}
 * method. Graph execution always happens outside of any DL4J workspace. Only the final
 * activations and input gradients are copied into the workspaces supplied by the caller.
 */
public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
    /** Name of the SameDiff variable that receives the layer input. */
    public static final String INPUT_KEY = "input";

    /** SameDiff graph for this layer; {@code null} until {@link #doInit()} has run. */
    protected SameDiff sameDiff;
    /** Output variable returned by the configuration's {@code defineLayer}. */
    protected SDVariable outputVar;
    /** External-errors op through which the epsilon from the next layer is injected during backprop. */
    protected ExternalErrorsFunction fn;
    /** Name of {@link #outputVar} inside {@link #sameDiff}. */
    protected String outputKey;
    /** Flattened parameter array, as set by {@link #setParamsViewArray(INDArray)}. */
    protected INDArray params;
    /** Flattened gradient array, as set by {@link #setBackpropGradientsViewArray(INDArray)}. */
    protected INDArray gradients;
    /** Parameter arrays keyed by parameter name. */
    protected Map<String, INDArray> paramTable;
    /** Per-parameter gradient arrays, derived from {@link #gradients} by the layer's param initializer. */
    protected Map<String, INDArray> gradTable;

    public SameDiffLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Cloning is not supported for SameDiff-backed layers.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer clone() {
        throw new UnsupportedOperationException();
    }

    /** @return {@code false}: this layer has no layerwise pretraining phase. */
    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** No-op: this layer does not use weight noise. */
    @Override
    public void clearNoiseWeightParams() {
    }

    /**
     * Runs the SameDiff graph forward on the current input.
     *
     * <p>The graph is built on first use. A copy of the input and the current parameter
     * arrays are bound to the graph before execution.
     *
     * @param training     unused by this implementation
     * @param workspaceMgr workspace manager used to place the returned activations
     * @return the layer output, duplicated into the {@link ArrayType#ACTIVATIONS} workspace
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        if (sameDiff == null) {
            doInit();
        }
        // The try-with-resources only scopes execution outside of any workspace; the handle itself is unused.
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            sameDiff.clearExecutionCache();
            sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
            associateParams();
            sameDiff.exec();
            INDArray result = sameDiff.getArrForVarName(outputKey);
            return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
        }
    }

    /**
     * Runs the SameDiff graph backward, given dL/dOutput ({@code epsilon}).
     *
     * <p>Each parameter gradient computed by SameDiff is copied into the corresponding array of
     * {@link #gradTable}. Those arrays, not the SameDiff-owned ones, are returned in the
     * {@link Gradient}.
     *
     * @param epsilon      gradient of the loss with respect to this layer's output
     * @param workspaceMgr workspace manager used to place the returned input gradient
     * @return the parameter gradients, and dL/dInput duplicated into the
     *         {@link ArrayType#ACTIVATION_GRAD} workspace
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        if (sameDiff == null) {
            // Same lazy initialisation as activate(); previously this path hit an NPE on sameDiff/fn.
            doInit();
        }
        DefaultGradient g = new DefaultGradient();
        INDArray dLdIn;
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            sameDiff.clearExecutionCache();
            sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
            // Feed the externally supplied dL/dOutput into the graph via the external-errors op.
            fn.updateVariable(outputVar.getVarName(), epsilon.dup());
            associateParams();
            sameDiff.execBackwards();
            for (String name : paramTable.keySet()) {
                INDArray sdGrad = sameDiff.grad(name).getArr();
                INDArray dl4jGrad = gradTable.get(name);
                // Copy into the pre-allocated gradient array rather than exposing SameDiff's array.
                dl4jGrad.assign(sdGrad);
                g.gradientForVariable().put(name, dl4jGrad);
            }
            dLdIn = sameDiff.grad(INPUT_KEY).getArr();
        }
        return new Pair<Gradient, INDArray>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));
    }

    /** Binds every array in {@link #paramTable} to the SameDiff variable of the same name. */
    private void associateParams() {
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            sameDiff.associateArrayWithVariable(e.getValue(), e.getKey());
        }
    }

    /** @return the flattened parameter array set via {@link #setParamsViewArray(INDArray)}, possibly {@code null} */
    @Override
    public INDArray params() {
        return params;
    }

    /** @return the parameter array registered under {@code param}, or {@code null} if there is none */
    @Override
    public INDArray getParam(String param) {
        return paramTable.get(param);
    }

    /**
     * @return the number of elements in the flattened parameter array, or 0 if none is set.
     *         Computed in {@code long} arithmetic, without narrowing to {@code int}.
     */
    @Override
    public long numParams() {
        return params == null ? 0L : params.length();
    }

    /**
     * Copies {@code val} into the existing parameter array registered under {@code key}.
     *
     * <p>The copy is done in place, with {@code assign}, so the registered array object is kept.
     * If that array is a view into the flattened parameter array (presumably the case; confirm
     * against the param initializer), the flattened array sees the update too.
     *
     * @param key parameter name; must already exist in the parameter table
     * @param val new values; must have exactly the same shape as the current parameter
     * @throws IllegalArgumentException if {@code key} is unknown, {@code val} is null, or the shapes differ
     */
    @Override
    public void setParam(String key, INDArray val) {
        if (!paramTable.containsKey(key)) {
            throw new IllegalArgumentException("Cannot set parameter, invalid/unknown parameter key: " + key);
        }
        if (val == null) {
            throw new IllegalArgumentException("Cannot set parameter \"" + key + "\": value is null");
        }
        INDArray current = paramTable.get(key);
        if (!Arrays.equals(current.shape(), val.shape())) {
            throw new IllegalArgumentException("Cannot set parameter \"" + key + "\", invalid shape: parameter array has shape " + Arrays.toString(current.shape()) + ", trying to set parameter of shape " + Arrays.toString(val.shape()));
        }
        // Previously the value was validated but never written, so this method was a silent no-op.
        current.assign(val);
    }

    /**
     * Setting parameters from a single flattened array is not supported.
     *
     * @param params must be {@code null} (which is accepted as a no-op)
     * @throws UnsupportedOperationException if {@code params} is non-null
     */
    @Override
    public void setParams(INDArray params) {
        if (params != null) {
            throw new UnsupportedOperationException("Not supported");
        }
    }

    /** Delegates to {@link #setParams(INDArray)}; {@code order} is ignored. */
    @Override
    protected void setParams(INDArray params, char order) {
        setParams(params);
    }

    /** Stores the flattened parameter array returned by {@link #params()}. */
    @Override
    public void setParamsViewArray(INDArray params) {
        this.params = params;
    }

    /** @return the flattened gradient array, possibly {@code null} */
    @Override
    public INDArray getGradientsViewArray() {
        return gradients;
    }

    /**
     * Stores the flattened gradient array and rebuilds {@link #gradTable} from it using the
     * layer configuration's param initializer.
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        this.gradients = gradients;
        this.gradTable = ((AbstractSameDiffLayer) layerConf()).initializer().getGradientsFromFlattened(conf(), gradients);
    }

    /**
     * Sets the parameter table.
     *
     * <p>On the first call the given map is adopted as-is. On later calls each entry is copied
     * into the existing parameter arrays via {@link #setParam(String, INDArray)}, which validates
     * key and shape.
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        if (this.paramTable == null) {
            this.paramTable = paramTable;
        } else {
            for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
                setParam(e.getKey(), e.getValue());
            }
        }
    }

    /** @return the internal parameter table (not a copy) */
    @Override
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    /**
     * @param backpropParamsOnly ignored; every parameter of this layer is returned
     * @return the internal parameter table (not a copy)
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        return paramTable;
    }

    /**
     * Builds the SameDiff graph for this layer.
     *
     * <p>It creates one variable for the input, shaped like the input present at the time of
     * the call, and one variable per parameter shape reported by the configuration. It then
     * lets the configuration define the computation and binds the current parameter arrays.
     * Finally it wraps the output in an external-errors op so that backprop can be driven by an
     * externally supplied epsilon.
     */
    protected void doInit() {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl =
                    (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
            sameDiff = SameDiff.create();
            Map<String, INDArray> p = paramTable();

            long[] inputShape = input.shape().clone();
            SDVariable inputVar = sameDiff.var(INPUT_KEY, inputShape);

            Map<String, long[]> paramShapes = ((AbstractSameDiffLayer) layerConf()).getLayerParams().getParamShapes();
            // Insertion-ordered so defineLayer sees parameters in the order the config declares them.
            Map<String, SDVariable> paramVars = new LinkedHashMap<String, SDVariable>();
            for (Map.Entry<String, long[]> e : paramShapes.entrySet()) {
                paramVars.put(e.getKey(), sameDiff.var(e.getKey(), e.getValue()));
            }

            SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, paramVars);
            Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
            outputVar = layerOutput;

            for (Map.Entry<String, INDArray> e : p.entrySet()) {
                sameDiff.associateArrayWithVariable(e.getValue(), sameDiff.getVariable(e.getKey()));
            }

            fn = sameDiff.f().externalErrors(new SDVariable[]{layerOutput});
            // Return value unused; call retained from the original (presumably it materialises the
            // op's output variable in the graph; TODO confirm).
            fn.outputVariable();
            outputKey = outputVar.getVarName();
        }
    }
}

