package gapt.proofs.resolution

import gapt.expr._
import gapt.expr.formula.Eq
import gapt.expr.formula.Formula
import gapt.expr.formula.fol.Hol2FolDefinitions
import gapt.expr.formula.fol.undoHol2Fol.Signature
import gapt.expr.formula.fol.{ replaceAbstractions, undoHol2Fol }
import gapt.expr.subst.Substitution
import gapt.expr.ty.To
import gapt.expr.ty.Ty
import gapt.expr.util.toVNF
import gapt.proofs.{ Ant, Suc }

/**
 * The identity instance of [[ResolutionToRal]]: formulas, substitutions and paramodulation contexts are
 * returned unchanged, so applying this object only rebuilds the proof node by node.
 *
 * Background: one of our heuristics maps higher-order types into first-order ones. Converters that need to map
 * those types back (where possible) do so by overriding the conversion hooks during the Ral transformation.
 * Before the layer cleanup this step also had to bring all formulas to the same layer type (i.e. not mix
 * FOLFormulas with Formulas); that is no longer necessary, hence the identity hooks here.
 */
object ResolutionToRal extends ResolutionToRal {

  /** Returns the formula unchanged. */
  override def convert_formula( e: Formula ): Formula = e

  /** Returns the substitution unchanged. */
  override def convert_substitution( s: Substitution ): Substitution = s

  /** Returns the paramodulation context unchanged. */
  override def convert_context( con: Abs ): Abs = con

}

/**
 * Rebuilds a resolution proof inference by inference, passing every formula, substitution and
 * paramodulation context through the conversion hooks declared below.
 *
 * Only the inference rules Input, Taut, Refl, Factor, Subst, Resolution, Paramod (with the equation in the
 * succedent and an abstraction as context) and Flip are handled; any other proof makes [[apply]] throw a
 * `scala.MatchError`.
 */
abstract class ResolutionToRal {

  /** Called on every formula (input clauses, tautologies, reflexivity equations) before it enters the new proof. */
  def convert_formula( e: Formula ): Formula

  /** Called on the substitution of every `Subst` inference before it enters the new proof. */
  def convert_substitution( s: Substitution ): Substitution

  /** Called on the context of every `Paramod` inference before it enters the new proof. */
  def convert_context( con: Abs ): Abs

  /**
   * Converts `p` bottom-up using the conversion hooks.
   *
   * Resolution proofs are DAGs: the same subproof may be referenced by several inferences. Converted
   * subproofs are memoized for the duration of one call. Each distinct subproof is therefore converted only
   * once, instead of once per reference, which can be exponential. Sharing is also preserved in the result.
   *
   * @param p the proof to convert
   * @return the converted proof
   */
  def apply( p: ResolutionProof ): ResolutionProof = {
    import scala.collection.mutable
    val converted = mutable.Map[ResolutionProof, ResolutionProof]()

    // Memoizing wrapper. Uses explicit get/update rather than getOrElseUpdate, because the computation
    // recursively inserts into the same map.
    def convert( q: ResolutionProof ): ResolutionProof = converted.get( q ) match {
      case Some( done ) => done
      case None =>
        val result = convertStep( q )
        converted( q ) = result
        result
    }

    // Converts the last inference of `q`, delegating its premises to `convert`.
    def convertStep( q: ResolutionProof ): ResolutionProof = q match {
      case Input( cls )         => Input( cls map convert_formula )
      case Taut( f )            => Taut( convert_formula( f ) )
      // convert_formula works on formulas only, so the term is converted inside the equation t = t
      // and read back out of the result.
      case Refl( t )            => convert_formula( t === t ) match { case Eq( t_, _ ) => Refl( t_ ) }
      case Factor( p1, i1, i2 ) => Factor( convert( p1 ), i1, i2 )
      case Subst( p1, subst ) =>
        val substNew = convert_substitution( subst )
        Subst( convert( p1 ), substNew )
      case Resolution( p1, i1, p2, i2 ) =>
        val p1New = convert( p1 )
        val p2New = convert( p2 )
        Resolution( p1New, i1, p2New, i2 )
      case Paramod( p1, eq @ Suc( _ ), dir, p2, lit, con: Abs ) =>
        val p1New = convert( p1 )
        val p2New = convert( p2 )
        Paramod( p1New, eq, dir, p2New, lit, convert_context( con ) )
      case Flip( p1, i1 ) => Flip( convert( p1 ), i1 )
    }

    convert( p )
  }
}

