/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;

/**
 * Base class for layers that support unsupervised layer-wise pretraining.
 *
 * <p>Besides the weight ({@code "W"}) and bias ({@code "b"}) parameters used in backprop, this layer
 * owns a third parameter, {@code "vb"}. That parameter is left out of {@link #paramTable(boolean)}
 * when only backprop parameters are requested. Its gradient is also zeroed in
 * {@link #backpropGradient(INDArray, LayerWorkspaceMgr)}, so it is updated only through the
 * pretraining path. (Presumably {@code "vb"} is the visible bias — confirm against the param
 * initializer.)
 *
 * @param <LayerConfT> the pretrain layer configuration type
 */
public abstract class BasePretrainNetwork<LayerConfT extends org.deeplearning4j.nn.conf.layers.BasePretrainNetwork>
extends BaseLayer<LayerConfT> {
    /** Parameter / gradient key of the weight matrix. */
    private static final String WEIGHT_KEY = "W";
    /** Parameter / gradient key of the hidden bias (used by both backprop and pretraining). */
    private static final String BIAS_KEY = "b";
    /** Parameter / gradient key of the pretraining-only bias. */
    private static final String VISIBLE_BIAS_KEY = "vb";

    public BasePretrainNetwork(NeuralNetConfiguration conf) {
        super(conf);
    }

    public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Returns a corrupted copy of {@code x}.
     *
     * <p>Each entry is kept with probability {@code 1 - corruptionLevel} and zeroed otherwise. A
     * binomial (n = 1) mask with the same shape as {@code x} is sampled and multiplied element-wise
     * by {@code x}. The input array itself is not modified.
     *
     * @param x               the array to corrupt
     * @param corruptionLevel probability of zeroing each entry
     * @return the masked copy of {@code x}
     */
    public INDArray getCorruptedInput(INDArray x, double corruptionLevel) {
        INDArray corrupted = Nd4j.getDistributions().createBinomial(1, 1.0 - corruptionLevel).sample(x.shape());
        corrupted.muli(x);
        return corrupted;
    }

    /**
     * Copies the supplied gradients into this layer's gradient views and wraps them in a
     * {@link Gradient} that is backed by the flattened gradient array.
     *
     * @param wGradient     gradient for {@code "W"}
     * @param vBiasGradient gradient for {@code "vb"}
     * @param hBiasGradient gradient for {@code "b"}
     * @return a gradient whose entries are the (now updated) gradient views
     */
    protected Gradient createGradient(INDArray wGradient, INDArray vBiasGradient, INDArray hBiasGradient) {
        DefaultGradient ret = new DefaultGradient(this.gradientsFlattened);
        copyIntoView(ret, WEIGHT_KEY, wGradient);
        copyIntoView(ret, BIAS_KEY, hBiasGradient);
        copyIntoView(ret, VISIBLE_BIAS_KEY, vBiasGradient);
        return ret;
    }

    /** Assigns {@code value} into the gradient view for {@code key} and registers that view in {@code gradient}. */
    private void copyIntoView(Gradient gradient, String key, INDArray value) {
        INDArray view = this.gradientViews.get(key);
        view.assign(value);
        gradient.gradientForVariable().put(key, view);
    }

    /** This layer's configuration, typed as the pretrain configuration base class. */
    private org.deeplearning4j.nn.conf.layers.BasePretrainNetwork pretrainConf() {
        return (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) this.layerConf();
    }

    @Override
    public long numParams(boolean backwards) {
        return super.numParams(backwards);
    }

    /**
     * Samples the hidden representation given a visible input.
     *
     * @param visible the visible-layer activations
     * @return a pair of arrays as defined by the concrete subclass
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray visible);

    /**
     * Samples the visible representation given a hidden input.
     *
     * @param hidden the hidden-layer activations
     * @return a pair of arrays as defined by the concrete subclass
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray hidden);

    /**
     * Computes the reconstruction score of {@code z} against the current input.
     *
     * <p>The score is the configured loss function's score, plus the full L1/L2 penalties (including
     * {@code "vb"}), divided by the minibatch size. The result is stored in {@code this.score}.
     *
     * @throws IllegalStateException if there is no input or {@code z} is null
     */
    @Override
    protected void setScoreWithZ(INDArray z) {
        if (this.input == null || z == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        ILossFunction lossFunction = pretrainConf().getLossFunction().getILossFunction();
        double score = lossFunction.computeScore(this.input, z, pretrainConf().getActivationFn(), this.maskArray, false);
        // Penalties are computed over all params, including the pretraining-only "vb".
        score += this.calcL1(false) + this.calcL2(false);
        this.score = score / (double) this.getInputMiniBatchSize();
    }

    /**
     * Returns the parameter table.
     *
     * @param backpropParamsOnly if true, returns a new map with only {@code "W"} and {@code "b"};
     *                           otherwise returns the layer's own parameter map (not a copy)
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        if (!backpropParamsOnly) {
            return this.params;
        }
        Map<String, INDArray> map = new LinkedHashMap<>();
        map.put(WEIGHT_KEY, this.params.get(WEIGHT_KEY));
        map.put(BIAS_KEY, this.params.get(BIAS_KEY));
        return map;
    }

    /** Returns all parameters, in parameter-map iteration order, flattened into one 'f'-order array. */
    @Override
    public INDArray params() {
        List<INDArray> list = new ArrayList<>(this.params.values());
        return Nd4j.toFlattened('f', list);
    }

    /**
     * Returns the total number of parameter entries across all parameter arrays.
     *
     * <p>Accumulated as {@code long}: the previous {@code int} accumulator silently wrapped for very
     * large layers.
     */
    @Override
    public long numParams() {
        long ret = 0L;
        for (INDArray param : this.params.values()) {
            ret += param.length();
        }
        return ret;
    }

    /**
     * Copies {@code params} into the flattened parameter array.
     *
     * <p>Nothing happens if {@code params} is already the flattened parameter array.
     *
     * @throws IllegalArgumentException if the length does not match the total parameter length
     */
    @Override
    public void setParams(INDArray params) {
        if (params == this.paramsFlattened) {
            return;
        }
        List<String> parameterList = this.conf.variables();
        long paramLength = 0L;
        for (String s : parameterList) {
            paramLength += this.getParam(s).length();
        }
        if (params.length() != paramLength) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + paramLength + ", got params of length " + params.length() + " " + this.layerId());
        }
        this.paramsFlattened.assign(params);
    }

    /**
     * Runs standard backprop and then attaches a zeroed {@code "vb"} gradient.
     *
     * <p>The pretraining-only bias therefore receives no update from supervised training. Weight-noise
     * parameters are cleared afterwards.
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        Pair<Gradient, INDArray> result = super.backpropGradient(epsilon, workspaceMgr);
        Gradient gradient = result.getFirst();
        ((DefaultGradient) gradient).setFlattenedGradient(this.gradientsFlattened);
        INDArray vBiasGradient = this.gradientViews.get(VISIBLE_BIAS_KEY);
        gradient.gradientForVariable().put(VISIBLE_BIAS_KEY, vBiasGradient);
        vBiasGradient.assign(0);
        this.weightNoiseParams.clear();
        return result;
    }

    /**
     * Returns the L2 penalty {@code 0.5 * l2 * ||p||^2}.
     *
     * @param backpropParamsOnly if true, excludes the pretraining-only {@code "vb"} term
     */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        double l2Sum = super.calcL2(true);
        if (backpropParamsOnly) {
            return l2Sum;
        }
        double l2Coeff = pretrainConf().getL2ByParam(VISIBLE_BIAS_KEY);
        if (l2Coeff > 0.0) {
            double l2Norm = this.getParam(VISIBLE_BIAS_KEY).norm2Number().doubleValue();
            l2Sum += 0.5 * l2Coeff * l2Norm * l2Norm;
        }
        return l2Sum;
    }

    /**
     * Returns the L1 penalty {@code l1 * ||p||_1}.
     *
     * @param backpropParamsOnly if true, excludes the pretraining-only {@code "vb"} term. This is
     *                           consistent with {@link #calcL2(boolean)}; previously the flag was
     *                           ignored here.
     */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        double l1Sum = super.calcL1(true);
        if (backpropParamsOnly) {
            return l1Sum;
        }
        double l1Coeff = pretrainConf().getL1ByParam(VISIBLE_BIAS_KEY);
        if (l1Coeff > 0.0) {
            l1Sum += l1Coeff * this.getParam(VISIBLE_BIAS_KEY).norm1Number().doubleValue();
        }
        return l1Sum;
    }
}

