Repository: systemml
Updated Branches:
  refs/heads/master 2896f3316 -> d916ba5bd


[SYSTEMML-540] Added a rewrite to support a common tensor operation (sum over 
channels)

- Added a rewrite to convert out = rowSums(matrix(colSums(A), rows=C, cols=HW))
  to out = channel_sums(A) when nrow(A) > 1 and the execution type is CP or GPU.
- This avoids unnecessary intermediates and the GPU-CP-GPU transfer (for the
  reshape), saving about 150 seconds on the sentence CNN for 200 epochs.
- When we move to a newer cuDNN version, we can replace the custom
  channel_sums kernel with the (possibly more optimized) cuDNN reduce-tensor
  kernel.
- Added the corresponding CPU and GPU tests.
- Replaced T_MAX(val) with MAX<T>(). Interestingly, nvcc had already removed
  the unused parameter automatically, so the generated PTX remained the same
  after the change.

Closes #693.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d916ba5b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d916ba5b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d916ba5b

Branch: refs/heads/master
Commit: d916ba5bd8ceec591a04f4d16c6d24f3985e3e4f
Parents: 2896f33
Author: Niketan Pansare <[email protected]>
Authored: Mon Oct 30 10:32:53 2017 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Mon Oct 30 10:32:53 2017 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/Makefile                   |   2 +-
 src/main/cpp/kernels/SystemML.cu                | 308 ++++++++++---------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  95 ++++--
 .../apache/sysml/lops/ConvolutionTransform.java |  42 ++-
 .../instructions/CPInstructionParser.java       |   1 +
 .../instructions/GPUInstructionParser.java      |   1 +
 .../cp/ConvolutionCPInstruction.java            |  86 ++++++
 .../gpu/ConvolutionGPUInstruction.java          |  47 +++
 .../spark/QuantilePickSPInstruction.java        |   2 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |  31 ++
 .../runtime/matrix/data/LibMatrixCuDNN.java     |   4 +-
 .../sysml/test/gpu/AggregateUnaryOpTests.java   |  31 ++
 .../apache/sysml/test/gpu/UnaryOpTestsBase.java |   8 +-
 .../functions/tensor/ChannelSumTest.java        | 146 +++++++++
 .../scripts/functions/tensor/ChannelSumTest.R   |  39 +++
 .../scripts/functions/tensor/ChannelSumTest.dml |  35 +++
 16 files changed, 690 insertions(+), 188 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/cpp/kernels/Makefile
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/Makefile b/src/main/cpp/kernels/Makefile
index 5feae69..ec10317 100644
--- a/src/main/cpp/kernels/Makefile
+++ b/src/main/cpp/kernels/Makefile
@@ -16,7 +16,7 @@
 # under the License.
 
 NVCC=nvcc
-CUDAFLAGS= -ptx -c -arch=sm_30 
+CUDAFLAGS= -ptx -c -arch=sm_30 --std c++11
 
 # Use these flags for precise math
 #CUDAFLAGS= -ptx -c -arch=sm_30 -ftz=false -prec-div=true -prec-sqrt=true

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index d176f8f..ade2dd1 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -20,7 +20,7 @@
 /**********************************
 When updating a kernel or adding a new one,
 please compile the ptx file and commit it:
-nvcc -ptx -arch=sm_30 SystemML.cu
+nvcc -ptx -arch=sm_30 --std c++11 SystemML.cu
 ***********************************/
 
 #include <cfloat>
@@ -29,7 +29,8 @@ nvcc -ptx -arch=sm_30 SystemML.cu
 extern "C" __global__ void double2float_f(double *A, float *ret, int N) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (tid < N) {
-       // TODO: Use __double2float_rd or __double2float_rn  or 
__double2float_ru or __double2float_rz after 
+    // TODO: Use __double2float_rd or __double2float_rn  or __double2float_ru 
or
+    // __double2float_rz after
     ret[tid] = (float)A[tid];
   }
 }
@@ -84,15 +85,14 @@ __device__ void slice_sparse_dense_row(T *inVal, int 
*inRowPtr, int *colInd,
      *
      * int size = inRowPtr[rowIndex+1] - inRowPtr[rowIndex];
      * double numThreads = (double)min(size, MAX_NUM_THREADS_CHILD_KERNEL);
-     * slice_sparse_dense_row_helper<<< ceil(numThreads/
-*MAX_NUM_THREADS_CHILD_KERNEL), MAX_NUM_THREADS_CHILD_KERNEL>>>(inVal, 
inRowPtr,
-*colInd, ret,
-*                      rl, ru, cl, cu, retClen, inRowPtr[rowIndex],
-*inRowPtr[rowIndex+1], index);
-*
-* Two-step compilation and linking process in JCudaKernels's constructor:
-* cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_LIBRARY,
-*"/usr/local/cuda/lib64/libcudadevrt.a", jitOptions);
+     * slice_sparse_dense_row_helper
+     * <<< ceil(numThreads/MAX_NUM_THREADS_CHILD_KERNEL), 
MAX_NUM_THREADS_CHILD_KERNEL>>>
+     * (inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen, 
inRowPtr[rowIndex],
+     * inRowPtr[rowIndex+1], index);
+     *
+     * Two-step compilation and linking process in JCudaKernels's constructor:
+     * cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_LIBRARY,
+     * "/usr/local/cuda/lib64/libcudadevrt.a", jitOptions);
      */
     // Iterate over elements of the row 'rowIndex'.
     for (int i = inRowPtr[rowIndex]; i < inRowPtr[rowIndex + 1]; i++) {
@@ -104,17 +104,18 @@ __device__ void slice_sparse_dense_row(T *inVal, int 
*inRowPtr, int *colInd,
   }
 }
 
-extern "C" __global__ void slice_sparse_dense_row_d(double *inVal, int 
*inRowPtr,
-                                                   int *colInd, double *ret,
-                                                   int rl, int ru, int cl,
-                                                   int cu, int retClen) {
+extern "C" __global__ void slice_sparse_dense_row_d(double *inVal,
+                                                    int *inRowPtr, int *colInd,
+                                                    double *ret, int rl, int 
ru,
+                                                    int cl, int cu,
+                                                    int retClen) {
   slice_sparse_dense_row(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, 
retClen);
 }
 
 extern "C" __global__ void slice_sparse_dense_row_f(float *inVal, int 
*inRowPtr,
-                                                   int *colInd, float *ret,
-                                                   int rl, int ru, int cl,
-                                                   int cu, int retClen) {
+                                                    int *colInd, float *ret,
+                                                    int rl, int ru, int cl,
+                                                    int cu, int retClen) {
   slice_sparse_dense_row(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, 
retClen);
 }
 
@@ -153,17 +154,18 @@ __device__ void slice_sparse_dense_nnz(T *inVal, int 
*inRowPtr, int *colInd,
   }
 }
 
