(* -*- mode: sml -*-
* $Id: prodcons.smlnj,v 1.1 2001/09/01 01:40:21 doug Exp $
* http://www.bagley.org/~doug/shootout/
* from Matthias Blume
*)
(* producer-consumer threads in SML/NJ
* (concurrency primitives re-implemented "by hand" on top of call/cc
* using the code in John Reppy's book "Concurrent Programming in ML")
*
* (C) 2001 Lucent Technologies, Bell Labs
* written by Matthias Blume
*)
structure Queue :> sig
    exception Empty
    type tt = unit SMLofNJ.Cont.cont
    type q
    val new : unit -> q
    val enqueue : q * tt -> unit
    val dequeue : q -> tt
    val empty : q -> bool
end = struct
    exception Empty
    type tt = unit SMLofNJ.Cont.cont

    (* A FIFO of continuations kept as two mutable lists.  Elements are
     * removed from the head of the front list; new elements are consed
     * onto the back list (so it holds the newest element first).  When the
     * front list runs dry, the back list is reversed to become the front. *)
    type q = tt list ref * tt list ref

    (* A fresh queue with both lists empty. *)
    fun new () : q = (ref [], ref [])

    (* True exactly when neither list holds an element. *)
    fun empty ((front, back) : q) = null (!front) andalso null (!back)

    (* Append x.  Into a completely empty queue it goes straight to the
     * front list; otherwise it is pushed onto the back list. *)
    fun enqueue ((front, back) : q, x) =
        if null (!front) andalso null (!back)
        then front := [x]
        else back := x :: !back

    (* Remove and return the oldest element; raises Empty when the queue
     * holds nothing. *)
    fun dequeue ((front, back) : q) =
        let
            fun take (y :: ys) = (front := ys; y)
              | take [] = raise Empty
        in
            case !front of
                [] => let val refill = rev (!back)
                      in back := []; take refill end
              | pending => take pending
        end
end
(* Preemptive user-level threads built on SMLofNJ.Cont continuations,
 * plus mutexes and condition variables.  Threads are represented by the
 * continuation at which they will resume.  Preemption comes from a
 * SIGALRM handler, which is installed only while [run] executes its
 * argument. *)
structure Mutex :> sig
val yield : unit -> unit
val fork : (unit -> unit) -> unit
val exit : unit -> 'a
type mutex
type condition
val mutex : unit -> mutex
val lock : mutex -> unit
val unlock : mutex -> unit
val condition : mutex -> condition
val wait : condition -> unit
val signal : condition -> unit
val run : (unit -> unit) * Time.time -> unit
end = struct
local
structure C = SMLofNJ.Cont
structure Q = Queue
type tt = unit C.cont
(* We take the easy way out: Simply drop signals that
* arrive during an atomic section on the floor. This is
* enough for our purpose and simplifies the coding... *)
(* atomicState is a plain flag, not a counter, so atomic sections do
 * not nest: an inner atomicEnd would end the outer section early. *)
val atomicState = ref false
fun atomicBegin () = atomicState := true
fun atomicEnd () = atomicState := false
(* Continuations of threads that are runnable but not currently running. *)
val readyQ : Q.q = Q.new ()
(* Switch to the thread at the front of readyQ.  Never returns to the
 * caller.  If no thread is ready, Q.dequeue raises Queue.Empty. *)
fun dispatch () = C.throw (Q.dequeue readyQ) ()
(* SIGALRM handler; the continuation it returns is what runs next.
 * Inside an atomic section the interrupted thread k is simply resumed
 * (the tick is dropped).  Otherwise k goes to the back of readyQ and the
 * thread at the front runs, giving round-robin time slicing. *)
fun sigH (_: Signals.signal, _: int, k: tt) =
if !atomicState then k
else (Q.enqueue (readyQ, k); Q.dequeue readyQ)
in
(* Give up the processor: queue the current thread at the back of
 * readyQ and run the next one.  When this thread is resumed, callcc
 * returns and the atomic section ends. *)
fun yield () =
(atomicBegin ();
C.callcc (fn k => (Q.enqueue (readyQ, k); dispatch ()));
atomicEnd ())
(* Terminate the current thread by switching away without re-queuing it.
 * The atomic section is ended by whichever thread is resumed. *)
fun exit () = (atomicBegin (); dispatch ())
(* Create a thread running f and append it to readyQ; the caller keeps
 * running.  C.isolate turns the thread body into a continuation that
 * dispatch can throw to.  The body first clears the atomic flag, because
 * it may be started by a dispatch issued inside an atomic section.
 * Exceptions escaping f are silently discarded; when f finishes, the
 * thread exits. *)
fun fork f = let
val k = C.isolate (fn () => (atomicEnd ();
f () handle _ => ();
exit ()))
in
atomicBegin ();
Q.enqueue (readyQ, k);
atomicEnd ()
end
(* locked: whether some thread currently owns the mutex.
 * blocked: threads parked in lock, or reacquiring in wait, until the
 * mutex is handed to them. *)
datatype mutex = Mutex of { locked : bool ref, blocked : Q.q }
fun mutex () = Mutex { locked = ref false, blocked = Q.new () }
(* Acquire the mutex.  If it is held, park this thread on blocked and
 * switch away.  unlock and wait hand the mutex over directly, leaving
 * locked = true, so a resumed thread already owns it and only has to
 * end the atomic section. *)
fun lock (Mutex { locked, blocked }) =
(atomicBegin ();
if !locked then
C.callcc (fn k => (Q.enqueue (blocked, k);
dispatch ()))
else locked := true;
atomicEnd ())
(* Release the mutex.  If no thread is blocked on it, the flag is cleared.
 * Otherwise ownership passes directly to the first blocked thread (the
 * flag stays true).  That thread runs immediately, and the current thread
 * goes to the back of readyQ. *)
fun unlock (Mutex { locked, blocked }) =
(atomicBegin ();
if Q.empty blocked then locked := false
else C.callcc (fn k => (Q.enqueue (readyQ, k);
C.throw (Q.dequeue blocked) ()));
atomicEnd ())
(* A condition bound to one mutex.  waiting holds threads parked in wait. *)
datatype condition = Cond of { mutex : mutex, waiting : Q.q }
fun condition m = Cond { mutex = m, waiting = Q.new () }
(* The caller is expected to hold the condition's mutex.  In one atomic
 * step, park this thread on waiting and release the mutex.  The release
 * works like unlock: it hands off to a blocked thread if there is one,
 * otherwise it clears the flag and switches away.  Once signal has made
 * this thread runnable again, the mutex is reacquired exactly as in lock.
 * signal does not hand over the mutex, so other threads may run first.
 * Callers should therefore re-test their predicate after wait returns. *)
fun wait (Cond { mutex = m as Mutex { locked, blocked }, waiting }) =
(atomicBegin ();
C.callcc (fn k =>
(Q.enqueue (waiting, k);
if Q.empty blocked then (locked := false;
dispatch ())
else C.throw (Q.dequeue blocked) ()));
if !locked then
C.callcc (fn k => (Q.enqueue (blocked, k);
dispatch ()))
else locked := true;
atomicEnd ())
(* Wake at most one waiting thread by moving it to the back of readyQ.
 * The caller keeps running, and keeps the mutex if it holds it.  A signal
 * with no waiter has no effect. *)
fun signal (Cond { waiting, ... }) =
(atomicBegin ();
if Q.empty waiting then ()
else Q.enqueue (readyQ, Q.dequeue waiting);
atomicEnd ())
(* Run f with preemption enabled: install sigH for SIGALRM and arm the
 * interval timer with t.  When f returns or raises, restore the previous
 * handler and disable the timer; exceptions are re-raised. *)
fun run (f, t) = let
val oh = Signals.setHandler (Signals.sigALRM,
Signals.HANDLER sigH)
val _ = SMLofNJ.IntervalTimer.setIntTimer (SOME t)
fun reset () =
(ignore (Signals.setHandler (Signals.sigALRM, oh));
SMLofNJ.IntervalTimer.setIntTimer NONE)
in
(f () handle e => (reset (); raise e))
before reset ()
end
end
end
structure ProdCons : sig
    val main : string * string list -> OS.Process.status
end = struct
    (* Run one producer thread and one consumer thread that pass the
     * integers 1..n through a single shared slot (data).  Then print
     * "<produced> <consumed>" on stdout.
     *
     * The slot is guarded by mutex m.  Condition c is signalled after every
     * hand-off, and consumer's_turn records which side may use the slot
     * next.  The main thread locks p_running and c_running up front.  The
     * producer and consumer unlock them when they finish, so locking them
     * again makes the main thread wait for both workers.
     *
     * For n < 1 nothing is produced, so the consumer exits at once rather
     * than waiting for a value that never arrives.  Waiting would leave no
     * runnable thread, and Queue.Empty would escape from the scheduler. *)
    fun doit n = let
        val c_running = Mutex.mutex ()
        val p_running = Mutex.mutex ()
        val consumer's_turn = ref false
        val data = ref 0
        val produced = ref 0
        val consumed = ref 0
        val m = Mutex.mutex ()
        val c = Mutex.condition m

        fun producer () = let
            (* With m held, sleep until the consumer has emptied the slot.
             * The flag is re-tested after every wake-up. *)
            fun wait () = if !consumer's_turn then wait (Mutex.wait c) else ()
            fun loop i =
                if i <= n then
                    let val _ = Mutex.lock m
                        val _ = wait ()
                    in
                        data := i;
                        consumer's_turn := true;
                        produced := !produced + 1;
                        Mutex.signal c;
                        Mutex.unlock m;
                        loop (i + 1)
                    end
                else ()
        in
            loop 1 before Mutex.unlock p_running
        end

        fun consumer () = let
            (* With m held, sleep until the producer has filled the slot. *)
            fun wait () = if !consumer's_turn then () else wait (Mutex.wait c)
            fun loop () = let
                val _ = Mutex.lock m
                val _ = wait ()
                val i = !data
            in
                consumer's_turn := false;
                consumed := !consumed + 1;
                Mutex.signal c;
                Mutex.unlock m;
                (* Values arrive as 1, 2, ..., n; stop after taking n. *)
                if i < n then loop () else ()
            end
        in
            (if n >= 1 then loop () else ()) before Mutex.unlock c_running
        end

        val _ = Mutex.lock p_running
        val _ = Mutex.lock c_running
        val _ = Mutex.fork producer
        val _ = Mutex.fork consumer
    in
        (* Wait for both workers to finish. *)
        Mutex.lock p_running;
        Mutex.lock c_running;
        TextIO.output (TextIO.stdOut,
                       concat [Int.toString (!produced), " ",
                               Int.toString (!consumed), "\n"])
    end

    (* Entry point.  The first argument, if present, is n.  It defaults to 1
     * when absent, unparsable, or too large for Int (Int.fromString raises
     * Overflow in that case).  The threads run with a 1 ms preemption tick. *)
    fun main (_, args) = let
        val n = case args of
                    [] => 1
                  | (x :: _) => getOpt (Int.fromString x handle Overflow => NONE, 1)
    in
        Mutex.run (fn () => doit n, Time.fromMilliseconds 1);
        OS.Process.success
    end
end
(* Dump a heap image named "prodcons" whose entry point is ProdCons.main. *)
val () = SMLofNJ.exportFn ("prodcons", ProdCons.main);