(*===========================================================================*)
(*
 * CIL plugin for rootkit analysis.
 *
 * Vinod Ganapathy <vg@cs.wisc.edu>, November 26, 2007.
*)
(*===========================================================================*)

open Cil
open Str

(** [strip_typ_attribs t] rebuilds [t] with an empty attribute list.  The
    pointee of a pointer, the element type of an array and the return type
    of a function are stripped recursively.  Function argument types and the
    info records carried by named, composite and enum types are reused
    unchanged. *)
let rec strip_typ_attribs (t: typ) : typ =
  match t with
  | TVoid _ -> TVoid []
  | TInt (ikind, _) -> TInt (ikind, [])
  | TFloat (fkind, _) -> TFloat (fkind, [])
  | TPtr (pointee, _) -> TPtr (strip_typ_attribs pointee, [])
  | TArray (elem, len, _) -> TArray (strip_typ_attribs elem, len, [])
  | TFun (ret, args, varargs, _) ->
    TFun (strip_typ_attribs ret, args, varargs, [])
  | TNamed (tinfo, _) -> TNamed (tinfo, [])
  | TComp (cinfo, _) -> TComp (cinfo, [])
  | TEnum (einfo, _) -> TEnum (einfo, [])
  | TBuiltin_va_list _ -> TBuiltin_va_list []

(** [typ_tostring_noattrib t] pretty-prints [t] after its attributes have
    been removed with [strip_typ_attribs].  The value 500 is handed to
    [Pretty.sprint] as its first argument (presumably the output width --
    confirm against the Pretty module). *)
let typ_tostring_noattrib (t: typ) =
  let stripped = strip_typ_attribs t in
  Pretty.sprint 500 (d_type () stripped)

(* Counter of [build_tg] invocations: [build_tg] increments and prints it,
   and [get_ordered_typedefs] resets it to 0 before building the graph. *)
let repeat = ref 0 

(* Type graph: node.  A node wraps the CIL type it stands for. *)
type tgnode_t = 
  | Tgnode of typ

(* Type graph: node ordering.  Nodes are ordered by comparing the
   attribute-free printed form of their types, so two types that differ
   only in the attributes removed by [strip_typ_attribs] compare equal. *)
module TgnodeOrder =
struct
  type t = tgnode_t
  let compare (Tgnode lhs) (Tgnode rhs) =
    let key_l = typ_tostring_noattrib lhs in
    let key_r = typ_tostring_noattrib rhs in
    String.compare key_l key_r
end

(* Set of nodes *)
module Tgnodeset = Set.Make(TgnodeOrder)

(** Emit a node: pretty-print the type held by the node as it is, without
    stripping attributes.  A second [tg_dumpnode] defined further down in
    this file shadows this one for all code that follows it. *)
let tg_dumpnode (Tgnode ty : tgnode_t) : string =
  Pretty.sprint 500 (d_type () ty)

(* Type graph edges: An edge exists from a source type to a target type 
   if the source must be defined before the target.  The first component
   is the source node and the second component is the target node. *)
type tgedge_t =
  | Tgedge of tgnode_t * tgnode_t

(** Edge ordering.  Edges are ordered lexicographically: first by source
    node, then by target node, each compared with [TgnodeOrder.compare].

    The endpoints are compared separately rather than by concatenating
    their printed forms: a plain concatenation has no separator, so two
    different (source, target) pairs could produce the same string and be
    treated as one edge.  Using [TgnodeOrder.compare] also makes edge
    identity agree with node identity in [Tgnodeset], i.e. attributes are
    ignored for both. *)
module TgedgeOrder =
struct
  type t = tgedge_t
  let compare (Tgedge (src1, tgt1)) (Tgedge (src2, tgt2)) =
    let by_source = TgnodeOrder.compare src1 src2 in
    if by_source <> 0 then by_source
    else TgnodeOrder.compare tgt1 tgt2
end

