Repository: systemml
Updated Branches:
  refs/heads/master 5ca8706e9 -> 04bc667f3


[SYSTEMML-445] Added SGD Nesterov update operator via rewrite for the GPU 
backend

- This leads to 10-15% speedup for ResNet200 with batch size of 32.
- Also, added GPU tests for this operator.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/04bc667f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/04bc667f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/04bc667f

Branch: refs/heads/master
Commit: 04bc667f3650d57c0bc9de20e46e7624205cc1e6
Parents: 5ca8706
Author: Niketan Pansare <[email protected]>
Authored: Thu Aug 9 21:00:21 2018 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Thu Aug 9 21:00:21 2018 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |  23 ++-
 src/main/cpp/kernels/SystemML.ptx               | 161 +++++++++++++++----
 src/main/java/org/apache/sysml/hops/DnnOp.java  |  15 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |   4 +-
 .../hops/rewrite/RewriteGPUSpecificOps.java     |  61 +++++++
 .../org/apache/sysml/lops/DnnTransform.java     |  33 +++-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/gpu/DnnGPUInstruction.java     |  56 +++++++
 .../gpu/context/GPUMemoryManager.java           |   5 +-
 .../org/apache/sysml/test/gpu/GPUTests.java     |  10 ++
 .../org/apache/sysml/test/gpu/SGDUpdate.java    |  91 +++++++++++
 11 files changed, 419 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 485b7e2..9ddaaff 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2248,12 +2248,6 @@ extern "C" __global__ void prepare_lstm_dinput_f(float* 
smlInput, float* cudnnIn
 }
 
 
-/**
- * Do an log over all the elements of a matrix
- * @param A the input matrix (of length = size)
- * @param C the pre-allocated output matrix (of length = size)
- * @param size the length of the input and output matrices
- */
 template <typename T>
 __device__ void colwise_reshape(T *A, T *C, unsigned int size, 
        unsigned int inRows, unsigned int inCols,
@@ -2278,4 +2272,21 @@ extern "C" __global__ void colwise_reshape_f(float *A, 
float *C, unsigned int si
        unsigned int inRows, unsigned int inCols,
        unsigned int outRows, unsigned int outCols) {
   colwise_reshape(A, C, size, inRows, inCols, outRows, outCols);
+}
+
+// Performs the operation: out = X - mu*v_prev + (1+mu)*v
+template <typename T>
+__device__ void update_nesterov_x(T *X, T *v, T *v_prev, double mu, T *out, 
unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+       out[index] = X[index] - mu*v_prev[index] + (1+mu)*v[index];
+  }
+}
+
+extern "C" __global__ void update_nesterov_x_d(double *X, double *v, double 
*v_prev, double mu, double *out, unsigned int size) {
+  update_nesterov_x(X, v, v_prev, mu, out, size);
+}
+
+extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float 
*v_prev, double mu, float *out, unsigned int size) {
+  update_nesterov_x(X, v, v_prev, mu, out, size);
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx 
b/src/main/cpp/kernels/SystemML.ptx
index 93e5e35..8a14876 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -12977,12 +12977,119 @@ BB113_2:
        ret;
 }
 