/**
 * Converts resolution proofs obtained from a FOL export back into Ral proofs, reintroducing the lambda
 * abstractions that were replaced by symbols for the export.
 *
 * Every formula, every substitution entry and the body of every paramodulation context is passed through
 * `undoHol2Fol.backtranslate` and then beta-normalized.
 *
 * @param sig_vars   The signature of the variables in the original proof, grouped by name
 * @param sig_consts The signature of the constants in the original proof, grouped by name
 * @param cmap       The mapping of abstracted symbols to lambda terms. The abstracted symbols must be unique
 *                   (i.e. cmap must be a bijection)
 */
class Resolution2RalWithAbstractions(
    sig_vars:   Map[String, List[Var]],
    sig_consts: Map[String, List[Const]],
    cmap:       Hol2FolDefinitions ) extends ResolutionToRal {

  /** Back-translates `e` (optionally at the expected type `t_expected`) and beta-normalizes the result. */
  private def restore( e: Expr, t_expected: Option[Ty] ): Expr = {
    val translated = undoHol2Fol.backtranslate( e, sig_vars, sig_consts, cmap, t_expected )
    BetaReduction.betaNormalize( translated )
  }

  /** Back-translates a formula with expected type `To`. */
  override def convert_formula( e: Formula ): Formula =
    restore( e, Some( To ) ).asInstanceOf[Formula]

  /**
   * Back-translates the body of a paramodulation context and normalizes the rebuilt abstraction with `toVNF`.
   *
   * NOTE(review): the bound variable is kept as is; only the body is back-translated. Confirm this is
   * intended when the binder itself originates from the FOL export.
   */
  override def convert_context( con: Abs ): Abs = con match {
    case Abs( boundVar, body ) =>
      toVNF( Abs( boundVar, restore( body, None ) ) ).asInstanceOf[Abs]
  }

  /** Back-translates both sides of every entry of the substitution; each key must translate back to a `Var`. */
  override def convert_substitution( s: Substitution ): Substitution = {
    val pairs: List[( Var, Expr )] =
      for ( ( variable, term ) <- s.map.toList )
        yield ( restore( variable, None ).asInstanceOf[Var], restore( term, None ) )

    Substitution( pairs )
  }
}

/**
 * Factory methods for [[Resolution2RalWithAbstractions]], the converter from Robinson resolution proofs to
 * Ral proofs that reintroduces the lambda abstractions removed for the FOL export.
 */
object Resolution2RalWithAbstractions {

  /**
   * Builds a converter from a combined signature.
   *
   * @param signature The signature of the original proof: a pair of (constants grouped by name,
   *                  variables grouped by name)
   * @param cmap      The mapping of abstracted symbols to lambda terms. The abstracted symbols must be unique
   *                  (i.e. cmap must be a bijection)
   */
  def apply( signature: Signature, cmap: Hol2FolDefinitions ): Resolution2RalWithAbstractions = {
    val constsByName = signature._1
    val varsByName = signature._2
    val varLists = varsByName.map { case ( name, vars ) => ( name, vars.toList ) }
    val constLists = constsByName.map { case ( name, consts ) => ( name, consts.toList ) }
    new Resolution2RalWithAbstractions( varLists, constLists, cmap )
  }

  /**
   * Builds a converter from separate variable and constant signatures.
   *
   * @param sig_vars   The signature of the variables in the original proof, grouped by name
   * @param sig_consts The signature of the constants in the original proof, grouped by name
   * @param cmap       The mapping of abstracted symbols to lambda terms. The abstracted symbols must be unique
   *                   (i.e. cmap must be a bijection)
   */
  def apply(
    sig_vars:   Map[String, List[Var]],
    sig_consts: Map[String, List[Const]],
    cmap:       Hol2FolDefinitions ): Resolution2RalWithAbstractions =
    new Resolution2RalWithAbstractions( sig_vars, sig_consts, cmap )

}