-extern "C" __global__ void slice_sparse_dense_nnz_d(double *inVal, int 
*inRowPtr,
-                                                   int *colInd, double *ret,
-                                                   int rl, int ru, int cl,
-                                                   int cu, int retClen) {
+extern "C" __global__ void slice_sparse_dense_nnz_d(double *inVal,
+                                                    int *inRowPtr, int *colInd,
+                                                    double *ret, int rl, int 
ru,
+                                                    int cl, int cu,
+                                                    int retClen) {
   slice_sparse_dense_nnz(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, 
retClen);
 }
 
 extern "C" __global__ void slice_sparse_dense_nnz_f(float *inVal, int 
*inRowPtr,
-                                                   int *colInd, float *ret,
-                                                   int rl, int ru, int cl,
-                                                   int cu, int retClen) {
+                                                    int *colInd, float *ret,
+                                                    int rl, int ru, int cl,
+                                                    int cu, int retClen) {
   slice_sparse_dense_nnz(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, 
retClen);
 }
 
@@ -194,16 +196,16 @@ __device__ void slice_dense_dense(T *in, T *ret, int rl, 
int ru, int cl, int cu,
 }
 
 extern "C" __global__ void slice_dense_dense_d(double *in, double *ret, int rl,
-                                              int ru, int cl, int cu,
-                                              int inClen, int retRlen,
-                                              int retClen) {
+                                               int ru, int cl, int cu,
+                                               int inClen, int retRlen,
+                                               int retClen) {
   slice_dense_dense(in, ret, rl, ru, cl, cu, inClen, retRlen, retClen);
 }
 
 extern "C" __global__ void slice_dense_dense_f(float *in, float *ret, int rl,
-                                              int ru, int cl, int cu,
-                                              int inClen, int retRlen,
-                                              int retClen) {
+                                               int ru, int cl, int cu,
+                                               int inClen, int retRlen,
+                                               int retClen) {
   slice_dense_dense(in, ret, rl, ru, cl, cu, inClen, retRlen, retClen);
 }
 
@@ -236,15 +238,15 @@ extern "C" __global__ void copy_u2l_dense_f(float *ret, 
int dim, int N) {
 
 // Use this method in templates to fetch the maximum value for a given datatype
 template <typename T>
-__forceinline__ __device__ T T_MAX(T x) {
-  return (T)DBL_MAX;
+__forceinline__ __device__ T MAX() {
+  return T();
 }
 template <>
-__forceinline__ __device__ float T_MAX(float x) {
+__forceinline__ __device__ float MAX <float>() {
   return FLT_MAX;
 }
 template <>
-__forceinline__ __device__ double T_MAX(double x) {
+__forceinline__ __device__ double MAX <double>() {
   return DBL_MAX;
 }
 
@@ -311,7 +313,7 @@ __forceinline__ __device__ T binaryOp(T x, T y, int op) {
       }
     }
     default:
-      return T_MAX(x);
+      return MAX<T>();
   }
 }
 
@@ -342,7 +344,8 @@ extern "C" __global__ void relu_f(float *A, float *ret, int 
rlen, int clen) {
 }
 
 /**
- * This method computes the backpropagation errors for previous layer of relu 
operation
+ * This method computes the backpropagation errors for previous layer of relu
+ * operation
  *
  * @param X input activation array allocated on the GPU
  * @param dout errors from previous layer
@@ -361,12 +364,12 @@ __device__ void relu_backward(T *X, T *dout, T *ret, int 
rlen, int clen) {
 }
 
 extern "C" __global__ void relu_backward_d(double *X, double *dout, double 
*ret,
-                                          int rlen, int clen) {
+                                           int rlen, int clen) {
   relu_backward(X, dout, ret, rlen, clen);
 }
 
 extern "C" __global__ void relu_backward_f(float *X, float *dout, float *ret,
-                                          int rlen, int clen) {
+                                           int rlen, int clen) {
   relu_backward(X, dout, ret, rlen, clen);
 }
 
@@ -389,12 +392,12 @@ __device__ void inplace_add(T *input, T *ret, int rlen, 
int clen) {
 }
 
 extern "C" __global__ void inplace_add_d(double *input, double *ret, int rlen,
-                                        int clen) {
+                                         int clen) {
   inplace_add(input, ret, rlen, clen);
 }
 
 extern "C" __global__ void inplace_add_f(float *input, float *ret, int rlen,
-                                        int clen) {
+                                         int clen) {
   inplace_add(input, ret, rlen, clen);
 }
 
@@ -416,12 +419,12 @@ __device__ void bias_add(T *input, T *bias, T *ret, int 
rlen, int clen,
 }
 
 extern "C" __global__ void bias_add_d(double *input, double *bias, double *ret,
-                                     int rlen, int clen, int PQ) {
+                                      int rlen, int clen, int PQ) {
   bias_add(input, bias, ret, rlen, clen, PQ);
 }
 
 extern "C" __global__ void bias_add_f(float *input, float *bias, float *ret,
-                                     int rlen, int clen, int PQ) {
+                                      int rlen, int clen, int PQ) {
   bias_add(input, bias, ret, rlen, clen, PQ);
 }
 
@@ -443,16 +446,16 @@ __device__ void daxpy_matrix_vector(T *A, T *B, double 
alpha, T *ret, int rlenA,
 }
 
 extern "C" __global__ void daxpy_matrix_vector_d(double *A, double *B,
-                                                double alpha, double *ret,
-                                                int rlenA, int clenA, int 
rlenB,
-                                                int clenB) {
+                                                 double alpha, double *ret,
+                                                 int rlenA, int clenA,
+                                                 int rlenB, int clenB) {
   daxpy_matrix_vector(A, B, alpha, ret, rlenA, clenA, rlenB, clenB);
 }
 
 extern "C" __global__ void daxpy_matrix_vector_f(float *A, float *B,
-                                                double alpha, float *ret,
-                                                int rlenA, int clenA, int 
rlenB,
-                                                int clenB) {
+                                                 double alpha, float *ret,
+                                                 int rlenA, int clenA,
+                                                 int rlenB, int clenB) {
   daxpy_matrix_vector(A, B, alpha, ret, rlenA, clenA, rlenB, clenB);
 }
 
@@ -471,13 +474,14 @@ __device__ void bias_multiply(T *input, T *bias, T *ret, 
int rlen, int clen,
 }
 
 extern "C" __global__ void bias_multiply_d(double *input, double *bias,
-                                          double *ret, int rlen, int clen,
-                                          int PQ) {
+                                           double *ret, int rlen, int clen,
+                                           int PQ) {
   bias_multiply(input, bias, ret, rlen, clen, PQ);
 }
 
-extern "C" __global__ void bias_multiply_f(float *input, float *bias, float 
*ret,
-                                          int rlen, int clen, int PQ) {
+extern "C" __global__ void bias_multiply_f(float *input, float *bias,
+                                           float *ret, int rlen, int clen,
+                                           int PQ) {
   bias_multiply(input, bias, ret, rlen, clen, PQ);
 }
 
@@ -563,14 +567,14 @@ __device__ void matrix_scalar_op(T *A, T scalar, T *C, 
int size, int op,
 }
 
 extern "C" __global__ void matrix_scalar_op_d(double *A, double scalar,
-                                             double *C, int size, int op,
-                                             int isLeftScalar) {
+                                              double *C, int size, int op,
+                                              int isLeftScalar) {
   matrix_scalar_op(A, scalar, C, size, op, isLeftScalar);
 }
 
 extern "C" __global__ void matrix_scalar_op_f(float *A, double scalar, float 
*C,
-                                             int size, int op,
-                                             int isLeftScalar) {
+                                              int size, int op,
+                                              int isLeftScalar) {
   matrix_scalar_op(A, (float)scalar, C, size, op, isLeftScalar);
 }
 
@@ -635,12 +639,12 @@ __device__ void cbind(T *A, T *B, T *C, int rowsA, int 
colsA, int rowsB,
 }
 
 extern "C" __global__ void cbind_d(double *A, double *B, double *C, int rowsA,
-                                  int colsA, int rowsB, int colsB) {
+                                   int colsA, int rowsB, int colsB) {
   cbind(A, B, C, rowsA, colsA, rowsB, colsB);
 }
 
 extern "C" __global__ void cbind_f(float *A, float *B, float *C, int rowsA,
-                                  int colsA, int rowsB, int colsB) {
+                                   int colsA, int rowsB, int colsB) {
   cbind(A, B, C, rowsA, colsA, rowsB, colsB);
 }
 
@@ -684,12 +688,12 @@ __device__ void rbind(T *A, T *B, T *C, int rowsA, int 
colsA, int rowsB,
 }
 
 extern "C" __global__ void rbind_d(double *A, double *B, double *C, int rowsA,
-                                  int colsA, int rowsB, int colsB) {
+                                   int colsA, int rowsB, int colsB) {
   rbind(A, B, C, rowsA, colsA, rowsB, colsB);
 }
 
 extern "C" __global__ void rbind_f(float *A, float *B, float *C, int rowsA,
-                                  int colsA, int rowsB, int colsB) {
+                                   int colsA, int rowsB, int colsB) {
   rbind(A, B, C, rowsA, colsA, rowsB, colsB);
 }
 
@@ -828,15 +832,15 @@ template <typename ReductionOp, typename AssignmentOp, 
typename T>
 __device__ void reduce_row(
     T *g_idata,  ///< input data stored in device memory (of size rows*cols)
     T *g_odata,  ///< output/temporary array store in device memory (of size
-                 ///rows*cols)
+    /// rows*cols)
     unsigned int rows,  ///< rows in input and temporary/output arrays
     unsigned int cols,  ///< columns in input and temporary/output arrays
     ReductionOp
         reduction_op,  ///< Reduction operation to perform (functor object)
     AssignmentOp assignment_op,  ///< Operation to perform before assigning 
this
-                                 ///to its final location in global memory for
-                                 ///each row
-    T initialValue) {            ///< initial value for the reduction variable
+    /// to its final location in global memory for
+    /// each row
+    T initialValue) {  ///< initial value for the reduction variable
   // extern __shared__ T sdata[];
   extern __shared__ __align__(sizeof(T)) unsigned char my_sdata[];
   T *sdata = reinterpret_cast<T *>(my_sdata);
@@ -935,15 +939,15 @@ template <typename ReductionOp, typename AssignmentOp, 
typename T>
 __device__ void reduce_col(
     T *g_idata,  ///< input data stored in device memory (of size rows*cols)
     T *g_odata,  ///< output/temporary array store in device memory (of size
-                 ///rows*cols)
+    /// rows*cols)
     unsigned int rows,  ///< rows in input and temporary/output arrays
     unsigned int cols,  ///< columns in input and temporary/output arrays
     ReductionOp
         reduction_op,  ///< Reduction operation to perform (functor object)
     AssignmentOp assignment_op,  ///< Operation to perform before assigning 
this
-                                 ///to its final location in global memory for
-                                 ///each column
-    T initialValue)              ///< initial value for the reduction variable
+    /// to its final location in global memory for
+    /// each column
+    T initialValue)  ///< initial value for the reduction variable
 {
   unsigned int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (global_tid >= cols) {
@@ -990,12 +994,12 @@ __device__ void reduce_sum(T *g_idata, T *g_odata, 
unsigned int n) {
 }
 
 extern "C" __global__ void reduce_sum_d(double *g_idata, double *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_sum(g_idata, g_odata, n);
 }
 
 extern "C" __global__ void reduce_sum_f(float *g_idata, float *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_sum(g_idata, g_odata, n);
 }
 
@@ -1016,14 +1020,14 @@ __device__ void reduce_row_sum(T *g_idata, T *g_odata, 
unsigned int rows,
 }
 
 extern "C" __global__ void reduce_row_sum_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_sum(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_row_sum_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_sum(g_idata, g_odata, rows, cols);
 }
 
@@ -1044,14 +1048,14 @@ __device__ void reduce_col_sum(T *g_idata, T *g_odata, 
unsigned int rows,
 }
 
 extern "C" __global__ void reduce_col_sum_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_sum(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_col_sum_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_sum(g_idata, g_odata, rows, cols);
 }
 
@@ -1063,12 +1067,13 @@ struct MaxOp {
   __device__ __forceinline__ T operator()(T a, T b) const { return fmax(a, b); 
}
 };
 
-template<>
+template <>
 struct MaxOp<float> {
-  __device__ __forceinline__ float operator()(float a, float b) const { return 
fmaxf(a, b); }
+  __device__ __forceinline__ float operator()(float a, float b) const {
+    return fmaxf(a, b);
+  }
 };
 
-
 /**
  * Do a max over all elements of an array/matrix
  * @param g_idata   input data stored in device memory (of size n)
@@ -1078,16 +1083,16 @@ struct MaxOp<float> {
 template <typename T>
 __device__ void reduce_max(T *g_idata, T *g_odata, unsigned int n) {
   MaxOp<T> op;
-  reduce<MaxOp<T>, T>(g_idata, g_odata, n, op, -T_MAX(g_idata[0]));
+  reduce<MaxOp<T>, T>(g_idata, g_odata, n, op, -MAX<T>());
 }
 
 extern "C" __global__ void reduce_max_d(double *g_idata, double *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_max(g_idata, g_odata, n);
 }
 
 extern "C" __global__ void reduce_max_f(float *g_idata, float *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_max(g_idata, g_odata, n);
 }
 
@@ -1104,18 +1109,18 @@ __device__ void reduce_row_max(T *g_idata, T *g_odata, 
unsigned int rows,
   MaxOp<T> op;
   IdentityOp<T> aop;
   reduce_row<MaxOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
-                                         -T_MAX(g_idata[0]));
+                                         -MAX<T>());
 }
 
 extern "C" __global__ void reduce_row_max_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_max(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_row_max_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_max(g_idata, g_odata, rows, cols);
 }
 
@@ -1132,18 +1137,18 @@ __device__ void reduce_col_max(T *g_idata, T *g_odata, 
unsigned int rows,
   MaxOp<T> op;
   IdentityOp<T> aop;
   reduce_col<MaxOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
-                                         (T)-T_MAX(g_idata[0]));
+                                         -MAX<T>());
 }
 
 extern "C" __global__ void reduce_col_max_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_max(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_col_max_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_max(g_idata, g_odata, rows, cols);
 }
 
@@ -1164,16 +1169,16 @@ struct MinOp {
 template <typename T>
 __device__ void reduce_min(T *g_idata, T *g_odata, unsigned int n) {
   MinOp<T> op;
-  reduce<MinOp<T>, T>(g_idata, g_odata, n, op, T_MAX(g_idata[0]));
+  reduce<MinOp<T>, T>(g_idata, g_odata, n, op, MAX<T>());
 }
 
 extern "C" __global__ void reduce_min_d(double *g_idata, double *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_min(g_idata, g_odata, n);
 }
 
 extern "C" __global__ void reduce_min_f(float *g_idata, float *g_odata,
-                                       unsigned int n) {
+                                        unsigned int n) {
   reduce_min(g_idata, g_odata, n);
 }
 
@@ -1190,18 +1195,18 @@ __device__ void reduce_row_min(T *g_idata, T *g_odata, 
unsigned int rows,
   MinOp<T> op;
   IdentityOp<T> aop;
   reduce_row<MinOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
-                                         T_MAX(g_idata[0]));
+                                         MAX<T>());
 }
 
 extern "C" __global__ void reduce_row_min_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_min(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_row_min_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_row_min(g_idata, g_odata, rows, cols);
 }
 
@@ -1218,18 +1223,18 @@ __device__ void reduce_col_min(T *g_idata, T *g_odata, 
unsigned int rows,
   MinOp<T> op;
   IdentityOp<T> aop;
   reduce_col<MinOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
-                                         T_MAX(g_idata[0]));
+                                         MAX<T>());
 }
 
 extern "C" __global__ void reduce_col_min_d(double *g_idata, double *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_min(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_col_min_f(float *g_idata, float *g_odata,
-                                           unsigned int rows,
-                                           unsigned int cols) {
+                                            unsigned int rows,
+                                            unsigned int cols) {
   reduce_col_min(g_idata, g_odata, rows, cols);
 }
 
@@ -1254,12 +1259,12 @@ __device__ void reduce_prod(T *g_idata, T *g_odata, 
unsigned int n) {
 }
 
 extern "C" __global__ void reduce_prod_d(double *g_idata, double *g_odata,
-                                        unsigned int n) {
+                                         unsigned int n) {
   reduce_prod(g_idata, g_odata, n);
 }
 
 extern "C" __global__ void reduce_prod_f(float *g_idata, float *g_odata,
-                                        unsigned int n) {
+                                         unsigned int n) {
   reduce_prod(g_idata, g_odata, n);
 }
 
@@ -1293,14 +1298,14 @@ __device__ void reduce_row_mean(T *g_idata, T *g_odata, 
unsigned int rows,
 }
 
 extern "C" __global__ void reduce_row_mean_d(double *g_idata, double *g_odata,
-                                            unsigned int rows,
-                                            unsigned int cols) {
+                                             unsigned int rows,
+                                             unsigned int cols) {
   reduce_row_mean(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_row_mean_f(float *g_idata, float *g_odata,
-                                            unsigned int rows,
-                                            unsigned int cols) {
+                                             unsigned int rows,
+                                             unsigned int cols) {
   reduce_row_mean(g_idata, g_odata, rows, cols);
 }
 
@@ -1321,14 +1326,14 @@ __device__ void reduce_col_mean(T *g_idata, T *g_odata, 
unsigned int rows,
 }
 
 extern "C" __global__ void reduce_col_mean_d(double *g_idata, double *g_odata,
-                                            unsigned int rows,
-                                            unsigned int cols) {
+                                             unsigned int rows,
+                                             unsigned int cols) {
   reduce_col_mean(g_idata, g_odata, rows, cols);
 }
 
 extern "C" __global__ void reduce_col_mean_f(float *g_idata, float *g_odata,
-                                            unsigned int rows,
-                                            unsigned int cols) {
+                                             unsigned int rows,
+                                             unsigned int cols) {
   reduce_col_mean(g_idata, g_odata, rows, cols);
 }
 
@@ -1347,7 +1352,7 @@ __device__ void matrix_exp(T *A, T *C, unsigned int size) 
{
 }
 
 extern "C" __global__ void matrix_exp_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_exp(A, C, size);
 }
 
@@ -1370,11 +1375,12 @@ __device__ void matrix_sqrt(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_sqrt_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_sqrt(A, C, size);
 }
 
-extern "C" __global__ void matrix_sqrt_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_sqrt_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_sqrt(A, C, size);
 }
 
@@ -1393,12 +1399,12 @@ __device__ void matrix_round(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_round_d(double *A, double *C,
-                                         unsigned int size) {
+                                          unsigned int size) {
   matrix_round(A, C, size);
 }
 
 extern "C" __global__ void matrix_round_f(float *A, float *C,
-                                         unsigned int size) {
+                                          unsigned int size) {
   matrix_round(A, C, size);
 }
 
@@ -1417,7 +1423,7 @@ __device__ void matrix_abs(T *A, T *C, unsigned int size) 
{
 }
 
 extern "C" __global__ void matrix_abs_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_abs(A, C, size);
 }
 
@@ -1440,7 +1446,7 @@ __device__ void matrix_log(T *A, T *C, unsigned int size) 
{
 }
 
 extern "C" __global__ void matrix_log_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_log(A, C, size);
 }
 
@@ -1463,12 +1469,12 @@ __device__ void matrix_floor(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_floor_d(double *A, double *C,
-                                         unsigned int size) {
+                                          unsigned int size) {
   matrix_floor(A, C, size);
 }
 
 extern "C" __global__ void matrix_floor_f(float *A, float *C,
-                                         unsigned int size) {
+                                          unsigned int size) {
   matrix_floor(A, C, size);
 }
 
@@ -1487,11 +1493,12 @@ __device__ void matrix_ceil(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_ceil_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_ceil(A, C, size);
 }
 
-extern "C" __global__ void matrix_ceil_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_ceil_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_ceil(A, C, size);
 }
 
@@ -1510,7 +1517,7 @@ __device__ void matrix_sin(T *A, T *C, unsigned int size) 
{
 }
 
 extern "C" __global__ void matrix_sin_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_sin(A, C, size);
 }
 
@@ -1533,11 +1540,12 @@ __device__ void matrix_sinh(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_sinh_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_sinh(A, C, size);
 }
 
-extern "C" __global__ void matrix_sinh_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_sinh_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_sinh(A, C, size);
 }
 
@@ -1556,7 +1564,7 @@ __device__ void matrix_cos(T *A, T *C, unsigned int size) 
{
 }
 
 extern "C" __global__ void matrix_cos_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_cos(A, C, size);
 }
 
@@ -1579,11 +1587,12 @@ __device__ void matrix_cosh(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_cosh_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_cosh(A, C, size);
 }
 
-extern "C" __global__ void matrix_cosh_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_cosh_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_cosh(A, C, size);
 }
 
@@ -1602,7 +1611,7 @@ __device__ void matrix_tan(T *A, T *C, unsigned int size) 
{
 }
 
 extern "C" __global__ void matrix_tan_d(double *A, double *C,
-                                       unsigned int size) {
+                                        unsigned int size) {
   matrix_tan(A, C, size);
 }
 
@@ -1625,11 +1634,12 @@ __device__ void matrix_tanh(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_tanh_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_tanh(A, C, size);
 }
 
-extern "C" __global__ void matrix_tanh_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_tanh_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_tanh(A, C, size);
 }
 
@@ -1648,11 +1658,12 @@ __device__ void matrix_asin(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_asin_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_asin(A, C, size);
 }
 
-extern "C" __global__ void matrix_asin_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_asin_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_asin(A, C, size);
 }
 
@@ -1671,11 +1682,12 @@ __device__ void matrix_acos(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_acos_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_acos(A, C, size);
 }
 
-extern "C" __global__ void matrix_acos_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_acos_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_acos(A, C, size);
 }
 
@@ -1694,11 +1706,12 @@ __device__ void matrix_atan(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_atan_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_atan(A, C, size);
 }
 
-extern "C" __global__ void matrix_atan_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_atan_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_atan(A, C, size);
 }
 
@@ -1722,10 +1735,11 @@ __device__ void matrix_sign(T *A, T *C, unsigned int 
size) {
 }
 
 extern "C" __global__ void matrix_sign_d(double *A, double *C,
-                                        unsigned int size) {
+                                         unsigned int size) {
   matrix_sign(A, C, size);
 }
 
-extern "C" __global__ void matrix_sign_f(float *A, float *C, unsigned int 
size) {
+extern "C" __global__ void matrix_sign_f(float *A, float *C,
+                                         unsigned int size) {
   matrix_sign(A, C, size);
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 04a32bd..9b9406a 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -26,6 +26,7 @@ import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.lops.Aggregate;
 import org.apache.sysml.lops.Aggregate.OperationTypes;
 import org.apache.sysml.lops.Binary;
+import org.apache.sysml.lops.ConvolutionTransform;
 import org.apache.sysml.lops.Group;
 import org.apache.sysml.lops.Lop;
 import org.apache.sysml.lops.LopsException;
@@ -131,6 +132,20 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                return false;
        }
        
+       /**
+        * Checks if channels sum rewrite is applicable
+        * 
+        * @return returns true for pattern rowSums(matrix(colSums(X), rows=.., 
cols=..)) else false
+        */
+       private boolean isChannelSumRewriteApplicable() {
+               if( OptimizerUtils.ALLOW_OPERATOR_FUSION && _op == AggOp.SUM && 
_direction == Direction.Row
+                       && getInput().get(0) instanceof ReorgOp && 
((ReorgOp)getInput().get(0)).getOp() == ReOrgOp.RESHAPE) {
+                       Hop input1 = getInput().get(0).getInput().get(0);
+                       return input1 instanceof AggUnaryOp && 
((AggUnaryOp)input1)._op == AggOp.SUM && ((AggUnaryOp)input1)._direction == 
Direction.Col;
+               }
+               return false;
+       }
+       
        @Override
        public Lop constructLops()
                throws HopsException, LopsException 
@@ -147,41 +162,57 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                        if ( et == ExecType.CP || et == ExecType.GPU ) 
                        {
                                Lop agg1 = null;
-                               if( isTernaryAggregateRewriteApplicable() ) {
-                                       agg1 = 
constructLopsTernaryAggregateRewrite(et);
+                               long numChannels = 
isChannelSumRewriteApplicable() ? 
Hop.computeSizeInformation(getInput().get(0).getInput().get(1)) : -1;
+                               if(numChannels > 0 && numChannels < 1000000) {
+                                       // Apply channel sums only if rewrite 
is applicable and if the dimension of C is known at compile time
+                                       // and if numChannels is less than 8 MB.
+                                       ReorgOp in = 
((ReorgOp)getInput().get(0));
+                                       agg1 = new ConvolutionTransform(
+                                                       
in.getInput().get(0).getInput().get(0).constructLops(), 
+                                                       
in.getInput().get(1).constructLops(),
+                                                       
in.getInput().get(2).constructLops(),
+                                                       
ConvolutionTransform.OperationTypes.CHANNEL_SUMS, getDataType(), 
getValueType(), et, -1);
+                                       
agg1.getOutputParameters().setDimensions(numChannels, 1, getRowsInBlock(), 
getColsInBlock(), -1);
+                                       setLineNumbers(agg1);
+                                       setLops(agg1);
                                }
-                               else if( 
isUnaryAggregateOuterCPRewriteApplicable() )
-                               {
-                                       OperationTypes op = 
HopsAgg2Lops.get(_op);
-                                       DirectionTypes dir = 
HopsDirection2Lops.get(_direction);
-
-                                       BinaryOp binput = 
(BinaryOp)getInput().get(0);
-                                       agg1 = new UAggOuterChain( 
binput.getInput().get(0).constructLops(), 
-                                                       
binput.getInput().get(1).constructLops(), op, dir, 
-                                                       
HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), 
ExecType.CP);
-                                       
PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), 
input.getRowsInBlock(), input.getColsInBlock(), dir);
-                               
+                               else { 
+                                       if( 
isTernaryAggregateRewriteApplicable() ) {
+                                               agg1 = 
constructLopsTernaryAggregateRewrite(et);
+                                       }
+                                       else if( 
isUnaryAggregateOuterCPRewriteApplicable() )
+                                       {
+                                               OperationTypes op = 
HopsAgg2Lops.get(_op);
+                                               DirectionTypes dir = 
HopsDirection2Lops.get(_direction);
+       
+                                               BinaryOp binput = 
(BinaryOp)getInput().get(0);
+                                               agg1 = new UAggOuterChain( 
binput.getInput().get(0).constructLops(), 
+                                                               
binput.getInput().get(1).constructLops(), op, dir, 
+                                                               
HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), 
ExecType.CP);
+                                               
PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), 
input.getRowsInBlock(), input.getColsInBlock(), dir);
+                                       
+                                               if (getDataType() == 
DataType.SCALAR) {
+                                                       UnaryCP unary1 = new 
UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
+                                                                               
            getDataType(), getValueType());
+                                                       
unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
+                                                       setLineNumbers(unary1);
+                                                       setLops(unary1);
+                                               }
+                                       
+                                       }                               
+                                       else { //general case           
+                                               int k = 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+                                               agg1 = new 
PartialAggregate(input.constructLops(), 
+                                                               
HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), 
getDataType(),getValueType(), et, k);
+                                       }
+                                       
+                                       setOutputDimensions(agg1);
+                                       setLineNumbers(agg1);
+                                       setLops(agg1);
+                                       
                                        if (getDataType() == DataType.SCALAR) {
-                                               UnaryCP unary1 = new 
UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
-                                                                               
    getDataType(), getValueType());
