[SYSTEMML-2268] Performance native BLAS (dispatch DOT, GEMV, GEMM). This patch improves the performance of native BLAS matrix multiply operations for the special (but common) cases of vector-vector dot products and matrix-vector products by dispatching BLAS calls according to the input dimensions instead of always calling a single general routine. Specifically, mm now dispatches to DOT, GEMV, or GEMM (previously always GEMM), and tsmm dispatches to DOT for vector inputs and to SYRK otherwise (previously always SYRK).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/aff00094 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/aff00094 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/aff00094 Branch: refs/heads/master Commit: aff000942ec8bdcbec4e219a2e86cc5c85e7b2ea Parents: 2f278bc Author: Matthias Boehm <[email protected]> Authored: Fri Apr 20 23:58:51 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Apr 20 23:58:51 2018 -0700 ---------------------------------------------------------------------- .../cpp/lib/libsystemml_mkl-Linux-x86_64.so | Bin 32208 -> 32376 bytes .../lib/libsystemml_openblas-Linux-x86_64.so | Bin 31288 -> 31464 bytes src/main/cpp/libmatrixmult.cpp | 31 +++++++++++++++---- .../runtime/matrix/data/LibMatrixNative.java | 8 ++--- 4 files changed, 28 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/aff00094/src/main/cpp/lib/libsystemml_mkl-Linux-x86_64.so ---------------------------------------------------------------------- diff --git a/src/main/cpp/lib/libsystemml_mkl-Linux-x86_64.so b/src/main/cpp/lib/libsystemml_mkl-Linux-x86_64.so index adc3bbe..fb1d33e 100755 Binary files a/src/main/cpp/lib/libsystemml_mkl-Linux-x86_64.so and b/src/main/cpp/lib/libsystemml_mkl-Linux-x86_64.so differ http://git-wip-us.apache.org/repos/asf/systemml/blob/aff00094/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so ---------------------------------------------------------------------- diff --git a/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so b/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so index 0b39eaa..8905252 100755 Binary files a/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so and b/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so differ 
http://git-wip-us.apache.org/repos/asf/systemml/blob/aff00094/src/main/cpp/libmatrixmult.cpp ---------------------------------------------------------------------- diff --git a/src/main/cpp/libmatrixmult.cpp b/src/main/cpp/libmatrixmult.cpp index 773a85a..868fd24 100644 --- a/src/main/cpp/libmatrixmult.cpp +++ b/src/main/cpp/libmatrixmult.cpp @@ -42,18 +42,37 @@ void setNumThreadsForBLAS(int numThreads) { } void dmatmult(double* m1Ptr, double* m2Ptr, double* retPtr, int m, int k, int n, int numThreads) { + //BLAS routine dispatch according to input dimension sizes (we don't use cblas_dgemv + //with CblasColMajor for matrix-vector because it was generally slower than dgemm) setNumThreadsForBLAS(numThreads); - cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, m1Ptr, k, m2Ptr, n, 0, retPtr, n); + if( m == 1 && n == 1 ) //VV + retPtr[0] = cblas_ddot(k, m1Ptr, 1, m2Ptr, 1); + else if( n == 1 ) //MV + cblas_dgemv(CblasRowMajor, CblasNoTrans, m, k, 1, m1Ptr, k, m2Ptr, 1, 0, retPtr, 1); + else //MM + cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, m1Ptr, k, m2Ptr, n, 0, retPtr, n); } void smatmult(float* m1Ptr, float* m2Ptr, float* retPtr, int m, int k, int n, int numThreads) { + //BLAS routine dispatch according to input dimension sizes (we don't use cblas_sgemv + //with CblasColMajor for matrix-vector because it was generally slower than sgemm) setNumThreadsForBLAS(numThreads); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, m1Ptr, k, m2Ptr, n, 0, retPtr, n); + if( m == 1 && n == 1 ) //VV + retPtr[0] = cblas_sdot(k, m1Ptr, 1, m2Ptr, 1); + else if( n == 1 ) //MV + cblas_sgemv(CblasRowMajor, CblasNoTrans, m, k, 1, m1Ptr, k, m2Ptr, 1, 0, retPtr, 1); + else //MM + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, m1Ptr, k, m2Ptr, n, 0, retPtr, n); } -void tsmm(double* m1Ptr, double* retPtr, int m1rlen, int m1clen, bool isLeftTrans, int numThreads) { - int n = isLeftTrans ? 
m1clen : m1rlen; - int k = isLeftTrans ? m1rlen : m1clen; +void tsmm(double* m1Ptr, double* retPtr, int m1rlen, int m1clen, bool leftTrans, int numThreads) { setNumThreadsForBLAS(numThreads); - cblas_dsyrk(CblasRowMajor, CblasUpper, isLeftTrans ? CblasTrans : CblasNoTrans, n, k, 1, m1Ptr, n, 0, retPtr, n); + if( (leftTrans && m1clen == 1) || (!leftTrans && m1rlen == 1) ) { + retPtr[0] = cblas_ddot(leftTrans ? m1rlen : m1clen, m1Ptr, 1, m1Ptr, 1); + } + else { //general case + int n = leftTrans ? m1clen : m1rlen; + int k = leftTrans ? m1rlen : m1clen; + cblas_dsyrk(CblasRowMajor, CblasUpper, leftTrans ? CblasTrans : CblasNoTrans, n, k, 1, m1Ptr, n, 0, retPtr, n); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/aff00094/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java index e122e7f..9fec026 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java @@ -123,7 +123,7 @@ public class LibMatrixNative public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean leftTrans, int k) { if( m1.isEmptyBlock(false) ) return; - if( NativeHelper.isNativeLibraryLoaded() && ret.clen > 1 + if( NativeHelper.isNativeLibraryLoaded() && (ret.clen > 1 || ret.getLength()==1) && (!m1.sparse && m1.getDenseBlock().isContiguous() ) ) { ret.sparse = false; ret.allocateDenseBlock(); @@ -136,10 +136,8 @@ public class LibMatrixNative ret.examSparsity(); return; } - else { - Statistics.incrementNativeFailuresCounter(); - //fallback to default java implementation - } + //fallback to default java implementation + Statistics.incrementNativeFailuresCounter(); } if( k > 1 ) LibMatrixMult.matrixMultTransposeSelf(m1, ret, leftTrans, k);