+       // .globl       update_nesterov_x_d
+.visible .entry update_nesterov_x_d(
+       .param .u64 update_nesterov_x_d_param_0,
+       .param .u64 update_nesterov_x_d_param_1,
+       .param .u64 update_nesterov_x_d_param_2,
+       .param .f64 update_nesterov_x_d_param_3,
+       .param .u64 update_nesterov_x_d_param_4,
+       .param .u32 update_nesterov_x_d_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .b32       %r<6>;
+       .reg .f64       %fd<9>;
+       .reg .b64       %rd<14>;
+
+
+       ld.param.u64    %rd1, [update_nesterov_x_d_param_0];
+       ld.param.u64    %rd2, [update_nesterov_x_d_param_1];
+       ld.param.u64    %rd3, [update_nesterov_x_d_param_2];
+       ld.param.f64    %fd1, [update_nesterov_x_d_param_3];
+       ld.param.u64    %rd4, [update_nesterov_x_d_param_4];
+       ld.param.u32    %r2, [update_nesterov_x_d_param_5];
+       mov.u32         %r3, %ctaid.x;
+       mov.u32         %r4, %ntid.x;
+       mov.u32         %r5, %tid.x;
+       mad.lo.s32      %r1, %r4, %r3, %r5;
+       setp.ge.u32     %p1, %r1, %r2;
+       @%p1 bra        BB114_2;
+
+       cvta.to.global.u64      %rd5, %rd1;
+       mul.wide.s32    %rd6, %r1, 8;
+       add.s64         %rd7, %rd5, %rd6;
+       cvta.to.global.u64      %rd8, %rd3;
+       add.s64         %rd9, %rd8, %rd6;
+       ld.global.f64   %fd2, [%rd9];
+       mul.f64         %fd3, %fd2, %fd1;
+       ld.global.f64   %fd4, [%rd7];
+       sub.f64         %fd5, %fd4, %fd3;
+       cvta.to.global.u64      %rd10, %rd2;
+       add.s64         %rd11, %rd10, %rd6;
+       ld.global.f64   %fd6, [%rd11];
+       add.f64         %fd7, %fd1, 0d3FF0000000000000;
+       fma.rn.f64      %fd8, %fd7, %fd6, %fd5;
+       cvta.to.global.u64      %rd12, %rd4;
+       add.s64         %rd13, %rd12, %rd6;
+       st.global.f64   [%rd13], %fd8;
+
+BB114_2:
+       ret;
+}
+
+       // .globl       update_nesterov_x_f
+.visible .entry update_nesterov_x_f(
+       .param .u64 update_nesterov_x_f_param_0,
+       .param .u64 update_nesterov_x_f_param_1,
+       .param .u64 update_nesterov_x_f_param_2,
+       .param .f64 update_nesterov_x_f_param_3,
+       .param .u64 update_nesterov_x_f_param_4,
+       .param .u32 update_nesterov_x_f_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .f32       %f<5>;
+       .reg .b32       %r<6>;
+       .reg .f64       %fd<9>;
+       .reg .b64       %rd<14>;
+
+
+       ld.param.u64    %rd1, [update_nesterov_x_f_param_0];
+       ld.param.u64    %rd2, [update_nesterov_x_f_param_1];
+       ld.param.u64    %rd3, [update_nesterov_x_f_param_2];
+       ld.param.f64    %fd1, [update_nesterov_x_f_param_3];
+       ld.param.u64    %rd4, [update_nesterov_x_f_param_4];
+       ld.param.u32    %r2, [update_nesterov_x_f_param_5];
+       mov.u32         %r3, %ctaid.x;
+       mov.u32         %r4, %ntid.x;
+       mov.u32         %r5, %tid.x;
+       mad.lo.s32      %r1, %r4, %r3, %r5;
+       setp.ge.u32     %p1, %r1, %r2;
+       @%p1 bra        BB115_2;
+
+       cvta.to.global.u64      %rd5, %rd1;
+       mul.wide.s32    %rd6, %r1, 4;
+       add.s64         %rd7, %rd5, %rd6;
+       ld.global.f32   %f1, [%rd7];
+       cvt.f64.f32     %fd2, %f1;
+       cvta.to.global.u64      %rd8, %rd3;
+       add.s64         %rd9, %rd8, %rd6;
+       ld.global.f32   %f2, [%rd9];
+       cvt.f64.f32     %fd3, %f2;
+       mul.f64         %fd4, %fd3, %fd1;
+       sub.f64         %fd5, %fd2, %fd4;
+       cvta.to.global.u64      %rd10, %rd2;
+       add.s64         %rd11, %rd10, %rd6;
+       ld.global.f32   %f3, [%rd11];
+       cvt.f64.f32     %fd6, %f3;
+       add.f64         %fd7, %fd1, 0d3FF0000000000000;
+       fma.rn.f64      %fd8, %fd7, %fd6, %fd5;
+       cvt.rn.f32.f64  %f4, %fd8;
+       cvta.to.global.u64      %rd12, %rd4;
+       add.s64         %rd13, %rd12, %rd6;
+       st.global.f32   [%rd13], %f4;
+
+BB115_2:
+       ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
        .param .b64 __internal_trig_reduction_slowpathd_param_0,
        .param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-       .local .align 8 .b8     __local_depot114[40];
+       .local .align 8 .b8     __local_depot116[40];
        .reg .b64       %SP;
        .reg .b64       %SPL;
        .reg .pred      %p<9>;
@@ -12991,7 +13098,7 @@ BB113_2:
        .reg .b64       %rd<102>;
 
 
-       mov.u64         %rd101, __local_depot114;
+       mov.u64         %rd101, __local_depot116;
        cvta.local.u64  %SP, %rd101;
        ld.param.f64    %fd4, [__internal_trig_reduction_slowpathd_param_0];
        ld.param.u64    %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -13005,7 +13112,7 @@ BB113_2:
        shr.u32         %r3, %r1, 20;
        bfe.u32         %r4, %r1, 20, 11;
        setp.eq.s32     %p1, %r4, 2047;
-       @%p1 bra        BB114_13;
+       @%p1 bra        BB116_13;
 
        add.s32         %r15, %r4, -1024;
        shr.u32         %r16, %r15, 6;
@@ -13018,7 +13125,7 @@ BB113_2:
        mov.u64         %rd94, 0;
        setp.ge.s32     %p2, %r5, %r6;
        mov.u64         %rd93, %rd1;
-       @%p2 bra        BB114_4;
+       @%p2 bra        BB116_4;
 
        mov.b64          %rd41, %fd4;
        shl.b64         %rd42, %rd41, 11;
@@ -13035,7 +13142,7 @@ BB113_2:
        mov.u64         %rd91, %rd1;
        mov.u32         %r39, %r5;
 
-BB114_3:
+BB116_3:
        .pragma "nounroll";
        ld.const.u64    %rd47, [%rd89];
        // inline asm
@@ -13065,15 +13172,15 @@ BB114_3:
        add.s64         %rd93, %rd93, 8;
        add.s64         %rd89, %rd89, 8;
        setp.lt.s32     %p3, %r39, %r6;
-       @%p3 bra        BB114_3;
+       @%p3 bra        BB116_3;
 
-BB114_4:
+BB116_4:
        st.local.u64    [%rd93], %rd94;
        ld.local.u64    %rd95, [%rd1+16];
        ld.local.u64    %rd96, [%rd1+24];
        and.b32         %r9, %r3, 63;
        setp.eq.s32     %p4, %r9, 0;
-       @%p4 bra        BB114_6;
+       @%p4 bra        BB116_6;
 
        mov.u32         %r27, 64;
        sub.s32         %r28, %r27, %r9;
@@ -13085,7 +13192,7 @@ BB114_4:
        shr.u64         %rd55, %rd54, %r28;
        or.b64          %rd95, %rd55, %rd53;
 
-BB114_6:
+BB116_6:
        cvta.to.local.u64       %rd56, %rd37;
        shr.u64         %rd57, %rd96, 62;
        cvt.u32.u64     %r29, %rd57;
@@ -13102,7 +13209,7 @@ BB114_6:
        selp.b32        %r34, %r32, %r33, %p5;
        st.local.u32    [%rd56], %r34;
        setp.eq.s32     %p6, %r31, 0;
-       @%p6 bra        BB114_8;
+       @%p6 bra        BB116_8;
 
        mov.u64         %rd64, 0;
        // inline asm
@@ -13122,10 +13229,10 @@ BB114_6:
        // inline asm
        xor.b32         %r40, %r40, -2147483648;
 
-BB114_8:
+BB116_8:
        clz.b64         %r41, %rd98;
        setp.eq.s32     %p7, %r41, 0;
-       @%p7 bra        BB114_10;
+       @%p7 bra        BB116_10;
 
        shl.b64         %rd67, %rd98, %r41;
        mov.u32         %r35, 64;
@@ -13133,7 +13240,7 @@ BB114_8:
        shr.u64         %rd68, %rd97, %r36;
        or.b64          %rd98, %rd68, %rd67;
 
-BB114_10:
+BB116_10:
        mov.u64         %rd72, -3958705157555305931;
        // inline asm
        {
@@ -13154,7 +13261,7 @@ BB114_10:
        }
        // inline asm
        setp.lt.s64     %p8, %rd100, 1;
-       @%p8 bra        BB114_12;
+       @%p8 bra        BB116_12;
 
        // inline asm
        {
@@ -13173,7 +13280,7 @@ BB114_10:
        // inline asm
        add.s32         %r41, %r41, 1;
 
-BB114_12:
+BB116_12:
        cvt.u64.u32     %rd79, %r40;
        shl.b64         %rd80, %rd79, 32;
        mov.u32         %r37, 1022;
@@ -13188,7 +13295,7 @@ BB114_12:
        or.b64          %rd88, %rd87, %rd80;
        mov.b64          %fd4, %rd88;
 
-BB114_13:
+BB116_13:
        st.param.f64    [func_retval0+0], %fd4;
        ret;
 }
@@ -13216,7 +13323,7 @@ BB114_13:
        }
        shr.u32         %r51, %r50, 20;
        setp.ne.s32     %p1, %r51, 0;
