(*===========================================================================*)
(*
 * CIL plugin for rootkit analysis.
 *
 * Vinod Ganapathy <vg@cs.wisc.edu>, October 15, 2007.
*)
(*===========================================================================*)

open Cil
open Str
open Typegraph

(* Plugin settings. Each one is assigned by the command-line option of the
 * same name declared in the feature description at the end of this file. *)

(* File that receives the local-variable type records (--locvartypes) *)
let locvartypes_file = ref ""
(* File that receives the global-variable type records (--globvartypes) *)
let globvartypes_file = ref ""
(* File that receives the type/size records (--typesizes) *)
let typesizes_file = ref ""
(* File that receives the type definitions (--typedefs) *)
let typedefs_file = ref ""
(* When true, attributes are kept when types and definitions are printed
 * (--printattr) *)
let printattrs = ref false

(* Raised by eval_exp_as_int when an expression cannot be reduced to an
 * integer *)
exception Notaninteger;;
(* Raised by attrparam_to_exp for attribute parameters it does not convert *)
exception Notanexp;;

(** Collect the keys of a hashtable into a list. A key is listed once for
 * every binding it has in the table. *)
let list_keys (table: ('a,'b) Hashtbl.t) : 'a list =
  Hashtbl.fold (fun key _ collected -> key :: collected) table []

(** List one binding per key of the Hashtbl, namely the
 * one found by Hashtbl.find. A key that has several bindings contributes
 * that same (most recent) binding once per binding. *)
let list_bindings_key (table: ('a,'b) Hashtbl.t) : 'b list =
  (* One pass over the table instead of indexing a key list with List.nth,
   * which was quadratic. The final List.rev yields the same order as consing
   * the lookups while walking the reversed key list. *)
  List.rev
    (Hashtbl.fold
       (fun key _ found -> Hashtbl.find table key :: found)
       table [])

(** add_if binds 'key' to 'data' in the table only when 'key' has no
 * binding yet; an existing binding is left untouched. *)
let add_if (table: ('a,'b) Hashtbl.t) (key: 'a) (data: 'b): unit =
  match Hashtbl.mem table key with
  | true -> ()
  | false -> Hashtbl.add table key data

