/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.graphframes.examples

import org.apache.spark.graphx.Graph
import org.apache.spark.graphx.VertexRDD
import org.apache.spark.graphx.{Edge => GXEdge}
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.functions.when
import org.graphframes.GraphFrame
import org.graphframes.examples.Graphs.gridIsingModel
import org.graphframes.lib.AggregateMessages

/**
 * Example code for Belief Propagation (BP)
 *
 * This provides a template for building customized BP algorithms for different types of graphical
 * models.
 *
 * This example:
 *   - Ising model on a grid
 *   - Parallel Belief Propagation using colored fields
 *
 * Ising models are probabilistic graphical models over binary variables x,,i,,. Each binary
 * variable x,,i,, corresponds to one vertex, and it may take values -1 or +1. The probability
 * distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,, and edge factors
 * b,,ij,,:
 * {{{
 *  P(X) = (1/Z) * exp[ \sum_i a_i x_i + \sum_{ij} b_{ij} x_i x_j ]
 * }}}
 * where Z is the normalization constant (partition function). See
 * [[https://en.wikipedia.org/wiki/Ising_model Wikipedia]] for more information on Ising models.
 *
 * Belief Propagation (BP) provides marginal probabilities of the values of the variables x,,i,,,
 * i.e., P(x,,i,,) for each i. This allows a user to understand likely values of variables. See
 * [[https://en.wikipedia.org/wiki/Belief_propagation Wikipedia]] for more information on BP.
 *
 * We use a batch synchronous BP algorithm, where batches of vertices are updated synchronously.
 * We follow the mean field update algorithm in Slide 13 of the
 * [[http://www.eecs.berkeley.edu/~wainwrig/Talks/A_GraphModel_Tutorial talk slides]] from:
 * Wainwright. "Graphical models, message-passing algorithms, and convex optimization."
 *
 * The batches are chosen according to a coloring. For background on graph colorings for
 * inference, see for example: Gonzalez et al. "Parallel Gibbs Sampling: From Colored Fields to
 * Thin Junction Trees." AISTATS, 2011.
 *
 * The BP algorithm works by:
 *   - Coloring the graph by assigning a color to each vertex such that no neighboring vertices
 *     share the same color.
 *   - In each step of BP, update all vertices of a single color. Alternate colors.
 */
object BeliefPropagation {

  /**
   * Entry point for the example: builds a 3 x 3 grid Ising model, runs Belief Propagation on it
   * with the GraphX-based implementation, and prints the model and the resulting beliefs.
   *
   * @param args
   *   Command-line arguments (unused).
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("BeliefPropagation example")
      .getOrCreate()

    // Always stop the session, even when model construction, BP, or show() throws.
    try {
      // Create graphical model g of size 3 x 3.
      val g = gridIsingModel(spark, 3)

      println("Original Ising model:")
      g.vertices.show()
      g.edges.show()

      // Run BP for 5 iterations.
      val numIter = 5
      val results = runBPwithGraphX(g, numIter)

      // Display beliefs.
      val beliefs = results.vertices.select("id", "belief")
      println(s"Done with BP. Final beliefs after $numIter iterations:")
      beliefs.show()
    } finally {
      spark.stop()
    }
  }

  /**
   * Attach a scheduling color to every vertex of a grid graph such that no two neighboring
   * vertices share a color.
   *
   * The color is the parity of the coordinate sum, `(i + j) % 2`, read from the vertex columns
   * "i" and "j". This only yields a valid coloring for grid graphs; other graphs need a more
   * general scheme (for example a greedy coloring).
   *
   * @param g
   *   Grid graph generated by [[org.graphframes.examples.Graphs.gridIsingModel()]]
   * @return
   *   Graph with the same edges, and vertices extended by a column "color" of type Int (0 or 1)
   */
  private def colorGraph(g: GraphFrame): GraphFrame = {
    // Parity of the coordinate sum alternates like a checkerboard across the grid.
    val parity = udf((row: Int, column: Int) => (row + column) % 2)
    val coloredVertices = g.vertices.withColumn("color", parity(col("i"), col("j")))
    GraphFrame(coloredVertices, g.edges)
  }

