[SYSTEMML-1039] Added row/col mean, sumsq, product

Closes #346.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/02040346 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/02040346 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/02040346 Branch: refs/heads/master Commit: 020403466fce4220721a9ee6de50e95bee4a32e4 Parents: a4c7be7 Author: Nakul Jindal <[email protected]> Authored: Tue Jan 17 14:59:33 2017 -0800 Committer: Niketan Pansare <[email protected]> Committed: Tue Jan 17 14:59:33 2017 -0800 ---------------------------------------------------------------------- src/main/cpp/kernels/SystemML.cu | 119 +- src/main/cpp/kernels/SystemML.ptx | 2497 ++++++++++++------ .../java/org/apache/sysml/hops/AggUnaryOp.java | 10 +- .../instructions/GPUInstructionParser.java | 12 +- .../runtime/matrix/data/LibMatrixCUDA.java | 440 +-- 5 files changed, 2042 insertions(+), 1036 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02040346/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index 8235d5d..4ce6fb2 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -203,7 +203,7 @@ __global__ void binCellScalarOp(double* A, double scalar, double* C, int rlenA, */ extern "C" __global__ void fill(double* A, double scalar, int lenA) { - int index = blockIdx.x * blockDim.x + threadIdx.x; + int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < lenA){ A[index] = scalar; } @@ -295,16 +295,18 @@ __device__ void reduce( * there could be under-utilization of the hardware. 
* The template-ized version of this function is similar to what is found in NVIDIA CUB * @param ReductionOp Type of the functor object that implements the reduction operation + * @param AssignmentOp Type of the functor object that is used to modify the value before writing it to its final location in global memory for each row */ -template <typename ReductionOp> +template <typename ReductionOp, + typename AssignmentOp> __device__ void reduce_row( double *g_idata, ///< input data stored in device memory (of size rows*cols) double *g_odata, ///< output/temporary array store in device memory (of size rows*cols) unsigned int rows, ///< rows in input and temporary/output arrays unsigned int cols, ///< columns in input and temporary/output arrays - ReductionOp reduction_op, ///< Reduction operation to perform (functor object) - double initialValue) ///< initial value for the reduction variable -{ + ReductionOp reduction_op, ///< Reduction operation to perform (functor object) + AssignmentOp assignment_op, ///< Operation to perform before assigning this to its final location in global memory for each row + double initialValue){ ///< initial value for the reduction variable extern __shared__ double sdata[]; // one block per row @@ -327,8 +329,8 @@ __device__ void reduce_row( sdata[tid] = v; __syncthreads(); - // do reduction in shared mem - if (blockDim.x >= 1024){ if (tid < 512) { sdata[tid] = v = reduction_op(v, sdata[tid + 512]); } __syncthreads(); } + // do reduction in shared mem + if (blockDim.x >= 1024){ if (tid < 512) { sdata[tid] = v = reduction_op(v, sdata[tid + 512]); } __syncthreads(); } if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = v = reduction_op(v, sdata[tid + 256]); } __syncthreads(); } if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = v = reduction_op(v, sdata[tid + 128]); } __syncthreads(); } if (blockDim.x >= 128) { if (tid < 64) { sdata[tid] = v = reduction_op(v, sdata[tid + 64]); } __syncthreads(); } @@ -347,9 +349,9 @@ __device__ void reduce_row( if (blockDim.x >= 2) { smem[tid] = v = reduction_op(v, smem[tid + 1]); } } - // write result for this block to global mem + // write result for this block to global mem, modify it with assignment op if (tid == 0) - g_odata[block] = sdata[0]; + g_odata[block] = assignment_op(sdata[0]); } @@ -362,14 +364,17 @@ __device__ void reduce_row( * * The template-ized version of this function is similar to what is found in NVIDIA CUB * @param ReductionOp Type of the functor object that implements the reduction operation + * @param AssignmentOp Type of the functor object that is used to modify the value before writing it to its final location in global memory for each column */ -template <typename ReductionOp> +template <typename ReductionOp, + typename AssignmentOp> __device__ void reduce_col( double *g_idata, ///< input data stored in device memory (of size rows*cols) double *g_odata, ///< output/temporary array store in device memory (of size rows*cols) unsigned int rows, ///< rows in input and temporary/output arrays unsigned int cols, ///< columns in input and temporary/output arrays ReductionOp reduction_op, ///< Reduction operation to perform (functor object) + AssignmentOp assignment_op, ///< Operation to perform before assigning this to its final location in global memory for each column double initialValue) ///< initial value for the reduction variable { unsigned int global_tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -385,10 +390,20 @@ __device__ void reduce_col( val = reduction_op(val, g_idata[i]); i += grid_size; } 
-  g_odata[global_tid] = val;
+  g_odata[global_tid] = assignment_op(val);
 }
 
 /**
+ * Functor op for assignment op. This is a dummy/identity op.
+ */
+typedef struct {
+  __device__ __forceinline__
+  double operator()(double a) const {
+    return a;
+  }
+} IdentityOp;
+
+/**
  * Functor op for summation operation
  */
 typedef struct {
@@ -421,7 +436,8 @@ __global__ void reduce_sum(double *g_idata, double *g_odata, unsigned int n){
 extern "C"
 __global__ void reduce_row_sum(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
   SumOp op;
-  reduce_row<SumOp>(g_idata, g_odata, rows, cols, op, 0.0);
+  IdentityOp aop;
+  reduce_row<SumOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, 0.0);
 }
 
 /**
@@ -434,7 +450,8 @@ __global__ void reduce_row_sum(double *g_idata, double *g_odata, unsigned int ro
 extern "C"
 __global__ void reduce_col_sum(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
   SumOp op;
-  reduce_col<SumOp>(g_idata, g_odata, rows, cols, op, 0.0);
+  IdentityOp aop;
+  reduce_col<SumOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, 0.0);
 }
 
@@ -471,7 +488,8 @@ __global__ void reduce_max(double *g_idata, double *g_odata, unsigned int n){
 extern "C"
 __global__ void reduce_row_max(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
   MaxOp op;
-  reduce_row<MaxOp>(g_idata, g_odata, rows, cols, op, DBL_MIN);
+  IdentityOp aop;
+  reduce_row<MaxOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, DBL_MIN);
 }
 
 /**
@@ -484,7 +502,8 @@ __global__ void reduce_row_max(double *g_idata, double *g_odata, unsigned int ro
 extern "C"
 __global__ void reduce_col_max(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
   MaxOp op;
-  reduce_col<MaxOp>(g_idata, g_odata, rows, cols, op, DBL_MIN);
+  IdentityOp aop;
+  reduce_col<MaxOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, DBL_MIN);
 }
 
 /**
@@ -519,7 +538,8 @@ __global__ void reduce_min(double *g_idata, double *g_odata, unsigned int n){
 extern "C"
 __global__ void reduce_row_min(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
   MinOp op;
-  reduce_row<MinOp>(g_idata, g_odata, rows, cols, op, DBL_MAX);
+  IdentityOp aop;
+  reduce_row<MinOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, DBL_MAX);
 }
 
 /**
@@ -532,5 +552,70 @@ __global__ void reduce_row_min(double *g_idata, double *g_odata, unsigned int ro
 extern "C"
 __global__ void reduce_col_min(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
   MinOp op;
-  reduce_col<MinOp>(g_idata, g_odata, rows, cols, op, DBL_MAX);
+  IdentityOp aop;
+  reduce_col<MinOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, DBL_MAX);
+}
+
+/**
+ * Functor op for product operation
+ */
+typedef struct {
+  __device__ __forceinline__
+  double operator()(double a, double b) const {
+    return a * b;
+  }
+} ProductOp;
+
+/**
+ * Do a product over all elements of an array/matrix
+ * @param g_idata input data stored in device memory (of size n)
+ * @param g_odata output/temporary array stored in device memory (of size n)
+ * @param n size of the input and temporary/output arrays
+ */
+extern "C"
+__global__ void reduce_prod(double *g_idata, double *g_odata, unsigned int n){
+  ProductOp op;
+  reduce<ProductOp>(g_idata, g_odata, n, op, 1.0);
+}
+
+/**
+ * Functor op for mean operation
+ */
+struct MeanOp {
+  const long _size; ///< Number of elements by which to divide to calculate mean
+  __device__ __forceinline__
+  MeanOp(long size): _size(size) {}
+  __device__ __forceinline__
+  double operator()(double total) const {
+    return total / _size;
+  }
+};
+
+
+/**
+ * Do a mean over all rows of a matrix
+ * @param g_idata input matrix stored in device memory (of size rows * cols)
+ * @param g_odata output vector stored in device memory (of size rows)
+ * @param rows number of rows in input matrix
+ * @param cols number of columns in input matrix
+ */
+extern "C"
+__global__ void reduce_row_mean(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
+  SumOp op;
+  MeanOp aop(cols);
+  reduce_row<SumOp, MeanOp>(g_idata, g_odata, rows, cols, op, aop, 0.0);
+}
+
+/**
+ * Do a mean over all columns of a matrix
+ * @param g_idata input matrix stored in device memory (of size rows * cols)
+ * @param g_odata output vector stored in device memory (of size cols)
+ * @param rows number of rows in input matrix
+ * @param cols number of columns in input matrix
+ */
+extern "C"
+__global__ void reduce_col_mean(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
+  SumOp op;
+  MeanOp aop(rows);
+  reduce_col<SumOp, MeanOp>(g_idata, g_odata, rows, cols, op, aop, 0.0);
+}
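
----------------------------------------------------------------------
For context, a minimal host-side driver for the new reduce_row_mean kernel could look like the sketch below. This is only an illustration and not part of the commit: the helper name launch_row_mean, the 128-thread block size, the shared-memory sizing, and the error handling (omitted) are assumptions, and the sketch presumes it is compiled or device-linked together with SystemML.cu. In SystemML itself these kernels are launched from LibMatrixCUDA.java (via JCuda). The launch follows the one-block-per-row scheme documented in reduce_row above, so g_odata holds one double per row.

#include <cuda_runtime.h>

// Declaration of the kernel added by this commit (defined in SystemML.cu).
extern "C" __global__ void reduce_row_mean(double *g_idata, double *g_odata,
                                           unsigned int rows, unsigned int cols);

// Hypothetical helper: copies a rows x cols matrix to the device, launches
// reduce_row_mean with one block per row, and copies the per-row means back.
void launch_row_mean(const double *h_in, double *h_out,
                     unsigned int rows, unsigned int cols) {
  double *d_in = 0, *d_out = 0;
  cudaMalloc((void**)&d_in,  (size_t)rows * cols * sizeof(double));
  cudaMalloc((void**)&d_out, (size_t)rows * sizeof(double));
  cudaMemcpy(d_in, h_in, (size_t)rows * cols * sizeof(double), cudaMemcpyHostToDevice);

  // Power-of-two block size so the shared-memory reduction tree in
  // reduce_row collapses cleanly; 128 threads is an arbitrary choice.
  unsigned int threads = 128;
  size_t shmem = threads * sizeof(double);
  reduce_row_mean<<<rows, threads, shmem>>>(d_in, d_out, rows, cols);
  cudaDeviceSynchronize();

  cudaMemcpy(h_out, d_out, (size_t)rows * sizeof(double), cudaMemcpyDeviceToHost);
  cudaFree(d_in);
  cudaFree(d_out);
}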

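----------------------------------------------------------------------
The commit title also mentions row/column sumsq, yet no dedicated sum-of-squares kernel appears in the SystemML.cu hunk above. The reduce_row/reduce_col templates only combine values that are already loaded, so the squaring has to happen before the values reach the reduction. One way to compose it on the GPU, shown purely as an illustration (the kernel name square_elements and the two-pass approach are assumptions, not necessarily how this commit wires sumsq up in LibMatrixCUDA.java), is an element-wise squaring pass whose output is then fed to the existing reduce_row_sum or reduce_col_sum kernels:

/**
 * Hypothetical element-wise squaring kernel (illustration only, not part of
 * this commit). After it runs, reduce_row_sum / reduce_col_sum over C give
 * row-wise / column-wise sums of squares.
 */
extern "C"
__global__ void square_elements(double *A, double *C, unsigned int n){
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n){
    C[i] = A[i] * A[i];
  }
}

On the host, this would be launched over all rows*cols elements (for example, 256-thread blocks with (rows*cols + 255)/256 blocks), after which the existing reduce_row_sum / reduce_col_sum launch path can be reused unchanged.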