package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.parser.MessageCollector
import org.mule.weave.v2.parser.VariableModuleOpenAccess
import org.mule.weave.v2.parser.VariableReferencedInOtherScope
import org.mule.weave.v2.parser.VariableReferencedMoreThanOnce
import org.mule.weave.v2.parser.annotation.MaterializeVariableAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.conditional.IfNode
import org.mule.weave.v2.parser.ast.conditional.UnlessNode
import org.mule.weave.v2.parser.ast.functions.DoBlockNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.UsingNode
import org.mule.weave.v2.parser.ast.header.directives.InputDirective
import org.mule.weave.v2.parser.ast.header.directives.VarDirective
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.patterns.DeconstructArrayPatternNode
import org.mule.weave.v2.parser.ast.patterns.DeconstructObjectPatternNode
import org.mule.weave.v2.parser.ast.patterns.DefaultPatternNode
import org.mule.weave.v2.parser.ast.patterns.EmptyArrayPatternNode
import org.mule.weave.v2.parser.ast.patterns.EmptyObjectPatternNode
import org.mule.weave.v2.parser.ast.patterns.ExpressionPatternNode
import org.mule.weave.v2.parser.ast.patterns.LiteralPatternNode
import org.mule.weave.v2.parser.ast.patterns.PatternExpressionNode
import org.mule.weave.v2.parser.ast.patterns.PatternMatcherNode
import org.mule.weave.v2.parser.ast.patterns.RegexPatternNode
import org.mule.weave.v2.parser.ast.patterns.TypePatternNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.types.ObjectTypeNode
import org.mule.weave.v2.parser.ast.types.TypeReferenceNode
import org.mule.weave.v2.parser.ast.types.UnionTypeNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.scope.AstNavigator
import org.mule.weave.v2.scope.Reference
import org.mule.weave.v2.scope.ScopesNavigator
import org.mule.weave.v2.ts.ArrayType
import org.mule.weave.v2.ts.ObjectType
import org.mule.weave.v2.ts.RecursionDetector
import org.mule.weave.v2.ts.ScopeGraphTypeReferenceResolver
import org.mule.weave.v2.ts.TypeHelper
import org.mule.weave.v2.ts.WeaveTypeTraverse
import org.mule.weave.v2.utils.IdentityHashMap

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * This phase marks the variables declared by functions, inputs and the other variable declarations it visits,
  * stating whether each one needs to be materialized or can be streamed.
  *
  * A variable needs to be materialized if it is referenced more than once or if it is referenced inside a lambda (a different scope).
  *
  */
class MaterializeVariableMarkerPhase[R <: AstNode, T <: AstNodeResultAware[R] with ScopeNavigatorResultAware]() extends CompilationPhase[T, T] {

  /**
    * Annotates every variable declaration reachable from the root node with a
    * [[MaterializeVariableAnnotation]].
    *
    * For a document, the `var` directives of the header are processed first and the `input` directives
    * afterwards, both analysed against the whole document. For a module, every `var` directive is
    * unconditionally flagged as materialized. After that, every function, `do` block, `using` node and
    * pattern expression found under the root node is handed to its dedicated `runOn...` method.
    *
    * @param source  the phase input, providing the AST root and the scope navigator
    * @param context the parsing context, passed through to the result untouched
    * @return a successful result wrapping the same `source`
    */
  override def doCall(source: T, context: ParsingContext): PhaseResult[T] = {
    val navigator = source.scope.rootScope.astNavigator()

    source.astNode match {
      case document: DocumentNode =>
        val headerDirectives = document.header.directives

        // Computes and stores the verdict of one header variable, using a collector of its own.
        def markVariable(variable: NameIdentifier): Unit = {
          val diagnostics = new MessageCollector
          val mustMaterialize = variableNeedsMaterialization(variable, document, navigator, source.scope, diagnostics)
          variable.annotate(MaterializeVariableAnnotation(mustMaterialize, diagnostics.errorMessages))
        }

        // Two passes on purpose: all `var` directives are handled before any `input` directive.
        headerDirectives.foreach {
          case varDirective: VarDirective => markVariable(varDirective.variable)
          case _                          =>
        }
        headerDirectives.foreach {
          case inputDirective: InputDirective => markVariable(inputDirective.variable)
          case _                              =>
        }

      case module: ModuleNode =>
        for (directive <- module.directives) {
          directive match {
            case varDirective: VarDirective =>
              // Module variables should always be materialized
              val declared = varDirective.variable
              val reason = Seq((declared.location(), VariableModuleOpenAccess(declared)))
              declared.annotate(MaterializeVariableAnnotation(true, reason))
            case _ =>
          }
        }

      case _ =>
    }

    val declaringNodes = AstNodeHelper.collectChildren(source.astNode, {
      case _: FunctionNode | _: DoBlockNode | _: UsingNode | _: PatternExpressionNode => true
      case _ => false
    })

    // The filter above only lets these four node kinds through, so the dispatch has no fallback case.
    declaringNodes.foreach({
      case function: FunctionNode         => runOnFunctionNode(source, navigator, function)
      case block: DoBlockNode             => runOnBlockNode(source, navigator, block)
      case using: UsingNode               => runOnUsingNode(source, navigator, using)
      case pattern: PatternExpressionNode => runOnPatternExpressionNode(source, navigator, pattern)
    })

    SuccessResult(source, context)
  }