-                                               
unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
-                                               setLineNumbers(unary1);
-                                               setLops(unary1);
+                                               
agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), 
getColsInBlock(), getNnz());
                                        }
-                               
-                               }                               
-                               else { //general case           
-                                       int k = 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-                                       agg1 = new 
PartialAggregate(input.constructLops(), 
-                                                       HopsAgg2Lops.get(_op), 
HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k);
-                               }
-                               
-                               setOutputDimensions(agg1);
-                               setLineNumbers(agg1);
-                               setLops(agg1);
-                               
-                               if (getDataType() == DataType.SCALAR) {
-                                       
agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), 
getColsInBlock(), getNnz());
                                }
                        }
                        else if( et == ExecType.MR )

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java 
b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
index dfc187c..94a67f0 100644
--- a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
+++ b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
@@ -32,7 +32,7 @@ public class ConvolutionTransform extends Lop
        public enum OperationTypes {
                MAX_POOLING, MAX_POOLING_BACKWARD, RELU_MAX_POOLING, 
RELU_BACKWARD, RELU_MAX_POOLING_BACKWARD,
                DIRECT_CONV2D, DIRECT_CONV2D_BACKWARD_FILTER, 
DIRECT_CONV2D_BACKWARD_DATA,
-               BIAS_ADD, DIRECT_CONV2D_BIAS_ADD, BIAS_MULTIPLY
+               BIAS_ADD, DIRECT_CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS
        }
        
        private OperationTypes operation = null;
