[SYSTEMML-2256] Exploit native matrix mult in dense wsigmoid 

This patch improves the performance of the fused dense wsigmoid operator on
modern processors with wide SIMD registers and FMA (fused multiply-add)
instructions by exploiting a native matrix multiplication as part of the
larger wsigmoid computation.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f6e3a91d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f6e3a91d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f6e3a91d

Branch: refs/heads/master
Commit: f6e3a91dfe8852839be6b471ba0eabc723f55bf1
Parents: 02e5ba5
Author: Matthias Boehm <[email protected]>
Authored: Wed Apr 18 21:26:01 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Apr 18 21:40:17 2018 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 43 +++++++++++++++++++-
 .../runtime/matrix/data/LibMatrixNative.java    |  2 +-
 2 files changed, 43 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f6e3a91d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index d9e741a..3a2c58e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -42,6 +42,7 @@ import org.apache.sysml.runtime.functionobjects.ValueFunction;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.CommonThreadPool;
 import org.apache.sysml.runtime.util.UtilFunctions;
+import org.apache.sysml.utils.NativeHelper;
 
 /**
  * MB: Library for matrix multiplications including MM, MV, VV for all
@@ -557,7 +558,13 @@ public class LibMatrixMult
                ret.allocateBlock();
                
                //core weighted square sum mm computation
-               if( !mW.sparse && !mU.sparse && !mV.sparse && 
!mU.isEmptyBlock() && !mV.isEmptyBlock() )
+               boolean allDense = !mW.sparse && !mU.sparse && !mV.sparse
+                       && !mU.isEmptyBlock() && !mV.isEmptyBlock();
+               if( NativeHelper.isNativeLibraryLoaded() && allDense && 
(mW.rlen == 1 || mW.clen == 1) 
+                       && !LibMatrixNative.isMatMultMemoryBound(mU.rlen, 
mU.clen, mV.rlen)
+                       && mW.getDenseBlock().isContiguous() && 
mU.getDenseBlock().isContiguous() && mV.getDenseBlock().isContiguous() )
+                       matrixMultWSigmoidDenseNative(mW, mU, mV, ret, wt);
+               else if( allDense )
                        matrixMultWSigmoidDense(mW, mU, mV, ret, wt, 0, 
mW.rlen);
                else if( mW.sparse && !mU.sparse && !mV.sparse && 
!mU.isEmptyBlock() && !mV.isEmptyBlock())
                        matrixMultWSigmoidSparseDense(mW, mU, mV, ret, wt, 0, 
mW.rlen);
@@ -2384,6 +2391,30 @@ public class LibMatrixMult
                        dotProduct(tmp1.getDenseBlockValues(), 
tmp2.getDenseBlockValues(), mU.clen*mU.clen)));
        }
 
+       private static void matrixMultWSigmoidDenseNative(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt) {
+               //dense wsigmoid for vector-shaped W: native uv product written into ret, then in-place weighted sigmoid
+               double[] w = mW.getDenseBlockValues();
+               double[] c = ret.getDenseBlockValues();
+               final int m = mW.rlen, n = mW.clen;
+               final int cd = mU.clen;
+               boolean flagminus = (wt==WSigmoidType.MINUS || wt==WSigmoidType.LOG_MINUS);
+               boolean flaglog = (wt==WSigmoidType.LOG || wt==WSigmoidType.LOG_MINUS);
+               //call native matrix multiplication (only called for single-threaded and matrix-vector
+               //because this ensures that we can deal with the transpose mV without additional transpose)
+               if(!NativeHelper.dmmdd(((m==1)?mV:mU).getDenseBlockValues(),
+                       ((m==1)?mU:mV).getDenseBlockValues(), c, (m==1)?n:m, cd, 1, 1) )
+                       throw new DMLRuntimeException("Error executing native matrix mult.");
+               
+               //compute remaining wsigmoid for all outputs (c currently holds the uv products, not zeros)
+               for(int i=0, ix=0; i<m; i++, ix+=n) {
+                       for(int j=0; j<n; j++) {
+                               double wij = w[ix +j];
+                               //zero weights must be written as 0 explicitly (also avoids 0*(-Inf)=NaN for log types)
+                               c[ix+j] = (wij != 0) ? wsigmoid(wij, c[ix+j], flagminus, flaglog) : 0;
+                       }
+               }
+       }
+       
        private static void matrixMultWSigmoidDense(MatrixBlock mW, MatrixBlock 
mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru) {
                DenseBlock w = mW.getDenseBlock();
                DenseBlock c = ret.getDenseBlock();
@@ -3462,6 +3493,16 @@ public class LibMatrixMult
                //compute weighted output
                return wij * ((flaglog) ? Math.log(cval) : cval);
        }
+       
+       private static double wsigmoid(final double wij, final double uvij, final boolean flagminus, final boolean flaglog) {
+               //sigmoid of the uv value, negated input for the MINUS variants: 1/(1+exp(-uv)) vs. 1/(1+exp(uv))
+               final double expArg = flagminus ? uvij : -uvij;
+               final double sig = 1 / (1 + FastMath.exp(expArg));
+               
+               //scale either the sigmoid itself or its natural log by the weight
+               final double core = flaglog ? Math.log(sig) : sig;
+               return wij * core;
+       }
 
        private static void wdivmm( final double wij, double[] u, double[] v, 
double[] c, final int uix, final int vix, final boolean left, final boolean 
mult, final boolean minus, final int len )
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/f6e3a91d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
index 1d46927..eade43f 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
@@ -48,7 +48,7 @@ public class LibMatrixNative
        
        // We could encapsulate heuristics in this function
        // For now, we only consider matrix-vector operation to be memory bound
-       private static boolean isMatMultMemoryBound(int m1Rlen, int m1Clen, int 
m2Clen) {
+       public static boolean isMatMultMemoryBound(int m1Rlen, int m1Clen, int 
m2Clen) {
                return (m1Rlen == 1 || m1Clen == 1 || m2Clen == 1)
                        && (8L*m1Rlen*m1Clen > 16 * LibMatrixMult.L3_CACHESIZE 
                                || 8L*m1Clen*m2Clen > 16 * 
LibMatrixMult.L3_CACHESIZE);

Reply via email to