  /**
   * Run Belief Propagation through GraphX.
   *
   * Demonstrates GraphX's aggregateMessages API together with the conversions between
   * GraphFrame and GraphX. The steps are:
   *   - Color the GraphFrame vertices to obtain the BP update schedule.
   *   - Convert the colored GraphFrame to a GraphX graph with typed attributes.
   *   - For every iteration and every color, send messages to the vertices of that color and
   *     update their beliefs.
   *   - Attach the final beliefs to the GraphFrame as a new vertex column.
   *
   * @param g
   *   Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()`
   * @param numIter
   *   Number of iterations of BP to run. One iteration updates the belief of every vertex once
   *   (one pass over all colors).
   * @return
   *   The graphical model with [[GraphFrame.vertices]] augmented with a new column "belief"
   *   containing P(x,,i,, = +1), the marginal probability of vertex i taking value +1 instead of
   *   -1.
   */
  def runBPwithGraphX(g: GraphFrame, numIter: Int): GraphFrame = {
    // Color the vertices; the number of distinct colors is the length of one BP sweep.
    val colored = colorGraph(g)
    val numColors: Int = colored.vertices.select("color").distinct().count().toInt

    // Column name -> position lookups used to read fields out of the GraphX Row attributes.
    val vertexIdx = colored.vertexColumnMap
    val edgeIdx = colored.edgeColumnMap

    // Swap the untyped Row attributes for case classes. Every belief starts at 0.0.
    val typedVertices: Graph[VertexAttr, Row] = colored.toGraphX.mapVertices { (_, row) =>
      VertexAttr(
        a = row.getDouble(vertexIdx("a")),
        belief = 0.0,
        color = row.getInt(vertexIdx("color")))
    }
    val toEdgeAttr: GXEdge[Row] => EdgeAttr = edge => EdgeAttr(edge.attr.getDouble(edgeIdx("b")))
    val initial: Graph[VertexAttr, EdgeAttr] = typedVertices.mapEdges(toEdgeAttr)

    // Update schedule: numIter sweeps, each visiting the colors 0, 1, ..., numColors - 1 in order.
    val schedule = (0 until numIter).flatMap(_ => 0 until numColors)

    val updated = schedule.foldLeft(initial) { (graph, color) =>
      // Collect, per vertex of the current color, the sum of (edge factor * neighbor belief).
      // Edges are treated as undirected, so the receiver may be either endpoint.
      val incoming: VertexRDD[Double] = graph.aggregateMessages[Double](
        ctx => {
          if (ctx.dstAttr.color == color) {
            val weighted = ctx.attr.b * ctx.srcAttr.belief
            // Zero-valued messages are not sent.
            if (weighted != 0) ctx.sendToDst(weighted)
          } else if (ctx.srcAttr.color == color) {
            val weighted = ctx.attr.b * ctx.dstAttr.belief
            // Zero-valued messages are not sent.
            if (weighted != 0) ctx.sendToSrc(weighted)
          }
        },
        (left, right) => left + right)

      // Vertices of other colors are left untouched; vertices of the current color get a new
      // belief, with a missing message sum counted as 0.0.
      graph.outerJoinVertices(incoming) { (_, attr, received) =>
        if (attr.color != color) {
          attr
        } else {
          val x = attr.a + received.getOrElse(0.0)
          // exp(-log(1 + exp(-x))) is the logistic function of x.
          attr.copy(belief = math.exp(-log1pExp(-x)))
        }
      }
    }

    // Keep only the belief per vertex and hand it back to GraphFrames as column "belief".
    val beliefsOnly: Graph[Double, Unit] =
      updated.mapVertices((_, attr) => attr.belief).mapEdges(_ => ())
    GraphFrame.fromGraphX(colored, beliefsOnly, vertexNames = Seq("belief"))
  }

  /**
   * Typed vertex attribute used by the GraphX-based BP implementation.
   *
   * @param a
   *   Vertex factor, read from the vertex column "a".
   * @param belief
   *   Current belief of the vertex; initialized to 0.0 and overwritten whenever the vertex's
   *   color is updated.
   * @param color
   *   Scheduling color of the vertex, read from the vertex column "color".
   */
  case class VertexAttr(a: Double, belief: Double, color: Int)

  /**
   * Typed edge attribute used by the GraphX-based BP implementation.
   *
   * @param b
   *   Edge factor, read from the edge column "b"; it scales the neighbor's belief when a
   *   message is sent along the edge.
   */
  case class EdgeAttr(b: Double)

