/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ai.h2o.sparkling.ml.algos

import ai.h2o.sparkling.H2OFrame
import ai.h2o.sparkling.ml.params.H2OCommonParams
import ai.h2o.sparkling.ml.utils.{EstimatorCommonUtils, SchemaUtils}
import org.apache.spark.h2o.H2OContext
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.col

/**
 * Utilities shared by H2O algorithm estimators for turning a Spark dataset into the
 * H2O frame(s) used for training.
 *
 * Implementors supply the feature columns, the excluded columns, the columns to be
 * converted to categorical and the split ratio; the `*Internal` variants delegate to
 * the plain accessors by default and are the ones used by [[prepareDatasetForFitting]].
 */
trait H2OAlgoCommonUtils extends EstimatorCommonUtils {

  /** Columns that are left out of the automatically derived feature columns but are still sent to H2O. */
  protected def getExcludedCols(): Seq[String]

  /** Feature columns configured on the estimator. */
  private[sparkling] def getFeaturesCols(): Array[String]

  /** Feature columns used during dataset preparation; delegates to [[getFeaturesCols]] by default. */
  private[sparkling] def getFeaturesColsInternal(): Array[String] = getFeaturesCols()

  /** Columns that should be converted to categorical on the training frame. */
  private[sparkling] def getColumnsToCategorical(): Array[String]

  /** Categorical columns used during dataset preparation; delegates to [[getColumnsToCategorical]] by default. */
  private[sparkling] def getColumnsToCategoricalInternal(): Array[String] = getColumnsToCategorical()

  /** Ratio passed to the frame split; values of 1.0 or more disable splitting. */
  private[sparkling] def getSplitRatio(): Double

  /** Split ratio used during dataset preparation; delegates to [[getSplitRatio]] by default. */
  private[sparkling] def getSplitRatioInternal(): Double = getSplitRatio()

  /** Stores the feature columns on the estimator. */
  private[sparkling] def setFeaturesCols(value: Array[String]): this.type

  /** Setter used during dataset preparation; delegates to [[setFeaturesCols]] by default. */
  private[sparkling] def setFeaturesColsInternal(value: Array[String]): this.type = setFeaturesCols(value)

  /**
   * Converts the dataset into an H2O training frame and, when a split ratio below 1.0 is
   * configured, an additional frame produced by the split.
   *
   * When no feature columns are configured, every dataset column that does not match an
   * excluded column (compared case-insensitively) becomes a feature column and is stored via
   * [[setFeaturesColsInternal]]. Otherwise all configured feature columns must be present in
   * the dataset (exact name match).
   *
   * @param dataset the dataset to train on
   * @return the training frame and, optionally, the second frame returned by the split
   * @throws IllegalArgumentException if a configured feature column is missing from the dataset
   */
  protected def prepareDatasetForFitting(dataset: Dataset[_]): (H2OFrame, Option[H2OFrame]) = {
    val excludedCols = getExcludedCols()
    // Read the column names once instead of re-deriving them from the schema per feature.
    val datasetColumns = dataset.columns

    if (getFeaturesColsInternal().isEmpty) {
      val features = datasetColumns.filter(c => excludedCols.forall(e => c.compareToIgnoreCase(e) != 0))
      setFeaturesColsInternal(features)
    } else {
      val missingColumns = getFeaturesColsInternal()
        .filterNot(feature => datasetColumns.contains(feature))

      if (missingColumns.nonEmpty) {
        throw new IllegalArgumentException(
          "The following feature columns are not available on" +
            s" the training dataset: '${missingColumns.mkString(", ")}'")
      }
    }

    // Quote every name so that names containing dots or other special characters are
    // treated as a single top-level column by Spark.
    val featureColumns = getFeaturesColsInternal().map(sanitize).map(col)
    val excludedColumns = excludedCols.map(sanitize).map(col)
    val columns = featureColumns ++ excludedColumns
    val h2oContext = H2OContext.ensure(
      "H2OContext needs to be created in order to train the model. Please create one as H2OContext.getOrCreate().")
    val trainFrame = H2OFrame(h2oContext.asH2OFrameKeyString(dataset.select(columns: _*).toDF()))

    // Our MOJO wrapper needs the full column name before the array/vector expansion in order to do predictions
    // NOTE(review): the comment above does not obviously describe the categorical conversion
    // below; confirm whether it belongs to the feature-column handling instead.
    trainFrame.convertColumnsToCategorical(getColumnsToCategoricalInternal())

    val splitRatio = getSplitRatioInternal()
    if (splitRatio < 1.0) {
      val frames = trainFrame.split(splitRatio)
      if (frames.length > 1) {
        (frames(0), Some(frames(1)))
      } else {
        (frames(0), None)
      }
    } else {
      (trainFrame, None)
    }
  }

  /**
   * Quotes a column name with backticks so it can be passed to Spark's `col` as a single
   * identifier. Backticks already present in the name are doubled, which is how Spark
   * escapes them inside a quoted identifier; without this a name containing a backtick
   * would produce a malformed identifier.
   *
   * @param colName raw column name
   * @return the backtick-quoted, escaped column name
   */
  def sanitize(colName: String): String = {
    val escaped = colName.replace("`", "``")
    s"`${escaped}`"
  }
}
