[SYSTEMML-1039] Added row/col min, row/col max & mean. Closes #343.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/627fdbe2 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/627fdbe2 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/627fdbe2 Branch: refs/heads/master Commit: 627fdbe2d233d4326b596dd512fc5164f7fb140e Parents: c528b76 Author: Nakul Jindal <njin...@us.ibm.com> Authored: Fri Jan 13 19:08:01 2017 -0800 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Fri Jan 13 19:08:01 2017 -0800 ---------------------------------------------------------------------- src/main/cpp/kernels/SystemML.cu | 157 +- src/main/cpp/kernels/SystemML.ptx | 5704 ++++++++++++------ .../java/org/apache/sysml/api/DMLScript.java | 2 +- .../java/org/apache/sysml/conf/DMLConfig.java | 2 +- .../java/org/apache/sysml/hops/AggUnaryOp.java | 6 +- .../instructions/GPUInstructionParser.java | 25 +- .../instructions/gpu/context/GPUContext.java | 5 +- .../instructions/gpu/context/JCudaContext.java | 214 +- .../instructions/gpu/context/JCudaObject.java | 12 +- .../runtime/matrix/data/LibMatrixCUDA.java | 150 +- 10 files changed, 4321 insertions(+), 1956 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/627fdbe2/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index 5964707..8235d5d 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -128,7 +128,7 @@ __global__ void reluBackward(double* X, double* dout, double* ret, int rlen, in } // Performs the operation corresponding to the DML script: -// ones = matrix(1, rows=1, cols=Hout*Wout) +// ones = matrix(1, rows=1, cols=Hout*Wout) // output = input + matrix(bias %*% ones, rows=1, cols=F*Hout*Wout) // This 
operation is often followed by conv2d and hence we have introduced bias_add(input, bias) built-in function extern "C" @@ -226,7 +226,7 @@ __global__ void fill(double* A, double scalar, int lenA) { template <typename ReductionOp> __device__ void reduce( double *g_idata, ///< input data stored in device memory (of size n) - double *g_odata, ///< output/temporary array stode in device memory (of size n) + double *g_odata, ///< output/temporary array stored in device memory (of size n) unsigned int n, ///< size of the input and temporary/output arrays ReductionOp reduction_op, ///< Reduction operation to perform (functor object) double initialValue) ///< initial value for the reduction variable @@ -259,6 +259,7 @@ __device__ void reduce( // do reduction in shared mem + if (blockDim.x >= 1024){ if (tid < 512) { sdata[tid] = v = reduction_op(v, sdata[tid + 512]); } __syncthreads(); } if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = v = reduction_op(v, sdata[tid + 256]); } __syncthreads(); } if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = v = reduction_op(v, sdata[tid + 128]); } __syncthreads(); } if (blockDim.x >= 128) { if (tid < 64) { sdata[tid] = v = reduction_op(v, sdata[tid + 64]); } __syncthreads(); } @@ -292,13 +293,17 @@ __device__ void reduce( * This works out fine for SystemML, since the maximum elements in a Java array can be 2^31 - c (some small constant) * If the matrix is "fat" and "short", i.e. there are small number of rows and a large number of columns, * there could be under-utilization of the hardware. 
- * @param g_idata input matrix stored in device memory - * @param g_odata output vector of size [rows * 1] in device memory - * @param rows number of rows in input matrix - * @param cols number of columns in input matrix + * The template-ized version of this function is similar to what is found in NVIDIA CUB + * @param ReductionOp Type of the functor object that implements the reduction operation */ -extern "C" -__global__ void reduce_row(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols) +template <typename ReductionOp> +__device__ void reduce_row( + double *g_idata, ///< input data stored in device memory (of size rows*cols) + double *g_odata, ///< output/temporary array store in device memory (of size rows*cols) + unsigned int rows, ///< rows in input and temporary/output arrays + unsigned int cols, ///< columns in input and temporary/output arrays + ReductionOp reduction_op, ///< Reduction operation to perform (functor object) + double initialValue) ///< initial value for the reduction variable { extern __shared__ double sdata[]; @@ -312,9 +317,9 @@ __global__ void reduce_row(double *g_idata, double *g_odata, unsigned int rows, unsigned int i = tid; unsigned int block_offset = block * cols; - double v = 0; + double v = initialValue; while (i < cols){ - v += g_idata[block_offset + i]; + v = reduction_op(v, g_idata[block_offset + i]); i += blockDim.x; } @@ -322,11 +327,11 @@ __global__ void reduce_row(double *g_idata, double *g_odata, unsigned int rows, sdata[tid] = v; __syncthreads(); - - // do reduction in shared mem - if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = v = v + sdata[tid + 256]; } __syncthreads(); } - if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = v = v + sdata[tid + 128]; } __syncthreads(); } - if (blockDim.x >= 128) { if (tid < 64) { sdata[tid] = v = v + sdata[tid + 64]; } __syncthreads(); } + // do reduction in shared mem + if (blockDim.x >= 1024){ if (tid < 512) { sdata[tid] = v = reduction_op(v, sdata[tid + 
512]); } __syncthreads(); } + if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = v = reduction_op(v, sdata[tid + 256]); } __syncthreads(); } + if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = v = reduction_op(v, sdata[tid + 128]); } __syncthreads(); } + if (blockDim.x >= 128) { if (tid < 64) { sdata[tid] = v = reduction_op(v, sdata[tid + 64]); } __syncthreads(); } if (tid < 32) { @@ -334,12 +339,12 @@ __global__ void reduce_row(double *g_idata, double *g_odata, unsigned int rows, // we need to declare our shared memory volatile so that the compiler // doesn't reorder stores to it and induce incorrect behavior. volatile double* smem = sdata; - if (blockDim.x >= 64) { smem[tid] = v = v + smem[tid + 32]; } - if (blockDim.x >= 32) { smem[tid] = v = v + smem[tid + 16]; } - if (blockDim.x >= 16) { smem[tid] = v = v + smem[tid + 8]; } - if (blockDim.x >= 8) { smem[tid] = v = v + smem[tid + 4]; } - if (blockDim.x >= 4) { smem[tid] = v = v + smem[tid + 2]; } - if (blockDim.x >= 2) { smem[tid] = v = v + smem[tid + 1]; } + if (blockDim.x >= 64) { smem[tid] = v = reduction_op(v, smem[tid + 32]); } + if (blockDim.x >= 32) { smem[tid] = v = reduction_op(v, smem[tid + 16]); } + if (blockDim.x >= 16) { smem[tid] = v = reduction_op(v, smem[tid + 8]); } + if (blockDim.x >= 8) { smem[tid] = v = reduction_op(v, smem[tid + 4]); } + if (blockDim.x >= 4) { smem[tid] = v = reduction_op(v, smem[tid + 2]); } + if (blockDim.x >= 2) { smem[tid] = v = reduction_op(v, smem[tid + 1]); } } // write result for this block to global mem @@ -351,16 +356,21 @@ __global__ void reduce_row(double *g_idata, double *g_odata, unsigned int rows, /** * Does a column wise reduction. * The intuition is that there are as many global threads as there are columns - * Each global thread is responsible for a single element in the output vector + * Each global thread is responsible for a single element in the output vector * This of course leads to a under-utilization of the GPU resources. 
* For cases, where the number of columns is small, there can be unused SMs - * @param g_idata input matrix stored in device memory - * @param g_odata output vector of size [1 * cols] in device memory - * @param rows number of rows in input matrix - * @param cols number of columns in input matrix + * + * The template-ized version of this function is similar to what is found in NVIDIA CUB + * @param ReductionOp Type of the functor object that implements the reduction operation */ -extern "C" -__global__ void reduce_col(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols) +template <typename ReductionOp> +__device__ void reduce_col( + double *g_idata, ///< input data stored in device memory (of size rows*cols) + double *g_odata, ///< output/temporary array store in device memory (of size rows*cols) + unsigned int rows, ///< rows in input and temporary/output arrays + unsigned int cols, ///< columns in input and temporary/output arrays + ReductionOp reduction_op, ///< Reduction operation to perform (functor object) + double initialValue) ///< initial value for the reduction variable { unsigned int global_tid = blockIdx.x * blockDim.x + threadIdx.x; if (global_tid >= cols) { @@ -369,10 +379,10 @@ __global__ void reduce_col(double *g_idata, double *g_odata, unsigned int rows, unsigned int i = global_tid; unsigned int grid_size = cols; - double val = 0; + double val = initialValue; while (i < rows * cols) { - val += g_idata[i]; + val = reduction_op(val, g_idata[i]); i += grid_size; } g_odata[global_tid] = val; @@ -392,7 +402,7 @@ typedef struct { /** * Do a summation over all elements of an array/matrix * @param g_idata input data stored in device memory (of size n) - * @param g_odata output/temporary array stode in device memory (of size n) + * @param g_odata output/temporary array stored in device memory (of size n) * @param n size of the input and temporary/output arrays */ extern "C" @@ -402,6 +412,33 @@ __global__ void reduce_sum(double *g_idata, 
double *g_odata, unsigned int n){ } /** + * Do a summation over all rows of a matrix + * @param g_idata input matrix stored in device memory (of size rows * cols) + * @param g_odata output vector stored in device memory (of size rows) + * @param rows number of rows in input matrix + * @param cols number of columns in input matrix + */ +extern "C" +__global__ void reduce_row_sum(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){ + SumOp op; + reduce_row<SumOp>(g_idata, g_odata, rows, cols, op, 0.0); +} + +/** + * Do a summation over all columns of a matrix + * @param g_idata input matrix stored in device memory (of size rows * cols) + * @param g_odata output vector stored in device memory (of size cols) + * @param rows number of rows in input matrix + * @param cols number of columns in input matrix + */ +extern "C" +__global__ void reduce_col_sum(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){ + SumOp op; + reduce_col<SumOp>(g_idata, g_odata, rows, cols, op, 0.0); +} + + +/** * Functor op for max operation */ typedef struct { @@ -420,8 +457,34 @@ typedef struct { */ extern "C" __global__ void reduce_max(double *g_idata, double *g_odata, unsigned int n){ - MaxOp op; - reduce<MaxOp>(g_idata, g_odata, n, op, DBL_MIN); + MaxOp op; + reduce<MaxOp>(g_idata, g_odata, n, op, DBL_MIN); +} + +/** + * Do a max over all rows of a matrix + * @param g_idata input matrix stored in device memory (of size rows * cols) + * @param g_odata output vector stored in device memory (of size rows) + * @param rows number of rows in input matrix + * @param cols number of columns in input matrix + */ +extern "C" +__global__ void reduce_row_max(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){ + MaxOp op; + reduce_row<MaxOp>(g_idata, g_odata, rows, cols, op, DBL_MIN); +} + +/** + * Do a max over all columns of a matrix + * @param g_idata input matrix stored in device memory (of size rows * cols) + * @param g_odata output 
vector stored in device memory (of size cols) + * @param rows number of rows in input matrix + * @param cols number of columns in input matrix + */ +extern "C" +__global__ void reduce_col_max(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){ + MaxOp op; + reduce_col<MaxOp>(g_idata, g_odata, rows, cols, op, DBL_MIN); } /** @@ -443,5 +506,31 @@ typedef struct { extern "C" __global__ void reduce_min(double *g_idata, double *g_odata, unsigned int n){ MinOp op; - reduce<MinOp>(g_idata, g_odata, n, op, DBL_MAX); + reduce<MinOp>(g_idata, g_odata, n, op, DBL_MAX); +} + +/** + * Do a min over all rows of a matrix + * @param g_idata input matrix stored in device memory (of size rows * cols) + * @param g_odata output vector stored in device memory (of size rows) + * @param rows number of rows in input matrix + * @param cols number of columns in input matrix + */ +extern "C" +__global__ void reduce_row_min(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){ + MinOp op; + reduce_row<MinOp>(g_idata, g_odata, rows, cols, op, DBL_MAX); +} + +/** + * Do a min over all columns of a matrix + * @param g_idata input matrix stored in device memory (of size rows * cols) + * @param g_odata output vector stored in device memory (of size cols) + * @param rows number of rows in input matrix + * @param cols number of columns in input matrix + */ +extern "C" +__global__ void reduce_col_min(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){ + MinOp op; + reduce_col<MinOp>(g_idata, g_odata, rows, cols, op, DBL_MAX); }