// Copyright 2011 Foursquare Labs Inc. All Rights Reserved.

package com.foursquare.geo.quadtree

import com.vividsolutions.jts.geom.{Coordinate, Envelope, Geometry}
import java.io.{File, Serializable}
import java.net.URL
import org.geotools.data.{DataStoreFactorySpi, FileDataStore}
import org.geotools.data.shapefile.ShapefileDataStoreFactory
import org.geotools.data.simple.{SimpleFeatureIterator, SimpleFeatureSource}
import org.geotools.geometry.jts.JTSFactoryFinder
import scalaj.collection.Imports._

/** A loader for our custom shapefiles used in reverse geocoding
  *
  * The shapefiles are expected to have three attributes:
  * A geometry attribute (describes the feature's shape)
  * A key attribute (e.g. timezone, country code, etc)
  * An index attribute generated by our batch simplification process,
  *  which describes a shape's path in the Quadtree.
  * The name of the index attribute explains the layout and levels of the quadtree.
  * For example, GI40_2_2_2 divides the world up into (40 x 40)
  * at the top level, then (2 x 2), then (2 x 2), then (2 x 2).
  * Its value lists one "longIdx,latIdx" cell per level, separated by
  * semicolons (e.g. "12,13;0,1;").
  * At the last level, the actual shapes are preserved. */

object ShapefileGeo {
  /** Shared JTS factory (default hints) used to build the points and cell envelopes in this file. */
  val geometryFactory = JTSFactoryFinder.getGeometryFactory(null)

  /** A schema attribute whose name starts with this prefix carries the quadtree index. */
  val indexAttributePrefix: String = "GI"

  /** An axis-aligned cell: lower-left corner (minLong, minLat) and its extent,
    * all in the shapefile's (long, lat) coordinate units. */
  case class GeoBounds(
    minLong: Double,
    minLat: Double,
    width: Double,
    height: Double)

  /** A geometry paired with an optional key. */
  trait KeyShape {
    /** The key that applies to this shape, or None when no single key can be given. */
    def keyValue: Option[String]

    /** The region this entry covers. */
    def shape: Geometry
  }

  /** One source feature: its geometry plus the key it was loaded with. */
  class ShapeLeafNode(value: String, geom: Geometry) extends KeyShape {
    override def keyValue: Some[String] = Some(value)
    override def shape: Geometry = geom
  }

  /** A node of the Quadtree.
    *
    * Note that this is a more flexible interpretation of a Quadtree node
    * because each level need not be (2 x 2).  It can be (N x N).
    * Also, there is a limited number of levels.  When the deepest level is
    * reached, all remaining shapes are placed on the list and independently
    * queried.
    *
    * @param level depth of this node (the root is 0); also the index into
    *              `levelSizes` used when this node's sub-grid is built
    * @param bounds the cell covered by this node
    * @param alwaysCheckGeometry when false, a node holding exactly one shape
    *              answers with that shape's key without a point-in-polygon test */
  class ShapeTrieNode(level: Int, bounds: GeoBounds, alwaysCheckGeometry: Boolean) extends KeyShape {
    /** This node's cell as a JTS geometry. */
    def shape: Geometry =
      geometryFactory.toGeometry(new Envelope(
        bounds.minLong, bounds.minLong + bounds.width, bounds.minLat, bounds.minLat + bounds.height))

    def nodeBounds = bounds

    def nodeLevel = level

    /** Some(key) when this node can answer without any point-in-polygon test:
      * it holds exactly one leaf shape and `alwaysCheckGeometry` is off.
      *
      * The shortcut treats the lone shape as if it filled the whole cell. That is
      * only exact when the shape covers the cell; otherwise points inside the cell
      * but outside the shape still receive its key. */
    def keyValue: Option[String] = subList match {
      // A list pattern is O(1); `subList.length` walked the whole list on every lookup.
      case single :: Nil if !alwaysCheckGeometry => Some(single.keyValue.get)
      case _ => None
    }

    /** Some if there are deeper levels; indexed as grid(longIdx)(latIdx). */
    var subGrid: Option[Array[Array[ShapeTrieNode]]] = None

    /** Shapes stored at this node (deepest level reached); each one is tested
      * with point-in-polygon unless `keyValue` short-circuits. */
    var subList: List[ShapeLeafNode] = Nil

