package ai.minxiao.ds4s.core.dl4j.prediction

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.nd4j.linalg.dataset.DataSet

/**
  * Inference helpers for trained DL4J multilayer networks.
  *
  * @author mx
  */
object Predictor {

  /**
    * Runs a trained network over a raw feature array in inference mode.
    *
    * The `false` argument is DL4J's `train` flag, so the network is evaluated in
    * test/inference mode rather than training mode.
    *
    * @param net trained multilayer network
    * @param data test features
    * @return the network output for `data`
    */
  def run(net: MultiLayerNetwork, data: INDArray): INDArray =
    net.output(data, false)

  /**
    * Runs a trained network over the features of a DataSet in inference mode.
    *
    * The DataSet's feature and label mask arrays are forwarded to the network, in the
    * same way DL4J's own iterator-based `output` does. Masked inputs, such as padded
    * variable-length time series, are therefore evaluated with their padding masked out.
    * When the DataSet carries no masks, both arrays are null, which DL4J treats as
    * "no masking". The result is then the same as `run(net, data.getFeatures)`.
    *
    * @param net trained multilayer network
    * @param data test data set; only its features and (optional) masks are used
    * @return the network output for the DataSet's features
    */
  def run(net: MultiLayerNetwork, data: DataSet): INDArray =
    net.output(
      data.getFeatures,
      false,
      data.getFeaturesMaskArray,
      data.getLabelsMaskArray
    )
}
