(* -*- mode: sml -*-
 * $Id: wordfreq.smlnj,v 1.3 2001/07/09 00:25:29 doug Exp $
 * http://www.bagley.org/~doug/shootout/
 * from Stephen Weeks
 *)

(* for (start, stop, f) calls f on every integer from start up to and
 * including stop, in increasing order.  When start > stop, f is never
 * called.
 *)
fun for (start, stop, f) =
   let
      fun step k =
         if k <= stop
            then (f k; step (k + 1))
         else ()
   in
      step start
   end
(* incr r adds one to the integer stored in the reference cell r. *)
fun incr cell = cell := !cell + 1
(* Unqualified shorthands for the polymorphic array accessors. *)
fun sub (arr, i) = Array.sub (arr, i)
fun update (arr, i, x) = Array.update (arr, i, x)
   
(* Interface of a mutable hash set whose callers supply the hash value of
 * the entry they are looking for.
 *)
signature HASH_SET =
   sig
      (* A mutable set of entries of type 'a. *)
      type 'a t

      (* foreach (s, f) applies f to every entry currently in s.  No
       * particular order is promised by this signature.
       *)
      val foreach: 'a t * ('a -> unit) -> unit
      (* lookupOrInsert (s, h, p, f)  looks in the set s for an entry with hash h
       * satisfying predicate p.  If the entry is there, it is returned.
       * Otherwise, the function f is called to create a new entry, which is
       * inserted and returned.
       *)
      val lookupOrInsert: 'a t * word * ('a -> bool) * (unit -> 'a) -> 'a
      (* new {hash} creates an empty set.  The HashSet implementation in this
       * file calls hash to re-bucket entries when the table grows, so hash
       * applied to an entry must equal the hash value that was passed to
       * lookupOrInsert when that entry was inserted.
       *)
      val new: {hash: 'a -> word} -> 'a t
      (* size s is the number of entries in s. *)
      val size: 'a t -> int
   end

(* Chained hash set.  Entries live in an array of buckets (lists); the bucket
 * of an entry is its hash value masked down to the low bits.  The bucket
 * count is always a power of two and mask is always that count minus one.
 *)
structure HashSet: HASH_SET =
struct

datatype 'a t =
   T of {buckets: 'a list array ref,
         hash: 'a -> word,
         mask: word ref,
         numItems: int ref}

(* Bucket count of a fresh set, and the matching bit mask. *)
val initialSize: int = 65536
val initialMask: word = Word.fromInt initialSize - 0w1

(* An empty set with initialSize buckets. *)
fun 'a new {hash}: 'a t =
   T {buckets = ref (Array.array (initialSize, [])),
      hash = hash,
      mask = ref initialMask,
      numItems = ref 0}

(* Number of entries stored. *)
fun size (T {numItems = count, ...}) = !count

(* Current length of the bucket array. *)
fun numBuckets (T {buckets = ref bs, ...}) = Array.length bs

(* Bucket number for hash value h under bit mask m. *)
fun index (h: word, m: word): int =
   let
      val lowBits = Word.andb (h, m)
   in
      Word.toInt lowBits
   end

(* Move every entry into a new bucket array of the given size, recomputing
 * each entry's bucket with the stored hash function and newMask, then
 * install the new array and mask.
 *)
fun resize (T {buckets, hash, mask, ...}, size: int, newMask: word): unit =
   let
      val fresh = Array.array (size, [])
      fun rehash entry =
         let
            val slot = index (hash entry, newMask)
         in
            Array.update (fresh, slot, entry :: Array.sub (fresh, slot))
         end
      fun rehashBucket bucket = List.app rehash bucket
   in
      Array.app rehashBucket (!buckets);
      buckets := fresh;
      mask := newMask
   end

(* Double the bucket array once there are more than a quarter as many
 * entries as buckets.  The new mask gains one low-order one bit.
 *)
fun maybeGrow (s as T {buckets, mask, numItems, ...}): unit =
   let
      val n = Array.length (!buckets)
      val tooFull = !numItems * 4 > n
   in
      if tooFull
         then resize (s, n * 2, Word.orb (0w1, Word.<< (!mask, 0w1)))
      else ()
   end

(* Search the bucket selected by hash value w for an entry satisfying p.
 * Calls yes on the entry when found; otherwise calls no with the bucket
 * number and the bucket's current contents.
 *)
fun peekGen (T {buckets, mask, ...}, w, p, no, yes) =
   let
      val slot = index (w, !mask)
      val bucket = Array.sub (!buckets, slot)
   in
      case List.find p bucket of
         SOME hit => yes hit
       | NONE => no (slot, bucket)
   end

(* See HASH_SET.  On a miss the entry made by f is counted and pushed onto
 * the front of its bucket before the table is given a chance to grow.
 *)
fun lookupOrInsert (table as T {buckets, numItems, ...}, w, p, f) =
   let
      fun insert (slot, bucket) =
         let
            val entry = f ()
         in
            incr numItems;
            Array.update (!buckets, slot, entry :: bucket);
            maybeGrow table;
            entry
         end
      fun found entry = entry
   in
      peekGen (table, w, p, insert, found)
   end

(* Apply f to every entry, bucket by bucket. *)
fun foreach (T {buckets, ...}, f) =
   Array.app (List.app f) (!buckets)

end

(* A growable byte buffer.
 *
 *   new n       creates an empty buffer with room for n bytes
 *   add (b, x)  appends byte x, enlarging the storage when it is full
 *   clear b     makes b empty again without releasing its storage
 *   contents b  returns the bytes added since the last clear, as a string
 *)
structure Buffer:
   sig
      type t

      val add: t * Word8.word -> unit
      val clear: t -> unit
      val contents: t -> string
      val new: int -> t
   end =
   struct
      (* elts is the backing store; size is how many leading bytes of it
       * are in use.
       *)
      datatype t = T of {elts: Word8Array.array ref,
                         size: int ref}

      fun contents (T {elts, size, ...}) =
         Byte.bytesToString (Word8Array.extract (!elts, 0, SOME (!size)))

      fun clear (T {size, ...}) = size := 0

      fun new (bufSize) =
         T {elts = ref (Word8Array.array (bufSize, 0w0)),
            size = ref 0}

      fun add (T {elts, size}, x) =
         let
            val s = !size
            val a = !elts
            val n = Word8Array.length a
            val target =
               if s = n
                  then
                     let
                        (* Double the capacity, but never end up with less
                         * than one byte: a buffer created with new 0 would
                         * otherwise "grow" to length 0 and the write below
                         * would raise Subscript.
                         *)
                        val grown =
                           Word8Array.tabulate
                           (Int.max (1, 2 * n), fn i =>
                            if i < n then Word8Array.sub (a, i) else 0w0)
                     in
                        elts := grown;
                        grown
                     end
               else a
         in
            (* Store first, then publish the new size, so size never counts
             * a byte that was not actually written.
             *)
            Word8Array.update (target, s, x);
            size := s + 1
         end
   end

(* In-place sort of an array under a caller-supplied "less than or equal"
 * relation.  Ranges longer than a small cutoff are partitioned around a
 * randomly chosen pivot; a final insertion sort pass over the whole array
 * finishes the short ranges that partitioning left alone.
 *)
structure Quicksort:
   sig
      val quicksort: 'a array * ('a * 'a -> bool) -> unit
   end =
   struct
      (* Invariant checks are switched off: the condition starts with
       * "true orelse", so check is never called and Fail is never raised.
       *)
      fun assert (msg, check: unit -> bool) =
         if not (true orelse check ())
            then raise Fail (concat ["assert: ", msg])
         else ()

      (* forall (low, high, f) holds when f is true of every integer in
       * low..high; it stops at the first failure.
       *)
      fun forall (low, high, f) =
         let
            fun loop i =
               if i > high
                  then true
               else if f i
                  then loop (i + 1)
               else false
         in
            loop low
         end

      (* fold (l, u, state, f) threads state through f (i, state) for
       * i = l, l + 1, ..., u and returns the final state.
       *)
      fun fold (l, u, state, f) =
         let
            fun loop (i, acc) =
               if i <= u
                  then loop (i + 1, f (i, acc))
               else acc
         in
            loop (l, state)
         end

      (* True when a[lo..hi] is non-decreasing under the given relation.
       * Note that <= is rebound to that relation inside this function.
       *)
      fun 'a isSorted (a: 'a array,
                       lo: int,
                       hi: int,
                       op <= : 'a * 'a -> bool) =
         let
            fun loop (i, prev) =
               if i > hi
                  then true
               else
                  let
                     val cur = sub (a, i)
                  in
                     prev <= cur andalso loop (i + 1, cur)
                  end
         in
            if lo >= hi
               then true
            else loop (lo + 1, sub (a, lo))
         end

      (* Linear congruential generator over machine words; the state is
       * private to rand.
       *)
      local
         val seed: word ref = ref 0w13
      in
         fun rand () =
            let
               val next = Word.+ (Word.* (0w1664525, !seed), 0w1013904223)
            in
               seed := next;
               Word.toIntX next
            end
      end

      (* A pseudo-random integer in lo..hi.  Int.mod keeps the offset
       * non-negative even when rand () is negative.
       *)
      fun randInt (lo, hi) =
         let
            val span = hi - lo + 1
         in
            lo + Int.mod (rand (), span)
         end

      (* Straight insertion sort of the whole array.  <= is rebound to the
       * caller's relation inside this function.
       *)
      fun insertionSort (a: 'a array, op <= : 'a * 'a -> bool): unit =
         let
            fun at i = sub (a, i)
            (* Place a[i] into the already sorted prefix a[0..i-1]. *)
            fun insert i =
               let
                  val _ =
                     assert ("insertionSort1", fn () =>
                             isSorted (a, 0, i - 1, op <=))
                  val t = at i
                  (* Shift elements that t precedes one slot to the right
                   * and return the index of the hole left for t.
                   *)
                  fun sift (j: int) =
                     let
                        val _ =
                           assert ("insertionSort2", fn () =>
                                   isSorted (a, 0, j - 1, op <=)
                                   andalso isSorted (a, j + 1, i, op <=)
                                   andalso forall (j + 1, i, fn k => t <= at k))
                     in
                        if j > 0
                           then
                              let
                                 val left = j - 1
                                 val prev = at left
                              in
                                 if t <= prev
                                    then (update (a, j, prev); sift left)
                                 else j
                              end
                        else j
                     end
               in
                  update (a, sift i, t)
               end
         in
            for (1, Array.length a - 1, insert)
         end

      (* Sort a in place.  <= is rebound to the caller's relation inside
       * this function, so integer comparisons here use > only.
       *)
      fun 'a quicksort (a: 'a array, op <= : 'a * 'a -> bool): unit =
         let
            fun at i = Array.sub (a, i)
            fun swap (i, j) =
               let
                  val saved = at i
               in
                  update (a, i, at j);
                  update (a, j, saved)
               end
            (* Ranges with at most cutoff + 1 elements are left to the
             * final insertion sort.
             *)
            val cutoff = 20
            fun qsort (l: int, u: int): unit =
               if u - l > cutoff
                  then
                     let
                        (* Move a randomly chosen pivot to the front. *)
                        val _ = swap (l, randInt (l, u))
                        val pivot = at l
                        (* m is the last index of the "not after pivot"
                         * region; grow it whenever a[i] belongs there.
                         *)
                        fun step (i, m) =
                           (assert
                            ("qsort", fn () =>
                             forall (l + 1, m, fn k => at k <= pivot)
                             andalso forall (m + 1, i - 1,
                                             fn k => not (at k <= pivot)))
                            ; if at i <= pivot
                                 then (swap (m + 1, i); m + 1)
                              else m)
                        val m = fold (l + 1, u, l, step)
                     in
                        (* Put the pivot between the two regions. *)
                        swap (l, m);
                        qsort (l, m - 1);
                        qsort (m + 1, u)
                     end
               else ()
         in
            qsort (0, Array.length a - 1);
            insertionSort (a, op <=)
         end
   end

(* Word-frequency benchmark driver: reads standard input, counts each maximal
 * run of ASCII letters (folded to lower case), and prints the counts in
 * decreasing order, breaking ties by reverse alphabetical order.
 *)
structure Test : sig
    val main : (string * string list) -> OS.Process.status
end = struct

(* This hash function is taken from pages 56-57 of
 * The Practice of Programming by Kernighan and Pike.
 * Each character is folded in as  h := c + 31 h,  starting from zero.
 *)
fun hash (s: string): word =
   let
      val n = String.size s
      fun loop (i, w) =
         if i = n
            then w
         else loop (i + 1,
                    Word.fromInt (Char.ord (String.sub (s, i)))
                    + Word.* (w, 0w31))
   in
      loop (0, 0w0)
   end

(* Size of the input block read from stdin at a time. *)
val max = 4096
val buf = Word8Array.array (max, 0w0)
(* One entry per distinct word; the cached hash lets the table re-bucket
 * entries without rehashing the string.
 *)
val count: {hash: word,
            word: string,
            count: int ref} HashSet.t = HashSet.new {hash = #hash}
(* Letters of the word currently being scanned. *)
val wbuf = Buffer.new 64

val c2b = Byte.charToByte

(* Count one occurrence of the word accumulated in wbuf, then empty wbuf. *)
fun record_word () =
   let
      val w = Buffer.contents wbuf
      val h = hash w
   in
      incr (#count
            (HashSet.lookupOrInsert
             (count, h,
              fn {hash, word, ...} => hash = h andalso word = w,
              fn () => {hash = h, word = w, count = ref 0})));
      Buffer.clear wbuf
   end

(* scan_words (i, n, inword) consumes buf[i..n-1], refilling buf from stdin
 * whenever it runs out, until end of input.  inword says whether the bytes
 * seen most recently are part of a word still being accumulated in wbuf; it
 * is carried across refills so a word split over two blocks is counted once.
 *)
fun scan_words (i, n, inword) =
  if i < n
     then
        let
           val c = Word8Array.sub (buf, i)
        in
           if c2b #"a" <= c andalso c <= c2b #"z"
              then (Buffer.add (wbuf, c);
                    scan_words (i + 1, n, true))
           else if c2b #"A" <= c andalso c <= c2b #"Z"
              (* Adding 32 maps an ASCII capital to its lower-case letter. *)
              then (Buffer.add (wbuf, c + 0w32);
                    scan_words (i + 1, n, true))
           else (if inword then record_word () else ();
                 scan_words (i + 1, n, false))
        end
  else
     let
        val nread =
           Posix.IO.readArr (Posix.FileSys.stdin,
                             {buf = buf, i = 0, sz = NONE})
     in
        if nread = 0
           (* End of input: a word that runs right up to the end has no
            * terminating non-letter, so it must be counted here or it
            * would be lost.
            *)
           then (if inword then record_word () else ())
        else scan_words (0, nread, inword)
     end

(* Print the given strings one after another, followed by a newline. *)
fun printl [] = print "\n" | printl(h::t) = ( print h ; printl t )

(* Pad s with spaces on the left to width characters.  A string that is
 * already at least that wide is returned unchanged rather than raising Size.
 *)
fun rightJustify (s: string, width: int) =
   StringCvt.padLeft #" " width s

(* Program entry point; the program name and arguments are ignored. *)
fun main (name, args) =
   let
      val _ = scan_words (0, 0, false)
      (* Copy the table into an array of (count, word) pairs for sorting. *)
      val a = Array.array (HashSet.size count, (0, ""))
      val i = ref 0
      val _ = HashSet.foreach (count, fn {word, count, ...} =>
                               (Array.update (a, !i, (!count, word)); incr i))
      (* Higher counts first; equal counts in reverse alphabetical order. *)
      val _ = Quicksort.quicksort (a, fn ((c, w), (c', w')) =>
                                   c > c' orelse c = c' andalso w >= w')
      val _ = Array.app (fn (c, w) =>
                         printl [rightJustify (Int.toString c, 7), "\t", w]) a
   in
      OS.Process.success
   end
end

(* Hand Test.main to SML/NJ's exportFn under the name "wordfreq"; see the
 * SMLofNJ structure documentation for what is written and how it is run.
 *)
val _ = SMLofNJ.exportFn("wordfreq", Test.main);