(* -*- mode: sml -*-
 * $Id: matrix.smlnj,v 1.3 2001/07/11 01:40:04 doug Exp $
 * http://www.bagley.org/~doug/shootout/
 * from Stephen Weeks
 *)


structure Test : sig
    val main : (string * string list) -> OS.Process.status
end = struct

(* Add one, in place, to the integer stored in the reference cell counter. *)
fun incr counter = counter := 1 + !counter
(* for (start, stop, f) applies f to every integer from start up to and
 * including stop, in ascending order, discarding each result.
 * Nothing is called when start > stop.
 *)
fun for (start, stop, f) =
   let
      fun go i =
         if i <= stop
            then (f i; go (i + 1))
         else ()
   in
      go start
   end

(* A small two-dimensional array represented as an array of row arrays.
 * Indices are passed straight to Array.sub / Array.update, so bounds
 * behaviour is whatever those functions provide.
 *)
structure Array2 =
   struct
      datatype 'a t = T of 'a array array

      (* sub (m, r, c): the element stored at row r, column c. *)
      fun sub (T rows, r, c) =
         let val row = Array.sub (rows, r)
         in Array.sub (row, c)
         end

      (* subr (m, r): looks up row r once and returns a function from a
       * column index to the element of that row. *)
      fun subr (T rows, r) =
         let val row = Array.sub (rows, r)
         in fn c => Array.sub (row, c)
         end

      (* update (m, r, c, x): overwrite the element at row r, column c
       * with x. *)
      fun update (T rows, r, c, x) =
         let val row = Array.sub (rows, r)
         in Array.update (row, c, x)
         end

      (* array (r, c, x): a matrix of r rows, each a separately allocated
       * array of c copies of x. *)
      fun array (r, c, x) =
         let fun freshRow _ = Array.array (c, x)
         in T (Array.tabulate (r, freshRow))
         end
   end
(* Unqualified shorthands for the element read/write functions of the
 * Array2 structure defined above. *)
val sub = Array2.sub
val update = Array2.update
   
(* Dimension passed by main as both the row and the column count of every
 * matrix it builds and multiplies. *)
val size = 30

(* mkmatrix (rows, cols) returns a rows-by-cols integer matrix whose cells
 * are numbered 1, 2, 3, ... in row-major order (row 0 left to right, then
 * row 1, and so on).
 *)
fun mkmatrix (rows, cols) =
   let
      val next = ref 1
      val lastCol = cols - 1
      val m = Array2.array (rows, cols, 0)
      (* Store the next counter values across row i, left to right. *)
      fun fillRow i =
         for (0, lastCol, fn j =>
            (update (m, i, j, !next); incr next))
   in
      for (0, rows - 1, fillRow);
      m
   end

(* mmult (rows, cols, m1, m2, m3) writes the matrix product of m1 and m2
 * into m3 and returns unit.
 *
 * m1 and m3 are indexed as rows-by-cols and m2 as cols-by-cols. Each
 * result cell is the dot product of a row of m1 with a column of m2, and
 * that dot product runs over m1's column index, so the inner loop is
 * bounded by cols. When rows = cols this gives exactly the same result
 * as bounding it by rows.
 *)
fun mmult (rows, cols, m1, m2, m3) =
   let
      val last_col = cols - 1
      val last_row = rows - 1
   in
      for (0, last_row, fn i =>
         let
            (* Row i of m1 is fetched once and reused for every column j. *)
            val m1i = Array2.subr (m1, i)
         in
            for (0, last_col, fn j =>
               let
                  (* Accumulate m1[i][k] times m2[k][j] for k counting
                   * down from last_col to 0. *)
                  fun dot (k, sum) =
                     if k < 0
                        then sum
                     else dot (k - 1, sum + m1i k * sub (m2, k, j))
               in
                  update (m3, i, j, dot (last_col, 0))
               end)
         end)
   end

(* Convert s with Int.fromString, falling back to 0 when it yields NONE. *)
fun atoi s = Option.getOpt (Int.fromString s, 0)
(* Print every string of the list in order, followed by a newline. *)
fun printl strs = (List.app print strs; print "\n")

(* Program entry point.
 *
 * The first command-line argument, read with atoi, is the repetition
 * count n; with no arguments the string "1" is used instead. Two
 * size-by-size matrices from mkmatrix are multiplied into m3: n - 1 times
 * in the loop and once more afterwards, so the product is always computed
 * at least once. Four cells of m3 are then printed on one line, separated
 * by single spaces, and OS.Process.success is returned.
 *)
fun main (name, args) =
  let
     val n =
        case args of
           [] => atoi "1"
         | first :: _ => atoi first
     val m1 = mkmatrix (size, size)
     val m2 = mkmatrix (size, size)
     val m3 = Array2.array (size, size, 0)
     fun multiply () = mmult (size, size, m1, m2, m3)
     val _ = for (1, n - 1, fn _ => multiply ())
     val _ = multiply ()
     (* Decimal text of the product cell at row r, column c. *)
     fun cell (r, c) = Int.toString (sub (m3, r, c))
  in
     printl [cell (0, 0), " ", cell (2, 3), " ", cell (3, 2), " ", cell (4, 4)];
     OS.Process.success
  end
end

(* Hand Test.main to SMLofNJ.exportFn under the name "matrix"; the result
 * of the call is discarded. *)
val _ = SMLofNJ.exportFn("matrix", Test.main)