@@ -67,6 +67,18 @@ public class ConvolutionTransform extends Lop
                input2.addOutput(this);
                setLevel();
        }
+       
+       public ConvolutionTransform(Lop input1, Lop input2, Lop input3, 
ConvolutionTransform.OperationTypes op, DataType dt, ValueType vt, ExecType et, 
int k) 
+       {
+               super(Lop.Type.Transform, dt, vt);              
+               init(input1, op, dt, vt, et);
+               numThreads = k;
+               this.addInput(input2);
+               input2.addOutput(this);
+               this.addInput(input3);
+               input3.addOutput(this);
+               setLevel();
+       }
 
        private void init (Lop input, ConvolutionTransform.OperationTypes op, 
DataType dt, ValueType vt, ExecType et) 
        {
@@ -142,6 +154,9 @@ public class ConvolutionTransform extends Lop
                case DIRECT_CONV2D_BACKWARD_DATA:
                        return "conv2d_backward_data";
                        
+               case CHANNEL_SUMS:
+                       return "channel_sums";
+                       
                default:
                        throw new 
UnsupportedOperationException(this.printErrorLocation() + "Instruction is not 
defined for Transform operation " + operation);
                                
@@ -180,6 +195,31 @@ public class ConvolutionTransform extends Lop
        }
        
        @Override
+       public String getInstructions(String input, String C, String HW, String 
output) throws LopsException {
+               if(operation == OperationTypes.CHANNEL_SUMS) {
+                       StringBuilder sb = new StringBuilder();
+                       sb.append( getExecType() );
+                       
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getOpcode() );
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(0).prepInputOperand(input));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(1).prepInputOperand(C));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(2).prepInputOperand(HW));
+                       //output
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( this.prepOutputOperand(output));
+                       
+                       return sb.toString();
+               }
+               else {
+                       throw new LopsException("The operation is not supported 
with three operands:" + operation.name());
+               }
+       }
+       
+       @Override
        public String getInstructions(String[] inputs, String output) throws 
