package ai.minxiao.ds4s.core.dl4j.nnbase

import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeFeedForward
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.conf.{BackpropType, GradientNormalization, MultiLayerConfiguration, NeuralNetConfiguration, Updater}
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.{Builder => NNConfBuilder, ListBuilder}
import org.deeplearning4j.nn.conf.weightnoise.DropConnect
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.learning.config.{AMSGrad, AdaDelta, AdaGrad, AdaMax, Adam, Nadam, Nesterovs, NoOp, RmsProp, Sgd}
import org.nd4j.linalg.lossfunctions.impl._
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction

import ai.minxiao.ds4s.core.dl4j.ui.UIStarter

/**
  * Neural Network Base
  *
  * BASE -------------------------------------------------------------------------------------------------------------
  *
  * @param seed random generator seed, default=2018
  *
  * ------------------------------------------------------------------------------------------------------------
  *
  * REGULARIZATION
  * @param l2 l2 regularization, default=0.0
  * @param l1 l1 regularization, default=0.0
  * @param l2Bias l2 bias term, default=0.0
  * @param l1Bias l1 bias term, default=0.0
  * @param weightNoise whether to use weight noise (drop connect), default=false
  * @param weightRetainProbability weight retain probability for the weight noise (drop-connect), default=1 (no drop-connect)
  * @param applyToBiases whether to apply the weight noise (drop-connect) to the biases as well, default=false
  *
  * ------------------------------------------------------------------------------------------------------------------
  *
  * OPTIMIZATION
  *
  * @param optimizationAlgo
  *   <a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/OptimizationAlgorithm.java">
  *   optimization algorithm</a> (default=STOCHASTIC_GRADIENT_DESCENT)
  * {{{
  * STOCHASTIC_GRADIENT_DESCENT://<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java">StochasticGradientDescent.java</a>
  * LINE_GRADIENT_DESCENT://<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/LineGradientDescent.java">LineGradientDescent.java</a>
  * CONJUGATE_GRADIENT://<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/ConjugateGradient.java">ConjugateGradient.java</a>
  * LBFGS://<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java">LBFGS.java</a>
  * }}}
  * @param miniBatch whether to use mini-batch, default=true
  * @param learningRate learning rate, default=0.1
  * @param beta1 gradient moving avg decay rate, default=0.9
  * @param beta2 gradient sqrt decay rate, default=0.999
  * @param epsilon default=1E-8
  * @param momentum NESTEROVS momentum, default=0.9
  * @param rmsDecay RMSPROP decay rate, default=0.95
  * @param rho ADADELTA decay rate, default=0.95
  * @param updater <a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/Updater.java">
  *   weights updater</a>, (default = NESTEROVS).
  * Options: {{{
  * SGD: //<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/Sgd.java">Sgd.java</a>
  *   learningRate: learning rate (default = 1E-3)
  * ADAM: //<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/Adam.java">Adam.java</a>
  *   learningRate: learning rate, DEFAULT_ADAM_LEARNING_RATE = 1e-3;
  *   beta1: gradient moving avg decay rate, DEFAULT_ADAM_BETA1_MEAN_DECAY = 0.9;
  *   beta2: gradient sqrt decay rate, DEFAULT_ADAM_BETA2_VAR_DECAY = 0.999;
  *   epsilon: epsilon, DEFAULT_ADAM_EPSILON = 1e-8;
  *   //<a href="http://arxiv.org/abs/1412.6980">Adam: A Method for Stochastic Optimization</a>
  * ADAMAX: //<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/AdaMax.java">AdaMax.java</a>
  *   learningRate: learning rate, DEFAULT_ADAMAX_LEARNING_RATE = 1e-3;
  *   beta1: gradient moving avg decay rate, DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
  *   beta2: gradient sqrt decay rate, DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
  *   epsilon: epsilon, DEFAULT_ADAMAX_EPSILON = 1e-8;
  *   //<a href="http://arxiv.org/abs/1412.6980">Adam: A Method for Stochastic Optimization</a>
  * NADAM://<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/Nadam.java">Nadam.java</a>
  *   learningRate: learning rate, DEFAULT_NADAM_LEARNING_RATE = 1e-3;
  *   epsilon: DEFAULT_NADAM_EPSILON = 1e-8;
  *   beta1: gradient moving avg decay rate, DEFAULT_NADAM_BETA1_MEAN_DECAY = 0.9;
  *   beta2: gradient sqrt decay rate, DEFAULT_NADAM_BETA2_VAR_DECAY = 0.999;
  *   //<a href="https://arxiv.org/pdf/1609.04747.pdf">An overview of gradient descent optimization algorithms</a>
  * AMSGRAD: //<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/AMSGrad.java">AMSGrad.java</a>
  *   learningRate: learning rate, DEFAULT_AMSGRAD_LEARNING_RATE = 1e-3;
  *   epsilon: DEFAULT_AMSGRAD_EPSILON = 1e-8;
  *   beta1: DEFAULT_AMSGRAD_BETA1_MEAN_DECAY = 0.9;
  *   beta2: DEFAULT_AMSGRAD_BETA2_VAR_DECAY = 0.999;
  * ADAGRAD: Vectorized Learning Rate used per Connection Weight//<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/AdaGrad.java">AdaGrad.java</a>
  *   learningRate: learning rate, DEFAULT_ADAGRAD_LEARNING_RATE = 1e-1;
  *   epsilon: DEFAULT_ADAGRAD_EPSILON = 1e-6;
  *   //<a href="http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf">Adaptive Subgradient Methods for Online Learning and Stochastic Optimization</a>
  *   //<a href="http://xcorr.net/2014/01/23/adagrad-eliminating-learning-rates-in-stochastic-gradient-descent/">Adagrad – eliminating learning rates in stochastic gradient descent</a>
  * NESTEROVS: tracks previous layer's gradient and uses it as a way of updating the gradient //<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/Nesterovs.java">Nesterovs.java</a>
  *   learningRate: learning rate, DEFAULT_NESTEROV_LEARNING_RATE = 0.1;
  *   momentum: DEFAULT_NESTEROV_MOMENTUM = 0.9;
  * RMSPROP: //<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/RmsProp.java">RmsProp.java</a>
  *   learningRate: learning rate, DEFAULT_RMSPROP_LEARNING_RATE = 1e-1;
  *   epsilon: DEFAULT_RMSPROP_EPSILON = 1e-8;
  *   rmsDecay: decay rate, DEFAULT_RMSPROP_RMSDECAY = 0.95;
  *   //<a href="http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf">Neural Networks for Machine Learning</a>
  * ADADELTA: //<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/AdaDelta.java">AdaDelta.java</a>
  *   rho: decay rate, controlling the decay of the previous parameter updates, DEFAULT_ADADELTA_RHO = 0.95;
  *   epsilon: DEFAULT_ADADELTA_EPSILON = 1e-6;
  *   (no need to manually set the learning rate)
  *   //<a href="http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf">ADADELTA: AN ADAPTIVE LEARNING RATE METHOD</a>
  * NONE: no updates //<a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/NoOp.java">NoOp.java</a>
  * }}}
  * @param gradientNormalization gradient normalization, default=None
  * Options: GradientNormalization.X
  * {{{
  * ClipElementWiseAbsoluteValue:
  *  g <- sign(g)*min(maxAllowedValue,|g|), i.e. each element is clipped to [-maxAllowedValue, maxAllowedValue].
  * ClipL2PerLayer:
  *   GOut = G                             if l2Norm(G) < threshold (i.e., no change)
  *   GOut = threshold * G / l2Norm(G)     otherwise
  * ClipL2PerParamType: conditional renormalization. Very similar to ClipL2PerLayer, however instead of clipping per layer, do clipping on each parameter type separately.
  * None: no gradient normalization
  * RenormalizeL2PerLayer: rescale gradients by dividing by the L2 norm of all gradients for the layer
  * RenormalizeL2PerParamType:
  *  GOut_weight = G_weight / l2(G_weight)
  *  GOut_bias = G_bias / l2(G_bias)
  * }}}
  * @param gradientNormalizationThreshold gradient threshold, default=1.0
  *
  * -------------------------------------------------------------------------------------------------------------------------------------------
  *
  * @author mx
  */