    /** (longitude cells, latitude cells) of the sub-grid, or (0, 0) without one. */
    def subGridSize: (Int, Int) =
      subGrid
        .map(grid => (grid.length, grid.headOption.map(_.length).getOrElse(0)))
        .getOrElse((0, 0))

    // NOTE: order here is lat, long to comply with other interfaces,
    // whereas GeoTools is mostly long, lat.
    /** First key whose shape covers the point, or the fudger's answer if none does. */
    def getNearest(geoLat: Double, geoLong: Double, fudger: Option[Fudger] = None): Option[String] =
      getNearestList(geoLat, geoLong, fudger).headOption

    /** All keys whose shapes cover the point.
      *
      * Descends the grid to the cell containing the point. At a node without a
      * sub-grid, it returns the keys of every stored shape covering the point. If
      * none covers it, it returns the fudger's answer (if any) as a 0/1-element list. */
    def getNearestList(geoLat: Double, geoLong: Double, fudger: Option[Fudger] = None): List[String] = keyValue match {
      case Some(key) => List(key)
      case None => subGrid match {
        case Some(grid) =>
          val (nLongs, nLats) = subGridSize
          // Clamp so that points on (or numerically past) the cell edges land in an edge cell.
          val longIdx = math.max(0, math.min(nLongs - 1,
            math.floor((geoLong - bounds.minLong) / (bounds.width / nLongs)).toInt))
          val latIdx = math.max(0, math.min(nLats - 1,
            math.floor((geoLat - bounds.minLat) / (bounds.height / nLats)).toInt))

          // RECURSE -- note the (lat, long) argument order
          grid(longIdx)(latIdx).getNearestList(geoLat, geoLong, fudger)

        case None =>
          val point = geometryFactory.createPoint(new Coordinate(geoLong, geoLat))
          subList.filter(_.shape.covers(point)).map(_.keyValue.get) match {
            // fudgers might want to start returning lists
            case Nil => fudger.flatMap(_.fudge(geoLat, geoLong, this)).toList
            case hits => hits
          }
      }
    }

    /** Allocates this node's children if they do not exist yet: an (n x n) grid
      * with n = levelSizes(level), splitting `bounds` into equal cells. */
    def makeSubGrid(levelSizes: Array[Int]): Unit = {
      if (subGrid.isEmpty) {
        // Each level is square: the same cell count is used on both axes.
        val cellsPerSide = levelSizes(level)
        val longChunk = bounds.width / cellsPerSide
        val latChunk = bounds.height / cellsPerSide

        subGrid = Some(Array.tabulate(cellsPerSide, cellsPerSide) { (iLong, iLat) =>
          new ShapeTrieNode(level + 1,
                            GeoBounds(bounds.minLong + longChunk * iLong,
                                      bounds.minLat + latChunk * iLat,
                                      longChunk,
                                      latChunk),
                            alwaysCheckGeometry)
        })
      }
    }

    /** Inserts a shape by following `path`, a list of (longIdx, latIdx) cells, one
      * per level. Sub-grids are created on demand. When the path is exhausted, the
      * shape is prepended to the current node's `subList`. */
    def addFeature(path: List[(Int, Int)],
                   keyVal: String,
                   geometry: Geometry,
                   levelSizes: Array[Int]): Unit = path match {
      case Nil => subList ::= new ShapeLeafNode(keyVal, geometry)
      case (iLong, iLat) :: rest =>
        makeSubGrid(levelSizes) // no-op when the grid already exists
        subGrid.get(iLong)(iLat).addFeature(rest, keyVal, geometry, levelSizes)
    }
  } // end of ShapeTrieNode