LopsException {
                StringBuilder sb = new StringBuilder();
                appendOpcode(sb);

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 4e66042..d0bc429 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -233,6 +233,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "conv2d_backward_data"      , 
CPINSTRUCTION_TYPE.Convolution);
                String2CPInstructionType.put( "bias_add"      , 
CPINSTRUCTION_TYPE.Convolution);
                String2CPInstructionType.put( "bias_multiply"      , 
CPINSTRUCTION_TYPE.Convolution);
+               String2CPInstructionType.put( "channel_sums"      , 
CPINSTRUCTION_TYPE.Convolution);
                
                // Quaternary instruction opcodes
                String2CPInstructionType.put( "wsloss"  , 
CPINSTRUCTION_TYPE.Quaternary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 503576f..ae19969 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -53,6 +53,7 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "maxpooling_backward",    
GPUINSTRUCTION_TYPE.Convolution);
                String2GPUInstructionType.put( "bias_add",               
GPUINSTRUCTION_TYPE.Convolution);
                String2GPUInstructionType.put( "bias_multiply",          
GPUINSTRUCTION_TYPE.Convolution);
+               String2GPUInstructionType.put( "channel_sums",          
GPUINSTRUCTION_TYPE.Convolution);
 
                // Matrix Multiply Operators
                String2GPUInstructionType.put( "ba+*",  
GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index c6b4698..36422d9 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -27,12 +27,14 @@ import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.functionobjects.KahanPlus;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
 import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.SparseBlock;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 import org.apache.sysml.utils.NativeHelper;
@@ -59,6 +61,19 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                _numThreads = numThreads;
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
+       
+       public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand 
in3, CPOperand out, String opcode, String istr, int numThreads, double 
intermediateMemoryBudget) throws DMLRuntimeException {
+               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, 
out,
+                               opcode, istr);
+               if( !opcode.equals("channel_sums") ) {
+                       throw new DMLRuntimeException("Incorrect usage. 
Expected the opcode to be channel_sums, but found " + opcode);
+               }
+               _in2 = in2;
+               _in3 = in3;
+               _cptype = CPINSTRUCTION_TYPE.Convolution;
+               _numThreads = numThreads;
+               _intermediateMemoryBudget = intermediateMemoryBudget;
+       }
 
        private ConvolutionCPInstruction(CPOperand in, CPOperand out, String 
opcode, String istr,
                        ArrayList<CPOperand> stride, ArrayList<CPOperand> 
padding, ArrayList<CPOperand> input_shape,
@@ -212,6 +227,14 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                        int k = Integer.parseInt(parts[4]);
                        return new ConvolutionCPInstruction(in, in2, out, 
opcode, str, k, Double.parseDouble(parts[5]));
                }
