Repository: systemml Updated Branches: refs/heads/master 5ca8706e9 -> 04bc667f3
[SYSTEMML-445] Added SGD Nesterov update operator via rewrite for the GPU backend - This leads to 10-15% speedup for ResNet200 with batch size of 32. - Also, added GPU tests for this operator. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/04bc667f Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/04bc667f Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/04bc667f Branch: refs/heads/master Commit: 04bc667f3650d57c0bc9de20e46e7624205cc1e6 Parents: 5ca8706 Author: Niketan Pansare <[email protected]> Authored: Thu Aug 9 21:00:21 2018 -0700 Committer: Niketan Pansare <[email protected]> Committed: Thu Aug 9 21:00:21 2018 -0700 ---------------------------------------------------------------------- src/main/cpp/kernels/SystemML.cu | 23 ++- src/main/cpp/kernels/SystemML.ptx | 161 +++++++++++++++---- src/main/java/org/apache/sysml/hops/DnnOp.java | 15 +- src/main/java/org/apache/sysml/hops/Hop.java | 4 +- .../hops/rewrite/RewriteGPUSpecificOps.java | 61 +++++++ .../org/apache/sysml/lops/DnnTransform.java | 33 +++- .../instructions/GPUInstructionParser.java | 1 + .../instructions/gpu/DnnGPUInstruction.java | 56 +++++++ .../gpu/context/GPUMemoryManager.java | 5 +- .../org/apache/sysml/test/gpu/GPUTests.java | 10 ++ .../org/apache/sysml/test/gpu/SGDUpdate.java | 91 +++++++++++ 11 files changed, 419 insertions(+), 41 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index 485b7e2..9ddaaff 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -2248,12 +2248,6 @@ extern "C" __global__ void prepare_lstm_dinput_f(float* smlInput, float* cudnnIn } -/** - * Do an log over all the 
elements of a matrix - * @param A the input matrix (of length = size) - * @param C the pre-allocated output matrix (of length = size) - * @param size the length of the input and output matrices - */ template <typename T> __device__ void colwise_reshape(T *A, T *C, unsigned int size, unsigned int inRows, unsigned int inCols, @@ -2278,4 +2272,21 @@ extern "C" __global__ void colwise_reshape_f(float *A, float *C, unsigned int si unsigned int inRows, unsigned int inCols, unsigned int outRows, unsigned int outCols) { colwise_reshape(A, C, size, inRows, inCols, outRows, outCols); +} + +// Performs the operation: out = X - mu*v_prev + (1+mu)*v +template <typename T> +__device__ void update_nesterov_x(T *X, T *v, T *v_prev, double mu, T *out, unsigned int size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + out[index] = X[index] - mu*v_prev[index] + (1+mu)*v[index]; + } +} + +extern "C" __global__ void update_nesterov_x_d(double *X, double *v, double *v_prev, double mu, double *out, unsigned int size) { + update_nesterov_x(X, v, v_prev, mu, out, size); +} + +extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float *v_prev, double mu, float *out, unsigned int size) { + update_nesterov_x(X, v, v_prev, mu, out, size); } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/cpp/kernels/SystemML.ptx ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx index 93e5e35..8a14876 100644 --- a/src/main/cpp/kernels/SystemML.ptx +++ b/src/main/cpp/kernels/SystemML.ptx @@ -12977,12 +12977,119 @@ BB113_2: ret; } + // .globl update_nesterov_x_d +.visible .entry update_nesterov_x_d( + .param .u64 update_nesterov_x_d_param_0, + .param .u64 update_nesterov_x_d_param_1, + .param .u64 update_nesterov_x_d_param_2, + .param .f64 update_nesterov_x_d_param_3, + .param .u64 update_nesterov_x_d_param_4, + 
.param .u32 update_nesterov_x_d_param_5 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<6>; + .reg .f64 %fd<9>; + .reg .b64 %rd<14>; + + + ld.param.u64 %rd1, [update_nesterov_x_d_param_0]; + ld.param.u64 %rd2, [update_nesterov_x_d_param_1]; + ld.param.u64 %rd3, [update_nesterov_x_d_param_2]; + ld.param.f64 %fd1, [update_nesterov_x_d_param_3]; + ld.param.u64 %rd4, [update_nesterov_x_d_param_4]; + ld.param.u32 %r2, [update_nesterov_x_d_param_5]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r4, %r3, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra BB114_2; + + cvta.to.global.u64 %rd5, %rd1; + mul.wide.s32 %rd6, %r1, 8; + add.s64 %rd7, %rd5, %rd6; + cvta.to.global.u64 %rd8, %rd3; + add.s64 %rd9, %rd8, %rd6; + ld.global.f64 %fd2, [%rd9]; + mul.f64 %fd3, %fd2, %fd1; + ld.global.f64 %fd4, [%rd7]; + sub.f64 %fd5, %fd4, %fd3; + cvta.to.global.u64 %rd10, %rd2; + add.s64 %rd11, %rd10, %rd6; + ld.global.f64 %fd6, [%rd11]; + add.f64 %fd7, %fd1, 0d3FF0000000000000; + fma.rn.f64 %fd8, %fd7, %fd6, %fd5; + cvta.to.global.u64 %rd12, %rd4; + add.s64 %rd13, %rd12, %rd6; + st.global.f64 [%rd13], %fd8; + +BB114_2: + ret; +} + + // .globl update_nesterov_x_f +.visible .entry update_nesterov_x_f( + .param .u64 update_nesterov_x_f_param_0, + .param .u64 update_nesterov_x_f_param_1, + .param .u64 update_nesterov_x_f_param_2, + .param .f64 update_nesterov_x_f_param_3, + .param .u64 update_nesterov_x_f_param_4, + .param .u32 update_nesterov_x_f_param_5 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<5>; + .reg .b32 %r<6>; + .reg .f64 %fd<9>; + .reg .b64 %rd<14>; + + + ld.param.u64 %rd1, [update_nesterov_x_f_param_0]; + ld.param.u64 %rd2, [update_nesterov_x_f_param_1]; + ld.param.u64 %rd3, [update_nesterov_x_f_param_2]; + ld.param.f64 %fd1, [update_nesterov_x_f_param_3]; + ld.param.u64 %rd4, [update_nesterov_x_f_param_4]; + ld.param.u32 %r2, [update_nesterov_x_f_param_5]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r4, 
%r3, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra BB115_2; + + cvta.to.global.u64 %rd5, %rd1; + mul.wide.s32 %rd6, %r1, 4; + add.s64 %rd7, %rd5, %rd6; + ld.global.f32 %f1, [%rd7]; + cvt.f64.f32 %fd2, %f1; + cvta.to.global.u64 %rd8, %rd3; + add.s64 %rd9, %rd8, %rd6; + ld.global.f32 %f2, [%rd9]; + cvt.f64.f32 %fd3, %f2; + mul.f64 %fd4, %fd3, %fd1; + sub.f64 %fd5, %fd2, %fd4; + cvta.to.global.u64 %rd10, %rd2; + add.s64 %rd11, %rd10, %rd6; + ld.global.f32 %f3, [%rd11]; + cvt.f64.f32 %fd6, %f3; + add.f64 %fd7, %fd1, 0d3FF0000000000000; + fma.rn.f64 %fd8, %fd7, %fd6, %fd5; + cvt.rn.f32.f64 %f4, %fd8; + cvta.to.global.u64 %rd12, %rd4; + add.s64 %rd13, %rd12, %rd6; + st.global.f32 [%rd13], %f4; + +BB115_2: + ret; +} + .func (.param .b64 func_retval0) __internal_trig_reduction_slowpathd( .param .b64 __internal_trig_reduction_slowpathd_param_0, .param .b64 __internal_trig_reduction_slowpathd_param_1 ) { - .local .align 8 .b8 __local_depot114[40]; + .local .align 8 .b8 __local_depot116[40]; .reg .b64 %SP; .reg .b64 %SPL; .reg .pred %p<9>; @@ -12991,7 +13098,7 @@ BB113_2: .reg .b64 %rd<102>; - mov.u64 %rd101, __local_depot114; + mov.u64 %rd101, __local_depot116; cvta.local.u64 %SP, %rd101; ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0]; ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1]; @@ -13005,7 +13112,7 @@ BB113_2: shr.u32 %r3, %r1, 20; bfe.u32 %r4, %r1, 20, 11; setp.eq.s32 %p1, %r4, 2047; - @%p1 bra BB114_13; + @%p1 bra BB116_13; add.s32 %r15, %r4, -1024; shr.u32 %r16, %r15, 6; @@ -13018,7 +13125,7 @@ BB113_2: mov.u64 %rd94, 0; setp.ge.s32 %p2, %r5, %r6; mov.u64 %rd93, %rd1; - @%p2 bra BB114_4; + @%p2 bra BB116_4; mov.b64 %rd41, %fd4; shl.b64 %rd42, %rd41, 11; @@ -13035,7 +13142,7 @@ BB113_2: mov.u64 %rd91, %rd1; mov.u32 %r39, %r5; -BB114_3: +BB116_3: .pragma "nounroll"; ld.const.u64 %rd47, [%rd89]; // inline asm @@ -13065,15 +13172,15 @@ BB114_3: add.s64 %rd93, %rd93, 8; add.s64 %rd89, %rd89, 8; setp.lt.s32 %p3, %r39, %r6; - @%p3 bra 
BB114_3; + @%p3 bra BB116_3; -BB114_4: +BB116_4: st.local.u64 [%rd93], %rd94; ld.local.u64 %rd95, [%rd1+16]; ld.local.u64 %rd96, [%rd1+24]; and.b32 %r9, %r3, 63; setp.eq.s32 %p4, %r9, 0; - @%p4 bra BB114_6; + @%p4 bra BB116_6; mov.u32 %r27, 64; sub.s32 %r28, %r27, %r9; @@ -13085,7 +13192,7 @@ BB114_4: shr.u64 %rd55, %rd54, %r28; or.b64 %rd95, %rd55, %rd53; -BB114_6: +BB116_6: cvta.to.local.u64 %rd56, %rd37; shr.u64 %rd57, %rd96, 62; cvt.u32.u64 %r29, %rd57; @@ -13102,7 +13209,7 @@ BB114_6: selp.b32 %r34, %r32, %r33, %p5; st.local.u32 [%rd56], %r34; setp.eq.s32 %p6, %r31, 0; - @%p6 bra BB114_8; + @%p6 bra BB116_8; mov.u64 %rd64, 0; // inline asm @@ -13122,10 +13229,10 @@ BB114_6: // inline asm xor.b32 %r40, %r40, -2147483648; -BB114_8: +BB116_8: clz.b64 %r41, %rd98; setp.eq.s32 %p7, %r41, 0; - @%p7 bra BB114_10; + @%p7 bra BB116_10; shl.b64 %rd67, %rd98, %r41; mov.u32 %r35, 64; @@ -13133,7 +13240,7 @@ BB114_8: shr.u64 %rd68, %rd97, %r36; or.b64 %rd98, %rd68, %rd67; -BB114_10: +BB116_10: mov.u64 %rd72, -3958705157555305931; // inline asm { @@ -13154,7 +13261,7 @@ BB114_10: } // inline asm setp.lt.s64 %p8, %rd100, 1; - @%p8 bra BB114_12; + @%p8 bra BB116_12; // inline asm { @@ -13173,7 +13280,7 @@ BB114_10: // inline asm add.s32 %r41, %r41, 1; -BB114_12: +BB116_12: cvt.u64.u32 %rd79, %r40; shl.b64 %rd80, %rd79, 32; mov.u32 %r37, 1022; @@ -13188,7 +13295,7 @@ BB114_12: or.b64 %rd88, %rd87, %rd80; mov.b64 %fd4, %rd88; -BB114_13: +BB116_13: st.param.f64 [func_retval0+0], %fd4; ret; } @@ -13216,7 +13323,7 @@ BB114_13: } shr.u32 %r51, %r50, 20; setp.ne.s32 %p1, %r51, 0; - @%p1 bra BB115_2; + @%p1 bra BB117_2; mul.f64 %fd14, %fd12, 0d4350000000000000; { @@ -13230,13 +13337,13 @@ BB114_13: shr.u32 %r16, %r50, 20; add.s32 %r51, %r16, -54; -BB115_2: +BB117_2: add.s32 %r52, %r51, -1023; and.b32 %r17, %r50, -2146435073; or.b32 %r18, %r17, 1072693248; mov.b64 %fd135, {%r49, %r18}; setp.lt.u32 %p2, %r18, 1073127583; - @%p2 bra BB115_4; + @%p2 bra BB117_4; { .reg .b32 %temp; @@ 
-13250,7 +13357,7 @@ BB115_2: mov.b64 %fd135, {%r19, %r21}; add.s32 %r52, %r51, -1022; -BB115_4: +BB117_4: add.f64 %fd15, %fd135, 0d3FF0000000000000; rcp.approx.ftz.f64 %fd16, %fd15; neg.f64 %fd17, %fd15; @@ -13413,13 +13520,13 @@ BB115_4: mov.b32 %f2, %r35; abs.f32 %f1, %f2; setp.lt.f32 %p4, %f1, 0f4086232B; - @%p4 bra BB115_7; + @%p4 bra BB117_7; setp.lt.f64 %p5, %fd4, 0d0000000000000000; add.f64 %fd129, %fd4, 0d7FF0000000000000; selp.f64 %fd136, 0d0000000000000000, %fd129, %p5; setp.geu.f32 %p6, %f1, 0f40874800; - @%p6 bra BB115_7; + @%p6 bra BB117_7; mov.f64 %fd134, 0d4338000000000000; mov.f64 %fd133, 0d3FF71547652B82FE; @@ -13441,26 +13548,26 @@ BB115_4: mov.b64 %fd131, {%r44, %r43}; mul.f64 %fd136, %fd130, %fd131; -BB115_7: +BB117_7: { .reg .b32 %temp; mov.b64 {%temp, %r45}, %fd136; } and.b32 %r46, %r45, 2147483647; setp.ne.s32 %p7, %r46, 2146435072; - @%p7 bra BB115_9; + @%p7 bra BB117_9; { .reg .b32 %temp; mov.b64 {%r47, %temp}, %fd136; } setp.eq.s32 %p8, %r47, 0; - @%p8 bra BB115_10; + @%p8 bra BB117_10; -BB115_9: +BB117_9: fma.rn.f64 %fd136, %fd136, %fd5, %fd136; -BB115_10: +BB117_10: st.param.f64 [func_retval0+0], %fd136; ret; } http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/DnnOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java index 4e22f59..4ca90f8 100644 --- a/src/main/java/org/apache/sysml/hops/DnnOp.java +++ b/src/main/java/org/apache/sysml/hops/DnnOp.java @@ -136,6 +136,7 @@ public class DnnOp extends MultiThreadedHop } case BATCH_NORM2D_TEST: case CHANNEL_SUMS: + case UPDATE_NESTEROV_X: { if(et == ExecType.GPU) { setLops(constructDnnLops(et, inputs)); @@ -175,6 +176,8 @@ public class DnnOp extends MultiThreadedHop return 6; case CHANNEL_SUMS: return 3; + case UPDATE_NESTEROV_X: + return 4; default: return 13; } @@ -528,7 +531,9 @@ public class DnnOp extends 
MultiThreadedHop // [numRows, numCols, NNZ] long[] ret = new long[3]; - if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST) { + if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || + op == OpOpDnn.UPDATE_NESTEROV_X) { + // Same dimension as the first input MatrixCharacteristics[] mc = memo.getAllInputStats(getInput()); ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1; ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1; @@ -734,7 +739,8 @@ public class DnnOp extends MultiThreadedHop @Override public void refreshSizeInformation() { - if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST) { + if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) { + // Same dimension as the first input Hop input1 = getInput().get(0); setDim1(input1.getDim1()); setDim2(input1.getDim2()); @@ -840,8 +846,9 @@ public class DnnOp extends MultiThreadedHop * @return either -1 or value associated with the dimString */ private long getDim(String dimString) { - if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS) { - throw new RuntimeException("getDim method should not be invoked for batch_norm_test, channel_sums, bias_add and bias_multiply"); + if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS || + op == OpOpDnn.UPDATE_NESTEROV_X) { + throw new RuntimeException("getDim method should not be invoked for " + op.name()); } try { parseInput(); http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 6466575..73a58e3 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java 
+++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1099,7 +1099,8 @@ public abstract class Hop implements ParseInfo public enum OpOpDnn { MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, - BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS + BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS, + UPDATE_NESTEROV_X } public enum DataGenMethod { @@ -1174,6 +1175,7 @@ public abstract class Hop implements ParseInfo HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA); HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST); HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS); + HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X); } protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops; http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java index 2a1699d..b603aa7 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java @@ -124,6 +124,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { } hi = batchNormTest(hop, hi, i); hi = channelSums(hop, hi, i); + hi = updateNesterovX(hop, hi, i); if( !descendFirst ) rule_GPUKernels(roots, hi, descendFirst); @@ -281,6 +282,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { && getFirstInput(h).getDataType() == DataType.MATRIX && 
getSecondInput(h).getDataType() == DataType.MATRIX; } + private static boolean isBinaryMMMinus(Hop h) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS + && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX; + } + private static boolean isBinaryMSMult(Hop h, double expectedValue) { return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR @@ -323,6 +329,16 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR; } + private static boolean isBinarySMMult(Hop h, double expectedVal) { + return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT + && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR + && getValue(getFirstInput(h)) == expectedVal; + } + + private static double getValue(Hop h) { + return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>()); + } + /** * Checks if the "mean" hop is a moving average of mean in batch normalization layer. 
* @@ -704,6 +720,51 @@ public class RewriteGPUSpecificOps extends HopRewriteRule { // ------------------------------------------------------------ /** + * Checks for the nesterov_update_x pattern (X = X - mu*v_prev + (1+mu)*v) + * and returns a new DnnOp if matched + * + * @param parent parent of the input + * @param hi input to be matched + * @param pos position + * @return a new DnnOp or hi + */ + private static Hop updateNesterovX(Hop parent, Hop hi, int pos) { + if(fitsOnGPU(hi, 4) && isBinaryMMAdd(hi) && isBinaryMMMinus(getFirstInput(hi)) + && isBinarySMMult(getSecondInput(getFirstInput(hi))) + && isBinarySMMult(getSecondInput(hi))) { + Hop onePlusMu = getFirstInput(getSecondInput(hi)); + Hop tmp = getSecondInput(getFirstInput(hi)); + Hop mu = getFirstInput(tmp); + if(isOnePlusMu(onePlusMu, mu)) { + Hop v_prev = getSecondInput(tmp); + Hop v = getSecondInput(getSecondInput(hi)); + Hop X = getFirstInput(getFirstInput(hi)); + if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) { + ArrayList<Hop> inHops = new ArrayList<Hop>(); + inHops.add(X); + inHops.add(v); + inHops.add(v_prev); + inHops.add(mu); + LOG.debug("Applied updateNesterovX rewrite."); + Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(), + OpOpDnn.UPDATE_NESTEROV_X, inHops); + return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop); + } + } + } + return hi; + } + + private static boolean hasSameDimensions(Hop x, Hop y) { + return x.dimsKnown() && y.dimsKnown() && (x.getDim1() == y.getDim1()) && (x.getDim2() == y.getDim2()); + } + + private static boolean isOnePlusMu(Hop onePlusMu, Hop mu) { + return (isBinarySMMult(onePlusMu, 1.0) && getSecondInput(onePlusMu) == mu) || + getValue(onePlusMu) == getValue(mu) + 1; + } + + /** * Checks for the batch norm (mode="test") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar * and returns a new DnnOp if matched * 
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/lops/DnnTransform.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java index 6c61d4a..3183b5f 100644 --- a/src/main/java/org/apache/sysml/lops/DnnTransform.java +++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java @@ -31,7 +31,8 @@ public class DnnTransform extends Lop MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD, RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, - BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST + BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST, + UPDATE_NESTEROV_X } private OperationTypes operation; @@ -165,6 +166,9 @@ public class DnnTransform extends Lop case CHANNEL_SUMS: return "channel_sums"; + + case UPDATE_NESTEROV_X: + return "update_nesterov_x"; case BATCH_NORM2D_TEST: return "batch_norm2d_test"; @@ -232,6 +236,33 @@ public class DnnTransform extends Lop } @Override + public String getInstructions(String input1, String input2, String input3, String input4, String output) { + if(operation == OperationTypes.UPDATE_NESTEROV_X) { + StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); + + sb.append( OPERAND_DELIMITOR ); + sb.append( getOpcode() ); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(0).prepInputOperand(input1)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(1).prepInputOperand(input2)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(2).prepInputOperand(input3)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(3).prepInputOperand(input4)); + //output + sb.append( OPERAND_DELIMITOR ); + sb.append( this.prepOutputOperand(output)); + + return sb.toString(); + } + else { + throw new 
LopsException("The operation is not supported with four operands: " + operation.name());
_intermediateMemoryBudget = intermediateMemoryBudget; } + public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, + double intermediateMemoryBudget) throws DMLRuntimeException { + super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); + if( !opcode.equals("update_nesterov_x") ) { + throw new DMLRuntimeException("Incorrect opcode: " + opcode); + } + _input1 = in1; + _input2 = in2; + _input3 = in3; + _input4 = in4; + _gputype = GPUINSTRUCTION_TYPE.Dnn; + _output = out; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, @@ -298,6 +313,15 @@ public class DnnGPUInstruction extends GPUInstruction { CPOperand out = new CPOperand(parts[4]); return new DnnGPUInstruction(in, in2, in3, out, opcode, str, 0); } + else if (opcode.equalsIgnoreCase("update_nesterov_x")) { + InstructionUtils.checkNumFields(parts, 5); + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand in4 = new CPOperand(parts[4]); + CPOperand out = new CPOperand(parts[5]); + return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0); + } else if (opcode.equalsIgnoreCase("lstm")) { InstructionUtils.checkNumFields(parts, 8); CPOperand in1 = new CPOperand(parts[1]); @@ -552,6 +576,34 @@ public class DnnGPUInstruction extends GPUInstruction { ec.releaseMatrixOutputForGPUInstruction(_output.getName()); } + private void processNesterovUpdateInstruction(ExecutionContext ec) { + GPUStatistics.incrementNoOfExecutedGPUInst();; + MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject v = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject v_prev = 
getMatrixInputForGPUInstruction(ec, _input3.getName()); + double mu = (int) ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue(); + int rows = LibMatrixCUDA.toInt(input.getNumRows()); + int cols = LibMatrixCUDA.toInt(input.getNumColumns()); + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), rows, cols); + + GPUContext gCtx = ec.getGPUContext(0); + String instName = getExtendedOpcode(); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", + ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)), + LibMatrixCUDA.getDensePointer(gCtx, input, instName), + LibMatrixCUDA.getDensePointer(gCtx, v, instName), + LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName), + mu, + LibMatrixCUDA.getDensePointer(gCtx, out, instName), + rows*cols); + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } + private static int toInt(long num) throws DMLRuntimeException { if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) { throw new DMLRuntimeException("GPU : Exceeded supported size " + num); @@ -697,6 +749,10 @@ public class DnnGPUInstruction extends GPUInstruction { processChannelSumsInstruction(ec); return; } + else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) { + processNesterovUpdateInstruction(ec); + return; + } else if (instOpcode.equalsIgnoreCase("lstm")) { processLstmInstruction(ec); return; http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java index 3d7ab2c..5d0e4bc 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java @@ -405,7 +405,8 @@ public class GPUMemoryManager { allPointers.remove(toFree); lazyCudaFreeMemoryManager.removeIfPresent(size, toFree); allocator.free(toFree); - // JCuda.cudaDeviceSynchronize(); // Force a device synchronize after free-ing the pointer for debugging + if(DMLScript.SYNCHRONIZE_GPU) + jcuda.runtime.JCuda.cudaDeviceSynchronize(); // Force a device synchronize after free-ing the pointer for debugging } else { throw new RuntimeException("Attempting to free an unaccounted pointer:" + toFree); @@ -447,7 +448,7 @@ public class GPUMemoryManager { public void removeGPUObject(GPUObject gpuObj) { if(LOG.isDebugEnabled()) LOG.debug("Removing the GPU object: " + gpuObj); - matrixMemoryManager.gpuObjects.removeIf(a -> a.equals(gpuObj)); + matrixMemoryManager.gpuObjects.remove(gpuObj); } http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/test/java/org/apache/sysml/test/gpu/GPUTests.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java index cae2e33..e1ae1ae 100644 --- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java +++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java @@ -311,6 +311,16 @@ public abstract class GPUTests extends AutomatedTestBase { Set<String> heavyHitterOpCodes = Statistics.getCPHeavyHitterOpCodes(); Assert.assertTrue(heavyHitterOpCodes.contains(heavyHitterOpCode)); } + + /** + * asserts that the expected op was executed + * + * @param heavyHitterOpCode opcode of the heavy hitter for the unary op + */ + protected void assertHeavyHitterNotPresent(String heavyHitterOpCode) { + 
Set<String> heavyHitterOpCodes = Statistics.getCPHeavyHitterOpCodes(); + Assert.assertTrue(!heavyHitterOpCodes.contains(heavyHitterOpCode)); + } /** * Runs a program on the CPU http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java b/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java new file mode 100644 index 0000000..c98a74d --- /dev/null +++ b/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.gpu; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Test; + +/** + * Tests update rewrites for SGD + */ +public class SGDUpdate extends GPUTests { + + private final static String TEST_NAME = "SGDUpdateTests"; + private final int seed = 42; + + @Override + public void setUp() { + super.setUp(); + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testNesterovRewrite() { + String scriptStr = "mu=0.99; output = x - mu*v_prev + (1+mu)*v;" ; + int inRows = 10; + int inCols = 30; + HashMap<String, Object> inputs = new HashMap<>(); + inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed)); + inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed)); + inputs.put("v", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed)); + List<String> outputs = Arrays.asList("output"); + List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs); + List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs); + assertHeavyHitterPresent("gpu_update_nesterov_x"); + assertEqualObjects(outCPU.get(0), outGPU.get(0)); + } + + @Test + public void testNoNesterovRewrite1() { + String scriptStr = "mu=0.99; output = x - mu*v_prev + (1+mu)*v;" ; + int inRows = 10; + int inCols = 30; + HashMap<String, Object> inputs = new HashMap<>(); + inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed)); + inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed)); + inputs.put("v", generateInputMatrix(spark, inRows, 1, 0, 10, 0.9, seed)); + List<String> outputs = Arrays.asList("output"); + List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs); + List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs); + 
assertHeavyHitterNotPresent("gpu_update_nesterov_x"); + assertEqualObjects(outCPU.get(0), outGPU.get(0)); + } + + @Test + public void testNoNesterovRewrite2() { + String scriptStr = "mu=0.99; output = x - mu*v_prev + mu*v;" ; + int inRows = 10; + int inCols = 30; + HashMap<String, Object> inputs = new HashMap<>(); + inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed)); + inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed)); + inputs.put("v", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed)); + List<String> outputs = Arrays.asList("output"); + List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs); + List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs); + assertHeavyHitterNotPresent("gpu_update_nesterov_x"); + assertEqualObjects(outCPU.get(0), outGPU.get(0)); + } +}