  /** Loads a quadtree-indexed shapefile into an in-memory [[ShapeTrieNode]] tree.
    *
    * No simplification happens here. Each feature is inserted at the position
    * given by its precomputed index attribute (the attribute whose name starts
    * with `indexAttributePrefix`).
    *
    * @param url location of the shapefile
    * @param keyAttribute name of the attribute holding each feature's key
    * @param validValues when Some, keys outside this set are replaced by `defaultValue`
    * @param defaultValue key used when a feature's key is null or not in `validValues`
    * @param alwaysCheckGeometry passed to every node of the tree
    * @return the root node, covering the bounds reported by the feature source
    * @throws IllegalArgumentException if the schema lacks `keyAttribute` or an
    *         index attribute */
  def load(
    url: URL,
    keyAttribute: String,
    validValues: Option[Set[String]],
    defaultValue: String,
    alwaysCheckGeometry: Boolean): ShapeTrieNode = {

    // Converts "12,13;20,22;" into (12,13) :: (20,22) :: Nil.
    // (String.split drops the trailing empty segment left by the final ";".)
    def parseIndex(path: String): List[(Int, Int)] =
      path.split(";").toList.map { elem =>
        val elemArr = elem.split(",", 2).map(_.toInt)
        (elemArr(0), elemArr(1))
      }

    val dataStoreParams: java.util.Map[String, Serializable] = new java.util.HashMap[String, Serializable]()
    dataStoreParams.put(ShapefileDataStoreFactory.URLP.key, url)
    dataStoreParams.put(ShapefileDataStoreFactory.MEMORY_MAPPED.key, java.lang.Boolean.TRUE)
    dataStoreParams.put(ShapefileDataStoreFactory.CACHE_MEMORY_MAPS.key, java.lang.Boolean.FALSE)
    val dataStoreFactory: DataStoreFactorySpi = new ShapefileDataStoreFactory()
    val dataStore = dataStoreFactory.createDataStore(dataStoreParams).asInstanceOf[FileDataStore]

    try {
      val featureSource: SimpleFeatureSource = dataStore.getFeatureSource()

      // determine the key and index attribute names, and the number and size of the index levels
      if (featureSource.getSchema.getDescriptor(keyAttribute) == null)
        throw new IllegalArgumentException("Schema has no attribute named \""+keyAttribute+"\"")

      val indexAttribute: String = featureSource.getSchema.getDescriptors.asScala
        .find(d => d.getName.toString.startsWith(indexAttributePrefix))
        .map(_.getName.toString)
        .getOrElse(throw new IllegalArgumentException(
          "Schema has no attribute starting with \""+indexAttributePrefix+"\""))
      val sourceLevelSizes = indexAttribute.substring(indexAttributePrefix.length).split("_").map(_.toInt)

      // build the world
      val bounds = featureSource.getInfo.getBounds
      val world = new ShapeTrieNode(0,
                                    GeoBounds(bounds.getMinX,
                                              bounds.getMinY,
                                              bounds.getWidth,
                                              bounds.getHeight),
                                    alwaysCheckGeometry)

      // would love to do toScala here, but though it looks like an iterator
      // and quacks like an iterator, it is not a java iterator.
      val iterator: SimpleFeatureIterator = featureSource.getFeatures.features

      try {
        while (iterator.hasNext) {
          val feature = iterator.next()
          val sourceGeometry = feature.getDefaultGeometry().asInstanceOf[Geometry]
          // A null key is treated like a key outside validValues instead of failing with an NPE.
          val keyValue = Option(feature.getAttribute(keyAttribute))
            .map(_.toString)
            .filter(key => validValues.forall(_.contains(key)))
            .getOrElse(defaultValue)
          val index = parseIndex(feature.getAttribute(indexAttribute).toString)
          world.addFeature(index, keyValue, sourceGeometry, sourceLevelSizes)
        }
      } finally {
        iterator.close()
      }

      world
    } finally {
      // The tree holds only the geometries already read through the (now closed)
      // iterator, so the store and its mapped file can be released. This also runs
      // when the schema checks above throw.
      dataStore.dispose()
    }
  }


  /** A fudger is used when no map shape contains the point.
    *
    * Due to geocoding inaccuracies, a coordinate may not be contained by any
    * shape.  This almost always occurs in one of two scenarios:
    *
    *  1. We are at the deepest level (smallest cell), the cell contains a body of
    *     water (a "no-timezone region"), and there are at least two landmasses in
    *     the cell with different timezones; or
    *  1. we are straight up in the middle of some serious body of water, with no
    *     land in sight. */
  trait Fudger {
    /** @param lat  latitude of the query point
      * @param long longitude of the query point
      * @param node the node without a sub-grid whose shapes did not cover the point
      * @return a substitute key, or None if this fudger cannot decide */
    def fudge(lat: Double, long: Double, node: ShapeTrieNode): Option[String]
  }