  /**
    * Annotates the variables bound by a pattern of a pattern matcher with their materialization verdict.
    *
    * Every bound variable is analysed against the `onMatch` expression of the pattern. The variable of an
    * expression pattern is always flagged as materialized, and the empty object / empty array patterns
    * are left untouched.
    *
    * @param source       the phase input, providing the scope navigator
    * @param astNavigator navigator used to look for enclosing functions of the references
    * @param pen          the pattern whose variables are going to be annotated
    */
  private def runOnPatternExpressionNode(source: T, astNavigator: AstNavigator, pen: PatternExpressionNode) = {
    // Each variable gets a collector of its own. Sharing one collector between the variables of a
    // deconstruct pattern attached the diagnostics of the first variable to its siblings as well.
    def markVariable(variable: NameIdentifier, onMatch: AstNode): Unit = {
      val collector = new MessageCollector
      val willMaterialize = variableNeedsMaterialization(variable, onMatch, astNavigator, source.scope, collector)
      variable.annotate(MaterializeVariableAnnotation(willMaterialize, collector.errorMessages))
    }

    pen match {
      case tpn: TypePatternNode =>
        markVariable(tpn.name, tpn.onMatch)
      case dpn: DefaultPatternNode =>
        markVariable(dpn.name, dpn.onMatch)
      case dapn: DeconstructArrayPatternNode =>
        markVariable(dapn.head, dapn.onMatch)
        markVariable(dapn.tail, dapn.onMatch)
      case dopn: DeconstructObjectPatternNode =>
        markVariable(dopn.headKey, dopn.onMatch)
        markVariable(dopn.headValue, dopn.onMatch)
        markVariable(dopn.tail, dopn.onMatch)
      case lpn: LiteralPatternNode =>
        markVariable(lpn.name, lpn.onMatch)
      case epn: ExpressionPatternNode =>
        // No analysis is run for expression patterns: their variable is always materialized.
        epn.name.annotate(MaterializeVariableAnnotation(true))
      case rpn: RegexPatternNode =>
        markVariable(rpn.name, rpn.onMatch)
      case _: EmptyObjectPatternNode =>
      case _: EmptyArrayPatternNode  =>
    }
  }

  /**
    * Annotates every variable assigned by a `using` node with its materialization verdict.
    * The references are looked for inside the whole `using` node, and each variable reports
    * its diagnostics through a collector of its own.
    */
  private def runOnUsingNode(source: T, astNavigator: AstNavigator, un: UsingNode) = {
    for (binding <- un.assignments.assignmentSeq) {
      val diagnostics = new MessageCollector
      val mustMaterialize = variableNeedsMaterialization(binding.name, un, astNavigator, source.scope, diagnostics)
      val verdict = MaterializeVariableAnnotation(mustMaterialize, diagnostics.errorMessages)
      binding.name.annotate(verdict)
    }
  }

  /**
    * Annotates every variable declared by a `var` directive in the header of a `do` block with its
    * materialization verdict. The references are looked for inside the whole `do` block.
    */
  private def runOnBlockNode(source: T, astNavigator: AstNavigator, doBlock: DoBlockNode) = {
    doBlock.header.directives.foreach {
      case declaration: VarDirective =>
        val diagnostics = new MessageCollector
        val mustMaterialize = variableNeedsMaterialization(declaration.variable, doBlock, astNavigator, source.scope, diagnostics)
        declaration.variable.annotate(MaterializeVariableAnnotation(mustMaterialize, diagnostics.errorMessages))
      case _ => // every other kind of directive is ignored here
    }
  }

