Repository: incubator-systemml
Updated Branches:
  refs/heads/master 1fc764b9b -> c3aeb48bf


[HOTFIX] for sparse GPU transpose


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/c3aeb48b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/c3aeb48b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/c3aeb48b

Branch: refs/heads/master
Commit: c3aeb48bf6b54febb861b7b4381c3d7af450a8e8
Parents: 1fc764b
Author: Nakul Jindal <[email protected]>
Authored: Wed May 17 18:46:21 2017 -0700
Committer: Nakul Jindal <[email protected]>
Committed: Wed May 17 18:46:21 2017 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixCUDA.java      | 118 +++++++++++--------
 1 file changed, 68 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c3aeb48b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 074119b..b023159 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -52,6 +52,7 @@ import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
 import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
+import static jcuda.jcusparse.JCusparse.cusparseDcsr2csc;
 import static jcuda.jcusparse.JCusparse.cusparseDcsrgemm;
 import static jcuda.jcusparse.JCusparse.cusparseDcsrmv;
 import static 
jcuda.jcusparse.cusparseOperation.CUSPARSE_OPERATION_NON_TRANSPOSE;
@@ -61,6 +62,8 @@ import static 
jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
 
+import jcuda.jcusparse.cusparseAction;
+import jcuda.jcusparse.cusparseIndexBase;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
@@ -2732,7 +2735,7 @@ public class LibMatrixCUDA {
         * Performs sparse and dense dgeam given two input matrices
         * C = alpha* op( A ) + beta* op ( B )
         * where op = transpose or not (specified by isLeftTransposed and 
isRightTransposed).
-        *
+        * To indicate a transpose operation, make sure in1 == in2 and 
isLeftTransposed == isRightTransposed == true
         * @param ec execution context
         * @param gCtx a valid {@link GPUContext}
         * @param instName the invoking instruction's name for record {@link 
Statistics}.
@@ -2756,35 +2759,6 @@ public class LibMatrixCUDA {
                int transa = isLeftTransposed ? CUBLAS_OP_T : CUBLAS_OP_N;
                int transb = isRightTransposed ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-               int lda = (int) in1.getNumColumns();
-               int ldb = (int) in2.getNumColumns();
-               int m = (int) in1.getNumColumns();
-               int n = (int) in2.getNumRows();
-               if (isLeftTransposed && isRightTransposed) {
-                       m = (int) in1.getNumRows();
-                       n = (int) in2.getNumColumns();
-               }
-               else if (isLeftTransposed) {
-                       m = (int) in1.getNumRows();
-               } else if (isRightTransposed) {
-                       n = (int) in2.getNumColumns();
-               }
-               int ldc = m;
-
-
-
-               /**
-               int m = (int) in1.getNumRows();
-               int n = (int) in1.getNumColumns();
-               if(!isLeftTransposed && isRightTransposed) {
-                       m = (int) in1.getNumColumns();
-                       n = (int) in1.getNumRows();
-               }
-               int lda = isLeftTransposed ? n : m;
-               int ldb = isRightTransposed ? n : m;
-               int ldc = m;
-               **/
-
                MatrixObject out = ec.getMatrixObject(outputName);
                boolean isSparse1 = isInSparseFormat(gCtx, in1);
                boolean isSparse2 = isInSparseFormat(gCtx, in2);
@@ -2792,39 +2766,83 @@ public class LibMatrixCUDA {
                long t0=0,t1=0;
                // TODO: Implement sparse-dense matrix cublasDgeam kernel
                if(isSparse1 || isSparse2) {
+                       int m = (int)in1.getNumRows();
+                       int n = (int)in1.getNumColumns();
                        // Invoke cuSparse when either are in sparse format
                        // Perform sparse-sparse dgeam
-                       if(!isInSparseFormat(gCtx, in1)) {
-                               if (GPUStatistics.DISPLAY_STATISTICS) t0 = 
System.nanoTime();
+                       if (!isInSparseFormat(gCtx, in1)) {
+                               if (GPUStatistics.DISPLAY_STATISTICS)
+                                       t0 = System.nanoTime();
                                in1.getGPUObject(gCtx).denseToSparse();
-                               if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_DENSE_TO_SPARSE, System.nanoTime() - t0);
+                               if (GPUStatistics.DISPLAY_STATISTICS)
+                                       
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_DENSE_TO_SPARSE,
+                                                       System.nanoTime() - t0);
                        }
                        CSRPointer A = 
in1.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
-                       if(!isInSparseFormat(gCtx, in2)) {
-                               if (GPUStatistics.DISPLAY_STATISTICS) t0 = 
System.nanoTime();
+                       if (!isInSparseFormat(gCtx, in2)) {
+                               if (GPUStatistics.DISPLAY_STATISTICS)
+                                       t0 = System.nanoTime();
                                in2.getGPUObject(gCtx).denseToSparse();
-                               if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_DENSE_TO_SPARSE, System.nanoTime() - t0);
+                               if (GPUStatistics.DISPLAY_STATISTICS)
+                                       
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_DENSE_TO_SPARSE,
+                                                       System.nanoTime() - t0);
                        }
                        CSRPointer B = 
in2.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
 
                        ec.allocateGPUMatrixObject(outputName);
+                       out.getGPUObject(gCtx).addReadLock();
 
-                       if (GPUStatistics.DISPLAY_STATISTICS) t1 = 
System.nanoTime();
-                       CSRPointer C = CSRPointer.allocateForDgeam(gCtx, 
getCusparseHandle(gCtx), A, B, m, n);
-                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_SPARSE_ALLOCATE_LIB, System.nanoTime() - t1);
+                       if (in1 == in2 && isLeftTransposed == true && 
isLeftTransposed == isRightTransposed) {
+                               // Special case for transpose
 
-                       out.getGPUObject(gCtx).setSparseMatrixCudaPointer(C);
-                       //long sizeOfC = CSRPointer.estimateSize(C.nnz, 
out.getNumRows());
-                       out.getGPUObject(gCtx).addReadLock();
-                       if (GPUStatistics.DISPLAY_STATISTICS) t0 = 
System.nanoTime();
-                       JCusparse.cusparseDcsrgeam(getCusparseHandle(gCtx), m, 
n, alphaPtr, A.descr, (int)A.nnz, A.val, A.rowPtr, A.colInd, betaPtr,
-                                                       B.descr, (int)B.nnz, 
B.val, B.rowPtr, B.colInd,
-                                                       C.descr, C.val, 
C.rowPtr, C.colInd);
-                       //cudaDeviceSynchronize;
-                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_SPARSE_DGEAM_LIB, System.nanoTime() - t0);
-               }
-               else {
+                               int nnz = (int)A.nnz;
+                               CSRPointer C = CSRPointer.allocateEmpty(gCtx, 
nnz, n);
+                               
out.getGPUObject(gCtx).setSparseMatrixCudaPointer(C);
+                               cusparseDcsr2csc(getCusparseHandle(gCtx), m, n, 
nnz, A.val, A.rowPtr, A.colInd, C.val, C.colInd, C.rowPtr, 
cusparseAction.CUSPARSE_ACTION_NUMERIC, 
cusparseIndexBase.CUSPARSE_INDEX_BASE_ZERO);
+                       } else {
+                               // General case (cusparse does not 
accept the transpose operator for dgeam)
+                               // TODO: to implement transposed dgeam 
for sparse matrices, the inputs need to be converted to CSC, which is effectively a 
transpose
+                               if (isLeftTransposed || isRightTransposed) {
+                                       throw new DMLRuntimeException(
+                                                       "Transpose in 
cusparseDcsrgeam not supported for sparse matrices on GPU");
+                               }
+
+                               if (GPUStatistics.DISPLAY_STATISTICS)
+                                       t1 = System.nanoTime();
+                               CSRPointer C = 
CSRPointer.allocateForDgeam(gCtx, getCusparseHandle(gCtx), A, B, m, n);
+                               if (GPUStatistics.DISPLAY_STATISTICS)
+                                       
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_SPARSE_ALLOCATE_LIB,
+                                                       System.nanoTime() - t1);
+
+                               
out.getGPUObject(gCtx).setSparseMatrixCudaPointer(C);
+                               //long sizeOfC = CSRPointer.estimateSize(C.nnz, 
out.getNumRows());
+                               if (GPUStatistics.DISPLAY_STATISTICS)
+                                       t0 = System.nanoTime();
+                               
JCusparse.cusparseDcsrgeam(getCusparseHandle(gCtx), m, n, alphaPtr, A.descr, 
(int) A.nnz, A.val, A.rowPtr, A.colInd, betaPtr,
+                                               B.descr, (int) B.nnz, B.val, 
B.rowPtr, B.colInd, C.descr, C.val, C.rowPtr, C.colInd);
+                               //cudaDeviceSynchronize;
+                               if (GPUStatistics.DISPLAY_STATISTICS)
+                                       
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_SPARSE_DGEAM_LIB,
+                                                       System.nanoTime() - t0);
+                       }
+               } else {
                        // Dense-Dense dgeam
+
+                       int lda = (int) in1.getNumColumns();
+                       int ldb = (int) in2.getNumColumns();
+                       int m = (int) in1.getNumColumns();
+                       int n = (int) in2.getNumRows();
+                       if (isLeftTransposed && isRightTransposed) {
+                               m = (int) in1.getNumRows();
+                               n = (int) in2.getNumColumns();
+                       }
+                       else if (isLeftTransposed) {
+                               m = (int) in1.getNumRows();
+                       } else if (isRightTransposed) {
+                               n = (int) in2.getNumColumns();
+                       }
+                       int ldc = m;
+
                        Pointer A = getDensePointer(gCtx, in1, instName);
                        Pointer B = getDensePointer(gCtx, in2, instName);
                        getDenseMatrixOutputForGPUInstruction(ec, instName, 
outputName);        // Allocated the dense output matrix

Reply via email to