(** Set of edges *)
module Tgedgeset = Set.Make(TgedgeOrder)

(** Type Graph: a set of nodes together with a set of edges *)
type tgraph_t = {nodes: Tgnodeset.t; edges: Tgedgeset.t}

(** Hashtable storing the children of each node (i.e., neighbours reached
 * through outgoing edges).  This is a single module-level table, written
 * by [tg_add_edge] and read by [tg_getchildren]; it is not a field of
 * [tgraph_t].
 * NOTE(review): keys are matched with the hashtable's polymorphic equality
 * rather than with [TgnodeOrder.compare], so two nodes that are equal in
 * [Tgnodeset] may still be distinct keys here -- confirm this is intended *)
let children: (tgnode_t, Tgnodeset.t) Hashtbl.t = (Hashtbl.create 117)

(*---------------------------------------------------------------------------*)
(** Functions related to the type graph *)
(** Return an empty graph: no nodes and no edges *)
let tg_empty_graph () : tgraph_t =
  { nodes = Tgnodeset.empty; edges = Tgedgeset.empty }

(** Adds a node to a graph and returns the new graph.  The input graph
    value is not modified; the result keeps the same edge set. *)
let tg_add_node (n: tgnode_t) (g: tgraph_t) : tgraph_t =
  { g with nodes = Tgnodeset.add n g.nodes }

(** Adds an edge to a graph and returns the new graph.  As a side effect
    the module-level [children] table is updated so that the target of the
    edge is recorded among the children of its source. *)
let tg_add_edge (e: tgedge_t) (g: tgraph_t) : tgraph_t =
  let Tgedge (src, tgt) = e in
  let extended = { g with edges = Tgedgeset.add e g.edges } in
  (* Start from the neighbours already recorded for the source, if any. *)
  let known =
    try Hashtbl.find children src
    with Not_found -> Tgnodeset.empty
  in
  Hashtbl.replace children src (Tgnodeset.add tgt known);
  extended

(** Statistics: print the number of nodes and of edges of [g] on stderr *)
let tg_stats (g: tgraph_t) : unit =
  let node_count = Tgnodeset.cardinal g.nodes in
  let edge_count = Tgedgeset.cardinal g.edges in
  Printf.fprintf stderr "Nodes: %d, Edges: %d\n" node_count edge_count

(** Emit a node: the attribute-free printed form of the type held by the
    node.  This definition shadows the earlier [tg_dumpnode] for the rest
    of the file. *)
let tg_dumpnode (n: tgnode_t) : string =
  match n with
  | Tgnode ty -> typ_tostring_noattrib ty

(** Emit an edge: the printed source node, an arrow, then the printed
    target node *)
let tg_dumpedge (e: tgedge_t) : string =
  let Tgedge (src, tgt) = e in
  let src_text = tg_dumpnode src in
  let tgt_text = tg_dumpnode tgt in
  String.concat " -> " [src_text; tgt_text]

(** Get the children of a node (i.e., the targets of the outgoing edges of
 * this node).  The answer is read from the module-level [children] table;
 * the graph argument [g] is not consulted.  A node with no entry in the
 * table yields the empty list.  Every call also writes a trace line to
 * stderr. *)
let tg_getchildren (n: tgnode_t) (g: tgraph_t) : tgnode_t list =
  Printf.fprintf stderr "tg_getchildren: I was called\n";
  let recorded =
    try Some (Hashtbl.find children n)
    with Not_found -> None
  in
  match recorded with
  | Some nbrs -> Tgnodeset.elements nbrs
  | None -> []
(*
    let retval = ref [] in
    let edgelist = (Tgedgeset.elements g.edges) in
    for i = 0 to (List.length edgelist) - 1 do
      let ithedge = (List.nth edgelist i) in
      let Tgedge(src, tgt) = ithedge in
      if (TgnodeOrder.compare src n) = 0 then
        retval := (List.append !retval [tgt]);
    done;
    !retval;
*)

