package ai.minxiao.ds4s.core.dl4j.serializer

import java.io.{File, InputStream, OutputStream}

import org.nd4j.linalg.dataset.api.preprocessor.{DataNormalization, MultiDataNormalization, Normalizer}
import org.deeplearning4j.nn.api.Model
import org.deeplearning4j.nn.graph.ComputationGraph
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.util.ModelSerializer
import org.nd4j.linalg.primitives.Pair

/**
  * Serializer for DL4J
  *
  * @author mx
  */
object Serializer {

  /**
    * Writes `model` to the file at `path` via `ModelSerializer.writeModel`.
    *
    * @param model       network to serialize
    * @param path        destination file path
    * @param saveUpdater whether the updater state is written along with the model
    */
  def save(model: Model, path: String, saveUpdater: Boolean): Unit =
    ModelSerializer.writeModel(model, path, saveUpdater)

  /**
    * Writes `model` to `os`, optionally bundling a single-input normalizer with it.
    *
    * @param model             network to serialize
    * @param os                destination stream
    * @param saveUpdater       whether the updater state is written along with the model
    * @param dataNormalization normalizer to store alongside the model; `None` writes the model only
    */
  def save(model: Model, os: OutputStream, saveUpdater: Boolean, dataNormalization: Option[DataNormalization] = None): Unit =
    dataNormalization match {
      case Some(normalizer) => ModelSerializer.writeModel(model, os, saveUpdater, normalizer)
      case None             => ModelSerializer.writeModel(model, os, saveUpdater)
    }

  /**
    * Writes `model` to the file `oF`, then adds `multiDataNormalization` to that same file.
    *
    * @param model                  network to serialize
    * @param oF                     destination file; written first, then amended with the normalizer
    * @param saveUpdater            whether the updater state is written along with the model
    * @param multiDataNormalization multi-input normalizer to store with the model
    */
  def save(model: Model, oF: File, saveUpdater: Boolean, multiDataNormalization: MultiDataNormalization): Unit = {
    ModelSerializer.writeModel(model, oF, saveUpdater)
    ModelSerializer.addNormalizerToModel(oF, multiDataNormalization)
  }

  /**
    * Restores a `MultiLayerNetwork` from `is`.
    *
    * @param is          stream holding a serialized network
    * @param saveUpdater forwarded as the `loadUpdater` flag of the `ModelSerializer` restore call
    * @param saveType    `"mlnn-dtnm"` to also restore a `DataNormalization`; any other value
    *                    (including the default `"mlnn"`) restores the network only
    * @return the network and the restored normalizer, or `None` when none was requested or present
    * @throws ClassCastException if the stored normalizer is not a `DataNormalization`
    */
  def loadMLNN(is: InputStream, saveUpdater: Boolean, saveType: String = "mlnn"): (MultiLayerNetwork, Option[DataNormalization]) =
    saveType match {
      case "mlnn-dtnm" =>
        val restored = ModelSerializer.restoreMultiLayerNetworkAndNormalizer(is, saveUpdater)
        (restored.getFirst(), narrowNormalizer(restored.getSecond(), classOf[DataNormalization]))
      case _ => // "mlnn", and any unrecognised tag, falls back to loading the network alone
        (ModelSerializer.restoreMultiLayerNetwork(is, saveUpdater), None)
    }

  /**
    * Restores a `ComputationGraph` from `is`, with an optional single-input normalizer.
    *
    * @param is          stream holding a serialized graph
    * @param saveUpdater forwarded as the `loadUpdater` flag of the `ModelSerializer` restore call
    * @param saveType    `"cgnn-dtnm"` to also restore a `DataNormalization`; any other value
    *                    (including the default `"cgnn"`) restores the graph only
    * @return the graph and the restored normalizer, or `None` when none was requested or present
    * @throws ClassCastException if the stored normalizer is not a `DataNormalization`
    */
  def loadCGNN(is: InputStream, saveUpdater: Boolean, saveType: String = "cgnn"): (ComputationGraph, Option[DataNormalization]) =
    saveType match {
      case "cgnn-dtnm" =>
        val restored = ModelSerializer.restoreComputationGraphAndNormalizer(is, saveUpdater)
        (restored.getFirst(), narrowNormalizer(restored.getSecond(), classOf[DataNormalization]))
      case _ => // "cgnn", and any unrecognised tag, falls back to loading the graph alone
        (ModelSerializer.restoreComputationGraph(is, saveUpdater), None)
    }

  /**
    * Restores a `ComputationGraph` from `is`, with an optional multi-input normalizer.
    *
    * @param is          stream holding a serialized graph
    * @param saveUpdater forwarded as the `loadUpdater` flag of the `ModelSerializer` restore call
    * @param saveType    `"cgnn-dtnm"` to also restore a `MultiDataNormalization`; any other value
    *                    (including the default `"cgnn"`) restores the graph only
    * @return the graph and the restored normalizer, or `None` when none was requested or present
    * @throws ClassCastException if the stored normalizer is not a `MultiDataNormalization`
    */
  def loadCGNNMD(is: InputStream, saveUpdater: Boolean, saveType: String = "cgnn"): (ComputationGraph, Option[MultiDataNormalization]) =
    saveType match {
      case "cgnn-dtnm" =>
        val restored = ModelSerializer.restoreComputationGraphAndNormalizer(is, saveUpdater)
        (restored.getFirst(), narrowNormalizer(restored.getSecond(), classOf[MultiDataNormalization]))
      case _ => // "cgnn", and any unrecognised tag, falls back to loading the graph alone
        (ModelSerializer.restoreComputationGraph(is, saveUpdater), None)
    }

  /**
    * Converts the normalizer half of a restored (model, normalizer) pair into a typed `Option`.
    *
    * A `null` normalizer becomes `None` rather than `Some(null)`. DL4J may report an absent
    * normalizer as `null`; TODO confirm for the DL4J version in use. A non-null normalizer of
    * the wrong kind fails fast with the `ClassCastException` thrown by `Class.cast`.
    */
  private def narrowNormalizer[T](normalizer: AnyRef, expected: Class[T]): Option[T] =
    Option(normalizer).map(n => expected.cast(n))

}
