Repository: incubator-systemml
Updated Branches:
  refs/heads/master 6f8cea9bc -> 7a30925e6


[SYSTEMML-1039] Implemented uack+/uac+

Closes #331.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/7a30925e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/7a30925e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/7a30925e

Branch: refs/heads/master
Commit: 7a30925e6b366caaf5d8bbe0d2457c2e71858485
Parents: 6f8cea9
Author: Nakul Jindal <[email protected]>
Authored: Fri Jan 6 10:32:29 2017 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Fri Jan 6 10:32:29 2017 -0800

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                | 35 ++++++++-
 src/main/cpp/kernels/SystemML.ptx               | 78 +++++++++++++++++---
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  2 +-
 .../instructions/GPUInstructionParser.java      |  2 +
 .../runtime/matrix/data/LibMatrixCUDA.java      | 28 ++++++-
 5 files changed, 130 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7a30925e/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 11a337c..f57c04f 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -254,8 +254,12 @@ __global__ void reduce(double *g_idata, double *g_odata, unsigned int n)
 
 /**
  * Does a reduce (sum) over each row of the array.
+ * This kernel must be launched with as many blocks as there are rows.
 * The intuition for this kernel is that each block does a reduction over a single row.
- * The maximum numver
+ * The maximum number of blocks that can be launched (as of compute capability 3.0) is 2^31 - 1.
+ * This works out fine for SystemML, since the maximum number of elements in a Java array is 2^31 - c (for some small constant c).
+ * If the matrix is "fat" and "short", i.e. there are a small number of rows and a large number of columns,
+ * there could be under-utilization of the hardware.
  * @param g_idata   input matrix stored in device memory
  * @param g_odata   output vector of size [rows * 1] in device memory
  * @param rows      number of rows in input matrix
@@ -312,3 +316,32 @@ __global__ void reduce_row(double *g_idata, double *g_odata, unsigned int rows,
 }
 
 
+/**
+ * Does a column-wise reduction.
+ * The intuition is that there are as many global threads as there are columns.
+ * Each global thread is responsible for a single element in the output vector.
+ * This of course leads to an under-utilization of GPU resources.
+ * For cases where the number of columns is small, there can be unused SMs.
+ * @param g_idata   input matrix stored in device memory
+ * @param g_odata   output vector of size [1 * cols] in device memory
+ * @param rows      number of rows in input matrix
+ * @param cols      number of columns in input matrix
+ */
+extern "C"
+__global__ void reduce_col(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols)
+{
+    unsigned int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (global_tid >= cols) {
+        return;
+    }
+
+    unsigned int i = global_tid;
+    unsigned int grid_size = cols;
+    double val = 0;
+
+    while (i < rows * cols) {
+      val += g_idata[i];
+      i += grid_size;
+    }
+    g_odata[global_tid] = val;
+}
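
For context, a minimal standalone harness sketch that exercises the new kernel. This is hypothetical and not part of the commit: the file name harness.cu and all host-side names are invented, and it assumes the row-major double layout used above. It launches reduce_col with one global thread per column, mirroring the grid sizing done in getKernelParamsForReduceByCol further down. Build it together with SystemML.cu, e.g. nvcc -rdc=true harness.cu SystemML.cu.

// harness.cu - hypothetical test driver, not part of this commit.
// Sums each column of a rows x cols matrix of ones, so every column
// should sum to exactly `rows`.
#include <cstdio>
#include <cuda_runtime.h>

// Defined in SystemML.cu above; requires relocatable device code (-rdc=true).
extern "C" __global__ void reduce_col(double *g_idata, double *g_odata,
                                      unsigned int rows, unsigned int cols);

int main() {
    const unsigned int rows = 1000, cols = 3000;
    const size_t n = (size_t)rows * cols;

    double *h_in = new double[n];
    for (size_t i = 0; i < n; i++) h_in[i] = 1.0;

    double *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(double));
    cudaMalloc(&d_out, cols * sizeof(double));
    cudaMemcpy(d_in, h_in, n * sizeof(double), cudaMemcpyHostToDevice);

    // One global thread per column: ceil(cols / threads) blocks, the same
    // sizing as getKernelParamsForReduceByCol on the Java side.
    const unsigned int threads = 1024;
    const unsigned int blocks = (cols + threads - 1) / threads;
    reduce_col<<<blocks, threads>>>(d_in, d_out, rows, cols);
    cudaDeviceSynchronize();

    double *h_out = new double[cols];
    cudaMemcpy(h_out, d_out, cols * sizeof(double), cudaMemcpyDeviceToHost);
    printf("col 0 sum = %f (expected %u)\n", h_out[0], rows);

    cudaFree(d_in); cudaFree(d_out);
    delete[] h_in; delete[] h_out;
    return 0;
}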

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7a30925e/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 0683492..c708134 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -1901,6 +1901,62 @@ BB11_31:
        ret;
 }
 
+       // .globl       reduce_col
+.visible .entry reduce_col(
+       .param .u64 reduce_col_param_0,
+       .param .u64 reduce_col_param_1,
+       .param .u32 reduce_col_param_2,
+       .param .u32 reduce_col_param_3
+)
+{
+       .reg .pred      %p<4>;
+       .reg .b32       %r<11>;
+       .reg .f64       %fd<10>;
+       .reg .b64       %rd<9>;
+
+
+       ld.param.u64    %rd2, [reduce_col_param_0];
+       ld.param.u64    %rd3, [reduce_col_param_1];
+       ld.param.u32    %r5, [reduce_col_param_2];
+       ld.param.u32    %r6, [reduce_col_param_3];
+       mov.u32         %r7, %ntid.x;
+       mov.u32         %r8, %ctaid.x;
+       mov.u32         %r9, %tid.x;
+       mad.lo.s32      %r1, %r7, %r8, %r9;
+       setp.ge.u32     %p1, %r1, %r6;
+       @%p1 bra        BB12_5;
+
+       cvta.to.global.u64      %rd1, %rd2;
+       mul.lo.s32      %r2, %r6, %r5;
+       mov.f64         %fd8, 0d0000000000000000;
+       mov.f64         %fd9, %fd8;
+       setp.ge.u32     %p2, %r1, %r2;
+       @%p2 bra        BB12_4;
+
+       mov.u32         %r10, %r1;
+
+BB12_3:
+       mov.u32         %r3, %r10;
+       mul.wide.u32    %rd4, %r3, 8;
+       add.s64         %rd5, %rd1, %rd4;
+       ld.global.f64   %fd6, [%rd5];
+       add.f64         %fd9, %fd9, %fd6;
+       add.s32         %r4, %r3, %r6;
+       setp.lt.u32     %p3, %r4, %r2;
+       mov.u32         %r10, %r4;
+       mov.f64         %fd8, %fd9;
+       @%p3 bra        BB12_3;
+
+BB12_4:
+       cvta.to.global.u64      %rd6, %rd3;
+       mul.wide.u32    %rd7, %r1, 8;
+       add.s64         %rd8, %rd6, %rd7;
+       st.global.f64   [%rd8], %fd8;
+
+BB12_5:
+       ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_accurate_pow(
        .param .b64 __internal_accurate_pow_param_0,
        .param .b64 __internal_accurate_pow_param_1
@@ -1924,7 +1980,7 @@ BB11_31:
        }
        shr.u32         %r50, %r49, 20;
        setp.ne.s32     %p1, %r50, 0;
-       @%p1 bra        BB12_2;
+       @%p1 bra        BB13_2;
 
        mul.f64         %fd14, %fd12, 0d4350000000000000;
        {
@@ -1938,13 +1994,13 @@ BB11_31:
        shr.u32         %r16, %r49, 20;
        add.s32         %r50, %r16, -54;
 
-BB12_2:
+BB13_2:
        add.s32         %r51, %r50, -1023;
        and.b32         %r17, %r49, -2146435073;
        or.b32          %r18, %r17, 1072693248;
        mov.b64         %fd133, {%r48, %r18};
        setp.lt.u32     %p2, %r18, 1073127583;
-       @%p2 bra        BB12_4;
+       @%p2 bra        BB13_4;
 
        {
        .reg .b32 %temp; 
@@ -1958,7 +2014,7 @@ BB12_2:
        mov.b64         %fd133, {%r19, %r21};
        add.s32         %r51, %r50, -1022;
 
-BB12_4:
+BB13_4:
        add.f64         %fd16, %fd133, 0d3FF0000000000000;
        // inline asm
        rcp.approx.ftz.f64 %fd15,%fd16;
@@ -2124,13 +2180,13 @@ BB12_4:
        mov.b32          %f2, %r35;
        abs.f32         %f1, %f2;
        setp.lt.f32     %p4, %f1, 0f4086232B;
-       @%p4 bra        BB12_7;
+       @%p4 bra        BB13_7;
 
        setp.lt.f64     %p5, %fd4, 0d0000000000000000;
        add.f64         %fd130, %fd4, 0d7FF0000000000000;
        selp.f64        %fd134, 0d0000000000000000, %fd130, %p5;
        setp.geu.f32    %p6, %f1, 0f40874800;
-       @%p6 bra        BB12_7;
+       @%p6 bra        BB13_7;
 
        shr.u32         %r36, %r13, 31;
        add.s32         %r37, %r13, %r36;
@@ -2145,26 +2201,26 @@ BB12_4:
        mov.b64         %fd132, {%r44, %r43};
        mul.f64         %fd134, %fd131, %fd132;
 
-BB12_7:
+BB13_7:
        {
        .reg .b32 %temp; 
        mov.b64         {%temp, %r45}, %fd134;
        }
        and.b32         %r46, %r45, 2147483647;
        setp.ne.s32     %p7, %r46, 2146435072;
-       @%p7 bra        BB12_9;
+       @%p7 bra        BB13_9;
 
        {
        .reg .b32 %temp; 
        mov.b64         {%r47, %temp}, %fd134;
        }
        setp.eq.s32     %p8, %r47, 0;
-       @%p8 bra        BB12_10;
+       @%p8 bra        BB13_10;
 
-BB12_9:
+BB13_9:
        fma.rn.f64      %fd134, %fd134, %fd5, %fd134;
 
-BB12_10:
+BB13_10:
        st.param.f64    [func_retval0+0], %fd134;
        ret;
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7a30925e/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 99aef40..15308fc 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -146,7 +146,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
                                        int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                                        if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < OptimizerUtils.GPU_MEMORY_BUDGET) && (_op == AggOp.SUM)) {
                                                // Only implemented methods for GPU
-                                               if (_op == AggOp.SUM && (_direction == Direction.RowCol || _direction == Direction.Row)){
+                                               if (_op == AggOp.SUM && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col)){
                                                        et = ExecType.GPU;
                                                        k = 1;
                                                }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7a30925e/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index c1d884e..bc9b93e 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -76,6 +76,8 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "uak+"    , 
GPUINSTRUCTION_TYPE.AggregateUnary);
                String2GPUInstructionType.put( "uar+"    , 
GPUINSTRUCTION_TYPE.AggregateUnary);
                String2GPUInstructionType.put( "uark+"   , 
GPUINSTRUCTION_TYPE.AggregateUnary);
+               String2GPUInstructionType.put( "uac+"    , 
GPUINSTRUCTION_TYPE.AggregateUnary);
+               String2GPUInstructionType.put( "uack+"   , 
GPUINSTRUCTION_TYPE.AggregateUnary);
 
        }
        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7a30925e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 7da2891..00e3d87 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -1029,12 +1029,16 @@ public class LibMatrixCUDA {
                                                ec.setScalarOutput(output, new DoubleObject(result));
                                                break;
                                        }
-                                       case REDUCTION_COL : {
+                                       case REDUCTION_COL : {  // The names are a bit misleading: REDUCTION_COL refers to the direction (reduce all elements in a column)
                                                reduceRow(in, out, rlen, clen);
                                                break;
                                        }
+                                       case REDUCTION_ROW : {
+                                               reduceCol(in, out, rlen, clen);
+                                               break;
+                                       }
+
                                        case REDUCTION_DIAG :
-                                       case REDUCTION_ROW :
                                                throw new DMLRuntimeException("Internal Error - Row, Column and Diag summation not implemented yet");
                                }
                                break;
@@ -1172,6 +1176,14 @@ public class LibMatrixCUDA {
                cudaDeviceSynchronize();
        }
 
+       private static void reduceCol(Pointer in, Pointer out, int rows, int cols) throws DMLRuntimeException {
+               int[] tmp = getKernelParamsForReduceByCol(rows, cols);
+               int blocks = tmp[0], threads = tmp[1], sharedMem = tmp[2];
+               kernels.launchKernel("reduce_col", new ExecutionConfig(blocks, threads, sharedMem),
+                                               in, out, rows, cols);
+               cudaDeviceSynchronize();
+       }
+
        /**
         * Get threads, blocks and shared memory for a reduce all operation
         * @param n size of input array
@@ -1207,6 +1219,18 @@ public class LibMatrixCUDA {
                return new int[] {blocks, threads, sharedMemSize};
        }
 
+       private static int[] getKernelParamsForReduceByCol(int rows, int cols) {
+               int threads = Math.min(cols, MAX_THREADS);
+               int blocks = cols/1024;
+               if (cols % 1024 != 0) blocks++;
+               int sharedMemSize = threads * Sizeof.DOUBLE;
+               if (threads <= 32){
+                       sharedMemSize *=2;
+               }
+               return new int[] {blocks, threads, sharedMemSize};
+       }
+
+
        private static int nextPow2(int x)
        {
                --x;
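
To make the launch-parameter math above concrete: for cols = 3000 this yields threads = min(3000, 1024) = 1024 and blocks = ceil(3000 / 1024) = 3, so 3 * 1024 = 3072 global threads cover all 3000 columns. A small sketch (hypothetical, not from the commit; it assumes MAX_THREADS = 1024) mirroring the Java computation:

// sizing_demo.cu - hypothetical illustration of getKernelParamsForReduceByCol.
#include <algorithm>
#include <cstdio>
#include <initializer_list>

int main() {
    const int MAX_THREADS = 1024;  // assumed; the Java side queries the device
    for (int cols : {10, 1024, 3000}) {
        int threads = std::min(cols, MAX_THREADS);
        int blocks = cols / 1024;              // note: literal 1024, as in the
        if (cols % 1024 != 0) blocks++;        // Java code; equals ceil(cols/1024)
        printf("cols=%4d -> blocks=%d, threads=%4d (%d threads total)\n",
               cols, blocks, threads, blocks * threads);
    }
    return 0;
}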