+               else if (opcode.equalsIgnoreCase("channel_sums")) {
+                       InstructionUtils.checkNumFields(parts, 4);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand out = new CPOperand(parts[4]);
+                       return new ConvolutionCPInstruction(in, in2, in3, out, 
opcode, str, -1, 0);
+               }
                else {
                        throw new DMLRuntimeException("Unknown opcode while 
parsing a ConvolutionCPInstruction: " + str);
                }
@@ -297,6 +320,65 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
        }
        
+       public void processChannelSumsInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
+               MatrixBlock input = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
+               int C = (int) ec.getScalarInput(_in2.getName(), 
_in2.getValueType(), _in2.isLiteral()).getLongValue();
+               int HW = (int) ec.getScalarInput(_in3.getName(), 
_in3.getValueType(), _in3.isLiteral()).getLongValue();
+               if(C*HW != input.getNumColumns()) {
+                       throw new DMLRuntimeException("Expected rows*cols" + C 
+ "*" + HW + " to be equal to number of columns of input " + 
input.getNumColumns());
+               }
+               MatrixBlock outputBlock = null;
+               if(input.isEmpty()) {
+                       outputBlock = new MatrixBlock(C, 1, true);
+               }
+               else {
+                       outputBlock = new MatrixBlock(C, 1, 
false).allocateBlock();
+                       double [] output = outputBlock.getDenseBlock();
+                       if(input.isInSparseFormat()) {
+                               SparseBlock sblock = input.getSparseBlock();
+                               for(int n = 0; n < input.getNumRows(); n++) {
+                                       if( sblock.isEmpty(n) )
+                                               continue;
+                                       int apos = sblock.pos(n);
+                                       int alen = sblock.size(n);
+                                       int[] aix = sblock.indexes(n);
+                                       double[] avals = sblock.values(n);
+                                       
+                                       // Iterate over the sparse block
+                                       for(int j=apos; j<apos+alen; j++) {
+                                               // Note: the input is of shape 
[N, CHW]
+                                               int chw = aix[j];
+                                               
+                                               // Get individual zero-based 
c,h,w indexes from zero-based 'chw'
+                                               int c = chw / HW;
+                                               output[c] += avals[j];
+                                       }
+                               }
+                       }
+                       else {
+                               double [] inArr = input.getDenseBlock();
+                               if(inArr != null) {
+                                       KahanPlus kplus = 
KahanPlus.getKahanPlusFnObject();
+                                       for(int c = 0; c < C; c++) {
+                                               KahanObject sum = new 
KahanObject(0.0, 0.0);
+                                               for(int n = 0; n < 
input.getNumRows(); n++) {
+                                                       int index =  n*C*HW + 
c*HW;
+                                                       for(int hw = 0; hw < 
HW; hw++, index++) {
+                                                               
kplus.execute2(sum, inArr[index]);
+                                                       }
+                                               }
+                                               output[c] = sum._sum;
+                                       }
+                               }
+                       }
+                       outputBlock.recomputeNonZeros(getExtendedOpcode());
+               }
+               
+               // release inputs/outputs
+               ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+               ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
+       }
+       
        // Assumption: enableNative && NativeHelper.isNativeLibraryLoaded() is 
true
        // This increases the number of native calls. For example:the cases 
where filter is sparse but input is dense
        private static boolean isFilterSparse(MatrixBlock filter) throws 
DMLRuntimeException {
@@ -324,6 +406,10 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                        processReluBackwardInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("channel_sums")) {
+                       processChannelSumsInstruction(ec);
+                       return;
+               }
                
                // acquire inputs
                MatrixBlock outputBlock = null;

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index 8565b5a..fdb208e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -20,12 +20,17 @@ package org.apache.sysml.runtime.instructions.gpu;
 
 import java.util.ArrayList;
 
+import jcuda.Pointer;
+
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.ConvolutionCPInstruction;
+import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
@@ -57,6 +62,19 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
        
+       /**
+        * Constructor for the channel_sums instruction:
+        * out = rowSums(matrix(colSums(in1), rows=in2, cols=in3)).
+        *
+        * @param in1 input matrix operand
+        * @param in2 scalar operand holding the number of channels (C)
+        * @param in3 scalar operand holding height*width (HW)
+        * @param out output matrix operand
+        * @param opcode the opcode; must be channel_sums (matched case-insensitively, like parseInstruction)
+        * @param istr the instruction string
+        * @param intermediateMemoryBudget intermediate memory budget (parseInstruction passes 0)
+        * @throws DMLRuntimeException if the opcode is not channel_sums
+        */
+       public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
+               // NOTE(review): processChannelSumsInstruction does not use this operator; it looks like a placeholder
+               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+               // parseInstruction/processInstruction dispatch on this opcode with equalsIgnoreCase, so accept the same spellings here
+               if( !opcode.equalsIgnoreCase("channel_sums") ) {
+                       throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
+               }
+               _input1 = in1;
+               _input2 = in2;
+               _input3 = in3;
+               _gputype = GPUINSTRUCTION_TYPE.Convolution;
+               _output = out;
+               _intermediateMemoryBudget = intermediateMemoryBudget;
+       }
+       
        public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, 