-       @%p1 bra        BB115_2;
+       @%p1 bra        BB117_2;
 
        mul.f64         %fd14, %fd12, 0d4350000000000000;
        {
@@ -13230,13 +13337,13 @@ BB114_13:
        shr.u32         %r16, %r50, 20;
        add.s32         %r51, %r16, -54;
 
-BB115_2:
+BB117_2:
        add.s32         %r52, %r51, -1023;
        and.b32         %r17, %r50, -2146435073;
        or.b32          %r18, %r17, 1072693248;
        mov.b64         %fd135, {%r49, %r18};
        setp.lt.u32     %p2, %r18, 1073127583;
-       @%p2 bra        BB115_4;
+       @%p2 bra        BB117_4;
 
        {
        .reg .b32 %temp; 
@@ -13250,7 +13357,7 @@ BB115_2:
        mov.b64         %fd135, {%r19, %r21};
        add.s32         %r52, %r51, -1022;
 
-BB115_4:
+BB117_4:
        add.f64         %fd15, %fd135, 0d3FF0000000000000;
        rcp.approx.ftz.f64      %fd16, %fd15;
        neg.f64         %fd17, %fd15;
@@ -13413,13 +13520,13 @@ BB115_4:
        mov.b32          %f2, %r35;
        abs.f32         %f1, %f2;
        setp.lt.f32     %p4, %f1, 0f4086232B;
-       @%p4 bra        BB115_7;
+       @%p4 bra        BB117_7;
 
        setp.lt.f64     %p5, %fd4, 0d0000000000000000;
        add.f64         %fd129, %fd4, 0d7FF0000000000000;
        selp.f64        %fd136, 0d0000000000000000, %fd129, %p5;
        setp.geu.f32    %p6, %f1, 0f40874800;
-       @%p6 bra        BB115_7;
+       @%p6 bra        BB117_7;
 
        mov.f64         %fd134, 0d4338000000000000;
        mov.f64         %fd133, 0d3FF71547652B82FE;
@@ -13441,26 +13548,26 @@ BB115_4:
        mov.b64         %fd131, {%r44, %r43};
        mul.f64         %fd136, %fd130, %fd131;
 
-BB115_7:
+BB117_7:
        {
        .reg .b32 %temp; 
        mov.b64         {%temp, %r45}, %fd136;
        }
        and.b32         %r46, %r45, 2147483647;
        setp.ne.s32     %p7, %r46, 2146435072;
-       @%p7 bra        BB115_9;
+       @%p7 bra        BB117_9;
 
        {
        .reg .b32 %temp; 
        mov.b64         {%r47, %temp}, %fd136;
        }
        setp.eq.s32     %p8, %r47, 0;
-       @%p8 bra        BB115_10;
+       @%p8 bra        BB117_10;
 
-BB115_9:
+BB117_9:
        fma.rn.f64      %fd136, %fd136, %fd5, %fd136;
 
-BB115_10:
+BB117_10:
        st.param.f64    [func_retval0+0], %fd136;
        ret;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java 
b/src/main/java/org/apache/sysml/hops/DnnOp.java
index 4e22f59..4ca90f8 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -136,6 +136,7 @@ public class DnnOp extends MultiThreadedHop
                        }
                        case BATCH_NORM2D_TEST:
                        case CHANNEL_SUMS:
+                       case UPDATE_NESTEROV_X:
                        {       
                                if(et == ExecType.GPU) {
                                        setLops(constructDnnLops(et, inputs));
@@ -175,6 +176,8 @@ public class DnnOp extends MultiThreadedHop
                                return 6;
                        case CHANNEL_SUMS:
                                return 3;
+                       case UPDATE_NESTEROV_X:
+                               return 4;
                        default:
                                return 13;
                }
@@ -528,7 +531,9 @@ public class DnnOp extends MultiThreadedHop
                // [numRows, numCols, NNZ] 
                long[] ret = new long[3];
                
-               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST) {
+               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST ||
+                       op == OpOpDnn.UPDATE_NESTEROV_X) {
+                       // Same dimension as the first input
                        MatrixCharacteristics[] mc = 
memo.getAllInputStats(getInput());
                        ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
                        ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1;
@@ -734,7 +739,8 @@ public class DnnOp extends MultiThreadedHop
        @Override
        public void refreshSizeInformation()
        {
-               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST) {
+               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
+                       // Same dimension as the first input
                        Hop input1 = getInput().get(0);
                        setDim1(input1.getDim1());
                        setDim2(input1.getDim2());
@@ -840,8 +846,9 @@ public class DnnOp extends MultiThreadedHop
         * @return either -1 or value associated with the dimString
         */
        private long getDim(String dimString) {
-               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS) {
-                       throw new RuntimeException("getDim method should not be 
invoked for batch_norm_test, channel_sums, bias_add and bias_multiply");
+               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
+                       op == OpOpDnn.UPDATE_NESTEROV_X) {
+                       throw new RuntimeException("getDim method should not be 
invoked for " + op.name());
                }
                try {
                        parseInput();

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 6466575..73a58e3 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1099,7 +1099,8 @@ public abstract class Hop implements ParseInfo
        public enum OpOpDnn {
                MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
-               BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS
+               BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
+               UPDATE_NESTEROV_X
        }
        
        public enum DataGenMethod {
@@ -1174,6 +1175,7 @@ public abstract class Hop implements ParseInfo
                HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, 
org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA);
                HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, 
org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST);
                HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, 
org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS);
+               HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, 
org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X);
        }
 
        protected static final HashMap<Hop.Direction, 
org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 2a1699d..b603aa7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -124,6 +124,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                        }
                        hi = batchNormTest(hop, hi, i); 
                        hi = channelSums(hop, hi, i); 
+                       hi = updateNesterovX(hop, hi, i);
        
                        if( !descendFirst )
                                rule_GPUKernels(roots, hi, descendFirst);
@@ -281,6 +282,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                                && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
        }
        
