Hello,

I'm writing some matrix multiplication and inversion functions for
small matrices (3x3 and 4x4 mostly, for 3D graphics, modeling,
simulation, etc.). I noticed that the matrix multiplication was a
bottleneck, so I set out to optimize it and found that using unsafeAt
instead of (!) (or unsafeRead instead of readArray in stateful code)
helped a lot. So then I went to optimize my Gaussian elimination
function and found just the opposite: unsafeRead is slower than
readArray. This struck me as very odd, considering that readArray calls
unsafeRead.

If there is a "good" reason why the compiler optimized readArray
better than unsafeRead, I'd like to know what it is so that I can make
all my array code safe as well as fast. (By "good" reason I mean
something deterministic and repeatable, not just luck.)

On the other hand, if this is a fluke, I'm inclined to think that it's
not the safe code that is freakishly fast, but the unsafe code that
is needlessly slow. That is, something about my program is hindering
optimization of the unsafe code. What is it?

Attached are the profiling results and a test program with a handful of
matrix multiplication and Gaussian elimination functions to illustrate
what I've seen. This happens on both AMD64 and Intel Core
architectures.


Thanks for any insight,
Scott
{-
	total time  =     1196.70 secs   (23934 ticks @ 50 ms)
	total alloc = 419,350,893,280 bytes  (excludes profiling overheads)

COST CENTRE                    MODULE               %time %alloc  ticks     bytes

matrixMultSafe                 Main                  26.9   31.4   6449 16450164500
gaussElimUnsafe'               Main                  21.4   20.4   5126 10687510443
gaussElim2Unsafe               Main                  17.4   17.9   4173 9366179268
gaussElimSafe'                 Main                  16.2   16.0   3881 8377362484
gaussElim2Safe                 Main                  13.4   13.1   3207 6869088980
matrixMultUnsafe               Main                   3.5    0.8    829 400004000
-}


