/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.nlp.AnnotatorType.{CATEGORY, SENTENCE_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.ner.Verbose
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._

import scala.util.Random

/** Trains a ClassifierDL for generic Multi-class Text Classification.
  *
  * ClassifierDL uses the state-of-the-art Universal Sentence Encoder as an input for text
  * classifications. The ClassifierDL annotator uses a deep learning model (DNNs) we have built
  * inside TensorFlow and supports up to 100 classes.
  *
  * For instantiated/pretrained models, see [[ClassifierDLModel]].
  *
  * '''Notes''':
  *   - This annotator accepts a label column of a single item of type String, Int, Long, Float,
  *     or Double.
  *   - [[com.johnsnowlabs.nlp.embeddings.UniversalSentenceEncoder UniversalSentenceEncoder]],
  *     [[com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings BertSentenceEmbeddings]], or
  *     [[com.johnsnowlabs.nlp.embeddings.SentenceEmbeddings SentenceEmbeddings]] can be used for
  *     the `inputCol`.
  *
  * Setting a test dataset to monitor model metrics can be done with `.setTestDataset`. The method
  * expects a path to a parquet file containing a dataframe that has the same required columns as
  * the training dataframe. The pre-processing steps for the training dataframe should also be
  * applied to the test dataframe. The following example will show how to create the test dataset:
  *
  * {{{
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val embeddings = UniversalSentenceEncoder.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("sentence_embeddings")
  *
  * val preProcessingPipeline = new Pipeline().setStages(Array(documentAssembler, embeddings))
  *
  * val Array(train, test) = data.randomSplit(Array(0.8, 0.2))
  * preProcessingPipeline
  *   .fit(test)
  *   .transform(test)
  *   .write
  *   .mode("overwrite")
  *   .parquet("test_data")
  *
  * val classifier = new ClassifierDLApproach()
  *   .setInputCols("sentence_embeddings")
  *   .setOutputCol("category")
  *   .setLabelColumn("label")
  *   .setTestDataset("test_data")
  * }}}
  *
  * For extended examples of usage, see the Examples
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/scala/training/Train%20Multi-Class%20Text%20Classification%20on%20News%20Articles.scala [1]]]
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/classification/ClassifierDL_Train_multi_class_news_category_classifier.ipynb [2]]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/ClassifierDLTestSpec.scala ClassifierDLTestSpec]].
  *
  * ==Example==
  * In this example, the training data `"sentiment.csv"` has the form of
  * {{{
  * text,label
  * This movie is the best movie I have watched ever! In my opinion this movie can win an award.,0
  * This was a terrible movie! The acting was bad really bad!,1
  * ...
  * }}}
  * Then training can be done like so:
  * {{{
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.embeddings.UniversalSentenceEncoder
  * import com.johnsnowlabs.nlp.annotators.classifier.dl.ClassifierDLApproach
  * import org.apache.spark.ml.Pipeline
  *
  * val smallCorpus = spark.read.option("header","true").csv("src/test/resources/classifier/sentiment.csv")
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val useEmbeddings = UniversalSentenceEncoder.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("sentence_embeddings")
  *
  * val docClassifier = new ClassifierDLApproach()
  *   .setInputCols("sentence_embeddings")
  *   .setOutputCol("category")
  *   .setLabelColumn("label")
  *   .setBatchSize(64)
  *   .setMaxEpochs(20)
  *   .setLr(5e-3f)
  *   .setDropout(0.5f)
  *
  * val pipeline = new Pipeline()
  *   .setStages(
  *     Array(
  *       documentAssembler,
  *       useEmbeddings,
  *       docClassifier
  *     )
  *   )
  *
  * val pipelineModel = pipeline.fit(smallCorpus)
  * }}}
  *
  * @see
  *   [[MultiClassifierDLApproach]] for multi-label classification
  * @see
  *   [[SentimentDLApproach]] for sentiment analysis
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class ClassifierDLApproach(override val uid: String)
    extends AnnotatorApproach[ClassifierDLModel]
    with ParamsAndFeaturesWritable
    with ClassifierEncoder {

  /** Creates the approach with a freshly generated uid prefixed with `ClassifierDL`. */
  def this() = this(Identifiable.randomUID("ClassifierDL"))

  /** Trains TensorFlow model for multi-class text classification */
  override val description: String = "Trains TensorFlow model for multi-class text classification"

  /** Input annotator type : SENTENCE_EMBEDDINGS
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(SENTENCE_EMBEDDINGS)

  /** Output annotator type : CATEGORY
    *
    * @group anno
    */
  override val outputAnnotatorType: String = CATEGORY

  /** Dropout coefficient (Default: `0.5f`)
    *
    * @group param
    */
  val dropout = new FloatParam(this, "dropout", "Dropout coefficient")

  /** Dropout coefficient (Default: `0.5f`)
    *
    * @group setParam
    */
  def setDropout(dropout: Float): ClassifierDLApproach.this.type = set(this.dropout, dropout)

  /** Dropout coefficient (Default: `0.5f`)
    *
    * @group getParam
    */
  def getDropout: Float = $(this.dropout)

  setDefault(maxEpochs -> 10, lr -> 5e-3f, dropout -> 0.5f, batchSize -> 64)

  /** Trains the classifier on `dataset` and returns the fitted [[ClassifierDLModel]].
    *
    * The steps performed are:
    *   1. check that the label column has one of the supported types,
    *   1. build the training inputs and the label encoder from the first input column,
    *   1. if `testDataset` is set, read that data frame and build a separate encoder and inputs
    *      for it,
    *   1. load the TensorFlow graph via [[loadSavedModel]] and run the training loop,
    *   1. wrap the variables extracted from the TensorFlow session, together with the graph, in
    *      a new wrapper and hand it to a new [[ClassifierDLModel]].
    *
    * @param dataset
    *   training data; must contain the label column and the configured input column
    * @param recursivePipeline
    *   not used by this implementation
    * @return
    *   the trained model, carrying the encoder params, the TensorFlow wrapper and the storage
    *   reference read from the input embeddings column
    * @throws IllegalArgumentException
    *   if the label column is not of type String, Integer, Long, Float or Double
    */
  override def train(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): ClassifierDLModel = {

    val labelColType = dataset.schema($(labelColumn)).dataType
    val isSupportedLabelType =
      labelColType == StringType || labelColType == IntegerType ||
        labelColType == DoubleType || labelColType == FloatType || labelColType == LongType
    require(
      isSupportedLabelType,
      // Interpolate the column *name*; a bare `$labelColumn` would print the Param object.
      s"The label column ${$(labelColumn)} type is $labelColType and it's not compatible. " +
        "Compatible types are StringType, IntegerType, DoubleType, LongType, or FloatType. ")

    val (trainDataset, trainLabels) = buildDatasetWithLabels(dataset, getInputCols(0))
    val encoder = new ClassifierDatasetEncoder(ClassifierDatasetEncoderParams(tags = trainLabels))
    val trainInputs = extractInputs(encoder, trainDataset)

    // The test encoder and the test inputs are either both present or both absent, so they are
    // produced together and split afterwards instead of mutating a `var` from inside the branch.
    val testEncoderAndInputs =
      if (!isDefined(testDataset)) None
      else {
        val testDataFrame = ResourceHelper.readSparkDataFrame($(testDataset))
        val (test, testLabels) = buildDatasetWithLabels(testDataFrame, getInputCols(0))
        val encoderForTest =
          new ClassifierDatasetEncoder(ClassifierDatasetEncoderParams(tags = testLabels))
        Some((encoderForTest, extractInputs(encoderForTest, test)))
      }
    val testEncoder: Option[ClassifierDatasetEncoder] = testEncoderAndInputs.map(_._1)
    // `Option(...)` keeps the null-to-None mapping that the inputs have always gone through.
    val testInputs = testEncoderAndInputs.flatMap { case (_, inputs) => Option(inputs) }

    val tfWrapper: TensorflowWrapper = loadSavedModel()

    val classifier =
      new TensorflowClassifier(tensorflow = tfWrapper, encoder, testEncoder, Verbose($(verbose)))

    // Seeds the shared `scala.util.Random` instance before training starts.
    if (isDefined(randomSeed)) {
      Random.setSeed($(randomSeed))
    }

    classifier.train(
      trainInputs,
      testInputs,
      trainLabels.length,
      lr = $(lr),
      batchSize = $(batchSize),
      dropout = $(dropout),
      endEpoch = $(maxEpochs),
      configProtoBytes = getConfigProtoBytes,
      validationSplit = $(validationSplit),
      evaluationLogExtended = $(evaluationLogExtended),
      enableOutputLogs = $(enableOutputLogs),
      outputLogsPath = $(outputLogsPath),
      uuid = this.uid)

    // Build a new wrapper from the variables extracted out of the TensorFlow session and the
    // graph of the wrapper that was trained.
    val newWrapper = new TensorflowWrapper(
      TensorflowWrapper.extractVariablesSavedModel(
        tfWrapper.getTFSession(configProtoBytes = getConfigProtoBytes)),
      tfWrapper.graph)

    val embeddingsRef = HasStorageRef.getStorageRefFromInput(
      dataset,
      $(inputCols),
      AnnotatorType.SENTENCE_EMBEDDINGS)

    val model = new ClassifierDLModel()
      .setDatasetParams(classifier.encoder.params)
      .setModelIfNotSet(dataset.sparkSession, newWrapper)
      .setStorageRef(embeddingsRef)

    // Only forward the TensorFlow config when the user explicitly set one.
    get(configProtoBytes).foreach(bytes => model.setConfigProtoBytes(bytes))

    model
  }

  /** Loads the untrained TensorFlow graph used by [[train]].
    *
    * Reads the zipped saved model at `/classifier-dl` with the `serve` tag and all tables
    * initialised, then replaces the wrapper's variables with empty ones.
    *
    * NOTE(review): `/classifier-dl` is presumably a resource bundled with the library — confirm
    * against `TensorflowWrapper.readZippedSavedModel`.
    *
    * @return
    *   a wrapper around the loaded graph with empty variables
    */
  def loadSavedModel(): TensorflowWrapper = {

    val wrapper =
      TensorflowWrapper
        .readZippedSavedModel("/classifier-dl", tags = Array("serve"), initAllTables = true)

    wrapper.variables = Variables(Array.empty[Array[Byte]], Array.empty[Byte])
    wrapper
  }
}

/** This is the companion object of [[ClassifierDLApproach]]. Please refer to that class for the
  * documentation.
  *
  * It extends Spark ML's `DefaultParamsReadable`, which supplies the reader used to load a
  * previously saved [[ClassifierDLApproach]] (e.g. `ClassifierDLApproach.load(path)`).
  */
object ClassifierDLApproach extends DefaultParamsReadable[ClassifierDLApproach]
