(* -*- mode: sml -*-
 * $Id: prodcons.smlnj,v 1.1 2001/09/01 01:40:21 doug Exp $
 * http://www.bagley.org/~doug/shootout/
 * from Matthias Blume
 *)

(* producer-consumer threads in SML/NJ
 * (concurrency primitives re-implemented "by hand" on top of call/cc
 * using the code in John Reppy's book "Concurrent Programming in ML")
 *
 * (C) 2001 Lucent Technologies, Bell Labs
 * written by Matthias Blume
 *)

structure Queue :> sig
    exception Empty
    type tt = unit SMLofNJ.Cont.cont
    type q
    val new : unit -> q
    val enqueue : q * tt -> unit
    val dequeue : q -> tt
    val empty : q -> bool
end = struct

    (* Raised by [dequeue] when there is nothing left to hand out. *)
    exception Empty

    type tt = unit SMLofNJ.Cont.cont

    (* A FIFO queue stored as two mutable lists:
     *   front - items in dequeue order;
     *   back  - recently enqueued items, newest first.
     * [back] is reversed into [front] only when [front] runs dry. *)
    type q = tt list ref * tt list ref

    fun new () : q = (ref [], ref [])

    (* True exactly when both halves are empty. *)
    fun empty ((front, back) : q) = null (!front) andalso null (!back)

    (* Add [item] at the tail.  When the queue is completely empty the
     * item goes straight to [front]; otherwise it is consed onto [back]. *)
    fun enqueue ((front, back) : q, item) =
        if null (!front) andalso null (!back) then front := [item]
        else back := item :: !back

    (* Remove and return the head item, raising [Empty] if there is none. *)
    fun dequeue ((front, back) : q) =
        let
            (* The first element is the result; the rest becomes [front]. *)
            fun take (item :: rest) = (front := rest; item)
              | take [] = raise Empty
        in
            case !front of
                [] => take (rev (!back)) before back := []
              | ready => take ready
        end

end

(* User-level threads with mutexes and condition variables, built on
 * first-class continuations (SMLofNJ.Cont).  A suspended thread is just the
 * continuation that resumes it.  Runnable threads wait in [readyQ].
 * Preemption comes from SIGALRM, which [run] arranges via the interval
 * timer.  Every queue operation runs inside an "atomic section" during
 * which timer signals are ignored. *)
structure Mutex :> sig

    (* Requeue the current thread and switch to the next ready one
     * (which may be the current thread again). *)
    val yield : unit -> unit
    (* Queue a new ready thread that runs [f].  Any exception escaping [f]
     * is discarded, after which the thread exits. *)
    val fork : (unit -> unit) -> unit
    (* End the current thread by switching to the next ready thread;
     * does not return. *)
    val exit : unit -> 'a

    type mutex
    type condition

    val mutex : unit -> mutex
    (* Acquire the mutex, suspending the current thread if it is held. *)
    val lock : mutex -> unit
    (* Release the mutex.  If threads are blocked on it, ownership passes
     * directly to the oldest one and the caller is requeued as ready. *)
    val unlock : mutex -> unit

    val condition : mutex -> condition
    (* Release the condition's mutex and sleep until signalled, then
     * reacquire the mutex before returning.  The caller is expected to
     * hold the mutex.  Because [signal] only makes a waiter ready
     * (signal-and-continue), callers should re-check their predicate. *)
    val wait : condition -> unit
    (* Make the oldest waiting thread ready; a no-op if none is waiting. *)
    val signal : condition -> unit

    (* run (f, t): install the SIGALRM scheduling handler, arm the interval
     * timer with [t], run [f], then restore the previous handler and
     * disarm the timer.  Cleanup happens whether [f] returns or raises. *)
    val run : (unit -> unit) * Time.time -> unit

end = struct

    local
    structure C = SMLofNJ.Cont
    structure Q = Queue
    (* A suspended thread: the continuation that resumes it. *)
    type tt = unit C.cont

    (* We take the easy way out: a signal that arrives during an atomic
     * section is simply dropped, and the interrupted thread keeps running.
     * This is enough for our purpose and simplifies the code. *)
    val atomicState = ref false
    fun atomicBegin () = atomicState := true
    fun atomicEnd () = atomicState := false
    (* Runnable-but-not-running threads, in FIFO order. *)
    val readyQ : Q.q = Q.new ()

    (* Resume the next ready thread; control does not come back here.
     * If no thread is ready, Q.dequeue raises Queue.Empty in the current
     * thread (e.g. when every other thread is blocked). *)
    fun dispatch () = C.throw (Q.dequeue readyQ) ()

    (* SIGALRM handler.  [k] is the interrupted thread, and the returned
     * continuation is what runs next.  Inside an atomic section, [k] just
     * resumes.  Otherwise [k] goes to the back of the ready queue and the
     * front thread runs (possibly [k] itself if it was alone). *)
    fun sigH (_: Signals.signal, _: int, k: tt) =
        if !atomicState then k
        else (Q.enqueue (readyQ, k); Q.dequeue readyQ)
    in
        
        (* The resumed thread (this one, later) ends the atomic section
         * begun here, whichever thread began it. *)
        fun yield () =
        (atomicBegin ();
         C.callcc (fn k => (Q.enqueue (readyQ, k); dispatch ()));
         atomicEnd ())

    (* Enters an atomic section that the next thread to run will close. *)
    fun exit () = (atomicBegin (); dispatch ())

    fun fork f = let
        (* The new thread is normally entered via [dispatch] from inside
         * another thread's atomic section, so it closes that section first. *)
        val k = C.isolate (fn () => (atomicEnd ();
                     f () handle _ => ();
                     exit ()))
    in
        atomicBegin ();
        Q.enqueue (readyQ, k);
        atomicEnd ()
    end

    
        (* [locked]: whether some thread owns the mutex.
         * [blocked]: threads suspended in [lock] (or reacquiring after
         * [wait]) waiting for ownership, in FIFO order. *)
        datatype mutex = Mutex of { locked : bool ref, blocked : Q.q }

    fun mutex () = Mutex { locked = ref false, blocked = Q.new () }

    (* If the mutex is held, park on [blocked] and run someone else.  When
     * resumed, ownership has already been handed over by [unlock] or
     * [wait], so [locked] is left true. *)
    fun lock (Mutex { locked, blocked }) =
        (atomicBegin ();
         if !locked then
         C.callcc (fn k => (Q.enqueue (blocked, k);
                    dispatch ()))
         else locked := true;
         atomicEnd ())

    (* With no blocked threads, simply clear [locked].  Otherwise keep
     * [locked] true, requeue the caller as ready, and switch straight to
     * the oldest blocked thread, which now owns the mutex.
     * NOTE(review): does not check that the caller actually holds it. *)
    fun unlock (Mutex { locked, blocked }) =
        (atomicBegin ();
         if Q.empty blocked then locked := false
         else C.callcc (fn k => (Q.enqueue (readyQ, k);
                     C.throw (Q.dequeue blocked) ()));
         atomicEnd ())


        
    (* [waiting]: threads suspended in [wait] on this condition, FIFO. *)
    datatype condition = Cond of { mutex : mutex, waiting : Q.q }

    fun condition m = Cond { mutex = m, waiting = Q.new () }

    (* Step 1 (inside callcc): park on [waiting] and release the mutex in
     * the same way [unlock] does, handing it to a blocked thread if any.
     * Step 2 (after [signal] makes us ready and we are resumed):
     * reacquire the mutex in the same way [lock] does.
     * NOTE(review): the [m as] binding is unused. *)
    fun wait (Cond { mutex = m as Mutex { locked, blocked }, waiting }) =
        (atomicBegin ();
         C.callcc (fn k =>
              (Q.enqueue (waiting, k);
               if Q.empty blocked then (locked := false;
                            dispatch ())
               else C.throw (Q.dequeue blocked) ()));
         if !locked then
         C.callcc (fn k => (Q.enqueue (blocked, k);
                    dispatch ()))
         else locked := true;
         atomicEnd ())

    (* Move one waiter to the ready queue; the caller keeps running. *)
    fun signal (Cond { waiting, ... }) =
        (atomicBegin ();
         if Q.empty waiting then ()
         else Q.enqueue (readyQ, Q.dequeue waiting);
         atomicEnd ())

    fun run (f, t) = let
        (* [oh] is the previous SIGALRM handler, restored by [reset]. *)
        val oh = Signals.setHandler (Signals.sigALRM,
                     Signals.HANDLER sigH)
        val _ = SMLofNJ.IntervalTimer.setIntTimer (SOME t)
        fun reset () =
        (ignore (Signals.setHandler (Signals.sigALRM, oh));
         SMLofNJ.IntervalTimer.setIntTimer NONE)
             
    in
        (* Reset on both the normal and the exceptional path. *)
        (f () handle e => (reset (); raise e))
        before reset ()
    end
    end
end

structure ProdCons : sig
    val main : string * string list -> OS.Process.status
end = struct

    (* Run one producer thread and one consumer thread that pass the
     * integers 1..n through a single shared slot ([data]).  Then print
     * "<produced> <consumed>" on stdout.
     *
     * The slot is guarded by mutex [m].  The single condition [c] is used
     * in both directions: [consumer's_turn] says which side may act next.
     * Mutex.signal only makes a waiter ready (signal-and-continue), so each
     * side re-checks [consumer's_turn] in a loop after waking.
     *
     * For n < 1 nothing is produced, so the consumer does not wait at all.
     * Previously it blocked forever on [c], leaving no ready thread, and
     * the main thread's join raised Queue.Empty instead of printing "0 0". *)
    fun doit n = let

        (* Each is held by the main thread until the matching worker
         * finishes.  The main thread joins by locking it a second time. *)
        val c_running = Mutex.mutex ()
        val p_running = Mutex.mutex ()

        val consumer's_turn = ref false
        val data = ref 0
        val produced = ref 0
        val consumed = ref 0
        val m = Mutex.mutex ()
        val c = Mutex.condition m

        fun producer () = let
            fun wait () = if !consumer's_turn then wait (Mutex.wait c) else ()
            fun loop i =
                if i <= n then
                    let val _ = Mutex.lock m
                        val _ = wait ()
                    in
                        data := i;
                        consumer's_turn := true;
                        produced := !produced + 1;
                        Mutex.signal c;
                        Mutex.unlock m;
                        loop (i + 1)
                    end
                else ()
        in
            loop 1 before Mutex.unlock p_running
        end

        fun consumer () = let
            fun wait () = if !consumer's_turn then () else wait (Mutex.wait c)
            (* Take items until the final value n has been received. *)
            fun loop () = let
                val _ = Mutex.lock m
                val _ = wait ()
                val i = !data
            in
                consumer's_turn := false;
                consumed := !consumed + 1;
                Mutex.signal c;
                Mutex.unlock m;
                if i <> n then loop () else ()
            end
        in
            (* With n < 1 the producer never fills the slot, so waiting
             * for an item would deadlock. *)
            (if n >= 1 then loop () else ()) before Mutex.unlock c_running
        end

        val _ = Mutex.lock p_running
        val _ = Mutex.lock c_running
    in
        Mutex.fork producer;
        Mutex.fork consumer;

        (* Join: each lock succeeds only after the worker's final unlock. *)
        Mutex.lock p_running;
        Mutex.lock c_running;

        TextIO.output (TextIO.stdOut,
                       concat [Int.toString (!produced), " ",
                               Int.toString (!consumed), "\n"])
    end

    (* Entry point.  The first command-line argument is n.  It defaults to
     * 1 when absent, unparsable, or too large for Int (Int.fromString
     * raises Overflow then).  [doit n] runs under Mutex.run with a
     * 1 ms interval timer. *)
    fun main (_, args) = let
        val n = case args of
                    [] => 1
                  | x :: _ =>
                        getOpt ((Int.fromString x handle Overflow => NONE), 1)
    in
        Mutex.run (fn () => doit n, Time.fromMilliseconds 1);
        OS.Process.success
    end
end

(* Export ProdCons.main as the program entry point under the name "prodcons"
 * (SML/NJ heap export).  Matching on () asserts the call yields unit. *)
val () = SMLofNJ.exportFn ("prodcons", ProdCons.main);