(** Dump the type graph on stderr: one line per node, followed by one line
    per edge, each in the increasing order of the corresponding set *)
let tg_dump (g: tgraph_t) : unit =
  let show_node node =
    Printf.fprintf stderr "Node: %s\n" (tg_dumpnode node) in
  let show_edge edge =
    Printf.fprintf stderr "Edge: %s\n" (tg_dumpedge edge) in
  Tgnodeset.iter show_node g.nodes;
  Tgedgeset.iter show_edge g.edges

(*---------------------------------------------------------------------------*)
(* Implementation of the ordering algorithm *)
(** This hashtable stores the types to be ordered. The key is the string
 * that represents the type with all attributes stripped, and every input
 * type that matches this signature is added under that key *)
let types_to_be_ordered: (string, typ) Hashtbl.t = (Hashtbl.create 11)
(** This hashtable records, keyed by the same attribute-stripped string,
 * the types that have already been placed in the ordered output list *)
let types_ordered: (string, bool) Hashtbl.t = (Hashtbl.create 11)

(* deserves_tg_node: Does the input type deserve a node in the type graph?
 * Composite, named, array and function types do.  A pointer type does only
 * when it points directly at a named type: pointers are restricted this
 * way because we don't want cycles in the graph.  Every other type (the
 * base types) does not deserve a node. *)
let deserves_tg_node (t: typ) : bool =
  match t with
  | TPtr (TNamed _, _) -> true
  | TPtr (_, _) -> false
  | TComp _ | TNamed _ | TArray _ | TFun _ -> true
  | _ -> false

(* get_predtypes: For the input type, get the list of types that must be
   defined before this type:
     - pointer:   the pointed-to type;
     - array:     the element type;
     - function:  the return type, followed by the argument types in order
                  when an argument list is present;
     - named:     the type the name stands for;
     - composite: the type of every field, in field order;
     - otherwise: the empty list.
   The return list from get_predtypes may also include base types. *)
let get_predtypes (t: typ) : typ list =
  match t with
  | TPtr (dereftyp, _) -> [dereftyp]
  | TArray (basetyp, _, _) -> [basetyp]
  | TFun (rettyp, arglistopt, _, _) ->
    let argtyps =
      match arglistopt with
      | Some arglist -> List.map (fun (_, argtyp, _) -> argtyp) arglist
      | None -> []
    in
    rettyp :: argtyps
  | TNamed (tinfo, _) -> [tinfo.ttype]
  | TComp (cinfo, _) -> List.map (fun fld -> fld.ftype) cinfo.cfields
  | _ -> []

