[SYSTEMML-445] Improved performance of GPU right indexing

- Improved performance of dense right-indexing by parallelizing over 
individual elements rather than rows.
- Added slice_sparse_dense_nnz to handle indexing of sparse wide matrices.
- Added EAGER_CUDA_FREE flag to allow users to enable/disable cudaFree 
optimization.
- Introduced IGNORE_CLEAR_MEMORY_BUG flag in GPUTests to run tests.
- Performance improvement of 4x over a 1-layer CNN. Also observed
  similar improvements for large dense right-indexing (rix). Microbenchmarks
  show no performance degradation in other cases (please see
  https://github.com/apache/systemml/pull/667 for exhaustive results).

Closes #667.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/34bb3ca8
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/34bb3ca8
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/34bb3ca8

Branch: refs/heads/master
Commit: 34bb3ca82c9495b651d046ea21e8281fa2f8aa53
Parents: c14682b
Author: Niketan Pansare <[email protected]>
Authored: Wed Sep 20 14:19:27 2017 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Wed Sep 20 14:22:57 2017 -0700

----------------------------------------------------------------------
 conf/SystemML-config.xml.template               |    3 +
 src/main/cpp/kernels/SystemML.cu                |   97 +-
 src/main/cpp/kernels/SystemML.ptx               | 3074 +++++++++---------
 .../java/org/apache/sysml/api/DMLScript.java    |    5 +-
 .../apache/sysml/api/ScriptExecutorUtils.java   |    1 +
 .../java/org/apache/sysml/conf/DMLConfig.java   |    8 +-
 .../instructions/gpu/GPUInstruction.java        |    3 +-
 .../instructions/gpu/context/CSRPointer.java    |    2 +-
 .../instructions/gpu/context/GPUContext.java    |    4 +-
 .../instructions/gpu/context/GPUObject.java     |    2 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |   61 +-
 .../runtime/matrix/data/LibMatrixCuDNN.java     |    6 +-
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |   70 +
 .../org/apache/sysml/test/gpu/GPUTests.java     |   12 +-
 14 files changed, 1792 insertions(+), 1556 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/34bb3ca8/conf/SystemML-config.xml.template
----------------------------------------------------------------------
diff --git a/conf/SystemML-config.xml.template 
b/conf/SystemML-config.xml.template
index 1731a3b..ef24e30 100644
--- a/conf/SystemML-config.xml.template
+++ b/conf/SystemML-config.xml.template
@@ -87,6 +87,9 @@
     <!-- whether to synchronize GPUs after every GPU instruction -->
     <systemml.gpu.sync.postProcess>true</systemml.gpu.sync.postProcess>
     
+    <!-- whether to perform eager CUDA free on rmvar instruction -->
+    <systemml.gpu.eager.cudaFree>false</systemml.gpu.eager.cudaFree>
+    
     <!-- maximum wrap length for instruction and miscellaneous timer column of 
statistics -->
    <systemml.stats.maxWrapLength>30</systemml.stats.maxWrapLength>
 </root>

http://git-wip-us.apache.org/repos/asf/systemml/blob/34bb3ca8/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 231a32a..3e1a13a 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -26,10 +26,10 @@ nvcc -ptx -arch=sm_30 SystemML.cu
 #include <cfloat>
 #include <cmath>
 
-
 /**
  * Performs a slice operation where the input matrix is sparse and the output 
matrix is dense.
  * This function avoids unnecessary sparse to dense conversion of the input 
matrix.
+ * Parallelization: rows of output matrix.
  * 
  * @params inVal input val pointer
  * @params inRowPtr input row pointer
@@ -42,10 +42,33 @@ nvcc -ptx -arch=sm_30 SystemML.cu
  * @param retClen number of columns of output matrix
  */
 extern "C"
-__global__ void slice_sparse_dense(double* inVal, int* inRowPtr, int* colInd, 
double* ret, int rl, int ru, int cl, int cu, int retClen) {
-       int index = blockIdx.x * blockDim.x + threadIdx.x;
+__global__ void slice_sparse_dense_row(double* inVal, int* inRowPtr, int* 
colInd, double* ret, 
+    int rl, int ru, int cl, int cu, int retClen) {
+       int index = blockIdx.x * blockDim.x + threadIdx.x;
        int rowIndex = index + rl;
-    if (rowIndex <= ru){
+       if (rowIndex <= ru){
+               /*
+                * TODO: Alternative approach: use dynamic parallelism. We are 
skipping this for now to avoid
+                * the complexity of two-step separate compilation and linking 
process.
+                *  
+                * extern "C"
+                * __global__ void slice_sparse_dense_row_helper(double* inVal, 
int* inRowPtr, int* colInd, double* ret, 
+                *     int rl, int ru, int cl, int cu, int retClen, int start, 
int end, int index) {
+                *  int i = blockIdx.x * blockDim.x + threadIdx.x + start;   
+                *      // Only slice if the index falls into the given range
+                *      if(i < end && cl <= colInd[i] && colInd[i] <= cu) {
+                *              ret[ index*retClen + (colInd[i] - cl) ] = 
inVal[i];
+                *      }
+                * }
+                *
+                * int size = inRowPtr[rowIndex+1] - inRowPtr[rowIndex];
+                * double numThreads = (double)min(size, 
MAX_NUM_THREADS_CHILD_KERNEL);
+                * slice_sparse_dense_row_helper<<< ceil(numThreads/ 
MAX_NUM_THREADS_CHILD_KERNEL), MAX_NUM_THREADS_CHILD_KERNEL>>>(inVal, inRowPtr, 
colInd, ret, 
+        *                      rl, ru, cl, cu, retClen, inRowPtr[rowIndex], 
inRowPtr[rowIndex+1], index);
+        *
+        * Two-step compilation and linking process in JCudaKernels's 
constructor:
+        * cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_LIBRARY, 
"/usr/local/cuda/lib64/libcudadevrt.a", jitOptions);
+                */
        // Iterate over elements of the row 'rowIndex'.
        for(int i = inRowPtr[rowIndex]; i < inRowPtr[rowIndex+1]; i++) {
                // Only slice if the index falls into the given range
@@ -57,6 +80,38 @@ __global__ void slice_sparse_dense(double* inVal, int* 
inRowPtr, int* colInd, do
 }
 
 /**
+ * Performs a slice operation where the input matrix is sparse and the output 
matrix is dense.
+ * This function avoids unnecessary sparse to dense conversion of the input 
matrix.
+ * Parallelization: subset of number of non-zeroes of input matrix.
+ * 
+ * @params inVal input val pointer
+ * @params inRowPtr input row pointer
+ * @params colInd input col index pointer
+ * @params ret dense output pointer
+ * @param rl row lower
+ * @param ru row upper
+ * @param cl column lower
+ * @param cu column upper
+ * @param retClen number of columns of output matrix
+ */
+extern "C"
+__global__ void slice_sparse_dense_nnz(double* inVal, int* inRowPtr, int* 
colInd, double* ret, 
+    int rl, int ru, int cl, int cu, int retClen) {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    int i = tid + inRowPtr[rl];
+    
+    // Only slice if the index falls into the given range
+    if(i < inRowPtr[ru+1] && cl <= colInd[i] && colInd[i] <= cu) {
+       // Find the row index for corresponding non-zero value 'i'.
+       int rowIndex = rl;
+       while(inRowPtr[rowIndex+1] <= i) {
+               rowIndex++;
+       }
+           ret[ (rowIndex-rl)*retClen + (colInd[i] - cl) ] = inVal[i];
+    }
+}
+
+/**
  * Performs a slice operation where the input matrix is dense and the output 
matrix is dense.
  * 
  * @params in dense input pointer
@@ -66,19 +121,18 @@ __global__ void slice_sparse_dense(double* inVal, int* 
inRowPtr, int* colInd, do
  * @param cl column lower
  * @param cu column upper
  * @param inClen number of columns of input matrix
+ * @param retRlen number of rows of output matrix
  * @param retClen number of columns of output matrix
  */
 extern "C"
-__global__ void slice_dense_dense(double* in, double* ret, int rl, int ru, int 
cl, int cu, int inClen, int retClen) {
-       int index = blockIdx.x * blockDim.x + threadIdx.x;
-       int rowIndex = index + rl;
-    if (rowIndex <= ru){
-       int inIndex = rowIndex*inClen + cl;
-       int retIndex = index*retClen;
-       for(int i = retIndex; i < retIndex+retClen; i++, inIndex++) {
-                       ret[i] = in[inIndex];
-               }
-    }
+__global__ void slice_dense_dense(double* in, double* ret, int rl, int ru, int 
cl, int cu, int inClen, int retRlen, int retClen) {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+       int ix = tid / retClen;
+       int iy = tid % retClen;
+       if(ix < retRlen && iy < retClen) {
+           int inIndex = (ix + rl)*inClen + cl + iy;
+               ret[tid] = in[inIndex];
+       }
 }
 
 
@@ -164,8 +218,7 @@ __global__ void relu(double* A,  double* ret, int rlen, int 
clen) {
        int ix = tid / clen;
        int iy = tid % clen;
        if(ix < rlen && iy < clen) {
-               int index = ix * clen + iy;
-               ret[index] = max(0.0, A[index]);
+               ret[tid] = max(0.0, A[tid]);
        }
 }
 
@@ -176,8 +229,7 @@ __global__ void relu_backward(double* X,  double* dout, 
double* ret, int rlen, i
        int ix = tid / clen;
        int iy = tid % clen;
        if(ix < rlen && iy < clen) {
-               int index = ix * clen + iy;
-               ret[index] = X[index] > 0 ?  dout[index] : 0;
+               ret[tid] = X[tid] > 0 ?  dout[tid] : 0;
        }
 }
 
@@ -195,8 +247,7 @@ __global__ void inplace_add(double* input,  double* ret, 
int rlen, int clen) {
        int ix = tid / clen;
        int iy = tid % clen;
        if(ix < rlen && iy < clen) {
-               int index = ix * clen + iy;
-               ret[index] += input[index];
+               ret[tid] += input[tid];
        }
 }
 
@@ -210,9 +261,8 @@ __global__ void bias_add(double* input,  double* bias, 
double* ret, int rlen, in
        int ix = tid / clen;
        int iy = tid % clen;
        if(ix < rlen && iy < clen) {
-               int index = ix * clen + iy;
                int biasIndex = iy / PQ;
-               ret[index] = input[index] + bias[biasIndex];
+               ret[tid] = input[tid] + bias[biasIndex];
        }
 }
 
@@ -240,9 +290,8 @@ __global__ void bias_multiply(double* input,  double* bias, 
double* ret, int rle
        int ix = tid / clen;
        int iy = tid % clen;
        if(ix < rlen && iy < clen) {
-               int index = ix * clen + iy;
                int biasIndex = iy / PQ;
-               ret[index] = input[index] * bias[biasIndex];
+               ret[tid] = input[tid] * bias[biasIndex];
        }
 }
 

Reply via email to