  /** Picks the shape whose centroid has the smallest planar distance to the query
    * coordinate. Returns None when the node holds no shapes. On equal distances,
    * the later shape in `subList` wins. */
  class CentroidDistanceFudger extends Fudger {
    def fudge(lat: Double, long: Double, node: ShapeTrieNode): Option[String] =
      node.subList match {
        case Nil => None
        case candidates =>
          val target = geometryFactory.createPoint(new Coordinate(long, lat))
          val scored = candidates.map(leaf => (leaf, leaf.shape.getCentroid.distance(target)))
          val (closest, _) = scored.reduceLeft { (best, next) =>
            if (best._2 < next._2) best else next
          }
          Some(closest.keyValue.get)
      }
  }

  /** Fudges using each shape's bounding rectangle (envelope).
    *
    * Among the shapes whose envelope covers the point, the one with the smallest
    * envelope area wins; on equal areas, the later shape in `subList` wins.
    * Returns None when no envelope covers the point or the node has no shapes. */
  class BoxBoundaryFudger extends Fudger {
    def fudge(lat: Double, long: Double, node: ShapeTrieNode): Option[String] =
      if (node.subList.nonEmpty) {
        val target = geometryFactory.createPoint(new Coordinate(long, lat))
        val containing = for {
          leaf <- node.subList
          box = leaf.shape.getEnvelope
          if box.covers(target)
        } yield (leaf, box)
        containing
          .reduceLeftOption { (a, b) => if (a._2.getArea < b._2.getArea) a else b }
          .map { case (leaf, _) => leaf.keyValue.get }
      } else {
        None
      }
  }

  /** Open-ocean timezone fallback. When the node holds no shapes at all, it answers
    * with the "Etc/GMT" zone of the nearest 15-degree meridian; otherwise it
    * returns None. */
  class OpenOceanMeridianFudgerTZ extends Fudger {
    def fudge(lat: Double, long: Double, node: ShapeTrieNode): Option[String] =
      if (node.subList.nonEmpty) {
        // Shapes exist in this cell, so we are not in open ocean.
        None
      } else {
        // Etc/GMT+1 means one hour _WEST_ of GMT, i.e. what people normally call
        // GMT-01:00. Joda and java only agree on these "Etc/" ids, not on the ones
        // that read naturally, so the sign is inverted here.
        // See http://twiki.org/cgi-bin/xtra/tzdate?tz=Etc/GMT+5 and
        // http://bugs.sun.com/bugdatabase/view_bug.do?bug_id=4813746
        val hoursWest = -math.round(long / 15.0).toInt
        // Negative numbers already print their own "-"; zero gets no sign.
        val sign = if (hoursWest > 0) "+" else ""
        Some("Etc/GMT" + sign + hoursWest) // Etc/GMT+1, Etc/GMT0, Etc/GMT-10
      }
  }

  /** Open-ocean fallback: returns `default` when the node holds no shapes at all,
    * and None otherwise. */
  class OpenOceanDefaultFudger(default: String) extends Fudger {
    def fudge(lat: Double, long: Double, node: ShapeTrieNode): Option[String] =
      if (node.subList.isEmpty) Some(default) else None
  }

  /** Chains fudgers: asks each in list order and returns the first Some.
    * Fudgers after the first success are never invoked, which matters because
    * some of them (e.g. centroid distance) are expensive. */
  class MultiFudger(fudgers: List[Fudger]) extends Fudger {
    def fudge(lat: Double, long: Double, node: ShapeTrieNode): Option[String] =
      fudgers.foldLeft(Option.empty[String]) { (found, next) =>
        // orElse takes its argument by name, so `next` only runs while nothing is found.
        found orElse next.fudge(lat, long, node)
      }
  }

  /** Timezone chain: open-ocean meridian zone, then envelope containment, then nearest centroid. */
  class MultiFudgerTZ extends MultiFudger(List(
    new OpenOceanMeridianFudgerTZ,
    new BoxBoundaryFudger,
    new CentroidDistanceFudger))

  /** Country-code chain: `default` in open ocean, then envelope containment, then nearest centroid. */
  class MultiFudgerCC(default: String) extends MultiFudger(List(
    new OpenOceanDefaultFudger(default),
    new BoxBoundaryFudger,
    new CentroidDistanceFudger))
}


