﻿//-----------------------------------------------------------------------------
//
// Copyright (C) Microsoft Corporation.  All Rights Reserved.
//
//-----------------------------------------------------------------------------

namespace Microsoft.Research.Vx86.ContractGeneration

  module ContractGen =
    open Microsoft.Research.Vcc
    open Microsoft.FSharp.Math
    open CAST
    open Util

    /// VCC plugin that emits vx86 assembler contracts instead of verifying.
    /// For every top-level function declaration carrying the boolean VCC
    /// attribute `asm_routine`, it writes the function header followed by its
    /// writes/reads/requires/ensures clauses, rewritten in terms of the vx86
    /// `core` state, into a file whose name is the input file name with its
    /// extension replaced by "contracts.asm".
    [<System.ComponentModel.Composition.Export("Microsoft.Research.Vcc.Plugin")>]
    type ContractGeneratorPlugin() =
      inherit Microsoft.Research.Vcc.Plugin()

      // Options handed over by the host; recorded here but not consulted by this plugin.
      let pluginOptions = ref []
      let verifiedCOptions = ref null

      // True when the attribute list given as the last argument contains VccAttr (n, "true").
      let hasBoolAttr n = List.exists (function VccAttr (n', "true") -> n = n' | _ -> false)

      override this.IsModular() = false
      override this.Help() = "Don't panic!"
      override this.Name() = "vx86contractgen"
      override this.UseCommandLineOptions options = pluginOptions := [ for o in options -> o ]
      override this.UseVccOptions options = verifiedCOptions := options

      /// Writes the contract file for the `asm_routine` functions in `decls`.
      /// `env` is not used.
      override this.Verify(filename, env, decls) =
        let outfileName = System.IO.Path.ChangeExtension(filename, "contracts.asm")
        use out = new System.IO.StreamWriter(outfileName, false)

        // Every output line is an assembler comment: "; " prefixes plain text,
        // ";^ " prefixes a contract clause.
        let wr (str:string) = out.Write("; "); out.WriteLine(str)
        let wrc (str : string) = out.Write(";^ "); out.WriteLine(str)
        let nl() = wr ""

        let wrHeader() =
          nl()
          wr "vx86 assembler contract file"
          nl()
          wr "Automatically generated by vcc vx86contractgen plugin - DO NOT EDIT"
          nl()

        let wrFooter() =
          nl()
          wr "That's all folks."
          out.WriteLine()

        // Operator precedence used when printing: a smaller number binds tighter.
        // Keys are prefixed with "u" (unary) or "b" (binary) where needed.
        // NOTE(review): in VCC source syntax "==>" binds more loosely than && and ||;
        // confirm which precedence the consumer of the .contracts.asm file uses.
        let opp = Map.ofList [ ".", 0;
                               "->", 1;
                               "[]", 2;
                               "f()", 3;
                               "u~", 4; "u!", 4; "u-", 4; "u+", 4; "&", 4; "*()", 4; "(T)", 4;
                               "b==>", 5;
                               "b*", 6; "b/", 6; "b%", 6;
                               "b+", 7; "b-", 7;
                               "b<<", 8; "b>>",8;
                               "b<", 9; "b>", 9; "b<=", 9; "b>=", 9;
                               "b==", 10; "b!=", 10;
                               "b&", 11;
                               "b^", 12;
                               "b|", 13;
                               "b&&", 14;
                               "b||", 15;
                               "?:", 16; ]

        // Inserts Macro(_, "()", [e]) nodes wherever printing the tree in infix
        // form would otherwise change its meaning under the precedences in `opp`.
        // `prec` is the loosest precedence allowed at the current position
        // without parentheses.
        let parenthesize (e:Expr) =

          let rec addParens prec _ =

            let addParen p (e :Expr) = if (p > prec) then Some(Macro(e.Common, "()", [e])) else Some(e)
            let addParensList p = List.map (fun (e:Expr) -> e.SelfMap(addParens p))

            function
            | Prim(ec, Op(op, cs), [e]) ->
              let p = opp.["u"+op]
              addParen p (Prim(ec, Op(op, cs), [e.SelfMap(addParens p)]))
            | Prim(ec, Op(op, cs), [e1; e2]) ->
              let p = opp.["b"+op]
              // C-style binary operators associate to the left, so a right operand of
              // equal precedence needs parentheses (a - (b - c)). For "==>" both sides
              // are parenthesized at equal precedence so the output does not depend on
              // how the consumer associates implication.
              let pLeft = if op = "==>" then p - 1 else p
              addParen p (Prim(ec, Op(op, cs), [e1.SelfMap(addParens pLeft); e2.SelfMap(addParens (p - 1))]))
            | Macro(ec, "ite", [c; e1; e2]) ->
              // Printed as "c ? e1 : e2"; it must be parenthesized inside any operand.
              let p = opp.["?:"]
              addParen p (Macro(ec, "ite", [c.SelfMap(addParens (p - 1)); e1.SelfMap(addParens p); e2.SelfMap(addParens p)]))
            | Cast(ec, cs, e) ->
              let p = opp.["(T)"]
              addParen p (Cast(ec, cs, (e.SelfMap(addParens p))))
            | Deref(ec, e) ->
              let p = opp.["*()"]
              addParen p (Deref(ec, (e.SelfMap(addParens p))))
            | Call(ec, fn, targs, args) ->
              let p = opp.["f()"]
              addParen p (Call(ec, fn, targs, addParensList p args))
            | Macro(ec, "&", [e]) ->
              let p = opp.["&"]
              addParen p  (Macro(ec, "&", [e.SelfMap(addParens p)]))
            | _ -> None

          e.SelfMap(addParens 16)

        // Renders a normalized, parenthesized expression as contract text.
        // Fails on expression shapes the printer does not support, rather than
        // silently omitting them from the contract.
        let trExpr (expr:Expr) =
          let sBuf = new System.Text.StringBuilder()
          let wr (str : string) = sBuf.Append(str) |> ignore
          // Applies `f` to each element, writing `separationChar` between elements.
          let commas separationChar f l  =
            let rec loop isFirst = function
              | [] -> ()
              | e::es ->
                if not isFirst then wr separationChar
                f e
                loop false es
            loop true l

          let intString = function
            | UInt8 -> "uint8"
            | Int8 -> "int8"
            | UInt16 -> "uint16"
            | Int16 -> "int16"
            | UInt32 -> "uint32"
            | Int32 -> "int32"
            | UInt64 -> "uint64"
            | Int64 -> "int64"

          let rec tr = function
            | Macro(_, "()", [e]) -> wr "("; tr e; wr ")"
            | Macro(_, m, []) -> wr m
            | Macro(_, "ite", [cond; e1; e2]) -> tr cond; wr " ? "; tr e1; wr " : "; tr e2
            | Macro(_, m, args) -> wr m; wr "("; trs ", " args; wr ")"
            | Prim(_, Op(op,_), [e]) -> wr op; tr e
            | Prim(_, Op(op,_), [e1; e2]) -> tr e1; wr " "; wr op; wr " "; tr e2
            | Ref(_, v) -> trVar v
            | IntLiteral(_,i) -> wr (i.ToString())
            | BoolLiteral(_, true) -> wr "true"
            | BoolLiteral(_, false) -> wr "false"
            | Deref(_, e) -> wr "*"; tr e
            | Cast({Type = Ptr(Integer(kind))}, _, e) ->
              wr "("; wr (intString kind); wr "*)"
              tr e
            | Cast({Type = Integer(kind)}, _, e) ->
              wr "("; wr (intString kind); wr ")"
              tr e
            | Old(_, Macro(_, "prestate", []), e) ->  wr "old("; tr e; wr(")")
            | Old _ -> failwith "unexpected pre-state in old"
            | Quant(_, qdata) -> trQuant qdata
            | e -> failwith ("vx86contractgen: cannot translate expression '" + e.ToString() + "'")
          and trQuant (qd : QuantData) =
            let qToString = function
              | QuantKind.Forall -> "forall"
              | QuantKind.Exists -> "exists"
              | QuantKind.Lambda -> "lambda"
            let trQVar (v:Variable) = trType v.Type; wr " "; wr v.Name
            // Emits "{ t1, t2 } " for a non-empty trigger and nothing for an empty one.
            let trTrigger (trigger : Expr list) =
              if not trigger.IsEmpty then
                wr "{ "; trs ", " trigger; wr " } "
            let trCond = function
              | None -> ()
              | Some(e) -> tr e; wr "; "
            wr (qToString qd.Kind)
            wr "(";
            commas "," trQVar qd.Variables
            wr "; "
            List.iter trTrigger (qd.Triggers)
            trCond qd.Condition
            tr qd.Body
            wr ")"
          and trs separationChar = commas separationChar (fun e -> tr e)

          and trVar = function
            | { Kind = (VarKind.ConstGlobal | VarKind.Global); Name = n } -> wr n
            | { Kind = (VarKind.SpecParameter | VarKind.OutParameter); Name = n } -> wr n
            | { Kind = VarKind.QuantBound; Name = n } -> wr n
            | _ -> failwith "wrong variable kind survived"

          and trType = function
            | Integer i -> wr (intString i)
            | _ -> failwith "Cannot handle type"

          tr expr
          sBuf.ToString()

        // Writes one clause as ";^ ctr(expr)".
        let wrContract ctr expr = wrc (ctr + "(" + trExpr expr + ")")

        // Register holding the n-th (0-based) integer parameter; only four are supported.
        let parSubst = function
          | 0 -> "core.R[RCX]"
          | 1 -> "core.R[RDX]"
          | 2 -> "core.R[R8]"
          | 3 -> "core.R[R9]"
          | _ -> failwith "only up to 4 function parameters are supported"

        // Emits the header and all contract clauses of one function.
        let wrFunction (f : Function) =

          // Registers the routine must leave unchanged. The parameter registers in
          // parSubst follow the Microsoft x64 convention, where RBX, RSI, RDI, RBP,
          // RSP and R12-R15 are nonvolatile.
          let wrFixedEnsures() =
            let wrReg r = wrc ("ensures(core.R[" + r + "] == old(core.R[" + r + "]))")
            List.iter wrReg ["RBX"; "RSI"; "RDI"; "RBP"; "RSP"; "R12"; "R13"; "R14"; "R15"]

          // Collapses access paths rooted at the variable `core` into a single
          // nullary macro named after the source token of the access, and turns
          // _vcc_wrapped/_vcc_mutable applied to such a path into "wrapped"/"mutable"
          // macros. Other expressions are rebuilt unchanged.
          let macrofyCore (e:Expr) =
            let markAsCore (e:Expr) = Some(Macro(e.Common, "|core|", [e]))
            let coreVal (ec : ExprCommon) = markAsCore (Macro(ec, ec.Token.Value, []))
            let (|Core|_|) = function
              | Macro(_, "|core|", [Macro(_, str, [])]) -> Some str
              | _ -> None
            let macrofy self = function
              | Deref(_, Macro(_, "&", [e])) -> Some(self e)
              | Macro(_, "&", [Deref(_,e)]) -> Some(self e)
              | Ref(ec, {Name = "core"}) -> markAsCore (Macro(ec, "core", []))
              | Call(ec, ({Name = (("_vcc_wrapped" | "_vcc_mutable") as name)} as fn), targs, [e]) ->
                match self e with
                  | (Core _) as e' -> Some(Macro(ec, name.Substring(5), [e']))
                  | e -> Some(Call(ec, fn, targs, [e]))
              | Macro(ec, "&", [e])->
                match self e with
                  | Core _ -> coreVal ec
                  | e -> Some(Macro(ec, "&", [e]))
              | Deref(ec, e) ->
                match self e with
                  | Core _ -> coreVal ec
                  | e -> Some(Deref(ec, e))
              | Cast(ec, cs, e) ->
                match self e with
                  | Core _ -> coreVal ec
                  | e -> Some(Cast(ec, cs, e))
              | Dot(ec, e, fld) ->
                match self e with
                  | Core _ -> coreVal ec
                  | e -> Some(Dot(ec,e,fld))
              | Index(ec, e, (IntLiteral(_, i) as idx)) ->
                match self e with
                  | Core _ -> coreVal ec
                  | e -> Some(Index(ec, e, idx))
              | _ -> None

            // Strips the temporary "|core|" markers left by macrofy.
            let removeMarker self = function
              | Macro(_, "|core|", [e]) -> Some(self e)
              | _ -> None

            e.SelfMap(macrofy).SelfMap(removeMarker)

          let mkInt (i : int) = IntLiteral({bogusEC with Type = Integer(IntKind.Int64)}, new bigint(i))

          // Rewrites ptr_addition on array-typed fields into Index expressions.
          let preNormalize (e:Expr) =
            let elementTypeForArithmetic = function
              | Ptr Void -> Type.Integer IntKind.UInt8
              | Ptr t -> t
              | _ -> failwith "non-ptr type used in pointer arithmetic"

            // Converts a byte count into an element count of `elementType`;
            // returns None when the count is exactly one element.
            let extractArraySize (expr:Expr) (elementType:Type) (byteCount:Expr) =

              let typeSz = new bigint(elementType.SizeOf)
              let byteCount =
                match byteCount with
                  | Cast (_, _, e) -> e
                  | e -> e
              let (neg, byteCount) =
                match byteCount with
                  | Prim (c, (Op("-", _) as op), [e]) -> (fun e -> Prim (c, op, [e])), e
                  | e -> (fun x -> x), e
              let elts =
                match byteCount with
                  | IntLiteral (c, allocSz) when (allocSz % typeSz) = bigint.Zero ->
                    IntLiteral (c, allocSz / typeSz)
                  | Prim (_, Op("*", _), [Expr.IntLiteral (_, allocSz); e]) when allocSz = typeSz -> e
                  | Prim (_, Op("*", _), [e; Expr.IntLiteral (_, allocSz)]) when allocSz = typeSz -> e
                  | _ when typeSz = bigint.One -> byteCount
                  | _ -> failwith "don't know how to determine number of elements in array"

              // Compare against one explicitly: an uppercase name in a pattern that is
              // not a known constructor would bind a fresh variable and match any literal.
              match neg elts with
                | IntLiteral (_, n) when n = bigint.One -> None
                | sz -> Some sz

            let preNormalize' self = function
              | Expr.Macro (c, "ptr_addition", [(Dot(_,_,fld) as e1); e2]) as expr when fld.Type._IsArray ->
                let ptr, off =
                  if e1.Type._IsPtr then (e1, e2)
                  else (e2, e1)
                let elType = elementTypeForArithmetic ptr.Type
                let off =
                  match extractArraySize expr elType off with
                    | Some e -> e
                    | None -> Expr.IntLiteral (off.Common, bigint.One)
                Some (self (Expr.Index (c, ptr, off)))
              | _ ->  None

            e.SelfMap(preNormalize')

          // Lowers C-level constructs to the register / uint64-pointer view of the
          // vx86 core: result -> RAX, parameters -> argument registers, field access
          // -> pointer arithmetic, calls -> macros, redundant casts removed.
          let normalize (e : Expr) =

            let i64 = Type.Integer(IntKind.Int64)
            let u64Ptr = Type.PhysPtr(Type.Integer(IntKind.UInt64))
            // Size of the pointee of a pointer-typed expression.
            let sizeOf (e:Expr) =
              match e.Type with
                | Ptr(t) -> t.SizeOf
                | _ -> failwith "pointer type expected"
            // True when the first integer type strictly widens the second with the same signedness.
            let isLargerType = function
              | Integer i1, Integer i2 ->
                match i1, i2 with
                  | IntKind.UInt64, (IntKind.UInt32|IntKind.UInt16|IntKind.UInt8) -> true
                  | IntKind.UInt32, (IntKind.UInt16|IntKind.UInt8) -> true
                  | IntKind.UInt16, (IntKind.UInt8) -> true
                  | IntKind.Int64, (IntKind.Int32|IntKind.Int16|IntKind.Int8) -> true
                  | IntKind.Int32, (IntKind.Int16|IntKind.Int8) -> true
                  | IntKind.Int16, (IntKind.Int8) -> true
                  | _ -> false
              | _ -> false

            let toPtr (e:Expr) = Cast({e.Common with Type = u64Ptr}, CheckedStatus.Unchecked, e)
            let offset (f : Field) =
              match f.Offset with
                | Normal n -> n
                | BitField _ -> failwith "bit fields are not supported"

            let normalize' self = function
              | Result(ec) -> Some(Macro(ec, "core.R[RAX]", []))
              | Macro(_, "&", [e]) ->
                  match self e with
                    | Deref(_, e) -> Some(e)
                    | e -> Some(Macro(e.Common, "&", [e]))
              | Cast({Type = Ptr(Type.Void)}, _, e)
              | Cast({Type = Bool},_, e) -> Some(self e)
              | Cast(ec, cs, Cast(_, _, e)) -> Some (self (Cast(ec, cs, e)))
              | Cast({Type = Integer _}, _, (IntLiteral _ as i)) -> Some(i)
              | Cast(ec, _, e) when isLargerType (ec.Type,e.Type) -> Some(self e)
              | Cast(ec, cs, e) ->
                match self e with
                  | e when e.Type = ec.Type -> Some(e)
                  | e -> Some(Cast(ec, cs, e))
              | Call(ec, {Name = ("_vcc_wrapped"|"_vcc_mutable")}, _, args) -> Some(self(Macro(ec, "_vcc_mutable", args)))
              | Macro(ec, "_vcc_mutable", [Cast(_,_, p)]) ->
                let p' = self(toPtr p)
                Some(Macro(ec, "assembler_is_mutable_array", [p'; mkInt (((sizeOf p) + 7) / 8)]))
              | Macro(ec, "ptr_addition", [e1; e2]) ->
                  match self e2 with
                    | Prim(_, Op("*",_), [e2'; IntLiteral(_, n)]) ->
                        if n = new bigint(sizeOf e1) then
                          Some(Prim(ec, Op("+", Unchecked), [self e1; e2']))
                        else
                          failwith "unexpected expression structure"
                    | e2' -> Some(Prim(ec, Op("+", Unchecked), [self e1; Prim(e2'.Common, Op("/", Unchecked), [e2'; mkInt(sizeOf e1)])]))
              | Macro(_, "ite", [e; BoolLiteral(_, true); BoolLiteral(_, false)]) -> Some(self e)
              | Macro(_, "ite", [e; BoolLiteral(_, false); BoolLiteral(_, true)]) -> Some(self (Prim(e.Common, Op("!", CheckedStatus.Processed), [e])))
              | Macro(_, "ite", [e; e1; BoolLiteral(_,false)]) -> Some(self (Prim(e.Common, Op("&&", CheckedStatus.Processed), [e; e1])))
              | Macro(_, "ite", [e; BoolLiteral(_, true); e1]) -> Some(self (Prim(e.Common, Op("||", CheckedStatus.Processed), [e; e1])))
              | Macro(_, "ite", [e; e1; BoolLiteral(_, true)]) -> Some(self (Prim(e.Common, Op("==>", CheckedStatus.Processed), [e; e1])))
              | Call(ec, {Name = fnName}, _, args) when fnName.StartsWith("_vcc_") -> Some(self (Macro(ec, fnName.Substring(5), args)))
              | Call(ec, fn, _, args) -> Some(self (Macro(ec, fn.Name, args)))
              | Dot(_, e, f) ->
                let o = offset f
                if o % 8 <> 0 then failwith ("misaligned field '" + f.Name + "'")
                if o = 0 then
                  Some(self e)
                else
                  Some(Prim(e.Common, Op("+", CheckedStatus.Unchecked), [self e; mkInt (o/8)]))
              | Ref(ec, {Kind = VarKind.Parameter; Name = name}) ->
                let reg = Macro(ec, parSubst (List.findIndex (fun (v : Variable) -> v.Name = name) f.InParameters), [])
                let reg = if ec.Type._IsPtr then toPtr(reg) else reg
                Some(reg)
              | _ -> None
            e.SelfMap(normalize')

          // A written pointer whose pointee is a user-defined type becomes an
          // array_range covering that type's size rounded up to whole 8-byte units.
          let postprocessWrite = function
            | Cast({Type = Ptr(_)}, _, ptr) as e ->
              match ptr.Type with
                | Ptr(Type.Ref(td)) -> Macro(e.Common, "array_range", [e; mkInt ((td.SizeOf + 7) / 8)])
                | _ -> e
            | e -> e


          nl()
          let fForPrinting = {f with Reads = []; Writes = [];  Requires = []; Ensures = [] }
          wr (fForPrinting.ToString().TrimEnd([| '\r'; '\n' |] ))
          nl()

          let prepare = List.map preNormalize >> List.map macrofyCore >> List.map normalize >> List.map parenthesize

          f.Writes |> prepare |> List.map postprocessWrite |> List.iter (wrContract "writes")
          f.Reads |> prepare |> List.iter (wrContract "reads")
          f.Requires |> prepare |> List.iter (wrContract "requires")
          f.Ensures |> prepare |> List.iter (wrContract "ensures")
          wrFixedEnsures()

        wrHeader()

        for d in decls do
          match d with
            | Top.FunctionDecl(f) when hasBoolAttr "asm_routine" f.CustomAttr -> wrFunction f
            | _ -> ()

        wrFooter()
      