  /**
   * Run Belief Propagation.
   *
   * This implementation of BP shows how to use GraphFrame's aggregateMessages method.
   *   - Color GraphFrame vertices for BP scheduling.
   *   - Run BP using GraphFrame's aggregateMessages API.
   *   - Augment the original GraphFrame with the BP results (vertex beliefs).
   *
   * A vertex of the current color that receives no messages (for example a vertex without any
   * edges) is updated from its own factor "a" alone, i.e. its message sum is taken to be 0.0.
   * This matches the behavior of [[runBPwithGraphX]].
   *
   * @param g
   *   Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()`
   * @param numIter
   *   Number of iterations of BP to run. One iteration includes updating each vertex's belief
   *   once.
   * @return
   *   Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column "belief"
   *   containing P(x,,i,, = +1), the marginal probability of vertex i taking value +1 instead of
   *   -1. The temporary "color" column is dropped from the result.
   */
  def runBPwithGraphFrames(g: GraphFrame, numIter: Int): GraphFrame = {
    // Choose colors for vertices for BP scheduling.
    val colorG = colorGraph(g)
    val numColors: Int = colorG.vertices.select("color").distinct().count().toInt

    // Loop-invariant helpers, defined once instead of once per (iteration, color) step.
    // "AM" is shorthand for referring to the src, dst, edge, and msg fields. (See usage below.)
    val AM = AggregateMessages
    // exp(-log(1 + exp(-x))) is the logistic function of x.
    val logistic = udf { (x: Double) => math.exp(-log1pExp(-x)) }

    // Initialize vertex beliefs at 0.0.
    var gx = GraphFrame(colorG.vertices.withColumn("belief", lit(0.0)), colorG.edges)

    // Run BP for numIter iterations.
    for (_ <- Range(0, numIter)) {
      // For each color, have that color receive messages from neighbors.
      for (color <- Range(0, numColors)) {
        // Send messages to vertices of the current color.
        // We may send to source or destination since edges are treated as undirected.
        val msgForSrc: Column = when(AM.src("color") === color, AM.edge("b") * AM.dst("belief"))
        val msgForDst: Column = when(AM.dst("color") === color, AM.edge("b") * AM.src("belief"))
        val aggregates = gx.aggregateMessages
          .sendToSrc(msgForSrc)
          .sendToDst(msgForDst)
          .agg(sum(AM.msg).as("aggMess"))
        val v = gx.vertices
        // After the left outer join below, a vertex that received no messages has a null
        // "aggMess". Count that as a message sum of 0.0 so the vertex is still updated from "a"
        // instead of silently keeping its previous belief.
        val aggMess = aggregates("aggMess")
        val msgSumOrZero = when(aggMess.isNull, lit(0.0)).otherwise(aggMess)
        // Receive messages, and update beliefs for vertices of the current color.
        val newBeliefCol = when(v("color") === color, logistic(msgSumOrZero + v("a")))
          .otherwise(v("belief")) // keep old beliefs for other colors
        val newVertices = v
          .join(aggregates, v("id") === aggregates("id"), "left_outer") // join messages, vertices
          .drop(aggregates("id")) // drop duplicate ID column (from outer join)
          .withColumn("newBelief", newBeliefCol) // compute new beliefs
          .drop("aggMess") // drop messages
          .drop("belief") // drop old beliefs
          .withColumnRenamed("newBelief", "belief")
        // Cache new vertices using workaround for SPARK-13346
        val cachedNewVertices = AM.getCachedDataFrame(newVertices)
        gx = GraphFrame(cachedNewVertices, gx.edges)
      }
    }

    // Drop the "color" column from vertices
    GraphFrame(gx.vertices.drop("color"), gx.edges)
  }

  /**
   * Numerically stable evaluation of `log(1 + exp(x))`.
   *
   * The exponential is always taken of a non-positive argument, `-|x|`, so it cannot overflow;
   * for positive `x` the linear part `x` is added back separately.
   */
  private def log1pExp(x: Double): Double = {
    // log(1 + exp(-|x|)): equals log1p(exp(-x)) for x > 0 and log1p(exp(x)) otherwise.
    val softTail = math.log1p(math.exp(-math.abs(x)))
    if (x > 0) x + softTail else softTail
  }
}