CPOperand in3, CPOperand out, String opcode,
                        String istr, ArrayList<CPOperand> stride,
                        ArrayList<CPOperand> padding, ArrayList<CPOperand> 
input_shape,
@@ -210,6 +228,14 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                        CPOperand out = new CPOperand(parts[3]);
                        return new ConvolutionGPUInstruction(in1, in2, out, 
opcode, str, Double.parseDouble(parts[4]));
                }
+               else if (opcode.equalsIgnoreCase("channel_sums")) {
+                       InstructionUtils.checkNumFields(parts, 4);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand out = new CPOperand(parts[4]);
+                       return new ConvolutionGPUInstruction(in, in2, in3, out, 
opcode, str, 0);
+               }
                else {
                        throw new DMLRuntimeException("Unknown opcode while 
parsing a ConvolutionGPUInstruction: " + str);      
                }
@@ -246,6 +272,23 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
        }
        
+       /**
+        * Executes channel_sums on the GPU: out = rowSums(matrix(colSums(input), rows=C, cols=HW)),
+        * where C and HW are read from the second and third (scalar) operands. The output is a
+        * dense C x 1 matrix.
+        *
+        * @param ec execution context
+        * @throws DMLRuntimeException if C*HW does not equal the number of input columns
+        */
+       public void processChannelSumsInstruction(ExecutionContext ec) throws DMLRuntimeException {
+               GPUStatistics.incrementNoOfExecutedGPUInst();
+               MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
+               int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
+               int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue();
+               // multiply in long arithmetic so a large C*HW cannot wrap around and spuriously match
+               if( (long) C * HW != input.getNumColumns() ) {
+                       throw new DMLRuntimeException("Expected rows*cols " + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns());
+               }
+               MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1);
+               
+               LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW);
+               
+               // release inputs/outputs
+               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+       }
+       
        @Override
        public void processInstruction(ExecutionContext ec) 
                        throws DMLRuntimeException 
@@ -258,6 +301,10 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                        processReLUBackwardInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("channel_sums")) {
+                       processChannelSumsInstruction(ec);
+                       return;
+               }
                
                GPUStatistics.incrementNoOfExecutedGPUInst();
                                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
index e7f515a..caaa9e8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
@@ -316,7 +316,7 @@ public class QuantilePickSPInstruction extends 
BinarySPInstruction {
                                sum += v2.next()._2().sumWeightForQuantile();
                        
                        //return tuple for partition aggregate
-                       return Arrays.asList(new Tuple2<>(v1,sum)).iterator();
+                       return Arrays.asList(new Tuple2<Integer, 
Double>(v1,sum)).iterator();
                }
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 2cccde0..c0091c8 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -322,6 +322,37 @@ public class LibMatrixCUDA {
                if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_RELU_BACKWARD_KERNEL, System.nanoTime() - t1);
 
        }
+       
+       /**
+        * Perform channel_sums operations: out = rowSums(matrix(colSums(A), rows=C, cols=HW))
+        * 
+        * @param gCtx a valid {@link GPUContext}
+        * @param instName the invoking instruction's name for record {@link Statistics}.
+        * @param input input image (N x C*HW)
+        * @param outputBlock output (C x 1)
+        * @param C number of channels
+        * @param HW height*width
+        * @throws DMLRuntimeException if the number of input columns is not C*HW, or if a GPU operation fails
+        */
+       public static void channelSums(GPUContext gCtx, String instName, MatrixObject input, MatrixObject outputBlock, long C, long HW) throws DMLRuntimeException {
+               if(LOG.isTraceEnabled()) {
+                       LOG.trace("GPU : channelSums" + ", GPUContext=" + gCtx);
+               }
+               int N = toInt(input.getNumRows());
+               int cols = toInt(input.getNumColumns());
+               if(cols != C*HW) {
+                       throw new DMLRuntimeException("Incorrect parameters, number of columns " + cols + " != " + C + "*" + HW);
+               }
+               Pointer imagePointer = getDensePointer(gCtx, input, instName);
+               Pointer outputPointer = getDensePointer(gCtx, outputBlock, instName);
+               
+               // Two-step reduction: colSums over the N rows gives a 1 x (C*HW) vector, which is then
+               // summed per row of HW consecutive values to produce the C x 1 output.
+               // We can replace this with CuDNN tensor reduce
+               // widen to long before multiplying so the byte count cannot overflow int
+               Pointer tmp = gCtx.allocate(instName, ((long) cols) * sizeOfDataType);
+               try {
+                       reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, N, cols);
+                       reduceRow(gCtx, instName, "reduce_row_sum", tmp, outputPointer, toInt(C), toInt(HW));
+               }
+               finally {
+                       // release the temporary buffer even if one of the reductions throws
+                       gCtx.cudaFreeHelper(tmp);
+               }
+       }
 
        /**
         * Performs the operation corresponding to the DML script:

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index e0a6a57..5935285 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -64,7 +64,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
        protected static cudnnHandle getCudnnHandle(GPUContext gCtx) throws 
DMLRuntimeException {
                return gCtx.getCudnnHandle();
        }
-
+       
        /**
         * Does a 2D convolution followed by a bias_add
         *
@@ -722,4 +722,4 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                if(status != cudnnStatus.CUDNN_STATUS_SUCCESS)
                        throw new DMLRuntimeException("Error status returned by 
CuDNN:" + jcuda.jcudnn.cudnnStatus.stringFor(status));
        }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java 
b/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
index 0b229f0..59e9cb1 100644
--- a/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
@@ -45,6 +45,37 @@ public class AggregateUnaryOpTests extends UnaryOpTestsBase {
        public void colSums() {
                testSimpleUnaryOpMatrixOutput("colSums", "gpu_uack+");
        }
+       
+       /**
+        * Checks that out = rowSums(matrix(colSums(in1), rows=c, cols=hw)) is executed as
+        * gpu_channel_sums (via testUnaryOpMatrixOutput) for a range of row counts, channel counts,
+        * image sizes and sparsities.
+        */
+       @Test
+       public void channelSums() {
+               int[] C = new int[] { 2, 5, 10, 50 };
+               int[] HW = new int[] { 10, 12, 21, 51 };
+
+               for (double sparsity : sparsities) {
+                       if(sparsity == 0)
+                               continue; // sparsity == 0 has been independently tested but it fails with non-informative mlcontext error
+                       for (int row : rowSizes) {
+                               if(row == 1)
+                                       continue; // Currently channel_sums rewrite is enabled only for row > 1
+                               for (int c : C) {
+                                       if(c == 1)
+                                               continue; // C == 1 will result in scalar value, but this case has been independently tested
+                                       for (int hw : HW) {
+                                               String scriptStr = "out = rowSums(matrix(colSums(in1), rows=" + c + ", cols=" + hw + "));";
+                                               testUnaryOpMatrixOutput(scriptStr, "gpu_channel_sums", "in1", "out", seed, row, c*hw, sparsity);
+                                       }
+                               }
+                       }
+               }
+       }
 
        @Test
        public void rowSums() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java 