(* build_tg: Given an input list of types, build the type graph. We build
 * the type graph using the input types. If nodes corresponding to new types
 * are introduced in the process of constructing the graph, we account for
 * them by reconstructing the graph afresh from the node types of the graph
 * that was just built.
 *
 * Side effects: [types_before_tg] is refilled with the input types,
 * [repeat] is incremented and printed on stderr, and the module-level
 * [children] table is emptied and then refilled through [tg_add_edge], so
 * that on return it describes only the graph being returned.
 *
 * NOTE: Some optimization may be possible here: we could simply compute
 * the predecessors of newly added nodes instead of recomputing the graph
 * each time.
*)
let types_before_tg: (string, typ) Hashtbl.t = (Hashtbl.create 11)
let rec build_tg (tl: typ list) : tgraph_t =
  (* Remember the input types by their attribute-free printed form so that
     step 3 can tell them apart from newly introduced types. *)
  Hashtbl.clear types_before_tg;
  List.iter
    (fun ty -> Hashtbl.add types_before_tg (typ_tostring_noattrib ty) ty)
    tl;
  (* The children table is shared module state.  Without this reset it
     keeps neighbours recorded for previously built graphs, and the DFS
     would then reach nodes that are not part of the current graph. *)
  Hashtbl.clear children;
  repeat := !repeat + 1;
  Printf.fprintf stderr "build_tg %d\n%!" !repeat;
  (* 1. Initialize the graph to be the empty graph *)
  let curr_tg = ref (tg_empty_graph ()) in
  (* 2. Construct the type graph by traversing through the list of types *)
  let add_pred (tgt: typ) (pred: typ) : unit =
    (* 2.3 A predecessor gets a node, and an edge to its target, only if
           it deserves a node too *)
    if deserves_tg_node pred then begin
      let srcnode = Tgnode pred in
      curr_tg := tg_add_node srcnode !curr_tg;
      curr_tg := tg_add_edge (Tgedge (srcnode, Tgnode tgt)) !curr_tg
    end
  in
  let add_type (ty: typ) : unit =
    (* 2.1 We only care about complex types and will be adding nodes only
           for them. Base types don't deserve a node in the type graph *)
    if deserves_tg_node ty then begin
      curr_tg := tg_add_node (Tgnode ty) !curr_tg;
      (* 2.2 Get the types that must precede this type *)
      List.iter (add_pred ty) (get_predtypes ty)
    end
  in
  List.iter add_type tl;
  (* 3. Get the list of types in the graph that we just built. If we
   * introduced any new types in addition to the input types, we must
   * rebuild the graph (so as to compute their predecessors too) *)
  let curr_tg_tl =
    List.map (fun (Tgnode ty) -> ty) (Tgnodeset.elements (!curr_tg).nodes)
  in
  let new_node_added = ref false in
  List.iter
    (fun ty ->
      let key = typ_tostring_noattrib ty in
      if not (Hashtbl.mem types_before_tg key) then begin
        new_node_added := true;
        Printf.fprintf stderr "new node added: %s\n%!" key
      end)
    curr_tg_tl;
  if !new_node_added then build_tg curr_tg_tl else !curr_tg

(*
    (* Incorrect TG construction *)
    let tl_len = (List.length tl) in 
    let curr_tg_tl_len = (List.length !curr_tg_tl) in
    if (tl_len <> curr_tg_tl_len) then begin
      curr_tg := (build_tg !curr_tg_tl);
    end;
*)
(*
    let tl_len = (List.length tl) in 
    let curr_tg_tl_len = (List.length !curr_tg_tl) in
    for i = 0 to (tl_len - 1) do
      let ith = (List.nth tl i) in
      (Printf.fprintf stderr "=> %s\n" (typ_tostring_noattrib ith));
    done;
    for i = 0 to (curr_tg_tl_len - 1) do
      let ith = (List.nth !curr_tg_tl i) in
      (Printf.fprintf stderr "-> %s\n" (typ_tostring_noattrib ith));
    done;
*)


(*---------------------------------------------------------------------------*)
(* Depth first search state.  The three tables are keyed by the string
 * that [tg_dumpnode] produces for a node.
 *   starttime: value of the DFS clock when the visit of the node began
 *   endtime:   value of the DFS clock when the visit of the node ended
 *   dfsstatus: 0 = not visited yet, 1 = visit in progress, 2 = visit done
 * dfsTime is the DFS clock itself. *)
let starttime: (string, int) Hashtbl.t  = (Hashtbl.create 117)
let endtime: (string, int) Hashtbl.t = (Hashtbl.create 117)
let dfsstatus: (string, int) Hashtbl.t = (Hashtbl.create 117)
let dfsTime = ref 0

(** [dumptable table str] prints every binding of [table] on stderr, one
    per line: the label [str], the printed node, then the integer value.
    NOTE(review): the DFS tables declared in this file are keyed by string,
    not by node, so this helper cannot be applied to them as written. *)
let dumptable (table: (tgnode_t, int) Hashtbl.t) (str: string) : unit =
  Hashtbl.iter
    (fun node value ->
      Printf.fprintf stderr "%s: %s %d\n" str (tg_dumpnode node) value)
    table

(** Print on stderr, on a single line introduced by a CYCLE label, the key
    of every [dfsstatus] entry whose status is 1 (visit in progress). *)
let dumpcycle () : unit =
  let show_if_active (name: string) (status: int) : unit =
    match status with
    | 1 -> Printf.fprintf stderr " [%s]" name
    | _ -> ()
  in
  Printf.fprintf stderr "CYCLE: ";
  Hashtbl.iter show_if_active dfsstatus;
  Printf.fprintf stderr "\n%!"

(* DFS related routines.
 * dfs_visit: visit [node] and, recursively, every child that has not been
 * visited yet.  The node is marked in progress (status 1) and stamped with
 * a start time on entry, then marked done (status 2) and stamped with an
 * end time on exit.  A child that is still in progress reveals a cycle;
 * this is only reported on stderr and the child is not visited again.  A
 * child that has no [dfsstatus] entry trips an assertion. *)
let rec dfs_visit (node: tgnode_t) (g: tgraph_t) : unit =
  let key = tg_dumpnode node in
  let visit_child (child: tgnode_t) : unit =
    let status =
      try Hashtbl.find dfsstatus (tg_dumpnode child)
      with Not_found -> assert false
    in
    if status = 0 then dfs_visit child g
    else if status = 1 then begin
      Printf.fprintf stderr "WARNING: CYCLE\n%!";
      dumpcycle ()
    end
  in
  Hashtbl.replace dfsstatus key 1;
  incr dfsTime;
  Hashtbl.replace starttime key !dfsTime;
  List.iter visit_child (tg_getchildren node g);
  Hashtbl.replace dfsstatus key 2;
  incr dfsTime;
  Hashtbl.replace endtime key !dfsTime

(* DFS related routines: the following routine obtains start
 * and finishing times on the nodes.  It empties the three DFS tables,
 * gives every node of [g] status 0 and times 0, resets the DFS clock and
 * then starts a visit from every node that is still unvisited, in the
 * increasing order of the node set.  Progress is traced on stderr. *)
let dfs (g: tgraph_t) : unit =
  Hashtbl.clear dfsstatus;
  Hashtbl.clear starttime;
  Hashtbl.clear endtime;
  let nodelist = Tgnodeset.elements g.nodes in
  (* Initialization steps *)
  List.iter
    (fun node ->
      let key = tg_dumpnode node in
      Hashtbl.replace dfsstatus key 0;
      Hashtbl.replace starttime key 0;
      Hashtbl.replace endtime key 0;
      Printf.fprintf stderr "setting : [%s]\n%!" key)
    nodelist;
  dfsTime := 0;  (* Reset the DFS time *)
  (* DFS Visit step; the fold threads the position of the node in the list *)
  let visit_root (idx: int) (node: tgnode_t) : int =
    Printf.fprintf stderr "DFS: Doing node: %d\n%!" idx;
    let status =
      try Hashtbl.find dfsstatus (tg_dumpnode node)
      with Not_found -> assert false
    in
    if status = 0 then dfs_visit node g;  (* Visits children *)
    Printf.fprintf stderr "DFS: Done node: %d\n%!" idx;
    idx + 1
  in
  ignore (List.fold_left visit_root 0 nodelist)


(* nodes_dfs_order: Return the nodes of [g] sorted in INCREASING order of
   DFS finishing time, i.e. the node that finished first comes first.
   A caller that wants the nodes by decreasing finishing time has to
   reverse the result, as [get_ordered_typedefs] does.
   ASSUME: That dfs has already been done.  A node that has no entry in
   [endtime] makes the lookup raise Not_found.
*)
let nodes_dfs_order (g: tgraph_t): tgnode_t list =
  let finish_time (n: tgnode_t) : int =
    Hashtbl.find endtime (tg_dumpnode n) in
  let by_finish_time (a: tgnode_t) (b: tgnode_t) : int =
    let time_a = finish_time a in
    let time_b = finish_time b in
    if time_a < time_b then -1
    else if time_a > time_b then 1
    else 0
  in
  List.sort by_finish_time (Tgnodeset.elements g.nodes)

(* Topsort the graph: run [dfs] on [g], then return its nodes in the order
   produced by [nodes_dfs_order] *)
let topsort (g: tgraph_t) : tgnode_t list =
  let () = dfs g in
  nodes_dfs_order g


(*---------------------------------------------------------------------------*)
(* Query front end *)

(* get_ordered_typedefs:
 * Input is a list of types. Output is a list of types in the order that they
 * must appear in a file: first the input types that get no node in the type
 * graph (in input order), then the remaining input types in the order given
 * by the reversed result of [topsort].
 *
 * Side effects: [types_to_be_ordered] and [types_ordered] are emptied and
 * refilled for this call, [repeat] is reset, the graph is built with
 * [build_tg], and progress plus a dump of the graph are written to stderr.
 * If the output and input lengths differ, the discrepancy is reported and
 * an assertion fails for the first type found on one side only. *)
let get_ordered_typedefs (tl: typ list) : typ list =
  (* Start from empty tables: they are module-level state, and entries left
   * over from a previous call would otherwise show up as extra matches
   * and as types wrongly considered already ordered. *)
  Hashtbl.clear types_to_be_ordered;
  Hashtbl.clear types_ordered;
  (* Insert the types to be ordered into a hashtable. The hashtable
   * stores the stripped version of the type, and all the types that
   * share the same stripped version *)
  List.iter
    (fun ty -> Hashtbl.add types_to_be_ordered (typ_tostring_noattrib ty) ty)
    tl;

  (* For types that do not get a node in the tg, insert them at the
   * beginning of the return list *)
  let base_types = List.filter (fun ty -> not (deserves_tg_node ty)) tl in
  List.iter
    (fun ty -> Hashtbl.add types_ordered (typ_tostring_noattrib ty) true)
    base_types;
  Printf.fprintf stderr "stg1: retval_length: %d\n%!"
    (List.length base_types);

  (* For types that *do* get a node in the tg, before inserting them
   * into the return list, check whether they were part of the input
   * list *)
  repeat := 0;
  let typgraph = build_tg tl in
  Printf.fprintf stderr "stg1: finished building graph:\n%!";
  tg_stats typgraph;
  let tsort = List.rev (topsort typgraph) in
  Printf.fprintf stderr "stg2: tsort_length: %d\n%!" (List.length tsort);
  (* Take the topsorted list, and for each type, look at all the types
   * that are the same modulo type attributes and emit all of them. The
   * idea is that those types must be emitted at this slot in the topsort
   * order.  Graph nodes that were not part of the input have no match and
   * contribute nothing. *)
  let matches_of_node (Tgnode ty) : typ list =
    let key = typ_tostring_noattrib ty in
    match Hashtbl.find_all types_to_be_ordered key with
    | [] -> []
    | allmatches ->
      Hashtbl.add types_ordered key true;
      allmatches
  in
  let graph_types = List.concat (List.map matches_of_node tsort) in
  let retval = base_types @ graph_types in

  (* Sanity checks: the input list must be of the same length as the length of
   * the list that we're returning *)
  tg_dump typgraph;
  if List.length retval <> List.length tl then begin
    Printf.fprintf stderr "retval_length: %d, tl_length: %d\n"
      (List.length retval) (List.length tl);
    List.iter
      (fun ty ->
        let key = typ_tostring_noattrib ty in
        if not (Hashtbl.mem types_to_be_ordered key) then begin
          Printf.fprintf stderr "%s in retval, not in tl\n%!" key;
          assert false
        end)
      retval;
    List.iter
      (fun ty ->
        let key = typ_tostring_noattrib ty in
        if not (Hashtbl.mem types_ordered key) then begin
          Printf.fprintf stderr "%s in tl, not in retval\n%!" key;
          assert false
        end)
      tl
  end;

  retval
