(** Signature of a monad *)

module type MONAD = sig
  (** A computation producing a value of type ['a]. *)
  type 'a t

  (** [return x] is the computation whose result is [x]. *)
  val return : 'a -> 'a t

  (** [bind m f] passes the result of [m] to [f] and continues with the
      computation [f] returns. *)
  val bind : 'a t -> ('a -> 'b t) -> 'b t
end

(** Exception monad *)

module type EXN = sig
  include MONAD

  (** [throw e] is the computation that fails with the exception [e]. *)
  val throw : exn -> 'a t

  (** [try_with m handler] behaves like [m], except that when [m] fails
      with [e] the result is [handler e]. *)
  val try_with : 'a t -> (exn -> 'a t) -> 'a t

  (** [run m] extracts the value of [m]. In the [Exn] implementation of
      this file, a failed computation makes [run] raise the stored
      exception. *)
  val run : 'a t -> 'a
end

(** Exception monad implemented as a sum type: a computation is either a
    plain value or a captured exception. *)
module Exn : EXN = struct
  type 'a t = Val of 'a | Exn of exn

  let return x = Val x

  let throw e = Exn e

  (* A failure short-circuits: the continuation is applied only to a
     value. *)
  let bind m n =
    match m with
    | Exn e -> Exn e
    | Val x -> n x

  (* The handler is applied only to a failure; a value passes through. *)
  let try_with m f =
    match m with
    | Exn e -> f e
    | Val _ as ok -> ok

  (* Leaving the monad turns a captured exception into a real one. *)
  let run = function
    | Val x -> x
    | Exn e -> raise e
end

(** Test of the exception monad *)

let () =
  let module M = Exn in
  (* The outer computation throws, so the outer handler runs. Its own
     computation succeeds with 42, so the inner handler is never applied.
     Neither handler inspects the exception it receives, hence the
     wildcard patterns. *)
  let m =
    M.try_with
      (M.throw (Failure "blah"))
      (fun _ ->
         M.try_with
           (M.return 42)
           (fun _ -> M.throw (Failure "tu me vois pas")))
  in
    Printf.printf "Test exn: %d\n" (M.run m)

(** Continuation monad *)

module type CONT = sig
  include MONAD

  (** A captured continuation expecting a value of type ['a]. *)
  type 'a cont

  (** [throw k v] resumes the continuation [k] with the value [v]. The
      result type ['b] is unconstrained; in the [Cont] implementation of
      this file the current continuation is discarded. *)
  val throw : 'a cont -> 'a -> 'b t

  (** [callcc f] passes the current continuation to [f]. *)
  val callcc : ('a cont -> 'a t) -> 'a t

  (** [run m] executes [m] and returns its result. *)
  val run : 'a t -> 'a
end

(** Continuation monad in continuation-passing style. *)
module Cont : CONT = struct

  (* A computation receives the continuation that consumes its result. *)
  type 'a t = ('a -> unit) -> unit

  (* A captured continuation is simply the function awaiting the value. *)
  type 'a cont = 'a -> unit

  let return x k = k x
  let bind m n k = m (fun x -> n x k)

  (* Raised by the final continuation that [run] installs, in order to
     unwind back to [run] once the result has been stored. *)
  exception Return

  (* [run m] executes [m] with a final continuation that stores the result
     and escapes by raising [Return]. *)
  let run m =
    let res = ref None in
    try
      m (fun x -> res := Some x ; raise Return) ;
      (* Only reached if [m] returns without invoking any final
         continuation. *)
      assert false
    with Return ->
      (match !res with
       | Some x -> x
       | None ->
           (* The final continuation of this call never ran, so this
              [Return] was raised for an enclosing [run] (for instance by
              throwing to a continuation captured outside of a nested
              [run]). Let it propagate instead of failing an assertion. *)
           raise Return)

  (* [throw k v] ignores the current continuation and resumes [k]. *)
  let throw k v _ = k v

  (* The current continuation is both handed to [m] as a first-class value
     and used as the continuation of the computation [m] returns. *)
  let callcc m k = m k k

end

(** Test of the continuation monad *)

module M = Cont

(** [iter f l] sequences the computations [f x] for each element [x] of
    [l], from left to right, in the monad [M]. *)
let rec iter f = function
  | [] -> M.return ()
  | x :: rest -> M.bind (f x) (fun () -> iter f rest)

let () =
  (* [find pred lst] yields [Some (x, resume)] for the first element [x]
     satisfying [pred], where throwing [()] to [resume] continues the
     traversal after [x]; it yields [None] once the list is exhausted. *)
  let find pred lst =
    M.callcc (fun exit ->
      let visit x =
        if pred x then
          M.callcc (fun resume -> M.throw exit (Some (x, resume)))
        else
          M.return ()
      in
      M.bind (iter visit lst) (fun () -> M.throw exit None))
  in
  (* Prints every matching element by resuming the search after each hit. *)
  let print_all pred lst =
    M.bind
      (find pred lst)
      (fun found ->
         match found with
         | None -> M.return ()
         | Some (y, resume) ->
             Printf.printf "Test cont: %d\n" y ;
             M.throw resume ())
  in
  let is_even x = x mod 2 = 0 in
  M.run (print_all is_even [1;2;3;4])

(** Nondeterminism monad.
  * For a change from lists, this one uses success/failure continuations. *)

module type ND_t = sig
  include MONAD

  (** [orelse m n] offers the results of [m] and then those of [n]. *)
  val orelse : 'a t -> 'a t -> 'a t

  (** [run m f] applies [f] to the results of [m]. *)
  val run : 'a t -> ('a -> unit) -> unit
end

(** Nondeterminism monad built on success and failure continuations. *)
module ND = struct

  (* A suspended rest-of-search. *)
  type cont = unit -> unit

  (* What to do when a branch has no (more) results. *)
  type failure = cont

  (* What to do with a result, given a way to ask for further results. *)
  type 'a success = 'a -> cont -> unit

  type 'a t = 'a success -> failure -> unit

  (* A single result, after which the search falls back to [fail]. *)
  let return x succeed fail = succeed x fail

  (* Every result of [m] is fed to [n]; when [n x] is exhausted the search
     backtracks into [m] through [retry]. *)
  let bind m n succeed fail = m (fun x retry -> n x succeed retry) fail

  (* [n] is explored once [m] has run out of results. *)
  let orelse m n succeed fail = m succeed (fun () -> n succeed fail)

  (* Applies [f] to each result, asking for the next one after each call,
     and stops silently when the search is exhausted. *)
  let run m f =
    let on_success x next = f x ; next () in
    let on_failure () = () in
    m on_success on_failure

end

(* Test of the nondeterminism monad: enumerates both alternatives. *)
let () =
  let module M = ND in
  let first = M.return 3 in
  let second = M.bind (M.return 1) (fun x -> M.return (x+1)) in
  let show x = Printf.printf "Test 1: %d\n" x in
  M.run (M.orelse first second) show

(** Extension of the previous monad with probabilities *)

module type T = sig
  include MONAD

  (** The computation with no outcome. *)
  val fail : 'a t

  (** A fair coin; in the [P] implementation of this file it is
      [choice 0.5]. *)
  val flip : bool t

  (** [choice p] is a biased coin; in the [P] implementation of this file
      [true] is weighted by [p] and [false] by [1. -. p]. *)
  val choice : float -> bool t
end

(** Probability monad: a distribution is represented by the function that
    maps a weighting [f] of outcomes to the weighted sum of [f] over those
    outcomes. *)
module P = struct

  type 'a t = ('a -> float) -> float

  (* All the weight is on [x]. *)
  let return x k = k x

  (* The weighting is pushed through [n] for every outcome of [m]. *)
  let bind m n k = m (fun x -> n x k)

  (* No outcome at all: the weighting is ignored and the total is zero. *)
  let fail _ = 0.

  (* [true] is weighted by [p] and [false] by [1. -. p]. *)
  let choice p k = p *. k true +. (1. -. p) *. k false

  (* A fair coin. *)
  let flip = choice 0.5

end

(** Sample algorithms written against the abstract interface [T]. *)
module Algos (M:T) = struct

  (* Two successive coin flips, returned as a pair. *)
  let pair =
    let second_flip x = M.bind M.flip (fun y -> M.return (x, y)) in
    M.bind M.flip second_flip

  (* [pick l] walks down [l]: with [remaining] elements left, the head is
     kept when [M.choice (1. /. float remaining)] yields [true], otherwise
     the walk continues on the tail. An empty list gives [M.fail]. *)
  let pick l =
    let rec go remaining = function
      | [] -> M.fail
      | hd :: tl ->
          let p = 1. /. float remaining in
          M.bind
            (M.choice p)
            (fun keep ->
               if keep then M.return hd else go (remaining - 1) tl)
    in
    go (List.length l) l

end

(** Test of the probability monad: prints the probability of a few events
    computed with [Algos] instantiated on [P]. *)
module TestP = struct
  module A = Algos(P)

  (* Indicator of the event: the outcome equals [x]. *)
  let dirac x y = if x = y then 1. else 0.

  let () =
    let both_true = A.pair (dirac (true,true)) in
    Printf.printf "P[pair=(true,true)] = %.2f\n" both_true ;
    let two_among_distinct = A.pick [1;2;3] (dirac 2) in
    Printf.printf "P[pick([1;2;3])=2] = %.3f\n" two_among_distinct ;
    let two_among_dups = A.pick [1;2;2] (dirac 2) in
    Printf.printf "P[pick([1;2;2])=2] = %.3f\n" two_among_dups
end