/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.gradientcheck;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.Consumer;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Numerical gradient checking utilities.
 *
 * <p>Each check compares the gradient produced by backpropagation against a central-difference estimate
 * {@code (score(p + epsilon) - score(p - epsilon)) / (2 * epsilon)}, perturbing one parameter value at a time
 * in place and restoring it afterwards.</p>
 *
 * <p>Preconditions enforced by every check: the ND4J data type must be {@code DOUBLE}, and every
 * {@link BaseLayer} must use a {@link NoOp} updater or {@link Sgd} with a learning rate of exactly 1.0. The
 * configured updater is applied to the backprop gradient before the comparison, so any other updater would
 * change the values being checked.</p>
 */
public class GradientCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradientCheckUtil.class);

    /**
     * Activation functions treated as smooth enough for finite-difference checks. Layers using any other
     * activation only produce a warning, not an error.
     */
    private static final List<Class<? extends IActivation>> VALID_ACTIVATION_FUNCTIONS =
            Arrays.<Class<? extends IActivation>>asList(
                    Activation.CUBE.getActivationFunction().getClass(),
                    Activation.ELU.getActivationFunction().getClass(),
                    Activation.IDENTITY.getActivationFunction().getClass(),
                    Activation.RATIONALTANH.getActivationFunction().getClass(),
                    Activation.SIGMOID.getActivationFunction().getClass(),
                    Activation.SOFTMAX.getActivationFunction().getClass(),
                    Activation.SOFTPLUS.getActivationFunction().getClass(),
                    Activation.SOFTSIGN.getActivationFunction().getClass(),
                    Activation.TANH.getActivationFunction().getClass());

    /** Computes the score of the model under test at its current (possibly perturbed) parameter values. */
    private interface ScoreFunction {
        double score();
    }

    private GradientCheckUtil() {
    }

    /**
     * Disables loss-function clipping on an output layer, because clipping can cause spurious gradient check
     * failures. Handles {@link LossMCXENT} combined with a softmax activation (softmax clip epsilon) and
     * {@link LossBinaryXENT} (clip epsilon). Other loss functions are left untouched.
     *
     * @param outputLayer output layer whose loss function configuration may be modified in place
     */
    private static void configureLossFnClippingIfPresent(IOutputLayer outputLayer) {
        ILossFunction lfn = null;
        IActivation afn = null;
        if (outputLayer instanceof org.deeplearning4j.nn.layers.BaseOutputLayer) {
            org.deeplearning4j.nn.layers.BaseOutputLayer o = (org.deeplearning4j.nn.layers.BaseOutputLayer) outputLayer;
            lfn = ((BaseOutputLayer) o.layerConf()).getLossFn();
            afn = ((BaseLayer) o.layerConf()).getActivationFn();
        } else if (outputLayer instanceof org.deeplearning4j.nn.layers.LossLayer) {
            org.deeplearning4j.nn.layers.LossLayer o = (org.deeplearning4j.nn.layers.LossLayer) outputLayer;
            lfn = ((LossLayer) o.layerConf()).getLossFn();
            afn = ((LossLayer) o.layerConf()).getActivationFn();
        }
        if (lfn instanceof LossMCXENT && afn instanceof ActivationSoftmax && ((LossMCXENT) lfn).getSoftmaxClipEps() != 0.0) {
            log.info("Setting softmax clipping epsilon to 0.0 for " + lfn.getClass() + " loss function to avoid spurious gradient check failures");
            ((LossMCXENT) lfn).setSoftmaxClipEps(0.0);
        } else if (lfn instanceof LossBinaryXENT && ((LossBinaryXENT) lfn).getClipEps() != 0.0) {
            log.info("Setting clipping epsilon to 0.0 for " + lfn.getClass() + " loss function to avoid spurious gradient check failures");
            ((LossBinaryXENT) lfn).setClipEps(0.0);
        }
    }

    /** Checks all parameters of {@code mln} without masks, subsetting, exclusions or per-iteration callback. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) {
        return GradientCheckUtil.checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels, null, null);
    }

    /** Checks all parameters of {@code mln} except those whose names are in {@code excludeParams}. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, Set<String> excludeParams) {
        return GradientCheckUtil.checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels, null, null, false, -1, excludeParams, (Integer) null);
    }

    /** Checks all parameters of {@code mln} using the given (possibly null) input and label masks. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask) {
        return GradientCheckUtil.checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels, inputMask, labelMask, false, -1);
    }

    /** As the full overload, with no excluded parameters and no per-iteration callback. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam) {
        return GradientCheckUtil.checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels, inputMask, labelMask, subset, maxPerParam, null);
    }

    /** As the full overload, with no per-iteration callback. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam, Set<String> excludeParams) {
        // The cast selects the Consumer overload; a bare null would be ambiguous with the Integer overload
        return GradientCheckUtil.checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels, inputMask, labelMask, subset, maxPerParam, excludeParams, (Consumer<MultiLayerNetwork>) null);
    }

    /**
     * As the full overload, but resets the ND4J RNG seed to {@code rngSeedResetEachIter} before every score
     * computation when it is non-null. This is required when dropout is present.
     */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam, Set<String> excludeParams, final Integer rngSeedResetEachIter) {
        Consumer<MultiLayerNetwork> c = null;
        if (rngSeedResetEachIter != null) {
            c = new Consumer<MultiLayerNetwork>() {
                @Override
                public void accept(MultiLayerNetwork multiLayerNetwork) {
                    Nd4j.getRandom().setSeed(rngSeedResetEachIter.intValue());
                }
            };
        }
        return GradientCheckUtil.checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels, inputMask, labelMask, subset, maxPerParam, excludeParams, c);
    }

    /**
     * Checks the backprop gradients of a {@link MultiLayerNetwork} against central-difference numerical gradients.
     *
     * @param mln              network to check; each parameter value is perturbed in place and then restored
     * @param epsilon          perturbation size, must be in (0, 0.1]
     * @param maxRelError      maximum relative error for a value to pass, must be in (0, 0.25]
     * @param minAbsoluteError values whose absolute error is below this pass even if the relative error is too large
     * @param print            whether to log per-value results and a final summary
     * @param exitOnFirstError whether to return {@code false} immediately on the first failure
     * @param input            network input
     * @param labels           network labels
     * @param inputMask        input mask; may be null
     * @param labelMask        label mask; may be null
     * @param subset           if true, check only an evenly strided subset of each parameter array
     * @param maxPerParam      when {@code subset} is true, the maximum number of values checked per parameter array (> 0)
     * @param excludeParams    names (as in {@code paramTable()}) of parameters to skip; may be null
     * @param callEachIter     invoked before every score computation (e.g. to reset the RNG); may be null
     * @return true if every checked value passed
     * @throws IllegalArgumentException if epsilon, maxRelError or maxPerParam is out of range, or the network has no output layer
     * @throws IllegalStateException    if the data type is not double, an updater is unsuitable, dropout is used without
     *                                  {@code callEachIter}, or a numerical gradient is NaN
     */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam, Set<String> excludeParams, Consumer<MultiLayerNetwork> callEachIter) {
        validateEpsilonAndMaxRelError(epsilon, maxRelError);
        if (subset && maxPerParam <= 0) {
            throw new IllegalArgumentException("Invalid maxPerParam: must be > 0 when subset == true, got " + maxPerParam);
        }
        if (!(mln.getOutputLayer() instanceof IOutputLayer)) {
            throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
        }
        validateDoublePrecision();

        int layerIdx = 0;
        for (NeuralNetConfiguration n : mln.getLayerWiseConfigurations().getConfs()) {
            if (n.getLayer() instanceof BaseLayer) {
                validateLayerConfig((BaseLayer) n.getLayer(), String.valueOf(layerIdx));
            }
            // Dropout is stochastic: the +/- epsilon scores are only comparable if the RNG is reset before each one
            if (n.getLayer().getIDropout() != null && callEachIter == null) {
                throw new IllegalStateException("When gradient checking dropout, need to reset RNG seed each iter, or no dropout should be present during gradient checks - got dropout = " + n.getLayer().getIDropout() + " for layer " + layerIdx);
            }
            layerIdx++;
        }
        for (Layer l : mln.getLayers()) {
            if (l instanceof IOutputLayer) {
                configureLossFnClippingIfPresent((IOutputLayer) l);
            }
        }

        mln.setInput(input);
        mln.setLabels(labels);
        mln.setLayerMaskArrays(inputMask, labelMask);
        if (callEachIter != null) {
            callEachIter.accept(mln);
        }
        mln.computeGradientAndScore();
        Pair<Gradient, Double> gradAndScore = mln.gradientAndScore();
        Updater updater = UpdaterCreator.getUpdater(mln);
        updater.update(mln, gradAndScore.getFirst(), 0, 0, mln.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        // Copy: later score computations may overwrite the network's gradient arrays
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();

        INDArray params = mln.params();
        Map<String, INDArray> paramTable = mln.paramTable();
        List<String> paramNames = new ArrayList<String>(paramTable.keySet());
        long[] paramEnds = computeParamEnds(paramTable, paramNames);
        long[] stepSizes = subset ? computeStepSizes(paramTable, paramNames, maxPerParam) : null;

        if (print) {
            int printIdx = 0;
            for (Layer l : mln.getLayers()) {
                log.info("Layer " + printIdx + ": " + l.getClass().getSimpleName() + " - params " + l.paramTable().keySet());
                printIdx++;
            }
        }

        final MultiLayerNetwork net = mln;
        final Consumer<MultiLayerNetwork> callback = callEachIter;
        final DataSet ds = new DataSet(input, labels, inputMask, labelMask);
        ScoreFunction scoreFn = new ScoreFunction() {
            @Override
            public double score() {
                if (callback != null) {
                    callback.accept(net);
                }
                return net.score(ds, true);
            }
        };
        return checkParamsNumerically(params, gradientToCheck, params.length(), paramNames, paramEnds, stepSizes, excludeParams, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, scoreFn);
    }

    /** Checks all parameters of {@code graph} without masks, exclusions or per-iteration callback. */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, INDArray[] labels) {
        return GradientCheckUtil.checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels, null, null, null);
    }

    /** Checks all parameters of {@code graph} using the given (possibly null) feature and label masks. */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, INDArray[] labels, INDArray[] fMask, INDArray[] lMask) {
        return GradientCheckUtil.checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels, fMask, lMask, null);
    }

    /** As the full overload, with no per-iteration callback. */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, Set<String> excludeParams) {
        // The cast selects the Consumer overload; a bare null would be ambiguous with the Integer overload
        return GradientCheckUtil.checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels, fMask, lMask, excludeParams, (Consumer<ComputationGraph>) null);
    }

    /**
     * As the full overload, but resets the ND4J RNG seed to {@code rngSeedResetEachIter} before every score
     * computation when it is non-null. This is required when dropout is present.
     */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, Set<String> excludeParams, final Integer rngSeedResetEachIter) {
        Consumer<ComputationGraph> c = null;
        if (rngSeedResetEachIter != null) {
            c = new Consumer<ComputationGraph>() {
                @Override
                public void accept(ComputationGraph computationGraph) {
                    Nd4j.getRandom().setSeed(rngSeedResetEachIter.intValue());
                }
            };
        }
        return GradientCheckUtil.checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels, fMask, lMask, excludeParams, c);
    }

    /**
     * Checks the backprop gradients of a {@link ComputationGraph} against central-difference numerical gradients.
     *
     * @param graph            graph to check; each parameter value is perturbed in place and then restored
     * @param epsilon          perturbation size, must be in (0, 0.1]
     * @param maxRelError      maximum relative error for a value to pass, must be in (0, 0.25]
     * @param minAbsoluteError values whose absolute error is below this pass even if the relative error is too large
     * @param print            whether to log per-value results and a final summary
     * @param exitOnFirstError whether to return {@code false} immediately on the first failure
     * @param inputs           one array per graph input
     * @param labels           one array per graph output
     * @param fMask            feature masks; may be null
     * @param lMask            label masks; may be null
     * @param excludeParams    names (as in {@code paramTable()}) of parameters to skip; may be null
     * @param callEachIter     invoked before every score computation (e.g. to reset the RNG); may be null
     * @return true if every checked value passed
     * @throws IllegalArgumentException if epsilon or maxRelError is out of range, or the input/label counts are wrong
     * @throws IllegalStateException    if the data type is not double, an updater is unsuitable, dropout is used without
     *                                  {@code callEachIter}, or a numerical gradient is NaN
     */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, Set<String> excludeParams, Consumer<ComputationGraph> callEachIter) {
        validateEpsilonAndMaxRelError(epsilon, maxRelError);
        if (graph.getNumInputArrays() != inputs.length) {
            throw new IllegalArgumentException("Invalid input arrays: expect " + graph.getNumInputArrays() + " inputs");
        }
        if (graph.getNumOutputArrays() != labels.length) {
            throw new IllegalArgumentException("Invalid labels arrays: expect " + graph.getNumOutputArrays() + " outputs");
        }
        validateDoublePrecision();

        Map<String, GraphVertex> vertices = graph.getConfiguration().getVertices();
        for (String vertexName : vertices.keySet()) {
            GraphVertex gv = vertices.get(vertexName);
            if (!(gv instanceof LayerVertex)) {
                continue;
            }
            NeuralNetConfiguration layerConf = ((LayerVertex) gv).getLayerConf();
            // Vertices are identified by name: a positional index is not meaningful for a graph
            String layerId = "\"" + vertexName + "\"";
            if (layerConf.getLayer() instanceof BaseLayer) {
                validateLayerConfig((BaseLayer) layerConf.getLayer(), layerId);
            }
            if (layerConf.getLayer().getIDropout() != null && callEachIter == null) {
                throw new IllegalStateException("When gradient checking dropout, rng seed must be reset each iteration, or no dropout should be present during gradient checks - got dropout = " + layerConf.getLayer().getIDropout() + " for layer " + layerId);
            }
        }
        for (Layer l : graph.getLayers()) {
            if (l instanceof IOutputLayer) {
                configureLossFnClippingIfPresent((IOutputLayer) l);
            }
        }

        for (int in = 0; in < inputs.length; in++) {
            graph.setInput(in, inputs[in]);
        }
        for (int out = 0; out < labels.length; out++) {
            graph.setLabel(out, labels[out]);
        }
        graph.setLayerMaskArrays(fMask, lMask);
        if (callEachIter != null) {
            callEachIter.accept(graph);
        }
        graph.computeGradientAndScore();
        Pair<Gradient, Double> gradAndScore = graph.gradientAndScore();
        ComputationGraphUpdater updater = new ComputationGraphUpdater(graph);
        updater.update(gradAndScore.getFirst(), 0, 0, graph.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        // Copy: later score computations may overwrite the graph's gradient arrays
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();

        INDArray params = graph.params();
        Map<String, INDArray> paramTable = graph.paramTable();
        List<String> paramNames = new ArrayList<String>(paramTable.keySet());
        long[] paramEnds = computeParamEnds(paramTable, paramNames);

        final ComputationGraph cg = graph;
        final Consumer<ComputationGraph> callback = callEachIter;
        final MultiDataSet mds = new MultiDataSet(inputs, labels, fMask, lMask);
        ScoreFunction scoreFn = new ScoreFunction() {
            @Override
            public double score() {
                if (callback != null) {
                    callback.accept(cg);
                }
                return cg.score((org.nd4j.linalg.dataset.api.MultiDataSet) mds, true);
            }
        };
        return checkParamsNumerically(params, gradientToCheck, params.length(), paramNames, paramEnds, null, excludeParams, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, scoreFn);
    }

    /**
     * Checks the pretraining gradients of a single layer against central-difference numerical gradients.
     * The ND4J RNG is reset to {@code rngSeed} before every gradient/score computation so that any stochastic
     * behaviour is identical between the +epsilon and -epsilon evaluations.
     *
     * @param layer            layer to check; each parameter value is perturbed in place and then restored
     * @param epsilon          perturbation size, must be in (0, 0.1]
     * @param maxRelError      maximum relative error for a value to pass, must be in (0, 0.25]
     * @param minAbsoluteError values whose absolute error is below this pass even if the relative error is too large
     * @param print            whether to log per-value results and a final summary
     * @param exitOnFirstError whether to return {@code false} immediately on the first failure
     * @param input            layer input
     * @param rngSeed          seed the ND4J RNG is reset to before each computation
     * @return true if every checked value passed
     * @throws IllegalArgumentException if epsilon or maxRelError is out of range
     * @throws IllegalStateException    if the data type is not double or a numerical gradient is NaN
     */
    public static boolean checkGradientsPretrainLayer(Layer layer, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, int rngSeed) {
        final LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces();
        validateEpsilonAndMaxRelError(epsilon, maxRelError);
        validateDoublePrecision();

        layer.setInput(input, LayerWorkspaceMgr.noWorkspaces());
        Nd4j.getRandom().setSeed(rngSeed);
        layer.computeGradientAndScore(mgr);
        Pair<Gradient, Double> gradAndScore = layer.gradientAndScore();
        Updater updater = UpdaterCreator.getUpdater(layer);
        updater.update(layer, gradAndScore.getFirst(), 0, 0, layer.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        // Copy: every perturbed evaluation below recomputes (and may overwrite) the layer's gradient
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();

        INDArray params = layer.params();
        Map<String, INDArray> paramTable = layer.paramTable();
        List<String> paramNames = new ArrayList<String>(paramTable.keySet());
        long[] paramEnds = computeParamEnds(paramTable, paramNames);

        final Layer target = layer;
        final int seed = rngSeed;
        ScoreFunction scoreFn = new ScoreFunction() {
            @Override
            public double score() {
                Nd4j.getRandom().setSeed(seed);
                target.computeGradientAndScore(mgr);
                return target.score();
            }
        };
        return checkParamsNumerically(params, gradientToCheck, params.length(), paramNames, paramEnds, null, null, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, scoreFn);
    }

    /**
     * Validates the tolerance arguments shared by all checks. The range tests are written so that NaN is rejected.
     *
     * @throws IllegalArgumentException if epsilon is not in (0, 0.1] or maxRelError is not in (0, 0.25]
     */
    private static void validateEpsilonAndMaxRelError(double epsilon, double maxRelError) {
        if (!(epsilon > 0.0 && epsilon <= 0.1)) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (!(maxRelError > 0.0 && maxRelError <= 0.25)) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
    }

    /**
     * @throws IllegalStateException if the ND4J data type from the current context is not {@code DOUBLE}
     */
    private static void validateDoublePrecision() {
        DataBuffer.Type dataType = DataTypeUtil.getDtypeFromContext();
        if (dataType != DataBuffer.Type.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + dataType + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE); before using GradientCheckUtil");
        }
    }

    /**
     * Validates one layer configuration: the updater must be NoOp or SGD with lr == 1.0. Also warns, without
     * failing, when the activation function is not in {@link #VALID_ACTIVATION_FUNCTIONS}.
     *
     * @param bl      layer configuration to validate
     * @param layerId identifier used in messages (a layer index, or a quoted vertex name)
     * @throws IllegalStateException if the updater is unsuitable for gradient checking
     */
    private static void validateLayerConfig(BaseLayer bl, String layerId) {
        IUpdater u = bl.getIUpdater();
        if (u instanceof Sgd) {
            double lr = ((Sgd) u).getLearningRate();
            if (lr != 1.0) {
                throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerId + "; got " + u + " with lr=" + lr + " for layer \"" + bl.getLayerName() + "\"");
            }
        } else if (!(u instanceof NoOp)) {
            throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerId + "; got " + u);
        }
        IActivation activation = bl.getActivationFn();
        if (activation != null && !VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) {
            log.warn("Layer " + layerId + " is possibly using an unsuitable activation function: " + activation.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
        }
    }

    /**
     * Returns the exclusive end offset of each named parameter within the flattened parameter vector. Offsets are
     * cumulative lengths in {@code paramNames} order. This relies on the parameter table's iteration order matching
     * the flattened layout (NOTE(review): assumption carried over from the original implementation).
     * Returns an empty array when there are no parameters.
     */
    private static long[] computeParamEnds(Map<String, INDArray> paramTable, List<String> paramNames) {
        long[] ends = new long[paramNames.size()];
        long cumulative = 0L;
        for (int j = 0; j < ends.length; j++) {
            cumulative += paramTable.get(paramNames.get(j)).length();
            ends[j] = cumulative;
        }
        return ends;
    }

    /**
     * Returns, per parameter, the stride used when only a subset of values is checked.
     *
     * <p>The stride is {@code ceil(length / maxPerParam)}, clamped to at least 1. This means at most
     * {@code maxPerParam} values are checked, and every value is checked when the array is smaller than
     * {@code maxPerParam}.</p>
     *
     * @param maxPerParam must be > 0
     */
    private static long[] computeStepSizes(Map<String, INDArray> paramTable, List<String> paramNames, int maxPerParam) {
        long[] steps = new long[paramNames.size()];
        for (int j = 0; j < steps.length; j++) {
            long n = paramTable.get(paramNames.get(j)).length();
            steps[j] = Math.max(1L, (n + maxPerParam - 1) / maxPerParam);
        }
        return steps;
    }

    /** Relative error {@code |b - n| / (|b| + |n|)}, defined as 0 when both gradients are exactly 0. */
    private static double relativeError(double backpropGradient, double numericalGradient) {
        if (backpropGradient == 0.0 && numericalGradient == 0.0) {
            return 0.0;
        }
        return Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
    }

    /**
     * Decides whether one parameter value passes. It passes if its relative error is within {@code maxRelError},
     * or if its absolute error is below {@code minAbsoluteError}. Logs the outcome only when {@code print} is true.
     *
     * @return true if the value passed
     */
    private static boolean paramPasses(long paramIdx, String paramName, double backpropGradient, double numericalGradient, double relError, double scorePlus, double scoreMinus, double origValue, double maxRelError, double minAbsoluteError, boolean print) {
        boolean exceedsRelTolerance = relError > maxRelError || Double.isNaN(relError);
        if (!exceedsRelTolerance) {
            if (print) {
                log.info("Param " + paramIdx + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
            }
            return true;
        }
        double absError = Math.abs(backpropGradient - numericalGradient);
        if (absError < minAbsoluteError) {
            if (print) {
                log.info("Param " + paramIdx + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
            }
            return true;
        }
        if (print) {
            log.info("Param " + paramIdx + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue);
        }
        return false;
    }

    /**
     * Core finite-difference loop shared by all checks. For each selected value of {@code params} it:
     * sets the value to {@code orig + epsilon} and scores; sets it to {@code orig - epsilon} and scores;
     * restores the original value (also when scoring throws); and compares the numerical gradient with
     * {@code gradientToCheck}.
     *
     * @param params           live (view) parameter vector that is perturbed in place
     * @param gradientToCheck  backprop gradient, laid out like {@code params}
     * @param nParams          number of values in {@code params}
     * @param paramNames       parameter names in flattened order
     * @param paramEnds        exclusive end offset of each parameter (see {@link #computeParamEnds})
     * @param stepSizes        per-parameter stride, or null to check every value
     * @param excludeParams    parameter names to skip entirely; may be null
     * @param scoreFn          computes the score at the current parameter values
     * @return true if no checked value failed; false on the first failure when {@code exitOnFirstError} is set
     * @throws IllegalStateException if a numerical gradient is NaN, or the parameter table does not cover {@code nParams}
     */
    private static boolean checkParamsNumerically(INDArray params, INDArray gradientToCheck, long nParams, List<String> paramNames, long[] paramEnds, long[] stepSizes, Set<String> excludeParams, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, ScoreFunction scoreFn) {
        int totalNFailures = 0;
        long nChecked = 0L;
        double maxError = 0.0;
        int currParamNameIdx = 0;
        long idx = 0L;
        while (idx < nParams) {
            // A while loop (rather than a single if) so that zero-length parameters are skipped correctly
            while (currParamNameIdx < paramEnds.length && idx >= paramEnds[currParamNameIdx]) {
                currParamNameIdx++;
            }
            if (currParamNameIdx == paramEnds.length) {
                throw new IllegalStateException("Parameter index " + idx + " is not covered by the parameter table (" + nParams + " parameters in total)");
            }
            String paramName = paramNames.get(currParamNameIdx);
            if (excludeParams != null && excludeParams.contains(paramName)) {
                log.info("Skipping parameters for parameter name: {}", paramName);
                // Jump to the first value of the next parameter; nothing else advances idx on this path
                idx = paramEnds[currParamNameIdx];
                continue;
            }

            double origValue = params.getDouble(idx);
            double scorePlus;
            double scoreMinus;
            try {
                params.putScalar(idx, origValue + epsilon);
                scorePlus = scoreFn.score();
                params.putScalar(idx, origValue - epsilon);
                scoreMinus = scoreFn.score();
            } finally {
                params.putScalar(idx, origValue);
            }
            double numericalGradient = (scorePlus - scoreMinus) / (2.0 * epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + idx + " of " + nParams);
            }
            double backpropGradient = gradientToCheck.getDouble(idx);
            double relError = relativeError(backpropGradient, numericalGradient);
            if (relError > maxError) {
                maxError = relError;
            }
            nChecked++;
            if (!paramPasses(idx, paramName, backpropGradient, numericalGradient, relError, scorePlus, scoreMinus, origValue, maxRelError, minAbsoluteError, print)) {
                if (exitOnFirstError) {
                    return false;
                }
                totalNFailures++;
            }

            long step = stepSizes == null ? 1L : stepSizes[currParamNameIdx];
            // Never stride past the end of the current parameter, so the next parameter starts at its first value
            idx = Math.min(idx + step, paramEnds[currParamNameIdx]);
        }
        if (print) {
            long nPass = nChecked - (long) totalNFailures;
            log.info("GradientCheckUtil.checkGradients(): " + nChecked + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        }
        return totalNFailures == 0;
    }
}

