Hello, I'm writing some matrix multiplication and inversion functions for small matrices (3x3 and 4x4 mostly, for 3D graphics, modeling, simulation, etc.). I noticed that the matrix multiplication was a bottleneck, so I set out to optimize it and found that using unsafeRead instead of (!) (or readArray in stateful code) helped a lot. So then I went to optimize my Gaussian elimination function and found just the opposite: unsafeRead is slower than readArray. This struck me as very odd, considering that readArray calls unsafeRead.
If there is a "good" reason why the compiler optimized readArray better than unsafeRead, I'd like to know what it is so that I can make all my array code safe as well as fast. (By "good" reason I mean something deterministic and repeatable, not just luck.) On the other hand, if this is a fluke, I'm inclined to think that it's not the safe code which is freakishly fast, but the unsafe code which is needlessly slow. That is, something about my program is hindering optimization of the unsafe code. What is it? Attached are the profiling results and a test program with a handful of matrix multiplication and Gaussian elimination functions to illustrate what I've seen. This happens on both amd64 and Intel Core architectures. Thanks for any insight, Scott
{-
Profile of the program as originally posted (i.e. before the pivot-row
fix in swapRowsAndDivideByPivot below):

total time  = 1196.70 secs   (23934 ticks @ 50 ms)
total alloc = 419,350,893,280 bytes  (excludes profiling overheads)

COST CENTRE        MODULE  %time  %alloc  ticks        bytes
matrixMultSafe     Main     26.9    31.4   6449  16450164500
gaussElimUnsafe'   Main     21.4    20.4   5126  10687510443
gaussElim2Unsafe   Main     17.4    17.9   4173   9366179268
gaussElimSafe'     Main     16.2    16.0   3881   8377362484
gaussElim2Safe     Main     13.4    13.1   3207   6869088980
matrixMultUnsafe   Main      3.5     0.8    829    400004000
-}
{-# OPTIONS_GHC -O2 -optc-O2 -fglasgow-exts -fbang-patterns #-}

import Control.Monad
import Data.List
import Data.Array.IO
import Data.Array.Unboxed
import System
import System.IO.Unsafe
import System.Random
import Data.Array.Base

----------------------------------------------------------------------
-- Matrix multiplication
----------------------------------------------------------------------

-- | Multiply two n-by-n matrices using unchecked, flat, zero-based
-- element offsets (@n*row + col@).  The result is created with bounds
-- @((1,1),(n,n))@.
--
-- 'unsafePerformIO' / 'unsafeFreeze' are used on a scratch array that is
-- allocated here and never written again after being frozen.
matrixMultUnsafe n a b = unsafePerformIO $ do
    c <- newArray_ ((1,1),(n,n)) :: IO (IOUArray (Int,Int) Double)
    let -- dot product of row i of a with column j of b, accumulated in acc
        dot !i !j !k !acc
          | k == n    = acc
          | otherwise =
              dot i j (k + 1) $
                acc + (a `unsafeAt` (n * i + k)) * (b `unsafeAt` (n * k + j))
        colLoop !i !j
          | j == n    = return ()
          | otherwise = do
              unsafeWrite c (i * n + j) (dot i j 0 0)
              colLoop i (j + 1)
        rowLoop !i
          | i == n    = return ()
          | otherwise = do
              colLoop i 0
              rowLoop (i + 1)
    rowLoop 0
    unsafeFreeze c

-- | Multiply two n-by-n matrices using bounds-checked @(row, col)@
-- indexing over @1..n@.  The result is created with bounds
-- @((1,1),(n,n))@.
--
-- As in 'matrixMultUnsafe', the frozen array is a local scratch array
-- that is not touched after 'unsafeFreeze'.
matrixMultSafe n a b = unsafePerformIO $ do
    c <- newArray_ ((1,1),(n,n)) :: IO (IOUArray (Int,Int) Double)
    let -- dot product of row i of a with column j of b, accumulated in acc
        dot !i !j !k !acc
          | k > n     = acc
          | otherwise = dot i j (k + 1) $ acc + (a ! (i, k)) * (b ! (k, j))
        colLoop !i !j
          | j > n     = return ()
          | otherwise = do
              writeArray c (i, j) (dot i j 1 0)
              colLoop i (j + 1)
        rowLoop !i
          | i > n     = return ()
          | otherwise = do
              colLoop i 1
              rowLoop (i + 1)
    rowLoop 1
    unsafeFreeze c

----------------------------------------------------------------------
-- Gaussian elimination, hand-written loops, unchecked indexing
----------------------------------------------------------------------

-- | In-place elimination on a mutable matrix; derives the row and column
-- counts from the array bounds and delegates to 'gaussElimUnsafe''.
gaussElimUnsafe matrix = do
    ((i1,j1),(m,n)) <- getBounds matrix
    gaussElimUnsafe' matrix (m - i1 + 1) (n - j1 + 1)

-- | In-place elimination on a matrix of @m@ rows and @n@ columns that is
-- addressed by flat zero-based offsets (@row*n + col@).
--
-- For each column the row with the largest absolute entry (searching
-- from the current row downwards) is chosen as pivot, moved into the
-- current row, scaled so the pivot entry becomes 1, and then subtracted
-- from every other row.  Columns whose pivot is 'nearZero' are skipped.
gaussElimUnsafe' matrix m n = doColumn 0 0
  where
    doColumn !i !j
      | i == m || j == n = return ()
      | otherwise = do
          (pivotRow, pivotVal) <- findPivot i j
          if nearZero pivotVal
            then doColumn i (j + 1)
            else do
              swapRowsAndDivideByPivot i pivotRow pivotVal
              subtractRows i j
              doColumn (i + 1) (j + 1)

    -- Row index and value of the largest |entry| in column j, rows i..m-1.
    findPivot !i !j = go i (i, 0)
      where
        go !r (!maxi, !maxe)
          | r == m    = return (maxi, maxe)
          | otherwise = do
              e <- unsafeRead matrix (r * n + j)
              go (r + 1) $ if abs e > abs maxe then (r, e) else (maxi, maxe)

    -- Exchange row i with the pivot row pr and divide the (new) row i by
    -- the pivot value.  The pivot-row slot is written first so that when
    -- pr == i the final write is the divided value; when pr /= i the two
    -- writes hit different cells and the order is irrelevant.
    swapRowsAndDivideByPivot !i !pr !pv = go 0
      where
        go !c
          | c == n    = return ()
          | otherwise = do
              ei <- unsafeRead matrix (i  * n + c)
              ep <- unsafeRead matrix (pr * n + c)
              unsafeWrite matrix (pr * n + c) ei
              unsafeWrite matrix (i  * n + c) (ep / pv)
              go (c + 1)

    -- Subtract (entry in column j) * row i from every row other than i.
    subtractRows !i !j = eachRow 0
      where
        eachRow !u
          | u == m    = return ()
          | u == i    = eachRow (u + 1)
          | otherwise = do
              -- read the scale factor before the row is modified
              s <- unsafeRead matrix (u * n + j)
              eachCol s u 0
              eachRow (u + 1)
        eachCol !s !u !c
          | c == n    = return ()
          | otherwise = do
              ei <- unsafeRead matrix (i * n + c)
              eu <- unsafeRead matrix (u * n + c)
              unsafeWrite matrix (u * n + c) (eu - s * ei)
              eachCol s u (c + 1)

----------------------------------------------------------------------
-- Gaussian elimination, hand-written loops, bounds-checked indexing
----------------------------------------------------------------------

-- | In-place elimination on a mutable matrix; reads the array bounds and
-- delegates to 'gaussElimSafe''.
gaussElimSafe matrix = do
    bnds <- getBounds matrix
    gaussElimSafe' matrix bnds

-- | Same algorithm as 'gaussElimUnsafe'', but addressing the matrix with
-- checked @(row, col)@ indices inside the supplied bounds
-- @((i1,j1),(m,n))@.
gaussElimSafe' matrix ((i1,j1),(m,n)) = doColumn i1 j1
  where
    doColumn !i !j
      | i > m || j > n = return ()
      | otherwise = do
          (pivotRow, pivotVal) <- findPivot i j
          if nearZero pivotVal
            then doColumn i (j + 1)
            else do
              swapRowsAndDivideByPivot i pivotRow pivotVal
              subtractRows i j
              doColumn (i + 1) (j + 1)

    -- Row index and value of the largest |entry| in column j, rows i..m.
    findPivot !i !j = go i (i, 0)
      where
        go !r (!maxi, !maxe)
          | r > m     = return (maxi, maxe)
          | otherwise = do
              e <- readArray matrix (r, j)
              go (r + 1) $ if abs e > abs maxe then (r, e) else (maxi, maxe)

    -- Exchange row i with the pivot row pr and divide the (new) row i by
    -- the pivot value.  The pivot-row slot is written first so that when
    -- pr == i the final write is the divided value (writing it the other
    -- way round would overwrite the normalised entry with the original).
    swapRowsAndDivideByPivot !i !pr !pv = go j1
      where
        go !c
          | c > n     = return ()
          | otherwise = do
              ei <- readArray matrix (i , c)
              ep <- readArray matrix (pr, c)
              writeArray matrix (pr, c) ei
              writeArray matrix (i , c) (ep / pv)
              go (c + 1)

    -- Subtract (entry in column j) * row i from every row other than i.
    subtractRows !i !j = eachRow i1
      where
        eachRow !u
          | u > m     = return ()
          | u == i    = eachRow (u + 1)
          | otherwise = do
              -- read the scale factor before the row is modified
              s <- readArray matrix (u, j)
              eachCol s u j1
              eachRow (u + 1)
        eachCol !s !u !c
          | c > n     = return ()
          | otherwise = do
              ei <- readArray matrix (i, c)
              eu <- readArray matrix (u, c)
              writeArray matrix (u, c) (eu - s * ei)
              eachCol s u (c + 1)

----------------------------------------------------------------------
-- Gaussian elimination, list/fold based, unchecked indexing
----------------------------------------------------------------------

-- | In-place elimination on a matrix of @m@ rows and @n@ columns that is
-- addressed by flat zero-based offsets.  The fold threads the current
-- pivot row through the columns, starting at row 0.
gaussElim2Unsafe m n matrix = do
    _ <- foldM doColumn 0 [0 .. n - 1]
    return ()
  where
    -- Process column j with candidate pivot row i; yields the next row.
    doColumn i j
      | i == m    = return i
      | otherwise = do
          (pivotRow, pivotVal) <- findPivot i j
          if nearZero pivotVal
            then return i
            else do
              swapRows i pivotRow
              divideRow i pivotVal
              forM_ [0 .. m - 1] $ \i' -> do
                e <- unsafeRead matrix (i' * n + j)
                subtractRow i (e, i')
              return (i + 1)

    -- Row (from i..m-1) holding the largest |entry| in column j, and
    -- that entry.
    findPivot i j = do
        pivotRow <- fold1M larger [i .. m - 1]
        pivotVal <- unsafeRead matrix (pivotRow * n + j)
        return (pivotRow, pivotVal)
      where
        larger ia ib = do
          ea <- unsafeRead matrix (ia * n + j)
          eb <- unsafeRead matrix (ib * n + j)
          return (if abs ea > abs eb then ia else ib)

    swapRows ia ib = unless (ia == ib) $ mapM_ swapAt [0 .. n - 1]
      where
        swapAt j = do
          ea <- unsafeRead matrix (ia * n + j)
          eb <- unsafeRead matrix (ib * n + j)
          unsafeWrite matrix (ia * n + j) eb
          unsafeWrite matrix (ib * n + j) ea

    -- subtract s*row(ia) from row(ib)
    subtractRow ia (s, ib) = unless (ia == ib) $ mapM_ subAt [0 .. n - 1]
      where
        subAt j = do
          ea <- unsafeRead matrix (ia * n + j)
          eb <- unsafeRead matrix (ib * n + j)
          unsafeWrite matrix (ib * n + j) (eb - s * ea)

    -- divide row(i) by s
    divideRow i s = mapM_ divAt [0 .. n - 1]
      where
        divAt j = do
          e <- unsafeRead matrix (i * n + j)
          unsafeWrite matrix (i * n + j) (e / s)

----------------------------------------------------------------------
-- Gaussian elimination, list/fold based, bounds-checked indexing
----------------------------------------------------------------------

-- | Same algorithm as 'gaussElim2Unsafe', addressing the matrix with
-- checked @(row, col)@ indices inside the supplied bounds
-- @((i1,j1),(m,n))@.  The fold starts at the first row @i1@ (not at the
-- first column index), so bounds with @i1 /= j1@ are handled correctly.
gaussElim2Safe matrix ((i1,j1),(m,n)) = do
    _ <- foldM doColumn i1 [j1 .. n]
    return ()
  where
    -- Process column j with candidate pivot row i; yields the next row.
    doColumn i j
      | i > m     = return i
      | otherwise = do
          (pivotRow, pivotVal) <- findPivot i j
          if nearZero pivotVal
            then return i
            else do
              swapRows i pivotRow
              divideRow i pivotVal
              forM_ [i1 .. m] $ \i' -> do
                e <- readArray matrix (i', j)
                subtractRow i (e, i')
              return (i + 1)

    -- Row (from i..m) holding the largest |entry| in column j, and that
    -- entry.
    findPivot i j = do
        pivotRow <- fold1M larger [i .. m]
        pivotVal <- readArray matrix (pivotRow, j)
        return (pivotRow, pivotVal)
      where
        larger ia ib = do
          ea <- readArray matrix (ia, j)
          eb <- readArray matrix (ib, j)
          return (if abs ea > abs eb then ia else ib)

    swapRows ia ib = unless (ia == ib) $ mapM_ swapAt [j1 .. n]
      where
        swapAt j = do
          ea <- readArray matrix (ia, j)
          eb <- readArray matrix (ib, j)
          writeArray matrix (ia, j) eb
          writeArray matrix (ib, j) ea

    -- subtract s*row(ia) from row(ib)
    subtractRow ia (s, ib) = unless (ia == ib) $ mapM_ subAt [j1 .. n]
      where
        subAt j = do
          ea <- readArray matrix (ia, j)
          eb <- readArray matrix (ib, j)
          writeArray matrix (ib, j) (eb - s * ea)

    -- divide row(i) by s
    divideRow i s = mapM_ divAt [j1 .. n]
      where
        divAt j = do
          e <- readArray matrix (i, j)
          writeArray matrix (i, j) (e / s)

----------------------------------------------------------------------
-- Helpers
----------------------------------------------------------------------

-- | Monadic left fold seeded with the first list element; note that the
-- first element is also folded over, so the step function first sees it
-- paired with itself.  The seed is only demanded if the result is, and
-- an empty list then fails with a descriptive error.
fold1M :: Monad m => (a -> a -> m a) -> [a] -> m a
fold1M f xs = foldM f seed xs
  where
    seed = case xs of
      (x : _) -> x
      []      -> error "fold1M: empty list"

-- | 'fold1M' with the result discarded.
fold1M_ :: Monad m => (a -> a -> m a) -> [a] -> m ()
fold1M_ f xs = fold1M f xs >> return ()

-- | Whether a value is within 1e-5 of zero; used to reject pivots.
nearZero :: (Ord a, Fractional a) => a -> Bool
nearZero x = abs x < 1e-5

-- | Number of random generators (and therefore test matrices) per
-- benchmark section.
numItrs :: Int
numItrs = 100

----------------------------------------------------------------------
-- Benchmark driver
----------------------------------------------------------------------

-- | Run every multiplication and elimination variant over the same list
-- of random generators, printing the inputs and results of each.
main :: IO ()
main = do
    rngs <- replicateM numItrs newStdGen
    runMult "mulMatrixUnsafe==================" matrixMultUnsafe rngs
    runMult "mulMatrixSafe==================" matrixMultSafe rngs
    runElim "gaussElimSafe==================" gaussElimSafe rngs
    runElim "gaussElimUnsafe==================" gaussElimUnsafe rngs
    runElim "gaussElim2Safe=================="
            (\a -> gaussElim2Safe a ((1,1),(4,8))) rngs
    runElim "gaussElim2Unsafe==================" (gaussElim2Unsafe 4 8) rngs
  where
    -- For each generator: build a random 4x4 matrix, print it, print its
    -- square, and print the result of folding 100000 multiplications.
    runMult
      :: String
      -> (Int -> UArray (Int,Int) Double
              -> UArray (Int,Int) Double
              -> UArray (Int,Int) Double)
      -> [StdGen]
      -> IO ()
    runMult label mult gens = do
      putStrLn label
      forM_ gens $ \rng -> do
        let xs = randoms rng
            a  = listArray ((1,1),(4,4)) (take 16 xs) :: UArray (Int,Int) Double
            b  = foldl' (mult 4) a (replicate 100000 a)
        print a
        print (mult 4 a a)
        print b

    -- For each generator: build a 4x8 matrix, run the elimination on it
    -- 10000 times in place, and print the final contents.
    runElim
      :: String
      -> (IOUArray (Int,Int) Double -> IO ())
      -> [StdGen]
      -> IO ()
    runElim label eliminate gens = do
      putStrLn label
      forM_ gens $ \rng -> do
        a <- makeMatrix rng
        replicateM_ 10000 (eliminate a)
        printMatrix a

-- | Build a 4x8 mutable matrix with bounds @((1,1),(4,8))@: columns 1-4
-- hold random values, columns 5-8 are overwritten with the 4x4 identity.
makeMatrix :: RandomGen g => g -> IO (IOUArray (Int,Int) Double)
makeMatrix rng = do
    let xs = randoms rng
    m <- newListArray ((1,1),(4,8)) (take 32 xs) :: IO (IOUArray (Int,Int) Double)
    forM_ [1..4] $ \i ->
      forM_ [5..8] $ \j ->
        writeArray m (i,j) (if j == i + 4 then 1 else 0)
    return m

-- | Print a matrix row by row, entries separated by a space, followed by
-- a blank line.
printMatrix :: IOUArray (Int,Int) Double -> IO ()
printMatrix m = do
    ((r1,c1),(rm,cn)) <- getBounds m
    forM_ [ (i,j) | i <- [r1..rm], j <- [c1..cn] ] $ \(i,j) -> do
      a <- readArray m (i,j)
      putStr $ (show a) ++ " " ++ (if j == cn then "\n" else "")
    putStr "\n"
_______________________________________________ Glasgow-haskell-users mailing list Glasgow-haskell-users@haskell.org http://www.haskell.org/mailman/listinfo/glasgow-haskell-users