  /**
    * Annotates every parameter of a function with its materialization verdict.
    * The references to each parameter are looked for inside the function body only.
    */
  private def runOnFunctionNode(source: T, astNavigator: AstNavigator, fn: FunctionNode) = {
    for (parameter <- fn.params.paramList) {
      val diagnostics = new MessageCollector
      val mustMaterialize = variableNeedsMaterialization(parameter.variable, fn.body, astNavigator, source.scope, diagnostics)
      val verdict = MaterializeVariableAnnotation(mustMaterialize, diagnostics.errorMessages)
      parameter.variable.annotate(verdict)
    }
  }

  /**
    * Walks `node` and records in `collector` every name identifier that is a key of `references`,
    * grouping the identifiers by execution branch:
    *
    *  - `if` / `unless`: the condition is recorded in the current branch and the two result
    *    expressions become the alternatives of an exclusive branch.
    *  - pattern matcher: the matched expression is recorded in the current branch and every pattern
    *    becomes ONE alternative of an exclusive branch. All the children of a pattern are recorded in
    *    that single alternative, so a variable used by more than one child of the same pattern is
    *    counted once per use instead of being treated as mutually exclusive uses.
    *  - any other node: its children are visited with the same collector.
    *
    * @param node       the node to inspect
    * @param references the references to look for, keyed by name identifier
    * @param collector  the branch that receives the references found; a new root branch by default
    * @return the `collector` that was passed in
    */
  def collectReferenceByBranch(node: AstNode, references: IdentityHashMap[NameIdentifier, Reference], collector: BranchReferenceNode = new BranchReferenceNode()): BranchReferenceNode = {
    // Shared by `if` and `unless`: the condition belongs to the current branch, the two result
    // expressions are alternatives of each other.
    def collectConditional(condition: AstNode, ifExpr: AstNode, elseExpr: AstNode): Unit = {
      collectReferenceByBranch(condition, references, collector)
      collector.runInExclusiveBranch((eb) => {
        eb.runInNewBranch(collectReferenceByBranch(ifExpr, references, _))
        eb.runInNewBranch(collectReferenceByBranch(elseExpr, references, _))
      })
    }

    node match {
      case in: IfNode => {
        collectConditional(in.condition, in.ifExpr, in.elseExpr)
      }
      case in: UnlessNode => {
        collectConditional(in.condition, in.ifExpr, in.elseExpr)
      }
      case pmn: PatternMatcherNode => {
        collectReferenceByBranch(pmn.lhs, references, collector)
        collector.runInExclusiveBranch((eb) => {
          pmn.patterns.patterns.foreach((pattern) => {
            // One branch per pattern: the children of a single pattern are not alternatives of
            // each other, so their references have to be added up inside the same branch.
            eb.runInNewBranch((patternBranch) => {
              pattern
                .children()
                .foreach((n) => collectReferenceByBranch(n, references, patternBranch))
            })
          })
        })
      }
      case ni: NameIdentifier => {
        if (references.contains(ni)) {
          collector.addReference(ni)
        }
      }
      case _ => {
        node.children().foreach((n) => collectReferenceByBranch(n, references, collector))
      }
    }
    collector
  }