+       private static boolean isBinaryMMMinus(Hop h) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MINUS 
+                               && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
+       }
+       
        private static boolean isBinaryMSMult(Hop h, double expectedValue) {
                return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
                                && getFirstInput(h).getDataType() == 
DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
@@ -323,6 +329,16 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
                                && getSecondInput(h).getDataType() == 
DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR;
        }
        
+       private static boolean isBinarySMMult(Hop h, double expectedVal) {
+               return h instanceof BinaryOp && ((BinaryOp)h).getOp() == 
OpOp2.MULT 
+                               && getSecondInput(h).getDataType() == 
DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR
+                               && getValue(getFirstInput(h)) == expectedVal;
+       }
+       
+       private static double getValue(Hop h) {
+               return OptimizerUtils.rEvalSimpleDoubleExpression(h, new 
HashMap<>());
+       }
+       
        /**
         * Checks if the "mean" hop is a moving average of mean in batch 
normalization layer.
         *  
@@ -704,6 +720,51 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
        // ------------------------------------------------------------
        
        /**
+        * Checks for the update_nesterov_x pattern (X = X - mu*v_prev + 
(1+mu)*v)
+        * and returns a new DnnOp if matched
+        * 
+        * @param parent parent of the input
+        * @param hi input to be matched
+        * @param pos position
+        * @return a new DnnOp or hi
+        */
+       private static Hop updateNesterovX(Hop parent, Hop hi, int pos) {
+               if(fitsOnGPU(hi, 4) && isBinaryMMAdd(hi) && 
isBinaryMMMinus(getFirstInput(hi))
+                       && isBinarySMMult(getSecondInput(getFirstInput(hi))) 
+                       && isBinarySMMult(getSecondInput(hi))) {
+                       Hop onePlusMu = getFirstInput(getSecondInput(hi));
+                       Hop tmp = getSecondInput(getFirstInput(hi));
+                       Hop mu = getFirstInput(tmp);
+                       if(isOnePlusMu(onePlusMu, mu)) {
+                               Hop v_prev = getSecondInput(tmp);
+                               Hop v = getSecondInput(getSecondInput(hi));
+                               Hop X = getFirstInput(getFirstInput(hi));
+                               if(hasSameDimensions(X, v) && 
hasSameDimensions(X, v_prev)) {
+                                       ArrayList<Hop> inHops = new 
ArrayList<Hop>();
+                                       inHops.add(X);
+                                       inHops.add(v);
+                                       inHops.add(v_prev);
+                                       inHops.add(mu);
+                                       LOG.debug("Applied updateNesterovX 
rewrite.");
+                                       Hop newHop = new DnnOp(hi.getName(), 
hi.getDataType(), hi.getValueType(),
+                                                       
OpOpDnn.UPDATE_NESTEROV_X, inHops);
+                                       return 
HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+                               }
+                       }
+               }
+               return hi;
+       }
+       
+       private static boolean hasSameDimensions(Hop x, Hop y) {
+               return x.dimsKnown() && y.dimsKnown() && (x.getDim1() == 
y.getDim1()) && (x.getDim2() == y.getDim2());
+       }
+       
+       private static boolean isOnePlusMu(Hop onePlusMu, Hop mu) {
+               return (isBinarySMMult(onePlusMu, 1.0) && 
getSecondInput(onePlusMu) == mu) ||
+                               getValue(onePlusMu) == getValue(mu) + 1;
+       }
+       
+       /**
         * Checks for the batch norm (mode="test") pattern using the helper 
isBatchNormTrainMean and isBatchNormTrainVar
         * and returns a new DnnOp if matched
         * 

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java 
b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 6c61d4a..3183b5f 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -31,7 +31,8 @@ public class DnnTransform extends Lop
                MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
                RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
-               BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, 
BATCH_NORM2D_TEST
+               BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, 
BATCH_NORM2D_TEST, 
+               UPDATE_NESTEROV_X
        }
        
        private OperationTypes operation;
@@ -165,6 +166,9 @@ public class DnnTransform extends Lop
                        
                case CHANNEL_SUMS:
                        return "channel_sums";
+               
+               case UPDATE_NESTEROV_X:
+                       return "update_nesterov_x";
                        
                case BATCH_NORM2D_TEST:
                        return "batch_norm2d_test";
@@ -232,6 +236,33 @@ public class DnnTransform extends Lop
        }
        
        @Override
+       public String getInstructions(String input1, String input2, String 
input3, String input4, String output) {
+               if(operation == OperationTypes.UPDATE_NESTEROV_X) {
+                       StringBuilder sb = new StringBuilder();
+                       sb.append( getExecType() );
+                       
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getOpcode() );
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(0).prepInputOperand(input1));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(1).prepInputOperand(input2));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(2).prepInputOperand(input3));
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( getInputs().get(3).prepInputOperand(input4));
+                       //output
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( this.prepOutputOperand(output));
+                       
+                       return sb.toString();
+               }
+               else {
+                       throw new LopsException("The operation is not supported 
with three operands:" + operation.name());
+               }
+       }
+       
+       @Override
        public String getInstructions(String[] inputs, String output) {
                StringBuilder sb = new StringBuilder();
                appendOpcode(sb);

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index c90f9f9..f4122d9 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -63,6 +63,7 @@ public class GPUInstructionParser  extends InstructionParser
                String2GPUInstructionType.put( "batch_norm2d_backward",  
GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "batch_norm2d_test",      
GPUINSTRUCTION_TYPE.Dnn);
                String2GPUInstructionType.put( "batch_norm2d_train",      
GPUINSTRUCTION_TYPE.Dnn);
+               String2GPUInstructionType.put( "update_nesterov_x",      
GPUINSTRUCTION_TYPE.Dnn);
                
                // Matrix Multiply Operators
                String2GPUInstructionType.put( "ba+*",  
GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index a36d0fc..8d89032 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -124,6 +124,21 @@ public class DnnGPUInstruction extends GPUInstruction {
                _intermediateMemoryBudget = intermediateMemoryBudget;
        }
        
+       public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, 
CPOperand in4, CPOperand out, String opcode, String istr, 
+                       double intermediateMemoryBudget) throws 
DMLRuntimeException {
+               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
+               if( !opcode.equals("update_nesterov_x") ) {
+                       throw new DMLRuntimeException("Incorrect opcode: " + 
opcode);
+               }
+               _input1 = in1;
+               _input2 = in2;
+               _input3 = in3;
+               _input4 = in4;
+               _gputype = GPUINSTRUCTION_TYPE.Dnn;
+               _output = out;
+               _intermediateMemoryBudget = intermediateMemoryBudget;
+       }
+       
        public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, 
CPOperand out, String opcode,
                        String istr, ArrayList<CPOperand> stride,
                        ArrayList<CPOperand> padding, ArrayList<CPOperand> 
input_shape,
@@ -298,6 +313,15 @@ public class DnnGPUInstruction extends GPUInstruction {
                        CPOperand out = new CPOperand(parts[4]);
                        return new DnnGPUInstruction(in, in2, in3, out, opcode, 
str, 0);
                }
+               else if (opcode.equalsIgnoreCase("update_nesterov_x")) {
+                       InstructionUtils.checkNumFields(parts, 5);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand in4 = new CPOperand(parts[4]);
+                       CPOperand out = new CPOperand(parts[5]);
+                       return new DnnGPUInstruction(in, in2, in3, in4, out, 
opcode, str, 0);
+               }
                else if (opcode.equalsIgnoreCase("lstm")) {
                        InstructionUtils.checkNumFields(parts, 8);
                        CPOperand in1 = new CPOperand(parts[1]);
@@ -552,6 +576,34 @@ public class DnnGPUInstruction extends GPUInstruction {
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
        }
        
+       /**
+        * Executes the fused Nesterov-update kernel on the GPU:
+        * out = x - mu*v_prev + (1+mu)*v (see the rewrite tested in SGDUpdate).
+        * Inputs: _input1 = x, _input2 = v, _input3 = v_prev (dense matrices of
+        * identical shape), _input4 = mu (scalar). The kernel is launched over
+        * all rows*cols elements of x.
+        *
+        * @param ec execution context providing the GPU context, inputs and output
+        */
+       private void processNesterovUpdateInstruction(ExecutionContext ec) {
+               GPUStatistics.incrementNoOfExecutedGPUInst();
+               MatrixObject input = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
+               MatrixObject v = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
+               MatrixObject v_prev = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
+               // BUGFIX: previously cast to (int), which truncated mu (e.g. 0.99 -> 0)
+               // and silently corrupted the Nesterov update
+               double mu = ec.getScalarInput(_input4.getName(), 
_input4.getValueType(), _input4.isLiteral()).getDoubleValue();
+               int rows = LibMatrixCUDA.toInt(input.getNumRows());
+               int cols = LibMatrixCUDA.toInt(input.getNumColumns());
+               // compute the element count in long arithmetic so the overflow
+               // check in toInt actually fires instead of wrapping around
+               int numElems = LibMatrixCUDA.toInt(((long)rows) * cols);
+               MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, 
_output.getName(), rows, cols);
+               
+               GPUContext gCtx = ec.getGPUContext(0);
+               String instName = getExtendedOpcode();
+               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", 
+                               
ExecutionConfig.getConfigForSimpleVectorOperations(numElems),
+                               LibMatrixCUDA.getDensePointer(gCtx, input, 
instName), 
+                               LibMatrixCUDA.getDensePointer(gCtx, v, 
instName),
+                               LibMatrixCUDA.getDensePointer(gCtx, v_prev, 
instName),
+                               mu, 
+                               LibMatrixCUDA.getDensePointer(gCtx, out, 
instName),
+                               numElems);
+               
+               // release inputs/outputs
+               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+       }
+       
        private static int toInt(long num) throws DMLRuntimeException {
                if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
                        throw new DMLRuntimeException("GPU : Exceeded supported 
size " + num);
@@ -697,6 +749,10 @@ public class DnnGPUInstruction extends GPUInstruction {
                        processChannelSumsInstruction(ec);
                        return;
                }
+               else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) {
+                       processNesterovUpdateInstruction(ec);
+                       return;
+               }
                else if (instOpcode.equalsIgnoreCase("lstm")) {
                        processLstmInstruction(ec);
                        return;

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 3d7ab2c..5d0e4bc 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -405,7 +405,8 @@ public class GPUMemoryManager {
                        allPointers.remove(toFree);
                        lazyCudaFreeMemoryManager.removeIfPresent(size, toFree);
                        allocator.free(toFree);
-                       // JCuda.cudaDeviceSynchronize(); // Force a device 
synchronize after free-ing the pointer for debugging
+                       if(DMLScript.SYNCHRONIZE_GPU)
+                               jcuda.runtime.JCuda.cudaDeviceSynchronize(); // 
Force a device synchronize after free-ing the pointer for debugging
                }
                else {
                        throw new RuntimeException("Attempting to free an 
unaccounted pointer:" + toFree);
@@ -447,7 +448,7 @@ public class GPUMemoryManager {
        public void removeGPUObject(GPUObject gpuObj) {
                if(LOG.isDebugEnabled())
                        LOG.debug("Removing the GPU object: " + gpuObj);
-               matrixMemoryManager.gpuObjects.removeIf(a -> a.equals(gpuObj));
+               matrixMemoryManager.gpuObjects.remove(gpuObj);
        }
 
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java 
b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index cae2e33..e1ae1ae 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -311,6 +311,16 @@ public abstract class GPUTests extends AutomatedTestBase {
                Set<String> heavyHitterOpCodes = 
Statistics.getCPHeavyHitterOpCodes();
                
Assert.assertTrue(heavyHitterOpCodes.contains(heavyHitterOpCode));
        }
+       
+       /**
+        * Asserts that the given op was NOT executed, i.e. its opcode does not
+        * appear among the CP heavy hitters of the last run.
+        * (The previous javadoc was a copy-paste of assertHeavyHitterPresent
+        * and claimed the opposite.)
+        *
+        * @param heavyHitterOpCode opcode that must be absent from the heavy hitters
+        */
+       protected void assertHeavyHitterNotPresent(String heavyHitterOpCode) {
+               Set<String> heavyHitterOpCodes = 
Statistics.getCPHeavyHitterOpCodes();
+               
Assert.assertFalse(heavyHitterOpCodes.contains(heavyHitterOpCode));
+       }
 
        /**
         * Runs a program on the CPU

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java 
b/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
new file mode 100644
index 0000000..c98a74d
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the GPU rewrite of the SGD Nesterov update expression
+ * x - mu*v_prev + (1+mu)*v into the fused update_nesterov_x operator,
+ * including negative cases where the rewrite must not fire.
+ */
+public class SGDUpdate extends GPUTests {
+
+       private final static String TEST_NAME = "SGDUpdateTests";
+       private final int seed = 42;
+
+       @Override
+       public void setUp() {
+               super.setUp();
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_DIR, TEST_NAME);
+               getAndLoadTestConfiguration(TEST_NAME);
+       }
+
+       /**
+        * Runs the given script on CPU and GPU, checks whether the fused
+        * gpu_update_nesterov_x operator fired, and compares the outputs.
+        *
+        * @param scriptStr DML script computing "output" from x, v_prev, v
+        * @param vRows number of rows of input v (differs from x in negative tests)
+        * @param vCols number of columns of input v
+        * @param expectRewrite true iff the fused operator is expected to fire
+        */
+       private void runSGDTest(String scriptStr, int vRows, int vCols, boolean expectRewrite) {
+               int inRows = 10;
+               int inCols = 30;
+               HashMap<String, Object> inputs = new HashMap<>();
+               inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+               inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+               inputs.put("v", generateInputMatrix(spark, vRows, vCols, 0, 10, 0.9, seed));
+               List<String> outputs = Arrays.asList("output");
+               List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+               List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+               if(expectRewrite)
+                       assertHeavyHitterPresent("gpu_update_nesterov_x");
+               else
+                       assertHeavyHitterNotPresent("gpu_update_nesterov_x");
+               assertEqualObjects(outCPU.get(0), outGPU.get(0));
+       }
+
+       /** All shapes match the pattern: the rewrite should fire. */
+       @Test
+       public void testNesterovRewrite() {
+               runSGDTest("mu=0.99; output = x - mu*v_prev + (1+mu)*v;", 10, 30, true);
+       }
+       
+       /** v is a column vector (shape mismatch): the rewrite must not fire. */
+       @Test
+       public void testNoNesterovRewrite1() {
+               runSGDTest("mu=0.99; output = x - mu*v_prev + (1+mu)*v;", 10, 1, false);
+       }
+       
+       /** Coefficient of v is mu instead of (1+mu): the rewrite must not fire. */
+       @Test
+       public void testNoNesterovRewrite2() {
+               runSGDTest("mu=0.99; output = x - mu*v_prev + mu*v;", 10, 30, false);
+       }
+}

Reply via email to