[SYSTEMML-1039] Added uamean, uamax, uamin

Closes #334.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/88526702
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/88526702
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/88526702

Branch: refs/heads/master
Commit: 885267024596cf3d675208650cee7b7930af3177
Parents: 13e1bd9
Author: Nakul Jindal <[email protected]>
Authored: Sun Jan 8 13:33:59 2017 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Sun Jan 8 13:33:59 2017 -0800

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |  138 +-
 src/main/cpp/kernels/SystemML.ptx               | 3428 +++++++++++-------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |    5 +-
 .../instructions/GPUInstructionParser.java      |    4 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |   69 +-
 5 files changed, 2318 insertions(+), 1326 deletions(-)
----------------------------------------------------------------------
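
The reworked reduce kernel in the diff below documents a two-pass contract: an initial launch in which every block folds its grid-stride slice of the input into one partial result, and a follow-up launch that folds those partials. The launch configuration itself lives in LibMatrixCUDA.java (driven through JCuda) and is not part of this excerpt, so the following CUDA C++ host-side sketch is illustrative only: reduceAllMax, d_in and d_tmp are hypothetical names, and it assumes the kernels are compiled into the same translation unit.

// Hypothetical host-side driver for the reduce_max kernel added below.
// Assumes `threads` is a power of two (required by the shared-memory
// reduction tree) and blocks <= threads * 2, so one follow-up launch
// suffices; larger inputs would loop the second pass instead.
double reduceAllMax(double *d_in, double *d_tmp, unsigned int n) {
    unsigned int threads = 512;
    // each thread reads two elements per grid stride, hence the factor of 2
    unsigned int blocks = (n + threads * 2 - 1) / (threads * 2);
    size_t shmem = threads * sizeof(double);

    // pass 1: every block writes one partial maximum to d_tmp[blockIdx.x]
    reduce_max<<<blocks, threads, shmem>>>(d_in, d_tmp, n);
    // pass 2: a single block folds the `blocks` partials into d_tmp[0]
    reduce_max<<<1, threads, shmem>>>(d_tmp, d_tmp, blocks);

    double result;
    cudaMemcpy(&result, d_tmp, sizeof(double), cudaMemcpyDeviceToHost);
    return result;
}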


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/88526702/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index f57c04f..1812b6a 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -23,6 +23,7 @@ please compile the ptx file and commit it:
 nvcc -ptx SystemML.cu
 ***********************************/
 
+#include <cfloat>
 
 // dim => rlen (Assumption: rlen == clen)
 // N = length of dense array
@@ -182,9 +183,8 @@ __global__ void fill(double* A, double scalar, int lenA) {
        }
 }
 
-
 /**
- * Does a reduce (sum) over all elements of the array.
+ * Does a reduce operation over all elements of the array.
  * This method has been adapted from the Reduction sample in the NVIDIA CUDA Samples (v8.0)
  * and the Reduction example available through jcuda.org
  * When invoked initially, all blocks partly compute the reduction operation over the entire array
@@ -193,12 +193,17 @@ __global__ void fill(double* A, double scalar, int lenA) {
  * The number of threads, blocks and amount of shared memory is calculated in a specific way.
  * Please refer to the NVIDIA CUDA Sample or the SystemML code that invokes this method to see
  * how it's done.
- * @param g_idata   input data stored in device memory (of size n)
- * @param g_odata   output/temporary array stode in device memory (of size n)
- * @param n         size of the input and temporary/output arrays
+ * The templated version of this function is similar to what is found in NVIDIA CUB
+ *
+ * @param ReductionOp       Type of the functor object that implements the reduction operation
  */
-extern "C"
-__global__ void reduce(double *g_idata, double *g_odata, unsigned int n)
+template <typename ReductionOp>
+__device__ void reduce(
+    double *g_idata,            ///< input data stored in device memory (of size n)
+    double *g_odata,            ///< output/temporary array stored in device memory (of size n)
+    unsigned int n,             ///< size of the input and temporary/output arrays
+    ReductionOp reduction_op,   ///< reduction operation to perform (functor object)
+    double initialValue)        ///< initial value for the reduction variable
 {
     extern __shared__ double sdata[];
 
@@ -208,29 +213,29 @@ __global__ void reduce(double *g_idata, double *g_odata, unsigned int n)
     unsigned int i = blockIdx.x*blockDim.x*2 + threadIdx.x;
     unsigned int gridSize = blockDim.x*2*gridDim.x;
 
-    double mySum = 0;
+    double v = initialValue;
 
     // we reduce multiple elements per thread.  The number is determined by the
     // number of active thread blocks (via gridDim).  More blocks will result
     // in a larger gridSize and therefore fewer elements per thread
     while (i < n)
     {
-        mySum += g_idata[i];
+        v = reduction_op(v, g_idata[i]);
         // ensure we don't read out of bounds
         if (i + blockDim.x < n)
-            mySum += g_idata[i+blockDim.x];
+            v = reduction_op(v, g_idata[i+blockDim.x]);
         i += gridSize;
     }
 
     // each thread puts its local sum into shared memory
-    sdata[tid] = mySum;
+    sdata[tid] = v;
     __syncthreads();
 
 
     // do reduction in shared mem
-    if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = mySum = mySum + sdata[tid + 256]; } __syncthreads(); }
-    if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = mySum = mySum + sdata[tid + 128]; } __syncthreads(); }
-    if (blockDim.x >= 128) { if (tid <  64) { sdata[tid] = mySum = mySum + sdata[tid +  64]; } __syncthreads(); }
+    if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = v = reduction_op(v, sdata[tid + 256]); } __syncthreads(); }
+    if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = v = reduction_op(v, sdata[tid + 128]); } __syncthreads(); }
+    if (blockDim.x >= 128) { if (tid <  64) { sdata[tid] = v = reduction_op(v, sdata[tid +  64]); } __syncthreads(); }
 
     if (tid < 32)
     {
@@ -238,12 +243,12 @@ __global__ void reduce(double *g_idata, double *g_odata, unsigned int n)
         // we need to declare our shared memory volatile so that the compiler
         // doesn't reorder stores to it and induce incorrect behavior.
         volatile double* smem = sdata;
-        if (blockDim.x >=  64) { smem[tid] = mySum = mySum + smem[tid + 32]; }
-        if (blockDim.x >=  32) { smem[tid] = mySum = mySum + smem[tid + 16]; }
-        if (blockDim.x >=  16) { smem[tid] = mySum = mySum + smem[tid +  8]; }
-        if (blockDim.x >=   8) { smem[tid] = mySum = mySum + smem[tid +  4]; }
-        if (blockDim.x >=   4) { smem[tid] = mySum = mySum + smem[tid +  2]; }
-        if (blockDim.x >=   2) { smem[tid] = mySum = mySum + smem[tid +  1]; }
+        if (blockDim.x >=  64) { smem[tid] = v = reduction_op(v, smem[tid + 32]); }
+        if (blockDim.x >=  32) { smem[tid] = v = reduction_op(v, smem[tid + 16]); }
+        if (blockDim.x >=  16) { smem[tid] = v = reduction_op(v, smem[tid +  8]); }
+        if (blockDim.x >=   8) { smem[tid] = v = reduction_op(v, smem[tid +  4]); }
+        if (blockDim.x >=   4) { smem[tid] = v = reduction_op(v, smem[tid +  2]); }
+        if (blockDim.x >=   2) { smem[tid] = v = reduction_op(v, smem[tid +  1]); }
     }
 
     // write result for this block to global mem
@@ -252,6 +257,7 @@ __global__ void reduce(double *g_idata, double *g_odata, unsigned int n)
 }
 
 
+
 /**
  * Does a reduce (sum) over each row of the array.
  * This kernel must be launched with as many blocks as there are rows.
@@ -280,21 +286,21 @@ __global__ void reduce_row(double *g_idata, double *g_odata, unsigned int rows,
     unsigned int i = tid;
     unsigned int block_offset = block * cols;
 
-    double mySum = 0;
+    double v = 0;
     while (i < cols){
-        mySum += g_idata[block_offset + i];
+        v += g_idata[block_offset + i];
         i += blockDim.x;
     }
 
     // each thread puts its local sum into shared memory
-    sdata[tid] = mySum;
+    sdata[tid] = v;
     __syncthreads();
 
 
     // do reduction in shared mem
-    if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = mySum = mySum + sdata[tid + 256]; } __syncthreads(); }
-    if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = mySum = mySum + sdata[tid + 128]; } __syncthreads(); }
-    if (blockDim.x >= 128) { if (tid <  64) { sdata[tid] = mySum = mySum + sdata[tid +  64]; } __syncthreads(); }
+    if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = v = v + sdata[tid + 256]; } __syncthreads(); }
+    if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = v = v + sdata[tid + 128]; } __syncthreads(); }
+    if (blockDim.x >= 128) { if (tid <  64) { sdata[tid] = v = v + sdata[tid +  64]; } __syncthreads(); }
 
     if (tid < 32)
     {
@@ -302,12 +308,12 @@ __global__ void reduce_row(double *g_idata, double *g_odata, unsigned int rows,
         // we need to declare our shared memory volatile so that the compiler
         // doesn't reorder stores to it and induce incorrect behavior.
         volatile double* smem = sdata;
-        if (blockDim.x >=  64) { smem[tid] = mySum = mySum + smem[tid + 32]; }
-        if (blockDim.x >=  32) { smem[tid] = mySum = mySum + smem[tid + 16]; }
-        if (blockDim.x >=  16) { smem[tid] = mySum = mySum + smem[tid +  8]; }
-        if (blockDim.x >=   8) { smem[tid] = mySum = mySum + smem[tid +  4]; }
-        if (blockDim.x >=   4) { smem[tid] = mySum = mySum + smem[tid +  2]; }
-        if (blockDim.x >=   2) { smem[tid] = mySum = mySum + smem[tid +  1]; }
+        if (blockDim.x >=  64) { smem[tid] = v = v + smem[tid + 32]; }
+        if (blockDim.x >=  32) { smem[tid] = v = v + smem[tid + 16]; }
+        if (blockDim.x >=  16) { smem[tid] = v = v + smem[tid +  8]; }
+        if (blockDim.x >=   8) { smem[tid] = v = v + smem[tid +  4]; }
+        if (blockDim.x >=   4) { smem[tid] = v = v + smem[tid +  2]; }
+        if (blockDim.x >=   2) { smem[tid] = v = v + smem[tid +  1]; }
     }
 
     // write result for this block to global mem
@@ -345,3 +351,71 @@ __global__ void reduce_col(double *g_idata, double *g_odata, unsigned int rows,
     }
     g_odata[global_tid] = val;
 }
+
+/**
+ * Functor op for summation operation
+ */
+typedef struct {
+    __device__ __forceinline__
+    double operator()(double a, double b) const {
+        return a + b;
+    }
+} SumOp;
+
+
+/**
+ * Do a summation over all elements of an array/matrix
+ * @param g_idata   input data stored in device memory (of size n)
+ * @param g_odata   output/temporary array stored in device memory (of size n)
+ * @param n         size of the input and temporary/output arrays
+ */
+extern "C"
+__global__ void reduce_sum(double *g_idata, double *g_odata, unsigned int n){
+    SumOp op;
+    reduce<SumOp>(g_idata, g_odata, n, op, 0.0);
+}
+
+/**
+ * Functor op for max operation
+ */
+typedef struct {
+    __device__ __forceinline__
+    double operator()(double a, double b) const {
+        return fmax(a, b);
+    }
+} MaxOp;
+
+
+/**
+ * Do a max over all elements of an array/matrix
+ * @param g_idata   input data stored in device memory (of size n)
+ * @param g_odata   output/temporary array stored in device memory (of size n)
+ * @param n         size of the input and temporary/output arrays
+ */
+extern "C"
+__global__ void reduce_max(double *g_idata, double *g_odata, unsigned int n){
+    MaxOp op;
+    reduce<MaxOp>(g_idata, g_odata, n, op, -DBL_MAX); // identity for max is -DBL_MAX; DBL_MIN is the smallest positive double
+}
+
+/**
+ * Functor op for min operation
+ */
+typedef struct {
+    __device__ __forceinline__
+    double operator()(double a, double b) const {
+        return fmin(a, b);
+    }
+} MinOp;
+
+/**
+ * Do a min over all elements of an array/matrix
+ * @param g_idata   input data stored in device memory (of size n)
+ * @param g_odata   output/temporary array stored in device memory (of size n)
+ * @param n         size of the input and temporary/output arrays
+ */
+extern "C"
+__global__ void reduce_min(double *g_idata, double *g_odata, unsigned int n){
+    MinOp op;
+    reduce<MinOp>(g_idata, g_odata, n, op, DBL_MAX);
+}
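
The commit title also lists uamean, but no mean kernel appears in this file; judging from the diffstat, that plumbing lands in AggUnaryOp.java, GPUInstructionParser.java and LibMatrixCUDA.java. Presumably the mean is derived host-side from the sum reduction; a hypothetical sketch in the same vein as the driver above, where reduceAllSum is the same two-pass pattern launching reduce_sum with an initial value of 0.0:

// Hypothetical: uamean as a division over the full sum reduction.
double reduceAllMean(double *d_in, double *d_tmp, unsigned int n) {
    return reduceAllSum(d_in, d_tmp, n) / (double) n;
}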