{-# OPTIONS_GHC -O2 -optc-O2 -fglasgow-exts -fbang-patterns #-}

import Control.Monad
import Data.List
import Data.Array.IO
import Data.Array.Unboxed
import System
import System.IO.Unsafe
import System.Random
import Data.Array.Base 

-- | Product of two @n x n@ matrices without bounds checks.
--
-- Both inputs are read with 'unsafeAt' at the 0-based row-major offset
-- @n*row + col@. Each must therefore hold at least @n*n@ elements laid out
-- row by row, whatever its declared bounds; nothing is checked.
-- The result is written into a fresh 'IOUArray' with bounds @((1,1),(n,n))@
-- and handed out through 'unsafeFreeze' (no copy).
--
-- 'unsafePerformIO' is used only to allocate and fill that private array.
-- No reference to the mutable array escapes before it is frozen.
matrixMultUnsafe n a b = unsafePerformIO $ do
  c <- newArray_ ((1,1),(n,n)) :: IO (IOUArray (Int,Int) Double)
  let
    -- dot product of row r of a with column q of b, summed left to right
    dot !r !q = go 0 0
      where
        go !k !acc
          | k == n    = acc
          | otherwise = go (k+1) (acc + (a `unsafeAt` (n*r+k)) * (b `unsafeAt` (n*k+q)))
    -- fill columns q..n-1 of output row r
    cols !r !q
      | q == n    = return ()
      | otherwise = unsafeWrite c (r*n+q) (dot r q) >> cols r (q+1)
    -- fill output rows r..n-1
    rows !r
      | r == n    = return ()
      | otherwise = cols r 0 >> rows (r+1)
  rows 0
  unsafeFreeze c

-- | Product of two @n x n@ matrices using bounds-checked access.
--
-- Inputs are indexed with '(!)' at 1-based @(row, col)@ pairs from 1 to @n@,
-- so their bounds must cover @((1,1),(n,n))@ or '(!)' fails. The result is
-- written into a fresh 'IOUArray' with bounds @((1,1),(n,n))@ through
-- 'writeArray' and handed out through 'unsafeFreeze' (no copy).
--
-- 'unsafePerformIO' is used only to allocate and fill that private array.
-- No reference to the mutable array escapes before it is frozen.
matrixMultSafe n a b = unsafePerformIO $ do
  c <- newArray_ ((1,1),(n,n)) :: IO (IOUArray (Int,Int) Double)
  let
    -- dot product of row r of a with column q of b, summed for k = 1..n
    entry !r !q = sumFrom 1 0
      where
        sumFrom !k !acc
          | k > n     = acc
          | otherwise = sumFrom (k+1) (acc + (a ! (r,k)) * (b ! (k,q)))
    -- fill columns q..n of output row r
    fillRow !r !q
      | q > n     = return ()
      | otherwise = writeArray c (r,q) (entry r q) >> fillRow r (q+1)
    -- fill output rows r..n
    fillRows !r
      | r > n     = return ()
      | otherwise = fillRow r 1 >> fillRows (r+1)
  fillRows 1
  unsafeFreeze c


-- | Run 'gaussElimUnsafe'' in place on a mutable 2-D array. The row and
-- column counts are taken from the array's bounds. The lower bounds are used
-- only to compute those extents, because the inner routine addresses
-- elements by 0-based linear offset.
gaussElimUnsafe matrix =
  getBounds matrix >>= \((rowLo, colLo), (rowHi, colHi)) ->
    gaussElimUnsafe' matrix (rowHi - rowLo + 1) (colHi - colLo + 1)

-- | In-place Gauss-Jordan elimination on an @m@-row, @n@-column mutable array.
-- Elements are addressed by 0-based row-major offset @row*n + col@ through
-- 'unsafeRead' / 'unsafeWrite'. There are no bounds checks, so the array must
-- hold at least @m*n@ elements.
--
-- Columns are processed left to right. For each column, the entry of
-- largest magnitude in rows @i..m-1@ is chosen as the pivot. If the pivot is
-- 'nearZero', the column is skipped. Otherwise the pivot row is moved into
-- row @i@ and divided by the pivot. Then a multiple of that row is
-- subtracted from every other row so that the rest of the column becomes
-- zero.
gaussElimUnsafe' matrix m n = doColumn 0 0
  where
    -- i: next row to fill, j: current column
    doColumn !i !j
      | i == m || j == n = return ()
      | otherwise = do
          (pivotRow, pivotVal) <- findPivot i j
          if nearZero pivotVal
            then doColumn i (j+1)
            else do
              swapRowsAndDivideByPivot i pivotRow pivotVal
              subtractRows i j
              doColumn (i+1) (j+1)

    -- (row, value) of the largest-magnitude entry of column j in rows i..m-1.
    -- The search starts from (i, 0), so an all-zero column yields (i, 0).
    findPivot !i !j = f i (i,0)
      where
        f !r (!maxi,!maxe)
          | r == m = return (maxi,maxe)
          | otherwise = do
              e <- unsafeRead matrix (r*n+j)
              f (r+1) $ if abs e > abs maxe then (r,e) else (maxi,maxe)

    -- Afterwards row i holds (old row pr) / pv and row pr holds old row i.
    swapRowsAndDivideByPivot !i !pr !pv = f 0
      where
        f !j
          | j == n = return ()
          | otherwise = do
              ei <- unsafeRead matrix (i *n+j)
              ep <- unsafeRead matrix (pr*n+j)
              -- Write row pr first. When pr == i both writes hit the same
              -- cell, and the divided value written last must win.
              unsafeWrite matrix (pr*n+j) ei
              unsafeWrite matrix (i *n+j) (ep/pv)
              f (j+1)

    -- For every row u /= i, subtract matrix[u][j] times row i from row u.
    -- Column j is cleared only if row i has already been divided by its pivot.
    subtractRows !i !j = f 0
      where
        f !u
          | u == m = return ()
          | u == i = f (u+1)
          | otherwise = do
              s <- unsafeRead matrix (u*n+j)
              g s u 0
              f (u+1)

        g _ _ !k | k == n = return ()
        g !s !u !k = do
          ei <- unsafeRead matrix (i*n+k)
          eu <- unsafeRead matrix (u*n+k)
          unsafeWrite matrix (u*n+k) (eu - s*ei)
          g s u (k+1)

--------------------------------------------------

-- | Run 'gaussElimSafe'' in place on a mutable 2-D array, using the array's
-- own bounds as the index ranges.
gaussElimSafe matrix = getBounds matrix >>= gaussElimSafe' matrix

-- | In-place Gauss-Jordan elimination on a mutable 2-D array with bounds
-- @((i1,j1),(m,n))@. Elements are accessed through the bounds-checked
-- 'readArray' / 'writeArray'.
--
-- Columns are processed left to right. For each column, the entry of
-- largest magnitude in rows @i..m@ is chosen as the pivot. If the pivot is
-- 'nearZero', the column is skipped. Otherwise the pivot row is moved into
-- row @i@ and divided by the pivot. Then a multiple of that row is
-- subtracted from every other row so that the rest of the column becomes
-- zero.
gaussElimSafe' matrix ((i1,j1),(m,n)) = doColumn i1 j1
  where
    -- i: next row to fill, j: current column
    doColumn !i !j
      | i > m || j > n = return ()
      | otherwise = do
          (pivotRow, pivotVal) <- findPivot i j
          if nearZero pivotVal
            then doColumn i (j+1)
            else do
              swapRowsAndDivideByPivot i pivotRow pivotVal
              subtractRows i j
              doColumn (i+1) (j+1)

    -- (row, value) of the largest-magnitude entry of column j in rows i..m.
    -- The search starts from (i, 0), so an all-zero column yields (i, 0).
    findPivot !i !j = f i (i,0)
      where
        f !r (!maxi,!maxe)
          | r > m = return (maxi,maxe)
          | otherwise = do
              e <- readArray matrix (r,j)
              f (r+1) $ if abs e > abs maxe then (r,e) else (maxi,maxe)

    -- Afterwards row i holds (old row pr) / pv and row pr holds old row i.
    swapRowsAndDivideByPivot !i !pr !pv = f j1
      where
        f !j
          | j > n = return ()
          | otherwise = do
              ei <- readArray matrix (i ,j)
              ep <- readArray matrix (pr,j)
              -- Write row pr first. When pr == i both writes hit the same
              -- cell, and the divided value written last must win.
              writeArray matrix (pr,j) ei
              writeArray matrix (i ,j) (ep/pv)
              f (j+1)

    -- For every row u /= i, subtract matrix(u,j) times row i from row u.
    -- Column j is cleared only if row i has already been divided by its pivot.
    subtractRows !i !j = f i1
      where
        f !u
          | u > m  = return ()
          | u == i = f (u+1)
          | otherwise = do
              s <- readArray matrix (u,j)
              g s u j1
              f (u+1)

        g _ _ !k | k > n = return ()
        g !s !u !k = do
          ei <- readArray matrix (i,k)
          eu <- readArray matrix (u,k)
          writeArray matrix (u,k) (eu - s*ei)
          g s u (k+1)

------------------------------------------------------

-- | Gauss-Jordan elimination written with list-driven 'foldM' / 'mapM_'
-- loops. It works in place on an @m@-row, @n@-column mutable array addressed
-- by 0-based row-major offset @row*n + col@ through 'unsafeRead' /
-- 'unsafeWrite', which are unchecked.
gaussElim2Unsafe m n matrix = do
    -- Fold over the columns; the accumulator is the next row to fill. Seeding
    -- it with 0 explicitly gives the same result as the previous fold1M form
    -- for n >= 1, and makes n == 0 a no-op instead of a crash on 'head []'.
    _ <- foldM doColumn 0 [0..n-1]
    return ()
  where
    -- Process column j with i as the next row to fill. Returns the next row
    -- index: unchanged if the column has no usable pivot, otherwise i+1.
    doColumn i j | i==m = return i
    doColumn i j =
      do (pivotRow,pivotVal) <- findPivot i j
         if nearZero pivotVal
            then return i
            else do swapRows i pivotRow
                    divideRow i pivotVal
                    -- clear column j in every other row (subtractRow skips row i)
                    mapM_ (\i' -> do e <- unsafeRead matrix (i'*n+j); subtractRow i (e,i')) [0..m-1]
                    return (i+1)

    -- Row and value of the largest-magnitude entry of column j among rows
    -- i..m-1. Ties keep the later row. doColumn only calls this with i /= m,
    -- so the range is non-empty when i < m.
    findPivot i j =
      do pivotRow <- fold1M
           (\ ia ib ->
             do ea <- unsafeRead matrix (ia*n+j)
                eb <- unsafeRead matrix (ib*n+j)
                if abs ea > abs eb
                    then return ia
                    else return ib
           ) [i..m-1]
         pivotVal <- unsafeRead matrix (pivotRow*n+j)
         return (pivotRow,pivotVal)

    -- exchange rows ia and ib
    swapRows ia ib = unless (ia == ib) $ mapM_ f [0..n-1]
      where f j = do ea <- unsafeRead matrix (ia*n+j)
                     eb <- unsafeRead matrix (ib*n+j)
                     unsafeWrite matrix (ia*n+j) eb
                     unsafeWrite matrix (ib*n+j) ea

    -- subtract s*row(ia) from row(ib)
    subtractRow ia (s,ib) = unless (ia == ib) $ mapM_ f [0..n-1]
      where f j = do ea <- unsafeRead matrix (ia*n+j)
                     eb <- unsafeRead matrix (ib*n+j)
                     unsafeWrite matrix (ib*n+j) (eb - s*ea)

    -- divide row(i) by s
    divideRow i s = mapM_ f [0..n-1]
      where f j = do e <- unsafeRead matrix (i*n+j)
                     unsafeWrite matrix (i*n+j) (e/s)
----------------------------------------------------------------------


-- | Gauss-Jordan elimination written with list-driven 'foldM' / 'mapM_'
-- loops. It works in place on a mutable 2-D array whose bounds are passed
-- explicitly as @((i1,j1),(m,n))@, using the bounds-checked 'readArray' /
-- 'writeArray'.
gaussElim2Safe matrix ((i1,j1),(m,n)) = do
    -- Fold over the columns j1..n; the accumulator is the next row to fill,
    -- so it must start at the first ROW index i1. Seeding from the head of
    -- the column list (j1) would only be correct when i1 == j1.
    _ <- foldM doColumn i1 [j1..n]
    return ()
  where
    -- Process column j with i as the next row to fill. Returns the next row
    -- index: unchanged if the column has no usable pivot, otherwise i+1.
    doColumn i j | i > m = return i
    doColumn i j =
      do (pivotRow,pivotVal) <- findPivot i j
         if nearZero pivotVal
            then return i
            else do swapRows i pivotRow
                    divideRow i pivotVal
                    -- clear column j in every other row (subtractRow skips row i)
                    mapM_ (\i' -> do e <- readArray matrix (i',j); subtractRow i (e,i')) [i1..m]
                    return (i+1)

    -- Row and value of the largest-magnitude entry of column j among rows
    -- i..m. Ties keep the later row. doColumn only calls this with i <= m,
    -- so the range is non-empty.
    findPivot i j =
      do pivotRow <- fold1M
           (\ ia ib ->
             do ea <- readArray matrix (ia,j)
                eb <- readArray matrix (ib,j)
                if abs ea > abs eb
                    then return ia
                    else return ib
           ) [i..m]
         pivotVal <- readArray matrix (pivotRow,j)
         return (pivotRow,pivotVal)

    -- exchange rows ia and ib
    swapRows ia ib = unless (ia == ib) $ mapM_ f [j1..n]
      where f j = do ea <- readArray matrix (ia,j)
                     eb <- readArray matrix (ib,j)
                     writeArray matrix (ia,j) eb
                     writeArray matrix (ib,j) ea

    -- subtract s*row(ia) from row(ib)
    subtractRow ia (s,ib) = unless (ia == ib) $ mapM_ f [j1..n]
      where f j = do ea <- readArray matrix (ia,j)
                     eb <- readArray matrix (ib,j)
                     writeArray matrix (ib,j) (eb - s*ea)

    -- divide row(i) by s
    divideRow i s = mapM_ f [j1..n]
      where f j = do e <- readArray matrix (i,j)
                     writeArray matrix (i,j) (e/s)
---------------------------------------------------------------------

-- | Monadic left fold whose seed is the head of the list, folded over the
-- WHOLE list, head included:
--
-- > fold1M f (x:xs) == foldM f x (x:xs)
--
-- The first step is therefore @f x x@. This differs from a conventional
-- fold1, which would skip the head. Fails with a descriptive 'error' on an
-- empty list, instead of the generic 'head' failure.
fold1M f xs@(x0:_) = foldM f x0 xs
fold1M _ []        = error "fold1M: empty list"

-- | 'fold1M', discarding the result.
fold1M_ f xs = fold1M f xs >> return ()





-- | Pivot tolerance test: 'True' iff @x@ lies strictly inside
-- @(-1e-5, 1e-5)@. NaN compares false on both sides, so it is never
-- considered near zero.
nearZero :: (Ord a, Fractional a) => a -> Bool
nearZero x = lower < x && x < upper
  where
    upper = 1e-5
    lower = negate upper

-- | Number of random generators, and hence test matrices, that each
-- benchmark section of 'main' iterates over.
numItrs :: Int
numItrs = 100

-- | Benchmark driver. Creates 'numItrs' generators, then runs each variant
-- in turn. For each variant it prints a header line and then, for every
-- generator:
--
--   * multiplication variants: the random 4x4 matrix, its square, and the
--     result of left-folding 100000 further multiplications by it;
--   * elimination variants: the 4x8 matrix from 'makeMatrix' after the
--     routine has been run on it in place 10000 times.
main = do
  rngs <- replicateM numItrs newStdGen

  let
    benchMult label mult = do
      putStrLn label
      forM_ rngs $ \rng -> do
        let a = listArray ((1,1),(4,4)) (take 16 (randoms rng)) :: UArray (Int,Int) Double
            b = foldl' mult a (replicate 100000 a)
        print a
        print (mult a a)
        print b

    benchElim label elim = do
      putStrLn label
      forM_ rngs $ \rng -> do
        mat <- makeMatrix rng
        forM_ [1..10000] $ \_ -> elim mat
        printMatrix mat

  benchMult "mulMatrixUnsafe==================" (matrixMultUnsafe 4)
  benchMult "mulMatrixSafe==================" (matrixMultSafe 4)
  benchElim "gaussElimSafe==================" gaussElimSafe
  benchElim "gaussElimUnsafe==================" gaussElimUnsafe
  benchElim "gaussElim2Safe==================" (\mat -> gaussElim2Safe mat ((1,1),(4,8)))
  benchElim "gaussElim2Unsafe==================" (gaussElim2Unsafe 4 8)


-- | Build a 4x8 'IOUArray' of Doubles. It is first filled row-major from the
-- first 32 values of @randoms rng@. Its right half (columns 5..8) is then
-- overwritten with the 4x4 identity, giving an augmented matrix @[A | I]@.
-- Note that A therefore holds values 1-4, 9-12, 17-20 and 25-28 of the
-- stream, not the first 16.
makeMatrix rng = do
  m <- newListArray ((1,1),(4,8)) (take 32 (randoms rng)) :: IO (IOUArray (Int,Int) Double)
  sequence_ [ writeArray m (r,c) (if c == r + 4 then 1 else 0) | r <- [1..4], c <- [5..8] ]
  return m

-- | Print a 2-D 'IOUArray' row by row. Each element is shown followed by a
-- space, the last element of each row is also followed by a newline, and
-- one extra newline is printed after the whole matrix.
printMatrix :: IOUArray (Int,Int) Double -> IO ()
printMatrix m = do
  ((rLo,cLo),(rHi,cHi)) <- getBounds m
  let cell r c = do
        v <- readArray m (r,c)
        putStr (show v ++ " " ++ (if c == cHi then "\n" else ""))
  mapM_ (\r -> mapM_ (cell r) [cLo..cHi]) [rLo..rHi]
  putStr "\n"



_______________________________________________
Glasgow-haskell-users mailing list
Glasgow-haskell-users@haskell.org
http://www.haskell.org/mailman/listinfo/glasgow-haskell-users

Reply via email to