  /**
    * Decides whether `variable` has to be materialized.
    *
    * The local references to the variable are resolved through its scope and grouped by execution
    * branch inside `bodyNode`. The answer is `true` when the branch tree reports more than one
    * reference, or otherwise when some branch holds a reference that has a [[FunctionNode]] parent
    * below `bodyNode`. A variable without references, or without a scope, yields `false`.
    *
    * Whenever the answer is `true` the reason is reported on `messageCollector`; only the first
    * reason found is reported.
    *
    * @param variable         the declaration being analysed
    * @param bodyNode         the node inside which the references are counted
    * @param astNavigator     navigator used to look for enclosing function nodes
    * @param scopeNaviagtor   navigator used to find the scope that declares `variable`
    * @param messageCollector receives the diagnostics explaining a `true` answer
    * @return `true` if the variable must be materialized
    */
  def variableNeedsMaterialization(variable: NameIdentifier, bodyNode: AstNode, astNavigator: AstNavigator, scopeNaviagtor: ScopesNavigator, messageCollector: MessageCollector): Boolean = {
    val declaringScope = scopeNaviagtor.scopeOf(variable)
    val localReferences: Seq[Reference] = declaringScope
      .map(_.resolveLocalReferenceTo(variable))
      .getOrElse(Seq.empty)

    if (localReferences.isEmpty) {
      false
    } else {
      val byIdentifier = IdentityHashMap[NameIdentifier, Reference]()
      for (reference <- localReferences) {
        byIdentifier.put(reference.referencedNode, reference)
      }

      val root: BranchReferenceNode = collectReferenceByBranch(bodyNode, byIdentifier)
      val branches: Seq[BranchReferenceNode] = root.allBranches()

      if (root.maxReferences() > 1) {
        // Report the references held by every branch that, on its own, counts more than one.
        val repeated = branches.filter(_.maxReferences() > 1).flatMap(_.localRefs())
        messageCollector.error(VariableReferencedMoreThanOnce(variable, repeated.map(_.location())), variable.location())
        true
      } else {
        // Stops at the first branch with a reference nested in a function, so a single error is reported.
        branches.exists((branch) => {
          val enclosingFunctions = branch.localRefs().toStream.flatMap(astNavigator.parentWithTypeUntil(_, classOf[FunctionNode], bodyNode))
          val usedInOtherScope = enclosingFunctions.nonEmpty
          if (usedInOtherScope) {
            messageCollector.error(VariableReferencedInOtherScope(variable, enclosingFunctions.map(_.location())), variable.location())
          }
          usedInOtherScope
        })
      }
    }
  }

  /**
    * Node of the tree used to count the references to a variable per execution branch.
    *
    * A node holds the references recorded directly on it plus a list of child branches. When the node
    * is `exclusive` only the child with the highest count contributes to [[maxReferences]]; otherwise
    * the counts of all the children are added up.
    *
    * @param parent    the branch this one was created from, if any
    * @param exclusive whether the children of this node are alternatives of each other
    */
  class BranchReferenceNode(parent: Option[BranchReferenceNode] = None, exclusive: Boolean = false) {
    private val ownReferences: mutable.ArrayBuffer[NameIdentifier] = ArrayBuffer()
    private val subBranches: mutable.ArrayBuffer[BranchReferenceNode] = ArrayBuffer()

    /** Creates a non exclusive child branch and hands it to `callback`. */
    def runInNewBranch(callback: (BranchReferenceNode) => Unit): Unit = {
      val branch = newChildBranch(exclusive = false)
      callback(branch)
    }

    /** Creates an exclusive child branch and hands it to `callback`. */
    def runInExclusiveBranch(callback: (BranchReferenceNode) => Unit): Unit = {
      val branch = newChildBranch(exclusive = true)
      callback(branch)
    }

    /** Records `reference` directly on this branch. */
    def addReference(reference: NameIdentifier): Unit = {
      ownReferences += reference
    }

    /**
      * The references recorded on this branch plus the contribution of its children: the highest
      * child count when this node is exclusive, the sum of the child counts otherwise.
      */
    def maxReferences(): Int = {
      val perChild = subBranches.map(_.maxReferences())
      val fromChildren =
        if (perChild.isEmpty) 0
        else if (exclusive) perChild.max
        else perChild.sum
      ownReferences.size + fromChildren
    }

    /** The references recorded on this branch followed by the ones of its chain of parents. */
    def allReferences(): Seq[NameIdentifier] = {
      val inherited: Seq[NameIdentifier] = parent match {
        case Some(owner) => owner.allReferences()
        case None        => Seq()
      }
      ownReferences ++ inherited
    }

    /** This branch followed by all of its descendants, depth first. */
    def allBranches(): Seq[BranchReferenceNode] = {
      val descendants = for {
        child <- subBranches
        branch <- child.allBranches()
      } yield branch
      Seq(this) ++ descendants
    }

    /** The references recorded directly on this branch. */
    def localRefs(): Seq[NameIdentifier] = {
      ownReferences
    }

    // Creates a child of this node, registers it and returns it.
    private def newChildBranch(exclusive: Boolean): BranchReferenceNode = {
      val created = new BranchReferenceNode(Some(this), exclusive)
      subBranches += created
      created
    }
  }

}