b/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
index 0051dd4..0f6b59c 100644
--- a/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
+++ b/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
@@ -31,10 +31,10 @@ import org.apache.sysml.api.mlcontext.Matrix;
 public abstract class UnaryOpTestsBase extends GPUTests {
 
        // Set of rows and column sizes & sparsities to test unary ops
-       private final int[] rowSizes = new int[] { 2049, 1024, 140, 64, 1 };
-       private final int[] columnSizes = new int[] { 2049, 1024, 140, 64, 1 };
-       private final double[] sparsities = new double[] { 0.9, 0.3, 0.03, 0.0 
};
-       private final int seed = 42;
+       protected final int[] rowSizes = new int[] { 2049, 1024, 140, 64, 1 };
+       protected final int[] columnSizes = new int[] { 2049, 1024, 150, 64, 1 
};
+       protected final double[] sparsities = new double[] { 0.9, 0.3, 0.03, 
0.0 };
+       protected final int seed = 42;
 
        /**
         * Tests unary ops with a variety of matrix shapes and sparsities.

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
new file mode 100644
index 0000000..61ca370
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/tensor/ChannelSumTest.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.test.integration.functions.tensor;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+public class ChannelSumTest extends AutomatedTestBase
+{
+       
+       private final static String TEST_NAME = "ChannelSumTest";
+       private final static String TEST_DIR = "functions/tensor/";
+       // derive the class dir from this class (not another tensor test such as PoolTest)
+       // so this test gets its own configuration directory
+       private final static String TEST_CLASS_DIR = TEST_DIR + ChannelSumTest.class.getSimpleName() + "/";
+       private final static double epsilon=0.0000000001;
+       
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, 
+                               new String[] {"B"}));
+       }
+       
+       @Test
+       public void testChannelSumDense1() 
+       {
+               int numImg = 10; int imgSize = 9; int numChannels = 5; 
+               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
+       }
+       
+       @Test
+       public void testChannelSumDense2() 
+       {
+               int numImg = 2; int imgSize = 5; int numChannels = 3; 
+               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
+       }
+       
+       @Test
+       public void testChannelSumDense3() 
+       {
+               int numImg = 9; int imgSize = 4; int numChannels = 11; 
+               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
+       }
+       
+       @Test
+       public void testChannelSumDense4() 
+       {
+               int numImg = 7; int imgSize = 8; int numChannels = 12; 
+               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, false);
+       }
+       
+       @Test
+       public void testChannelSumSparse1() 
+       {
+               int numImg = 4; int imgSize = 10; int numChannels = 5; 
+               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
+       }
+       
+       @Test
+       public void testChannelSumSparse2() 
+       {
+               int numImg = 2; int imgSize = 10; int numChannels = 8; 
+               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
+       }
+       
+       @Test
+       public void testChannelSumSparse3() 
+       {
+               int numImg = 4; int imgSize = 10; int numChannels = 11; 
+               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
+       }
+       
+       @Test
+       public void testChannelSumSparse4() 
+       {
+               int numImg = 9; int imgSize = 6; int numChannels = 8; 
+               runChannelSumTest(ExecType.CP, imgSize, numImg, numChannels, true);
+       }
+       
+       /**
+        * Runs ChannelSumTest.dml and ChannelSumTest.R on the same synthetic NCHW input and
+        * compares their output matrix B.
+        *
+        * @param et execution type; MR and SPARK select the matching runtime platform, anything else runs HYBRID
+        * @param imgSize image height (= width), so H*W = imgSize*imgSize
+        * @param numImg number of images (rows of the input)
+        * @param numChannels number of channels
+        * @param sparse whether the scripts generate a sparse input
+        */
+       public void runChannelSumTest( ExecType et, int imgSize, int numImg, int numChannels, boolean sparse) 
+       {
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               switch( et ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+               }
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               
+               try
+               {
+                       String sparseVal = String.valueOf(sparse).toUpperCase();
+                       
+                       TestConfiguration config = getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+       
+                       String RI_HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-explain", "hops", "-args", String.valueOf(imgSize), 
+                               String.valueOf(numImg), String.valueOf(numChannels),
+                               output("B"), sparseVal};
+                       
+                       fullRScriptName = RI_HOME + TEST_NAME + ".R";
+                       rCmd = "Rscript" + " " + fullRScriptName + " " + imgSize + " " + numImg + 
+                               " " + numChannels + " " + expectedDir() + " " + sparseVal; 
+                       
+                       // run scripts
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+                       
+                       //compare results (expected values come from the R script)
+                       HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");
+                       HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("B");
+                       TestUtils.compareMatrices(dmlfile, bHM, epsilon, "B-DML", "B-R");
+               }
+               finally {
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+       
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/scripts/functions/tensor/ChannelSumTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/ChannelSumTest.R 
b/src/test/scripts/functions/tensor/ChannelSumTest.R
new file mode 100644
index 0000000..c605074
--- /dev/null
+++ b/src/test/scripts/functions/tensor/ChannelSumTest.R
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Reference (expected-output) script for ChannelSumTest.
+# Arguments: 1=imgSize (H = W), 2=numImg (N), 3=numChannels (C),
+#            4=output directory prefix, 5=sparse flag ("TRUE"/"FALSE").
+# Writes B = per-channel sums over all images and pixels, a vector of C values.
+args <- commandArgs(TRUE)
+library("Matrix")  # provides writeMM and the CsparseMatrix class used below
+library("matrixStats") 
+imgSize=as.integer(args[1])
+numImg=as.integer(args[2])
+numChannels=as.integer(args[3])
+
+# Assumption: NCHW image format
+# byrow=TRUE fills rows first, mirroring the row-wise fill of matrix(seq(...)) in ChannelSumTest.dml
+x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, 
numChannels*imgSize*imgSize, byrow=TRUE)
+if(as.logical(args[5])) {
+       # sparse input: zero out every entry that does not exceed 1.5*mean(x)
+       zero_mask = (x - 1.5*mean(x)) > 0 
+       x = x * zero_mask
+} else {
+       # dense input: center the values so the input contains both signs
+       x = x - mean(x)
+}
+
+# colSums reduces over images; reshaping row-wise to C x (H*W) and taking rowSums reduces over pixels
+output = rowSums(matrix(colSums(x), numChannels, imgSize*imgSize, byrow=TRUE));
+
+writeMM(as(output,"CsparseMatrix"), paste(args[4], "B", sep=""))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/d916ba5b/src/test/scripts/functions/tensor/ChannelSumTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/ChannelSumTest.dml 
b/src/test/scripts/functions/tensor/ChannelSumTest.dml
new file mode 100644
index 0000000..7810a12
--- /dev/null
+++ b/src/test/scripts/functions/tensor/ChannelSumTest.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# 
+#-------------------------------------------------------------
+# DML side of ChannelSumTest; its output is compared against ChannelSumTest.R.
+# Arguments: $1=imgSize (H = W), $2=numImg (N), $3=numChannels (C), $4=output path, $5=sparse flag.
+imgSize=$1
+numImg=$2
+numChannels=$3
+
+# Assumption: NCHW image format
+x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, 
cols=numChannels*imgSize*imgSize)
+if($5) {
+       # sparse input: zero out every entry that does not exceed 1.5*mean(x)
+       zero_mask = (x - 1.5*mean(x)) > 0 
+       x = x * zero_mask
+}
+else {
+       # dense input: center the values so the input contains both signs
+       x = x - mean(x)
+}
+# this expression is the pattern targeted by the channel_sums rewrite (applies when nrow(x) > 1)
+output = rowSums(matrix(colSums(x), rows=numChannels, cols=imgSize*imgSize))  # shape (C, 1)
+write(output, $4, format="text")
\ No newline at end of file

Reply via email to