(** Minimal monad interface: ['a m] is the type of computations that
    produce a value of type ['a]. *)
module type MONAD = sig
  type 'a m

  (** [bind m f] runs [m] and feeds its result to the continuation [f]. *)
  val bind : 'a m -> ('a -> 'b m) -> 'b m

  (** [return x] is the computation that simply yields [x]. *)
  val return : 'a -> 'a m
end

(** Identity monad: a computation is just its value, so [bind] is plain
    function application and [return] is the identity function. *)
module Id : MONAD with type 'a m = 'a = struct
  type 'a m = 'a

  let bind x k = k x
  let return = fun v -> v
end

(** Transformateurs de monades pour Reader et List *)

(** Signature of a monad that threads a character input through its
    computations, can read from that input, and can embed computations
    of type ['a m0]. *)
module type READER = sig
  include MONAD

  (** Input symbols. *)
  type t = char

  (** Computations that [lift] embeds into ['a m] (in [ReaderT], those of
      the underlying monad [M]). *)
  type 'a m0

  (** Reads the next input symbol, if any. In [ReaderT], [None] is
      returned when the input is exhausted. *)
  val read : t option m

  (** Embeds an ['a m0] computation without touching the input. *)
  val lift : 'a m0 -> 'a m
end

(** Input-threading transformer over [M]. A computation takes the
    remaining input and returns, inside [M], the input left over together
    with its result; [bind] passes that leftover input on to the next
    computation, so this behaves like a state monad over the input. *)
module ReaderT (M:MONAD) : READER with
  type 'a m0 = 'a M.m and
  type 'a m = char list -> (char list * 'a) M.m =
struct
  type t = char
  type 'a m0 = 'a M.m
  type 'a m = t list -> (t list * 'a) M.m

  (* Yield [x] and leave the input untouched. *)
  let return x input = M.return (input, x)

  (* Run [m], then run the continuation on whatever input [m] left. *)
  let bind (m : 'a m) (f : 'a -> 'b m) input =
    M.bind (m input) (fun (rest, v) -> f v rest)

  (* Pop one symbol; at end of input yield [None] and keep the empty list. *)
  let read input =
    match input with
    | [] -> M.return ([], None)
    | c :: rest -> M.return (rest, Some c)

  (* Run an [M] computation, pairing its result with the unchanged input. *)
  let lift m input = M.bind m (fun v -> M.return (input, v))
end

(** List transformer over [M]: a computation of type ['a m] is an [M]
    computation producing a list of results (nondeterminism on top of
    [M]). [bind] applies the continuation to every result, in order, and
    concatenates the result lists.
    NOTE(review): this list-inside-M construction is generally only a
    lawful monad when [M] is commutative; that holds for [Id], the only
    instance used in this file. *)
module ListT (M:MONAD) : MONAD with type 'a m = 'a list M.m = struct
  type 'a m = 'a list M.m

  let return x = M.return [x]

  (* [mmap f l acc] runs [f] on each element of [l] in order, threading
     the effects of [M], and collects all produced results. The results
     are accumulated in reverse with [List.rev_append] (linear in the new
     results at each step), then put back in order by a single [List.rev]
     at the end. Without that final reversal, [bind] returned its results
     in reverse order, so that [bind (return x) f] was the reverse of
     [f x], which breaks the left-identity law. *)
  let rec mmap f l acc =
    match l with
    | [] -> M.return (List.rev acc)
    | hd :: tl ->
        M.bind (f hd) (fun hd' -> mmap f tl (List.rev_append hd' acc))

  let bind m f =
    M.bind m (fun lv -> mmap f lv [])
end

(** Application à la construction de la monade de parsing *)

module Parsing = struct
  (* List-of-successes parsing monad. With [L = ListT(Id)], the type
     ['a m] obtained from [ReaderT(L)] is
     [char list -> (char list * 'a) list]: a parser maps the remaining
     input to every possible pair (input left over, result). Several
     definitions below rely on that concrete representation. *)
  module L = ListT(Id)
  include ReaderT(L)

  (** [plus m n] runs both parsers on the same input and returns all the
      results of [m] followed by all the results of [n]. *)
  let plus : 'a m -> 'a m -> 'a m =
    fun m n l -> m l @ n l

  (** The parser that always fails, producing no result. *)
  let zero : 'a m = fun _ -> []

  (** Consumes and returns one character; fails on empty input. This
      shadows the [read : t option m] inherited from [ReaderT]. *)
  let read : char m = function
    | hd :: tl -> [ (tl, hd) ]
    | [] -> []

  (** Consumes nothing; yields [true] exactly when the input is empty. *)
  let eos : bool m = function
    | [] -> [ ([], true) ]
    | l -> [ (l, false) ]

  (** [char c] accepts the single character [c] and returns it as the
      one-letter word [[c]]. *)
  let char c : char list m =
    bind read (fun c' -> if c' = c then return [ c ] else zero)

  (** Accepts the empty word without consuming anything. *)
  let one : char list m = return []

  (** [concat p q] runs [p] then [q] and concatenates their words. *)
  let concat p q =
    bind p (fun w -> bind q (fun w' -> return (w @ w')))

  (** Like [concat], but the second parser is produced by the thunk [f]
      only after [p] has succeeded. This is what allows recursive grammars
      to be written without looping while the parser is being built. *)
  let lconcat p f =
    bind p (fun w -> bind (f ()) (fun w' -> return (w @ w')))

  (** [string s] accepts exactly the characters of [s], in order, and
      returns them as a word. *)
  let string s =
    let len = String.length s in
    let rec go i =
      if i = len then one else concat (char s.[i]) (go (i + 1))
    in
    go 0

  (** [star p] accepts zero or more repetitions of [p] and returns the
      concatenation of their words; each possible number of repetitions
      is a separate result.

      Iterations of [p] that consume no input are discarded. Without this
      filter, a [p] that can succeed without consuming (for instance
      [star q], or [one]) made [star p] call itself on the very same input
      forever. Whenever such an iteration exists the unfiltered definition
      cannot terminate, so the filter never removes a result that used to
      be returned; when [p] always consumes, it keeps every result in the
      same order. All combinators here pass the input list through as the
      same physical list when they consume nothing, which is why physical
      inequality serves as the progress test: it is O(1), unlike a
      structural comparison of the two lists. *)
  let rec star p =
    let consuming l = List.filter (fun (rest, _) -> rest != l) (p l) in
    plus one (lconcat consuming (fun () -> star p))
end

(** Utilisation de la monade de parsing *)

open Parsing

(* Prints a header, then every result word of a parse (the leftover
   input is ignored), one quoted word per line. *)
let print_results l =
  let print_word (_, word) =
    print_string "  \"";
    List.iter print_char word;
    print_string "\"\n"
  in
  Printf.printf "Possible results:\n";
  List.iter print_word l

(* Notation for the word combinators: [++] is alternative, [^] is
   sequence and prefix [!] is Kleene star. Note that [(!)] and [(^)]
   shadow the Stdlib dereference and string-concatenation operators for
   the rest of this file. *)
let ( ++ ) m n = plus m n
let ( ^ ) p q = concat p q
let ( ! ) p = star p

(* [total m] keeps only the results of [m] after which the input is
   exhausted, i.e. the parses that consumed the whole input. *)
let total m =
  bind m (fun res ->
      bind eos (function
        | true -> return res
        | false -> zero))

let () =
  (* Parsing is pure, so building all result lists first and printing
     them afterwards produces the same output, in the same order. *)
  List.iter print_results
    [ total (char 'a') [ 'a'; 'a' ];
      total !(char 'a') [ 'a'; 'a' ];
      total !(char 'a' ++ char 'b') [ 'a'; 'b'; 'a' ];
      string "fo" [ 'f'; 'o'; 'o' ] ]

let () =
  (* [string "fo"] matches only when the input starts with f then o. *)
  let fo = string "fo" in
  let parses input = List.length (fo input) in
  assert (parses [ 'b'; 'a'; 'r' ] = 0);
  assert (parses [ 'f'; 'o'; 'o' ] = 1);
  assert (parses [ 'f'; 'a' ] = 0)

let () =
  (* One result per prefix of a-or-b letters: empty, a, ab, aba. *)
  let a_or_b_star = !(char 'a' ++ char 'b') in
  let parses = a_or_b_star [ 'a'; 'b'; 'a' ] in
  assert (List.length parses = 4)

(** Parseur arithmétique de type char list m. *)

(* Any single decimal digit, returned as a one-character word: the
   alternative of [char '0'] through [char '9'], folded left from [zero]. *)
let digit =
  List.fold_left
    (fun acc d -> acc ++ char d)
    zero
    [ '0'; '1'; '2'; '3'; '4'; '5'; '6'; '7'; '8'; '9' ]

(** Grammar: expr ::= digit | digit + expr.
    Recognises sums of single digits and returns the matched characters.
    The recursive occurrence goes through [lconcat] so that building the
    parser does not loop. *)
let expr =
  let rec sum () =
    plus digit (concat digit (lconcat (char '+') sum))
  in
  sum ()

let () =
  (* All prefix parses first, then only the ones consuming everything. *)
  let input = [ '0'; '+'; '1'; '+'; '2' ] in
  print_results (expr input);
  print_results (total expr input)

(** Parseur-interpréteur de type int m. *)

(* Infix bind, and sequencing that discards the first result. The right
   operand of [>>] is an ordinary argument, so it is evaluated when [>>]
   is applied, before any parsing happens; recursive parsers must
   therefore only appear there inside a [>>=] continuation. *)
let ( >>= ) m f = bind m f
let ( >> ) m n = m >>= fun _ -> n

(* Re-defines [digit] to yield the numeric value of the digit read,
   instead of the one-character word. *)
let digit =
  bind digit (fun chars ->
      let d = List.hd chars in
      return (Char.code d - Char.code '0'))

(** Same grammar, expr ::= digit | digit + expr, but the parser now
    evaluates the sum and returns it as an [int]. The recursive call sits
    inside the [>>=] continuation, so building the parser does not loop. *)
let expr =
  let rec sum () =
    digit
    ++ (digit >>= fun i ->
        (char '+' >> sum ()) >>= fun j ->
        return (i + j))
  in
  sum ()

let () =
  (* Every prefix that is a valid sum yields a result: 1, 1+2 and 1+2+3. *)
  let sums = expr [ '1'; '+'; '2'; '+'; '3' ] in
  Printf.printf "Possible results (1,3,6 expected):\n";
  List.iter (fun (_, n) -> Printf.printf "  %d\n" n) sums

(** Adds multiplication. Grammar, with [*] binding tighter than [+]:
      expr    ::= factors | factors + expr
      factors ::= atom | atom * factors
      atom    ::= digit | ( expr )
    The parser evaluates the expression and returns its [int] value.

    Both rules are left-factored: the leading operand is parsed once, then
    either returned as-is or extended with the rest of the rule. Writing
    [expr] as [factors () ++ (factors () >>= ...)] parses the leading
    operand twice; since an operand may itself be a parenthesised [expr],
    that doubling compounds with every level of nesting. The multiset of
    results is the same either way.

    Recursive calls ([expr ()], [factors ()]) only happen inside [>>=]
    continuations, so building a parser never loops. In particular the
    parenthesis case must not be written [char '(' >> expr ()], which
    would evaluate [expr ()] immediately. *)
let rec expr () =
  factors () >>= fun f ->
  return f ++ (char '+' >> expr () >>= fun e -> return (f + e))
and factors () =
  (digit ++ (char '(' >>= fun _ -> expr () >>= fun e -> char ')' >> return e))
  >>= fun i ->
  return i ++ (char '*' >> factors () >>= fun f -> return (i * f))

let expr = expr ()
let factors = factors ()

let () =
  (* Each input must have exactly one complete parse; print its value. *)
  let check input =
    match total expr input with
    | [ (_, value) ] -> Printf.printf "Résultat: %d\n" value
    | _ -> assert false
  in
  List.iter check
    [ [ '5'; '+'; '2' ];
      [ '1'; '+'; '2'; '*'; '3' ];
      [ '2'; '*'; '3'; '+'; '1' ];
      [ '2'; '*'; '('; '1'; '+'; '2'; ')'; '+'; '1' ];
      [ '('; '1'; '+'; '2'; ')'; '*'; '2'; '+'; '1' ];
      [ '('; '7'; '*'; '1'; ')'; '*'; '1' ] ]