(*---------------------------------------------------------------------------*)
(** CIL visitor to spit out the following information:
 * 1. Variable type information (varname, type, funcname/GLOB)
 * 2. Type size information     (type, size in bytes)
 * 3. Type definitions
 *
 * Records are accumulated in hash tables while the file is visited and are
 * appended to the configured output files by top_level.
*)
class emit_typeinfo = object (self) inherit nopCilVisitor
  (* Accumulated output. The first three tables are used as sets of
   * already-formatted records; typedefs_htab maps a pretty-printed
   * definition to the type it was printed from. *)
  val mutable locvartypes_htab : (string, bool) Hashtbl.t = Hashtbl.create 117
  val mutable globvartypes_htab : (string, bool) Hashtbl.t = Hashtbl.create 117
  val mutable typesizes_htab : (string, bool) Hashtbl.t = Hashtbl.create 117
  val mutable typedefs_htab : (string, typ) Hashtbl.t = Hashtbl.create 117

  (* Size of a type in bytes for use in constant arithmetic. sizeof_typ
   * reports failure with a negative sentinel; that sentinel must not be
   * used as a number, so it is turned into Notaninteger here. *)
  method private known_sizeof (t: typ) : int =
    let size = self#sizeof_typ t in
    if size < 0 then raise Notaninteger else size

  (* Try to flatten an exp to an integer as much as possible.
   * Raises Notaninteger when the expression (or a part of it) cannot be
   * evaluated: non-integer constants, non-integral operators or casts,
   * unsupported operators, unknown sizes, division by zero and shift
   * counts outside 0..63. *)
  method eval_exp_as_int (e: exp) : int =
    match e with
    | Const (CInt64 (i64, _, _)) -> Int64.to_int i64
    | Const _ -> raise Notaninteger
    | SizeOf t -> self#known_sizeof t
    (* sizeof applied to an expression is the size of the type of that
     * expression, not the value of the expression. *)
    | SizeOfE e' -> self#known_sizeof (Cil.typeOf e')
    | UnOp (unop, e', t) ->
      if not (isIntegralType t) then raise Notaninteger
      else begin
        let v = self#eval_exp_as_int e' in
        match unop with
        | Neg -> 0 - v
        | BNot -> lnot v
        | LNot -> if v <> 0 then 0 else 1
      end
    | BinOp (binop, e1, e2, t) ->
      if not (isIntegralType t) then raise Notaninteger
      else begin
        (* Both operands are evaluated before the operator is inspected *)
        let v1 = self#eval_exp_as_int e1 in
        let v2 = self#eval_exp_as_int e2 in
        match binop with
        | PlusA -> v1 + v2
        | MinusA -> v1 - v2
        | Shiftlt ->
          (* Int64 shifts are unspecified outside 0..63 *)
          if v2 < 0 || v2 > 63 then raise Notaninteger
          else Int64.to_int (Int64.shift_left (Int64.of_int v1) v2)
        | Shiftrt ->
          if v2 < 0 || v2 > 63 then raise Notaninteger
          else Int64.to_int (Int64.shift_right (Int64.of_int v1) v2)
        | Mult -> v1 * v2
        | Div ->
          (* Callers only catch Notaninteger, so Division_by_zero must not
           * escape from here. *)
          if v2 = 0 then raise Notaninteger else v1 / v2
        | BAnd -> v1 land v2
        | BOr -> v1 lor v2
        | _ -> raise Notaninteger
      end
    | CastE (t, e') ->
      if isIntegralType t
      then self#eval_exp_as_int e'
      else raise Notaninteger
    (* NOTE(review): this yields the value of the operand, not its
     * alignment. Left unchanged because attrparam_to_exp flattens
     * AAlignOfE the same way and consumers of the output may depend on
     * it -- confirm the intent. *)
    | AlignOfE e' -> self#eval_exp_as_int e'
    | _ -> raise Notaninteger

  (* Decimal rendering of the expression when it can be evaluated by
   * eval_exp_as_int, otherwise the CIL pretty-printed expression. *)
  method eval_exp_as_int_and_convert_to_string (e: exp) : string =
    (try
      let i = self#eval_exp_as_int e in
      Int64.to_string (Int64.of_int i)
    with Notaninteger -> Pretty.sprint 500 (d_exp () e))

  (* attrparam_to_exp: Convert an attribute parameter to an expression.
   * Raises Notanexp if it can't. Only integers, unary and binary operators
   * and AAlignOfE are supported for now; operators are given type int.
  *)
  method attrparam_to_exp (ap: attrparam) : exp =
    match ap with
    | AInt i -> integer i
    | AUnOp (u, ap1) -> UnOp (u, self#attrparam_to_exp ap1, intType)
    | ABinOp (b, ap1, ap2) ->
      let lhs = self#attrparam_to_exp ap1 in
      let rhs = self#attrparam_to_exp ap2 in
      BinOp (b, lhs, rhs, intType)
    (* NOTE(review): the align-of wrapper is dropped and only its operand
     * is kept -- confirm that this flattening is intended. *)
    | AAlignOfE ap1 -> self#attrparam_to_exp ap1
    | _ -> raise Notanexp

  (* String representation of attribute. Use the direct print out for
   * most cases, except those Arati wants flattened: an attribute whose
   * single parameter is an integer, a unary/binary operation or AAlignOfE
   * is printed as name(value). *)
  method attribute_tostring (a: attribute) : string =
    let Attr (name, params) = a in
    let verbatim () = Pretty.sprint 500 (d_attr () a) in
    match params with
    | [(AInt _ | AUnOp _ | ABinOp _ | AAlignOfE _) as param] ->
      (try
        let value =
          self#eval_exp_as_int_and_convert_to_string
            (self#attrparam_to_exp param) in
        name ^ "(" ^ value ^ ")"
      with Notanexp -> verbatim ())
    | _ -> verbatim ()

  (* String representation of a type. Arrays with a length get special
     treatment: the length expression is evaluated when possible and the
     result is printed as elementtype[length]. *)
  method typ_tostring (t: typ) : string =
    match t with
    | TArray (elt, Some len, _) ->
      let len_str = self#eval_exp_as_int_and_convert_to_string len in
      let elt_str = self#typ_tostring elt in
      elt_str ^ "[" ^ len_str ^ "]"
    | _ -> Pretty.sprint 500 (d_type () t)

  (** Strip attributes from an input type. Pointer targets, array elements
   * and function return types are stripped recursively; everything else
   * only loses the attributes attached to the type node itself. *)
  method strip_typ_attribs (t: typ) : typ =
    match t with
    | TVoid _ -> TVoid []
    | TInt (k, _) -> TInt (k, [])
    | TFloat (k, _) -> TFloat (k, [])
    | TPtr (t', _) -> TPtr (self#strip_typ_attribs t', [])
    | TArray (t', e, _) -> TArray (self#strip_typ_attribs t', e, [])
    | TFun (t', l, b, _) -> TFun (self#strip_typ_attribs t', l, b, [])
    | TNamed (tinfo, _) -> TNamed (tinfo, [])
    | TComp (cinfo, _) -> TComp (cinfo, [])
    | TEnum (einfo, _) -> TEnum (einfo, [])
    | TBuiltin_va_list _ -> TBuiltin_va_list []

  (* String representation of a type, without attribute information *)
  method typ_tostring_noattr (t: typ) : string =
    self#typ_tostring (self#strip_typ_attribs t)

  (* Method to print a type; attributes are kept only when printattrs is
   * set. *)
  method print_type (t: typ) : string =
    if !printattrs then self#typ_tostring t else self#typ_tostring_noattr t

  (* Append one __attribute__ clause per attribute to the buffer *)
  method private add_attributes (buf: Buffer.t) (attrs: attribute list) : unit =
    List.iter
      (fun attr ->
        Buffer.add_string buf " __attribute__((";
        Buffer.add_string buf (self#attribute_tostring attr);
        Buffer.add_string buf "))")
      attrs

  (* Prettyprint a type definition for typedefs, structs/unions and enums;
   * any other type yields the empty string. Attribute information is only
   * printed when printattrs is set. *)
  method prettyprint_type_definition (t: typ) : string =
    (* A Buffer avoids rebuilding the whole string for every field *)
    let buf = Buffer.create 256 in
    (match t with
    | TNamed (tinfo, al) ->
      Printf.bprintf buf "@typedef| @%s| @%s|"
        tinfo.tname (self#print_type tinfo.ttype);
      if !printattrs then self#add_attributes buf al;
      Buffer.add_string buf ";\n"
    | TComp (c, al) ->
      if c.cstruct
      then Printf.bprintf buf "struct %s {\n" c.cname
      else Printf.bprintf buf "union %s {\n" c.cname;
      List.iter
        (fun (fld: fieldinfo) ->
          match fld.fbitfield with
          | Some width ->
            Printf.bprintf buf "\t@field| @%s| @%s|:%d;\n"
              (self#print_type fld.ftype) fld.fname width
          | None ->
            Printf.bprintf buf "\t@field| @%s| @%s|;\n"
              (self#print_type fld.ftype) fld.fname)
        c.cfields;
      Buffer.add_string buf "}";
      (* Attributes of the composite itself first, then those of this use *)
      if !printattrs then begin
        self#add_attributes buf c.cattr;
        self#add_attributes buf al
      end;
      Buffer.add_string buf ";\n"
    | TEnum (einfo, al) ->
      Printf.bprintf buf "enum %s {\n" einfo.ename;
      List.iter
        (fun (item_name, item_exp, _) ->
          match item_exp with
          | Const _ ->
            Printf.bprintf buf "\t@field| @%s| @%s|,\n"
              item_name (Pretty.sprint 500 (d_exp () item_exp))
          | _ ->
            Printf.bprintf buf "\titem %s val_unknown\n" item_name)
        einfo.eitems;
      Buffer.add_string buf "}";
      if !printattrs then begin
        self#add_attributes buf einfo.eattr;
        self#add_attributes buf al
      end;
      Buffer.add_string buf ";\n"
    | _ -> ());
    Buffer.contents buf

  (* Get size information in bytes. When sizeOf does not return an integer
   * constant a message is printed on stdout and a negative code is
   * returned: -1 for a non-integer constant, -2 for a residual sizeof
   * expression, -3 for anything else. *)
  method sizeof_typ (t: typ) : int =
    let size =
      match sizeOf t with
      | Const (CInt64 (i64, _, _)) -> Int64.to_int i64
      | Const _ -> -1
      | SizeOf _ -> -2
      | _ -> -3
    in
    if size < 0 then
      Printf.printf "Can't determine size of %s\n" (self#print_type t);
    size

  (* Emit type names of local variables. str is the storage-class label
   * that ends the record. *)
  method emit_local_vartypes (v: varinfo) (fdec: fundec) (str: string) : unit =
    let size = self#sizeof_typ v.vtype in
    let record =
      Printf.sprintf "@LOCAL_%s| @%s| @%s| @%d| @%s|\n"
        fdec.svar.vname v.vname (self#print_type v.vtype) size str in
    add_if locvartypes_htab record true

  (* Emit type names of global variables. *)
  method emit_global_vartypes (v: varinfo) : unit =
    let size = self#sizeof_typ v.vtype in
    let record =
      Printf.sprintf "@GLOB| @%s| @%s| @%d|\n" v.vname
        (self#print_type v.vtype) size in
    add_if globvartypes_htab record true

  (* Emit sizes information for declared types *)
  method emit_size_info (t: typ) : unit =
    let size = self#sizeof_typ t in
    let record = Printf.sprintf "@%s| @%d|\n" (self#print_type t) size in
    add_if typesizes_htab record true

  (* Emit definitions of declared types. The table is keyed by the printed
   * definition, so only the first type producing a given text is kept. *)
  method emit_typedefs (t: typ) : unit =
    add_if typedefs_htab (self#prettyprint_type_definition t) t

  (* Remove duplicate typedefs: keep the first typedef of each name and the
   * first enum of each tag, in input order. Typedef names and enum tags
   * are tracked separately, so a typedef and an enum that share a name
   * (typedef enum foo ... foo) are both kept. All other types pass through
   * unchanged. *)
  method remove_duplicate_typedefs (tl: typ list) : typ list =
    let seen : (string, bool) Hashtbl.t = Hashtbl.create 117 in
    let first_occurrence (key: string) : bool =
      if Hashtbl.mem seen key then false
      else begin
        Hashtbl.add seen key true;
        true
      end
    in
    let keep (ty: typ) : bool =
      match ty with
      | TNamed (tinfo, _) -> first_occurrence ("typedef " ^ tinfo.tname)
      | TEnum (einfo, _) -> first_occurrence ("enum " ^ einfo.ename)
      | _ -> true
    in
    (* Cons and reverse once instead of appending to the end of the list *)
    List.rev
      (List.fold_left
         (fun kept ty -> if keep ty then ty :: kept else kept)
         [] tl)

  (* Visitor for function declarations. This records every local variable
   * of the function, labelled with its storage class (static, extern,
   * register, or the empty string otherwise). *)
  method vfunc (fdec: fundec) : fundec visitAction =
    List.iter
      (fun (local: varinfo) ->
        let storage =
          match local.vstorage with
          | Static -> "static"
          | Extern -> "extern"
          | Register -> "register"
          | _ -> ""
        in
        self#emit_local_vartypes local fdec storage)
      fdec.slocals;
    DoChildren

  (* Visitor for variable names: Emit the names and types of global
     variables. Don't emit function declarations *)
  method vvdec (v: varinfo) : varinfo visitAction =
    (if v.vglob then
      match v.vtype with
      | TFun _ -> ()
      | _ -> self#emit_global_vartypes v);
    DoChildren

  (* Visitor for type information. Do the following for every type that is
   * not a function type:
   * 1. Spit out the type name and size information.
   * 2. Spit out the definition of the type.
  *)
  method vtype (t: typ) : typ visitAction =
    (match t with
    | TFun _ -> ()
    | _ ->
      self#emit_size_info t;
      self#emit_typedefs t);
    DoChildren

  (* Top level function: visit the file, then append the collected records
   * to the four output files and report the number of collected type
   * definitions on stderr. *)
  method top_level (f: file) : unit =
    (* Open path for appending (creating it if it does not exist yet), let
     * emit write through the supplied function, and close the channel even
     * when emit raises. *)
    let append_all (path: string) (emit: (string -> unit) -> unit) : unit =
      let out = open_out_gen [Open_append; Open_creat] 0o644 path in
      (try emit (output_string out)
       with exn -> close_out_noerr out; raise exn);
      close_out out
    in
    visitCilFile (self :> cilVisitor) f;

    (* Write stuff now *)
    append_all !locvartypes_file
      (fun write -> List.iter write (list_keys locvartypes_htab));
    append_all !globvartypes_file
      (fun write -> List.iter write (list_keys globvartypes_htab));
    append_all !typesizes_file
      (fun write -> List.iter write (list_keys typesizes_htab));
    append_all !typedefs_file
      (fun write ->
        let collected = list_bindings_key typedefs_htab in
        let ordered = Typegraph.get_ordered_typedefs collected in
        let unique = self#remove_duplicate_typedefs ordered in
        List.iter
          (fun ty -> write (self#prettyprint_type_definition ty))
          unique);

    Printf.fprintf stderr "typedefs_htab: %d\n"
      (List.length (list_keys typedefs_htab))

end

(*---------------------------------------------------------------------------*)
(** Entry point of the rootkit feature: put the program in CFG form, run
 * the emit_typeinfo visitor over it and print the elapsed CPU time.
 *
 * The output files are not truncated here (the code that did so is
 * disabled); the visitor opens them in append mode. *)
let dorootkitanalysis (f:file) : unit =
  (* Compute CFG information for a function definition; other globals are
   * left alone. *)
  let build_cfg glob =
    match glob with
    | Cil.GFun (fd, _) ->
      Cil.prepareCFG fd;
      ignore (Cil.computeCFGInfo fd true)
    | _ -> ()
  in
  (* Simplify the program to make it look like a CFG -- the makeCFG
   * feature of CIL. This simplifies the analysis that we have to do *)
  Cil.initCIL ();
  ignore (Partial.calls_end_basic_blocks f);
  ignore (Partial.globally_unique_vids f);
  Cil.iterGlobals f build_cfg;

  (* The visitor emits variable name/type information and all type
   * information. *)
  let typeinfo : emit_typeinfo = new emit_typeinfo in
  typeinfo#top_level f;

  Printf.printf "#### Total execution time: %f\n" (Sys.time ())

(*---------------------------------------------------------------------------*)
(** CIL feature description for the rootkit analysis: its name, enabling
 * flag, command-line options and entry point. *)
let feature : featureDescr =
  (* An option that stores its string argument in the given reference *)
  let file_option flag (target: string ref) doc =
    (flag, Arg.String (fun value -> target := value), doc)
  in
  let options =
    [ ("--printattr",
       Arg.Bool (fun enabled -> printattrs := enabled),
       "true if you want to print type attrs, false otherwise");
      file_option "--locvartypes" locvartypes_file
        "File name to dump local variable type information";
      file_option "--globvartypes" globvartypes_file
        "File name to dump glob variable type information";
      file_option "--typesizes" typesizes_file
        "File name to store size information for types";
      file_option "--typedefs" typedefs_file
        "File name to store type definitions";
    ]
  in
  { fd_name = "rootkit";
    fd_enabled = Cilutil.rootkit;
    fd_description = "rootkit analysis";
    fd_extraopt = options;
    fd_doit = dorootkitanalysis;
    fd_post_check = true
  }

(*===========================================================================*)