@SerialVersionUID(787866L)
abstract class NNBase (
  // Base
  seed: Long = 2018L,
  // Regularization
  l2: Double = 0.0,
  l1: Double = 0.0,
  l2Bias: Double = 0.0,
  l1Bias: Double = 0.0,
  weightNoise: Boolean = false,
  weightRetainProbability: Double = 1.0,
  applyToBiases: Boolean = false,
  // Optimization
  optimizationAlgo: OptimizationAlgorithm = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
  miniBatch: Boolean = true,
  learningRate: Double = 0.1,
  beta1: Double = 0.9,
  beta2: Double = 0.999,
  epsilon: Double = 1E-8,
  momentum: Double = 0.9,
  rmsDecay: Double = 0.95,
  rho: Double = 0.95,
  updater: Updater = Updater.NESTEROVS,
  gradientNormalization: GradientNormalization = GradientNormalization.None,
  gradientNormalizationThreshold: Double = 1.0
) extends Serializable {

  /** Base configuration builder: a fresh `NeuralNetConfiguration.Builder` with only `seed` applied.
    *
    * This is a `lazy val`, so the builder is created on first access and the very same
    * builder object is handed out on every later access of this instance.
    */
  protected lazy val baseConfBuilder: NNConfBuilder =
    new NeuralNetConfiguration.Builder().
    seed(seed)

  /** Applies the regularization settings of this class to the given builder.
    *
    * Sets `l2`, `l1`, `l2Bias` and `l1Bias`, and, only when `weightNoise` is true,
    * a [[DropConnect]] weight noise built from `weightRetainProbability` and `applyToBiases`.
    *
    * @param confBuilder the configuration builder to configure
    * @return `confBuilder` itself (the object that was passed in)
    */
  protected def regConfBuilder(confBuilder: NNConfBuilder): NNConfBuilder = {
    // The values returned by the setters are intentionally not used: the method hands back
    // `confBuilder` itself, so it relies on the setters configuring that builder in place.
    confBuilder.
    l2(l2).
    l1(l1).
    l2Bias(l2Bias).
    l1Bias(l1Bias)

    if (weightNoise) confBuilder.weightNoise(new DropConnect(weightRetainProbability, applyToBiases))
    confBuilder
  }

  /** Applies the optimization settings of this class to the given builder.
    *
    * Sets the optimization algorithm, the mini-batch flag, the updater configuration and the
    * gradient normalization (strategy and threshold).
    *
    * The `updater` enum value is translated into the matching ND4J updater configuration,
    * constructed from the hyper-parameters of this class:
    * {{{
    * ADADELTA  -> AdaDelta(rho, epsilon)
    * ADAGRAD   -> AdaGrad(learningRate, epsilon)
    * ADAMAX    -> AdaMax(learningRate, beta1, beta2, epsilon)
    * ADAM      -> Adam(learningRate, beta1, beta2, epsilon)
    * AMSGRAD   -> AMSGrad(learningRate, beta1, beta2, epsilon)
    * NADAM     -> Nadam(learningRate, beta1, beta2, epsilon)
    * NESTEROVS -> Nesterovs(learningRate, momentum)
    * RMSPROP   -> RmsProp(learningRate, rmsDecay, epsilon)
    * SGD       -> Sgd(learningRate)
    * otherwise -> NoOp
    * }}}
    *
    * @param confBuilder the configuration builder to configure
    * @return the builder returned by the last setter in the chain
    */
  protected def optConfBuilder(confBuilder: NNConfBuilder): NNConfBuilder =
    confBuilder.
    optimizationAlgo(optimizationAlgo).
    miniBatch(miniBatch).
    updater(// set updater
      updater match {
        case Updater.ADADELTA  => new AdaDelta(rho, epsilon)
        case Updater.ADAGRAD   => new AdaGrad(learningRate, epsilon)
        case Updater.ADAMAX    => new AdaMax(learningRate, beta1, beta2, epsilon)
        case Updater.ADAM      => new Adam(learningRate, beta1, beta2, epsilon)
        // AMSGRAD previously had no case of its own and silently fell through to NoOp.
        case Updater.AMSGRAD   => new AMSGrad(learningRate, beta1, beta2, epsilon)
        case Updater.NADAM     => new Nadam(learningRate, beta1, beta2, epsilon)
        case Updater.NESTEROVS => new Nesterovs(learningRate, momentum)
        case Updater.RMSPROP   => new RmsProp(learningRate, rmsDecay, epsilon)
        case Updater.SGD       => new Sgd(learningRate)
        // Everything else (NONE) means "do not update the parameters".
        case /* NONE */ _      => new NoOp
      }
    ).
    gradientNormalization(gradientNormalization).
    gradientNormalizationThreshold(gradientNormalizationThreshold)
}
