This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 5e49750 [SYSTEMDS-2860] Fix native BLAS tsmm integration (dsyrk outer
products)
5e49750 is described below
commit 5e497509ab9f2cd3218b74bf9576bee5241d95c3
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 11 21:25:39 2021 +0100
[SYSTEMDS-2860] Fix native BLAS tsmm integration (dsyrk outer products)
This patch fixes an issue in the BLAS integration of dsyrk (tsmm in
SystemDS), which, for a row vector input and left tsmm, apparently
returns the input vector. Since this operation is memory-bandwidth
bound, we avoid this edge case by using the respective Java kernels.
Furthermore, the tsmm runtimes were not yet included in the native BLAS
runtime statistics, which is now also cleaned up.
---
src/main/cpp/libmatrixmult.cpp | 17 +++++++++--------
.../apache/sysds/runtime/matrix/data/LibMatrixMult.java | 4 ++++
.../sysds/runtime/matrix/data/LibMatrixNative.java | 15 ++++++++++++---
3 files changed, 25 insertions(+), 11 deletions(-)
diff --git a/src/main/cpp/libmatrixmult.cpp b/src/main/cpp/libmatrixmult.cpp
index 0d38511..a8ace72 100644
--- a/src/main/cpp/libmatrixmult.cpp
+++ b/src/main/cpp/libmatrixmult.cpp
@@ -26,11 +26,11 @@ void dmatmult(double *m1Ptr, double *m2Ptr, double *retPtr,
int m, int k, int n,
// slower than dgemm)
setNumThreadsForBLAS(numThreads);
if (m == 1 && n == 1) // VV
- retPtr[0] = cblas_ddot(k, m1Ptr, 1, m2Ptr, 1);
+ retPtr[0] = cblas_ddot(k, m1Ptr, 1, m2Ptr, 1);
else if (n == 1) // MV
- cblas_dgemv(CblasRowMajor, CblasNoTrans, m, k, 1, m1Ptr, k, m2Ptr, 1,
0, retPtr, 1);
+ cblas_dgemv(CblasRowMajor, CblasNoTrans, m, k, 1, m1Ptr, k, m2Ptr, 1, 0,
retPtr, 1);
else // MM
- cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1,
m1Ptr, k, m2Ptr, n, 0, retPtr, n);
+ cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, m1Ptr,
k, m2Ptr, n, 0, retPtr, n);
}
void smatmult(float *m1Ptr, float *m2Ptr, float *retPtr, int m, int k, int n,
int numThreads) {
@@ -39,20 +39,21 @@ void smatmult(float *m1Ptr, float *m2Ptr, float *retPtr,
int m, int k, int n, in
// slower than sgemm)
setNumThreadsForBLAS(numThreads);
if (m == 1 && n == 1) // VV
- retPtr[0] = cblas_sdot(k, m1Ptr, 1, m2Ptr, 1);
+ retPtr[0] = cblas_sdot(k, m1Ptr, 1, m2Ptr, 1);
else if (n == 1) // MV
- cblas_sgemv(CblasRowMajor, CblasNoTrans, m, k, 1, m1Ptr, k, m2Ptr, 1,
0, retPtr, 1);
+ cblas_sgemv(CblasRowMajor, CblasNoTrans, m, k, 1, m1Ptr, k, m2Ptr, 1, 0,
retPtr, 1);
else // MM
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1,
m1Ptr, k, m2Ptr, n, 0, retPtr, n);
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, m1Ptr,
k, m2Ptr, n, 0, retPtr, n);
}
void tsmm(double *m1Ptr, double *retPtr, int m1rlen, int m1clen, bool
leftTrans, int numThreads) {
setNumThreadsForBLAS(numThreads);
if ((leftTrans && m1clen == 1) || (!leftTrans && m1rlen == 1))
- retPtr[0] = cblas_ddot(leftTrans ? m1rlen : m1clen, m1Ptr, 1, m1Ptr, 1);
+ retPtr[0] = cblas_ddot(leftTrans ? m1rlen : m1clen, m1Ptr, 1, m1Ptr, 1);
else { // general case
int n = leftTrans ? m1clen : m1rlen;
int k = leftTrans ? m1rlen : m1clen;
- cblas_dsyrk(CblasRowMajor, CblasUpper, leftTrans ? CblasTrans :
CblasNoTrans, n, k, 1, m1Ptr, n, 0, retPtr, n);
+ cblas_dsyrk(CblasRowMajor, CblasUpper,
+ leftTrans ? CblasTrans : CblasNoTrans, n, k, 1, m1Ptr, n, 0, retPtr, n);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index a117a8d..02156f6 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -3886,6 +3886,10 @@ public class LibMatrixMult
&& m1.getLength()+m2.getLength() <
(long)m1.rlen*m2.clen
&& outSp < MatrixBlock.SPARSITY_TURN_POINT);
}
+
+ public static boolean isOuterProductTSMM(int rlen, int clen, boolean
left) {
+ return left ? rlen == 1 & clen > 1 : rlen > 1 & clen == 1;
+ }
private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1,
MatrixBlock m2 ) {
MatrixBlock ret = m2;
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
index 6e7ba49..ce7fbba 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
@@ -128,14 +128,23 @@ public class LibMatrixNative
public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean
leftTrans, int k) {
if( m1.isEmptyBlock(false) )
return;
+
if( NativeHelper.isNativeLibraryLoaded() && (ret.clen > 1 ||
ret.getLength()==1)
- && (!m1.sparse && m1.getDenseBlock().isContiguous() ) )
{
+ && !LibMatrixMult.isOuterProductTSMM(m1.rlen, m1.clen,
leftTrans)
+ && (!m1.sparse && m1.getDenseBlock().isContiguous() ) )
+ {
ret.sparse = false;
ret.allocateDenseBlock();
+ long start = DMLScript.STATISTICS ? System.nanoTime() :
0;
+
+ long nnz = NativeHelper.tsmm(m1.getDenseBlockValues(),
+ ret.getDenseBlockValues(), m1.rlen, m1.clen,
leftTrans, k);
- long nnz = NativeHelper.tsmm(m1.getDenseBlockValues(),
ret.getDenseBlockValues(),
-
m1.rlen, m1.clen, leftTrans, k);
if(nnz > -1) {
+ if(DMLScript.STATISTICS) {
+ Statistics.nativeLibMatrixMultTime +=
System.nanoTime() - start;
+
Statistics.numNativeLibMatrixMultCalls.increment();
+ }
ret.setNonZeros(nnz);
ret.examSparsity();
return;