Repository: incubator-systemml Updated Branches: refs/heads/master 0daae6cf0 -> 4316efeba
Bug fixes, instruction added, async cudaFree - Fixes for GPU mem mgmt and related integration tests - Added "exp" function for GPU - Do cudaFree asynchronously Closes #404 Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/4316efeb Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/4316efeb Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/4316efeb Branch: refs/heads/master Commit: 4316efebaf065d7a3de067354275d1b991e38bb4 Parents: 0daae6c Author: Nakul Jindal <[email protected]> Authored: Fri Feb 24 11:27:44 2017 -0800 Committer: Nakul Jindal <[email protected]> Committed: Fri Feb 24 11:27:44 2017 -0800 ---------------------------------------------------------------------- src/main/cpp/kernels/SystemML.cu | 15 ++ src/main/cpp/kernels/SystemML.ptx | 136 ++++++++++++-- .../java/org/apache/sysml/hops/UnaryOp.java | 2 +- .../context/ExecutionContext.java | 8 +- .../instructions/GPUInstructionParser.java | 3 +- .../gpu/BuiltinUnaryGPUInstruction.java | 2 +- .../gpu/ConvolutionGPUInstruction.java | 16 +- .../instructions/gpu/GPUInstruction.java | 10 +- .../gpu/MatrixBuiltinGPUInstruction.java | 14 +- .../instructions/gpu/context/GPUContext.java | 17 +- .../instructions/gpu/context/GPUObject.java | 177 +++++++++++-------- .../instructions/gpu/context/JCudaContext.java | 1 + .../instructions/gpu/context/JCudaObject.java | 161 ++++++++++------- .../runtime/matrix/data/LibMatrixCUDA.java | 103 +++++++---- .../java/org/apache/sysml/utils/Statistics.java | 23 ++- .../functions/misc/RewritePushdownUaggTest.java | 15 +- .../RewriteSimplifyRowColSumMVMultTest.java | 5 +- 17 files changed, 483 insertions(+), 225 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index cda28ba..40a1046 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -628,3 +628,18 @@ __global__ void reduce_col_mean(double *g_idata, double *g_odata, unsigned int r MeanOp aop(rows); reduce_col<SumOp, MeanOp>(g_idata, g_odata, rows, cols, op, aop, 0.0); } + + +/** + * Applies exp to every element of a matrix + * @param A the input matrix (of length = size) + * @param C the pre-allocated output matrix (of length = size) + * @param size the length of the input and output matrices + */ +extern "C" +__global__ void matrix_exp(double *A, double *C, unsigned int size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size){ + C[index] = exp(A[index]); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/cpp/kernels/SystemML.ptx ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx index 93f3879..b9efd9b 100644 --- a/src/main/cpp/kernels/SystemML.ptx +++ b/src/main/cpp/kernels/SystemML.ptx @@ -4810,6 +4810,120 @@ BB33_5: ret; } + // .globl matrix_exp +.visible .entry matrix_exp( + .param .u64 matrix_exp_param_0, + .param .u64 matrix_exp_param_1, + .param .u32 matrix_exp_param_2 +) +{ + .reg .pred %p<5>; + .reg .f32 %f<3>; + .reg .b32 %r<21>; + .reg .f64 %fd<42>; + .reg .b64 %rd<10>; + + + ld.param.u64 %rd2, [matrix_exp_param_0]; + ld.param.u64 
%rd3, [matrix_exp_param_1]; + ld.param.u32 %r5, [matrix_exp_param_2]; + mov.u32 %r6, %ctaid.x; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %tid.x; + mad.lo.s32 %r1, %r7, %r6, %r8; + setp.ge.u32 %p1, %r1, %r5; + @%p1 bra BB34_5; + + cvta.to.global.u64 %rd4, %rd2; + cvt.s64.s32 %rd1, %r1; + mul.wide.s32 %rd5, %r1, 8; + add.s64 %rd6, %rd4, %rd5; + ld.global.f64 %fd1, [%rd6]; + mov.f64 %fd6, 0d3FF71547652B82FE; + mul.rn.f64 %fd7, %fd1, %fd6; + mov.f64 %fd8, 0d4338000000000000; + add.rn.f64 %fd9, %fd7, %fd8; + { + .reg .b32 %temp; + mov.b64 {%r2, %temp}, %fd9; + } + mov.f64 %fd10, 0dC338000000000000; + add.rn.f64 %fd11, %fd9, %fd10; + mov.f64 %fd12, 0dBFE62E42FEFA39EF; + fma.rn.f64 %fd13, %fd11, %fd12, %fd1; + mov.f64 %fd14, 0dBC7ABC9E3B39803F; + fma.rn.f64 %fd15, %fd11, %fd14, %fd13; + mov.f64 %fd16, 0d3E928AF3FCA213EA; + mov.f64 %fd17, 0d3E5ADE1569CE2BDF; + fma.rn.f64 %fd18, %fd17, %fd15, %fd16; + mov.f64 %fd19, 0d3EC71DEE62401315; + fma.rn.f64 %fd20, %fd18, %fd15, %fd19; + mov.f64 %fd21, 0d3EFA01997C89EB71; + fma.rn.f64 %fd22, %fd20, %fd15, %fd21; + mov.f64 %fd23, 0d3F2A01A014761F65; + fma.rn.f64 %fd24, %fd22, %fd15, %fd23; + mov.f64 %fd25, 0d3F56C16C1852B7AF; + fma.rn.f64 %fd26, %fd24, %fd15, %fd25; + mov.f64 %fd27, 0d3F81111111122322; + fma.rn.f64 %fd28, %fd26, %fd15, %fd27; + mov.f64 %fd29, 0d3FA55555555502A1; + fma.rn.f64 %fd30, %fd28, %fd15, %fd29; + mov.f64 %fd31, 0d3FC5555555555511; + fma.rn.f64 %fd32, %fd30, %fd15, %fd31; + mov.f64 %fd33, 0d3FE000000000000B; + fma.rn.f64 %fd34, %fd32, %fd15, %fd33; + mov.f64 %fd35, 0d3FF0000000000000; + fma.rn.f64 %fd36, %fd34, %fd15, %fd35; + fma.rn.f64 %fd37, %fd36, %fd15, %fd35; + { + .reg .b32 %temp; + mov.b64 {%r3, %temp}, %fd37; + } + { + .reg .b32 %temp; + mov.b64 {%temp, %r4}, %fd37; + } + shl.b32 %r9, %r2, 20; + add.s32 %r10, %r4, %r9; + mov.b64 %fd41, {%r3, %r10}; + { + .reg .b32 %temp; + mov.b64 {%temp, %r11}, %fd1; + } + mov.b32 %f2, %r11; + abs.f32 %f1, %f2; + setp.lt.f32 %p2, %f1, 0f4086232B; + @%p2 bra BB34_4; + + setp.lt.f64 %p3, %fd1, 0d0000000000000000; + add.f64 %fd38, %fd1, 0d7FF0000000000000; + selp.f64 %fd41, 0d0000000000000000, %fd38, %p3; + setp.geu.f32 %p4, %f1, 0f40874800; + @%p4 bra BB34_4; + + shr.u32 %r12, %r2, 31; + add.s32 %r13, %r2, %r12; + shr.s32 %r14, %r13, 1; + shl.b32 %r15, %r14, 20; + add.s32 %r16, %r15, %r4; + mov.b64 %fd39, {%r3, %r16}; + sub.s32 %r17, %r2, %r14; + shl.b32 %r18, %r17, 20; + add.s32 %r19, %r18, 1072693248; + mov.u32 %r20, 0; + mov.b64 %fd40, {%r20, %r19}; + mul.f64 %fd41, %fd39, %fd40; + +BB34_4: + cvta.to.global.u64 %rd7, %rd3; + shl.b64 %rd8, %rd1, 3; + add.s64 %rd9, %rd7, %rd8; + st.global.f64 [%rd9], %fd41; + +BB34_5: + ret; +} + .func (.param .b64 func_retval0) __internal_accurate_pow( .param .b64 __internal_accurate_pow_param_0, .param .b64 __internal_accurate_pow_param_1 @@ -4833,7 +4947,7 @@ BB33_5: } shr.u32 %r50, %r49, 20; setp.ne.s32 %p1, %r50, 0; - @%p1 bra BB34_2; + @%p1 bra BB35_2; mul.f64 %fd14, %fd12, 0d4350000000000000; { @@ -4847,13 +4961,13 @@ BB33_5: shr.u32 %r16, %r49, 20; add.s32 %r50, %r16, -54; -BB34_2: +BB35_2: add.s32 %r51, %r50, -1023; and.b32 %r17, %r49, -2146435073; or.b32 %r18, %r17, 1072693248; mov.b64 %fd133, {%r48, %r18}; setp.lt.u32 %p2, %r18, 1073127583; - @%p2 bra BB34_4; + @%p2 bra BB35_4; { .reg .b32 %temp; @@ -4867,7 +4981,7 @@ BB34_2: mov.b64 %fd133, {%r19, %r21}; add.s32 %r51, %r50, -1022; -BB34_4: +BB35_4: add.f64 %fd16, %fd133, 0d3FF0000000000000; // inline asm rcp.approx.ftz.f64 %fd15,%fd16; @@ -5033,13 +5147,13 @@ BB34_4: mov.b32 %f2, %r35; abs.f32 %f1, 
%f2; setp.lt.f32 %p4, %f1, 0f4086232B; - @%p4 bra BB34_7; + @%p4 bra BB35_7; setp.lt.f64 %p5, %fd4, 0d0000000000000000; add.f64 %fd130, %fd4, 0d7FF0000000000000; selp.f64 %fd134, 0d0000000000000000, %fd130, %p5; setp.geu.f32 %p6, %f1, 0f40874800; - @%p6 bra BB34_7; + @%p6 bra BB35_7; shr.u32 %r36, %r13, 31; add.s32 %r37, %r13, %r36; @@ -5054,26 +5168,26 @@ BB34_4: mov.b64 %fd132, {%r44, %r43}; mul.f64 %fd134, %fd131, %fd132; -BB34_7: +BB35_7: { .reg .b32 %temp; mov.b64 {%temp, %r45}, %fd134; } and.b32 %r46, %r45, 2147483647; setp.ne.s32 %p7, %r46, 2146435072; - @%p7 bra BB34_9; + @%p7 bra BB35_9; { .reg .b32 %temp; mov.b64 {%r47, %temp}, %fd134; } setp.eq.s32 %p8, %r47, 0; - @%p8 bra BB34_10; + @%p8 bra BB35_10; -BB34_9: +BB35_9: fma.rn.f64 %fd134, %fd134, %fd5, %fd134; -BB34_10: +BB35_10: st.param.f64 [func_retval0+0], %fd134; ret; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/hops/UnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java index 85c396f..dd2a634 100644 --- a/src/main/java/org/apache/sysml/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java @@ -173,7 +173,7 @@ public class UnaryOp extends Hop implements MultiThreadedHop else //default unary { int k = isCumulativeUnaryOperation() ? OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ) : 1; - if(_op == OpOp1.SELP) { + if(_op == OpOp1.SELP || _op == OpOp1.EXP) { et = findGPUExecTypeByMemEstimate(et); } Unary unary1 = new Unary(input.constructLops(), HopsOpOp1LopsU.get(_op), http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index d6f4e6c..f14123e 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -66,9 +66,7 @@ public class ExecutionContext //debugging (optional) protected DebugState _dbState = null; - - protected GPUContext _gpuCtx = null; - + protected ExecutionContext() { //protected constructor to force use of ExecutionContextFactory @@ -101,8 +99,8 @@ public class ExecutionContext } public void destroyGPUContext() throws DMLRuntimeException { - if(_gpuCtx != null) - _gpuCtx.destroy(); + if(GPUContext.isGPUContextCreated) + GPUContext.getGPUContext().destroy(); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index 3dbdb1e..0be2139 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -69,8 +69,9 @@ public class GPUInstructionParser extends InstructionParser String2GPUInstructionType.put( "+*" , GPUINSTRUCTION_TYPE.ArithmeticBinary); 
String2GPUInstructionType.put( "-*" , GPUINSTRUCTION_TYPE.ArithmeticBinary); - + // Builtin functions String2GPUInstructionType.put( "sel+" , GPUINSTRUCTION_TYPE.BuiltinUnary); + String2GPUInstructionType.put( "exp" , GPUINSTRUCTION_TYPE.BuiltinUnary); // Aggregate Unary String2GPUInstructionType.put( "ua+" , GPUINSTRUCTION_TYPE.AggregateUnary); // Sum http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java index 08a7923..181af4e 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java @@ -35,7 +35,7 @@ public abstract class BuiltinUnaryGPUInstruction extends GPUInstruction { int _arity; CPOperand _input; - CPOperand _output; + CPOperand _output; public BuiltinUnaryGPUInstruction(Operator op, CPOperand in, CPOperand out, int _arity, String opcode, String istr ) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java index 5c49a91..56e95b7 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java @@ -205,9 +205,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction if (instOpcode.equalsIgnoreCase("conv2d")) { MatrixObject image = ec.getMatrixInputForGPUInstruction(_input1.getName()); MatrixObject filter = ec.getMatrixInputForGPUInstruction(_input2.getName()); - if( LibMatrixCUDA.isInSparseFormat(image) || LibMatrixCUDA.isInSparseFormat(filter) ) { - throw new DMLRuntimeException("Sparse convolution not implemented"); - } + if(image.getNumRows() != N || image.getNumColumns() != C*H*W) throw new DMLRuntimeException("Incorrect dimensions for image in conv2d"); if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S) @@ -221,8 +219,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) { MatrixObject image = ec.getMatrixInputForGPUInstruction(_input1.getName()); MatrixObject dout = ec.getMatrixInputForGPUInstruction(_input2.getName()); - if(LibMatrixCUDA.isInSparseFormat(image) || LibMatrixCUDA.isInSparseFormat(dout)) - throw new DMLRuntimeException("Sparse convolution_backward_filter not implemented"); + if(image.getNumRows() != N || image.getNumColumns() != C*H*W) throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_backward_filter"); if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q) @@ -239,8 +236,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) { MatrixObject filter = ec.getMatrixInputForGPUInstruction(_input1.getName()); MatrixObject dout = ec.getMatrixInputForGPUInstruction(_input2.getName()); - 
if(LibMatrixCUDA.isInSparseFormat(filter) || LibMatrixCUDA.isInSparseFormat(dout)) - throw new DMLRuntimeException("Sparse convolution_backward_data not implemented"); + if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S) throw new DMLRuntimeException("Incorrect dimensions for filter in convolution_backward_data"); if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q) @@ -254,8 +250,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction } else if (instOpcode.equalsIgnoreCase("maxpooling")) { MatrixObject image = ec.getMatrixInputForGPUInstruction(_input1.getName()); - if(LibMatrixCUDA.isInSparseFormat(image)) - throw new DMLRuntimeException("Sparse maxpooling not implemented"); + if(image.getNumRows() != N || image.getNumColumns() != C*H*W) throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling: " + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + C*H*W); @@ -268,8 +263,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction else if (instOpcode.equalsIgnoreCase("maxpooling_backward")) { MatrixObject image = ec.getMatrixInputForGPUInstruction(_input1.getName()); MatrixObject dout = ec.getMatrixInputForGPUInstruction(_input2.getName()); - if(LibMatrixCUDA.isInSparseFormat(image) || LibMatrixCUDA.isInSparseFormat(dout)) - throw new DMLRuntimeException("Sparse maxpooling_backward_data not implemented"); + if(dout.getNumRows() != N || dout.getNumColumns() != C*P*Q) throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward"); if(image.getNumRows() != N || image.getNumColumns() != C*H*W) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java index aca197e..1c91a51 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java @@ -19,6 +19,7 @@ package org.apache.sysml.runtime.instructions.gpu; +import jcuda.runtime.JCuda; import org.apache.sysml.lops.runtime.RunMRJobs; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; @@ -28,7 +29,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator; public abstract class GPUInstruction extends Instruction { - public enum GPUINSTRUCTION_TYPE { AggregateUnary, AggregateBinary, Convolution, MMTSJ, Reorg, ArithmeticBinary, BuiltinUnary }; + public enum GPUINSTRUCTION_TYPE { AggregateUnary, AggregateBinary, Convolution, MMTSJ, Reorg, ArithmeticBinary, BuiltinUnary, Builtin }; protected GPUINSTRUCTION_TYPE _gputype; protected Operator _optr; @@ -83,4 +84,11 @@ public abstract class GPUInstruction extends Instruction @Override public abstract void processInstruction(ExecutionContext ec) throws DMLRuntimeException; + + @Override + public void postprocessInstruction(ExecutionContext ec) + throws DMLRuntimeException + { + JCuda.cudaDeviceSynchronize(); + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java index 527ea55..a423cdd 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java @@ -40,16 +40,20 @@ public class MatrixBuiltinGPUInstruction extends BuiltinUnaryGPUInstruction { Statistics.incrementNoOfExecutedGPUInst(); String opcode = getOpcode(); - //get input - MatrixObject mat = ec.getMatrixInputForGPUInstruction(_input.getName()); + MatrixObject mat = ec.getMatrixInputForGPUInstruction(_input.getName()); + + ec.setMetaData(_output.getName(), mat.getNumRows(), mat.getNumColumns()); + if(opcode.equals("sel+")) { - ec.setMetaData(_output.getName(), mat.getNumRows(), mat.getNumColumns()); LibMatrixCUDA.relu(ec, mat, _output.getName()); - ec.releaseMatrixInputForGPUInstruction(_input.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + + } else if (opcode.equals("exp")) { + LibMatrixCUDA.exp(ec, mat, _output.getName()); } else { throw new DMLRuntimeException("Unsupported GPU operator:" + opcode); } + ec.releaseMatrixInputForGPUInstruction(_input.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java index c1b77eb..8ca00b2 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java @@ -19,6 +19,10 @@ package org.apache.sysml.runtime.instructions.gpu.context; import java.util.ArrayList; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import org.apache.sysml.api.DMLScript; import org.apache.sysml.hops.OptimizerUtils; @@ -28,7 +32,18 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; //FIXME merge JCudaContext into GPUContext as this context is anyway CUDA specific public abstract class GPUContext { - public static ArrayList<GPUObject> allocatedPointers = new ArrayList<GPUObject>(); + public static ArrayList<GPUObject> allocatedPointers = new ArrayList<GPUObject>(); + + /** cudaFree calls are done asynchronously on a separate thread; + * this queue holds the Futures of cudaFree calls that are still pending */ + public static ConcurrentLinkedQueue<Future> pendingDeallocates = new ConcurrentLinkedQueue<Future>(); + + /** All asynchronous cudaFree calls will be done on this executor service */ + public static ExecutorService deallocExecutorService = Executors.newSingleThreadExecutor(); + + /** Synchronization object to make sure no allocations happen when something is being evicted from memory */ + public static final Object syncObj = new Object(); + protected static GPUContext currContext; public static volatile Boolean isGPUContextCreated = false;
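The three fields added to GPUContext above implement a simple asynchronous-free protocol: each cudaFree is submitted to the single-threaded deallocExecutorService, the resulting Future is queued on pendingDeallocates, and the eviction path (GPUObject.evict, below) drains that queue while holding syncObj before concluding that device memory is truly exhausted. The following minimal, self-contained Java sketch illustrates the pattern; free(long) is a hypothetical stand-in for JCuda's cudaFree, and the field names merely mirror the ones added above.

    import java.util.concurrent.ConcurrentLinkedQueue;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;

    class AsyncFreeSketch {
        // Single worker thread, mirroring GPUContext.deallocExecutorService
        static final ExecutorService DEALLOC_POOL = Executors.newSingleThreadExecutor();
        // Frees that may still be in flight, mirroring GPUContext.pendingDeallocates
        static final ConcurrentLinkedQueue<Future<?>> PENDING = new ConcurrentLinkedQueue<Future<?>>();

        /** Submits a free asynchronously and remembers its Future. */
        static void freeAsync(final long devicePtr) {
            PENDING.offer(DEALLOC_POOL.submit(new Runnable() {
                @Override
                public void run() {
                    free(devicePtr); // hypothetical stand-in for cudaFree(Pointer)
                }
            }));
        }

        /** Drains pending frees; called before deciding that memory is exhausted. */
        static void drainPending() throws InterruptedException, ExecutionException {
            Future<?> f;
            while ((f = PENDING.poll()) != null)
                f.get(); // returns immediately if the free already completed
        }

        /** Hypothetical device free. */
        static void free(long devicePtr) { }
    }

A single-threaded executor keeps the frees ordered and ensures at most one host thread frees device memory at a time, which is why evict() can treat an empty queue as a sign that no more memory is about to be released.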
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java index bcffa46..9708fe8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java @@ -18,17 +18,19 @@ */ package org.apache.sysml.runtime.instructions.gpu.context; -import java.util.Collections; -import java.util.Comparator; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; - import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.CacheException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.utils.Statistics; +import java.util.Collections; +import java.util.Comparator; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + //FIXME merge JCudaObject into GPUObject to avoid unnecessary complexity public abstract class GPUObject { @@ -81,7 +83,7 @@ public abstract class GPUObject abstract void allocateDenseMatrixOnDevice() throws DMLRuntimeException; abstract void allocateSparseMatrixOnDevice() throws DMLRuntimeException; - abstract void deallocateMemoryOnDevice() throws DMLRuntimeException; + abstract void deallocateMemoryOnDevice(boolean synchronous) throws DMLRuntimeException; abstract long getSizeOnDevice() throws DMLRuntimeException; abstract void copyFromHostToDevice() throws DMLRuntimeException; @@ -99,82 +101,115 @@ public abstract class GPUObject /** * Cycles through the sorted list of allocated {@link GPUObject} instances. Sorting is based on * number of (read) locks that have been obtained on it (reverse order). It repeatedly frees up - * blocks on which there are zero locks until the required size has been freed up. + * blocks on which there are zero locks until the required size has been freed up. * // TODO: update it with hybrid policy * @param GPUSize Desired size to be freed up on the GPU * @throws DMLRuntimeException If no blocks to free up or if not enough blocks with zero locks on them. 
*/ protected static void evict(final long GPUSize) throws DMLRuntimeException { - if(GPUContext.allocatedPointers.size() == 0) { - throw new DMLRuntimeException("There is not enough memory on device for this matrix!"); - } - - Statistics.cudaEvictionCount.addAndGet(1); - - synchronized(evictionLock) { - Collections.sort(GPUContext.allocatedPointers, new Comparator<GPUObject>() { - - @Override - public int compare(GPUObject p1, GPUObject p2) { - long p1Val = p1.numLocks.get(); - long p2Val = p2.numLocks.get(); - - if(p1Val>0 && p2Val>0) { - // Both are locked, so don't sort - return 0; - } - else if(p1Val>0 || p2Val>0) { - // Put the unlocked one to RHS - return Long.compare(p2Val, p1Val); - } - else { - // Both are unlocked - - if(evictionPolicy == EvictionPolicy.MIN_EVICT) { - long p1Size = 0; long p2Size = 0; - try { - p1Size = p1.getSizeOnDevice() - GPUSize; - p2Size = p2.getSizeOnDevice() - GPUSize; - } catch (DMLRuntimeException e) { - throw new RuntimeException(e); - } - - if(p1Size>=0 && p2Size>=0 ) { - return Long.compare(p2Size, p1Size); - } - else { - return Long.compare(p1Size, p2Size); - } - } - else if(evictionPolicy == EvictionPolicy.LRU || evictionPolicy == EvictionPolicy.LFU) { - return Long.compare(p2.timestamp.get(), p1.timestamp.get()); - } - else { - throw new RuntimeException("Unsupported eviction policy:" + evictionPolicy.name()); - } - } - } - }); - - while(GPUSize > getAvailableMemory() && GPUContext.allocatedPointers.size() > 0) { - GPUObject toBeRemoved = GPUContext.allocatedPointers.get(GPUContext.allocatedPointers.size() - 1); - if(toBeRemoved.numLocks.get() > 0) { - throw new DMLRuntimeException("There is not enough memory on device for this matrix!"); - } - if(toBeRemoved.isDeviceCopyModified) { - toBeRemoved.copyFromDeviceToHost(); - } - toBeRemoved.clearData(); - } - } + synchronized (GPUContext.syncObj) { + // Check for the completion of asynchronous cudaFree calls + try { + while (GPUSize > getAvailableMemory()) { + Future f = GPUContext.pendingDeallocates.poll(); + if (f == null) { + break; + } else if (f.isDone()) { + continue; + } else { + f.get(); + } + } + } catch (InterruptedException e) { + throw new DMLRuntimeException("There was an error with pending deallocates", e); + } catch (ExecutionException e) { + throw new DMLRuntimeException("There was an error with pending deallocates", e); + } + + if (GPUSize <= getAvailableMemory()) + return; + + if (GPUContext.allocatedPointers.size() == 0) { + throw new DMLRuntimeException("There is not enough memory on device for this matrix!"); + } + + Statistics.cudaEvictionCount.addAndGet(1); + + synchronized (evictionLock) { + Collections.sort(GPUContext.allocatedPointers, new Comparator<GPUObject>() { + + @Override + public int compare(GPUObject p1, GPUObject p2) { + long p1Val = p1.numLocks.get(); + long p2Val = p2.numLocks.get(); + + if (p1Val > 0 && p2Val > 0) { + // Both are locked, so don't sort + return 0; + } else if (p1Val > 0 || p2Val > 0) { + // Put the unlocked one to RHS + return Long.compare(p2Val, p1Val); + } else { + // Both are unlocked + + if (evictionPolicy == EvictionPolicy.MIN_EVICT) { + long p1Size = 0; + long p2Size = 0; + try { + p1Size = p1.getSizeOnDevice() - GPUSize; + p2Size = p2.getSizeOnDevice() - GPUSize; + } catch (DMLRuntimeException e) { + throw new RuntimeException(e); + } + + if (p1Size >= 0 && p2Size >= 0) { + return Long.compare(p2Size, p1Size); + } else { + return Long.compare(p1Size, p2Size); + } + } else if (evictionPolicy == EvictionPolicy.LRU || evictionPolicy == 
EvictionPolicy.LFU) { + return Long.compare(p2.timestamp.get(), p1.timestamp.get()); + } else { + throw new RuntimeException("Unsupported eviction policy:" + evictionPolicy.name()); + } + } + } + }); + + while (GPUSize > getAvailableMemory() && GPUContext.allocatedPointers.size() > 0) { + GPUObject toBeRemoved = GPUContext.allocatedPointers.get(GPUContext.allocatedPointers.size() - 1); + if (toBeRemoved.numLocks.get() > 0) { + throw new DMLRuntimeException("There is not enough memory on device for this matrix!"); + } + if (toBeRemoved.isDeviceCopyModified) { + toBeRemoved.copyFromDeviceToHost(); + } + + toBeRemoved.clearData(true); + } + } + } } - + + /** + * Asynchronously clears the data associated with this {@link GPUObject} instance + * @throws CacheException if the data could not be cleared from the device + */ public void clearData() throws CacheException { + clearData(false); + } + + /** + * Clears the data associated with this {@link GPUObject} instance + * @param synchronous whether to be done synchronously or asynchronously + * @throws CacheException if the data could not be cleared from the device + */ + public void clearData(boolean synchronous) throws CacheException { synchronized(evictionLock) { GPUContext.allocatedPointers.remove(this); } try { - deallocateMemoryOnDevice(); + deallocateMemoryOnDevice(synchronous); } catch (DMLRuntimeException e) { throw new CacheException(e); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java index 38f4e4c..d118429 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java @@ -277,6 +277,7 @@ public class JCudaContext extends GPUContext { cudnnDestroy(LibMatrixCUDA.cudnnHandle); cublasDestroy(LibMatrixCUDA.cublasHandle); cusparseDestroy(LibMatrixCUDA.cusparseHandle); + GPUContext.deallocExecutorService.shutdown(); currContext = null; isGPUContextCreated = false; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java index 58bc9ec..24063b5 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java @@ -57,6 +57,8 @@ import jcuda.jcusparse.cusparseHandle; import jcuda.jcusparse.cusparseMatDescr; import jcuda.jcusparse.cusparsePointerMode; +import java.util.concurrent.Future; + /** * Handle to a matrix block on the GPU */ @@ -215,7 +217,6 @@ public class JCudaObject extends GPUObject { Statistics.cudaFromDevTime.addAndGet(System.nanoTime()-t0); Statistics.cudaFromDevCount.addAndGet(3); } - // ============================================================================================== @@ -407,12 +408,20 @@ public class JCudaObject extends GPUObject { } /** - * Calls cudaFree on the allocated {@link Pointer} instances + * Calls cudaFree asynchronously on the allocated {@link Pointer} instances */ public 
void deallocate() { - cudaFree(val); - cudaFree(rowPtr); - cudaFree(colInd); + deallocate(false); + } + + /** + * Calls cudaFree asynchronously or synchronously on the allocated {@link Pointer} instances + * @param synchronous whether to do synchronous or async cudaFrees + */ + public void deallocate(boolean synchronous){ + cudaFreeHelper(val, synchronous); + cudaFreeHelper(rowPtr, synchronous); + cudaFreeHelper(colInd, synchronous); + } }; @@ -453,15 +462,17 @@ public class JCudaObject extends GPUObject { * @throws DMLRuntimeException if DMLRuntimeException occurs */ public static Pointer allocate(long size, int statsCount) throws DMLRuntimeException{ - Pointer A = new Pointer(); - ensureFreeSpace(size); - long t0 = System.nanoTime(); - cudaMalloc(A, size); - // Set all elements to 0 since newly allocated space will contain garbage - cudaMemset(A, 0, size); - Statistics.cudaAllocTime.getAndAdd(System.nanoTime() - t0); - Statistics.cudaAllocCount.getAndAdd(statsCount); - return A; + synchronized (GPUContext.syncObj) { + Pointer A = new Pointer(); + ensureFreeSpace(size); + long t0 = System.nanoTime(); + cudaMalloc(A, size); + // Set all elements to 0 since newly allocated space will contain garbage + cudaMemset(A, 0, size); + Statistics.cudaAllocTime.getAndAdd(System.nanoTime() - t0); + Statistics.cudaAllocCount.getAndAdd(statsCount); + return A; + } } /** @@ -504,7 +515,6 @@ public class JCudaObject extends GPUObject { LibMatrixCUDA.kernels.launchKernel("fill", ExecutionConfig.getConfigForSimpleVectorOperations(numElems), jcudaDenseMatrixPtr, v, numElems); } - /** * If this {@link JCudaObject} is sparse and empty * Being allocated is a prerequisite to being sparse and empty. @@ -517,47 +527,27 @@ public class JCudaObject extends GPUObject { return isEmptyAndSparseAndAllocated; } - /** - * Allocate necessary memory on the GPU for this {@link JCudaObject} instance. 
- * - * @param isInput if the block is input, isSparse argument is ignored - * @param isSparse if the block is sparse - * @throws DMLRuntimeException if DMLRuntimeException occurs - */ - private void prepare(boolean isInput, boolean isSparse) throws DMLRuntimeException { - if(isAllocated()) { - // Already allocated on GPU and expected to be in sync - } - else { - if(isInput) { - copyFromHostToDevice(); - } - else { - mat.setDirty(true); - // Don't copy just allocate - if (isSparse){ - allocateSparseMatrixOnDevice(); - } else { // Dense block, size = numRows * numCols - allocateDenseMatrixOnDevice(); - } - synchronized(evictionLock) { - GPUContext.allocatedPointers.add(this); - } - } - } - numLocks.addAndGet(1); - } - @Override public synchronized void acquireDeviceRead() throws DMLRuntimeException { - prepare(true, false); - if(!isAllocated()) + if(!isAllocated()) { + copyFromHostToDevice(); + } else { + numLocks.addAndGet(1); + } + if(!isAllocated()) throw new DMLRuntimeException("Expected device data to be allocated"); } @Override public synchronized void acquireDeviceModifyDense() throws DMLRuntimeException { - prepare(false, false); + if(!isAllocated()) { + mat.setDirty(true); + // Dense block, size = numRows * numCols + allocateDenseMatrixOnDevice(); + synchronized(evictionLock) { + GPUContext.allocatedPointers.add(this); + } + } isDeviceCopyModified = true; if(!isAllocated()) throw new DMLRuntimeException("Expected device data to be allocated"); @@ -566,7 +556,13 @@ public class JCudaObject extends GPUObject { @Override public synchronized void acquireDeviceModifySparse() throws DMLRuntimeException { isInSparseFormat = true; - prepare(false, true); + if(!isAllocated()) { + mat.setDirty(true); + allocateSparseMatrixOnDevice(); + synchronized(evictionLock) { + GPUContext.allocatedPointers.add(this); + } + } isDeviceCopyModified = true; if(!isAllocated()) throw new DMLRuntimeException("Expected device data to be allocated"); @@ -716,17 +712,17 @@ public class JCudaObject extends GPUObject { } @Override - void deallocateMemoryOnDevice() { + void deallocateMemoryOnDevice(boolean synchronous) { if(jcudaDenseMatrixPtr != null) { long start = System.nanoTime(); - cudaFree(jcudaDenseMatrixPtr); + cudaFreeHelper(jcudaDenseMatrixPtr, synchronous); ((JCudaContext)GPUContext.currContext).getAndAddAvailableMemory(numBytes); Statistics.cudaDeAllocTime.addAndGet(System.nanoTime()-start); Statistics.cudaDeAllocCount.addAndGet(1); } if (jcudaSparseMatrixPtr != null) { long start = System.nanoTime(); - jcudaSparseMatrixPtr.deallocate(); + jcudaSparseMatrixPtr.deallocate(synchronous); ((JCudaContext)GPUContext.currContext).getAndAddAvailableMemory(numBytes); Statistics.cudaDeAllocTime.addAndGet(System.nanoTime()-start); Statistics.cudaDeAllocCount.addAndGet(1); @@ -789,14 +785,14 @@ public class JCudaObject extends GPUObject { long t0 = System.nanoTime(); SparseBlockCOO cooBlock = (SparseBlockCOO)block; csrBlock = new SparseBlockCSR(toIntExact(mat.getNumRows()), cooBlock.rowIndexes(), cooBlock.indexes(), cooBlock.values()); - Statistics.cudaConversionTime.addAndGet(System.nanoTime() - t0); - Statistics.cudaConversionCount.incrementAndGet(); + Statistics.cudaSparseConversionTime.addAndGet(System.nanoTime() - t0); + Statistics.cudaSparseConversionCount.incrementAndGet(); } else if (block instanceof SparseBlockMCSR) { long t0 = System.nanoTime(); SparseBlockMCSR mcsrBlock = (SparseBlockMCSR)block; csrBlock = new SparseBlockCSR(mcsrBlock.getRows(), toIntExact(mcsrBlock.size())); - 
Statistics.cudaConversionTime.addAndGet(System.nanoTime() - t0); - Statistics.cudaConversionCount.incrementAndGet(); + Statistics.cudaSparseConversionTime.addAndGet(System.nanoTime() - t0); + Statistics.cudaSparseConversionCount.incrementAndGet(); } else { throw new DMLRuntimeException("Unsupported sparse matrix format for CUDA operations"); } @@ -956,7 +952,7 @@ public class JCudaObject extends GPUObject { this.jcudaSparseMatrixPtr = sparseMatrixPtr; this.isInSparseFormat = true; if(jcudaDenseMatrixPtr != null) { - cudaFree(jcudaDenseMatrixPtr); + cudaFreeHelper(jcudaDenseMatrixPtr); jcudaDenseMatrixPtr = null; } } @@ -982,6 +978,7 @@ public class JCudaObject extends GPUObject { * @throws DMLRuntimeException if DMLRuntimeException occurs */ public void denseToSparse() throws DMLRuntimeException { + long t0 = System.nanoTime(); cusparseHandle cusparseHandle = LibMatrixCUDA.cusparseHandle; if(cusparseHandle == null) throw new DMLRuntimeException("Expected cusparse to be initialized"); @@ -995,6 +992,8 @@ public class JCudaObject extends GPUObject { setSparseMatrixCudaPointer(columnMajorDenseToRowMajorSparse(cusparseHandle, rows, cols, jcudaDenseMatrixPtr)); // TODO: What if mat.getNnz() is -1 ? numBytes = CSRPointer.estimateSize(mat.getNnz(), rows); + Statistics.cudaDenseToSparseTime.addAndGet(System.nanoTime() - t0); + Statistics.cudaDenseToSparseCount.addAndGet(1); } /** @@ -1005,7 +1004,6 @@ public class JCudaObject extends GPUObject { * @param lda rows in input matrix * @param ldc columns in output matrix * @return transposed matrix - * @throws DMLRuntimeException if DMLRuntimeException occurs */ public static Pointer transpose(Pointer densePtr, int m, int n, int lda, int ldc) throws DMLRuntimeException { Pointer alpha = LibMatrixCUDA.pointerTo(1.0); @@ -1032,7 +1030,7 @@ public class JCudaObject extends GPUObject { } Pointer tmp = transpose(jcudaDenseMatrixPtr, m, n, lda, ldc); - cudaFree(jcudaDenseMatrixPtr); + cudaFreeHelper(jcudaDenseMatrixPtr); setDenseMatrixCudaPointer(tmp); } @@ -1046,7 +1044,7 @@ public class JCudaObject extends GPUObject { } Pointer tmp = transpose(jcudaDenseMatrixPtr, m, n, lda, ldc); - cudaFree(jcudaDenseMatrixPtr); + cudaFreeHelper(jcudaDenseMatrixPtr); setDenseMatrixCudaPointer(tmp); } @@ -1056,11 +1054,14 @@ public class JCudaObject extends GPUObject { * @throws DMLRuntimeException if DMLRuntimeException occurs */ public void sparseToDense() throws DMLRuntimeException { + long t0 = System.nanoTime(); if(jcudaSparseMatrixPtr == null || !isAllocated()) throw new DMLRuntimeException("Expected allocated sparse matrix before sparseToDense() call"); sparseToColumnMajorDense(); convertDensePtrFromColMajorToRowMajor(); + Statistics.cudaSparseToDenseTime.addAndGet(System.nanoTime() - t0); + Statistics.cudaSparseToDenseCount.addAndGet(1); } @@ -1103,8 +1104,8 @@ public class JCudaObject extends GPUObject { ensureFreeSpace(getIntSizeOf(rows + 1)); long t1 = System.nanoTime(); - cudaMalloc(nnzPerRowPtr, getIntSizeOf(rows)); - cudaMalloc(nnzTotalDevHostPtr, getIntSizeOf(1)); + nnzPerRowPtr = allocate(getIntSizeOf(rows)); + nnzTotalDevHostPtr = allocate(getIntSizeOf(1)); Statistics.cudaAllocTime.addAndGet(System.nanoTime() - t1); Statistics.cudaAllocCount.addAndGet(2); @@ -1126,12 +1127,40 @@ public class JCudaObject extends GPUObject { cusparseDdense2csr(cusparseHandle, rows, cols, matDescr, densePtr, rows, nnzPerRowPtr, C.val, C.rowPtr, C.colInd); cudaDeviceSynchronize(); - cudaFree(nnzPerRowPtr); - cudaFree(nnzTotalDevHostPtr); + cudaFreeHelper(nnzPerRowPtr); + 
cudaFreeHelper(nnzTotalDevHostPtr); return C; } - + + /** + * Performs an asynchronous cudaFree call + * @param toFree {@link Pointer} instance to be freed + */ + public static void cudaFreeHelper(final Pointer toFree) { + cudaFreeHelper(toFree, false); + } + + /** + * Performs a cudaFree call, either synchronously or asynchronously + * @param toFree {@link Pointer} instance to be freed + * @param synchronous true if to be done synchronously + */ + public static void cudaFreeHelper(final Pointer toFree, boolean synchronous) { + if (synchronous) { + cudaFree(toFree); + } else { + Future submitted = GPUContext.deallocExecutorService.submit(new Runnable() { + @Override + public void run() { + cudaFree(toFree); + } + }); + GPUContext.pendingDeallocates.offer(submitted); + } + } + + /** * Gets the double array from GPU memory onto host memory and returns string. * @param A Pointer to memory on device (GPU), assumed to point to a double array http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java index c10d0bf..31ec348 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java @@ -51,13 +51,13 @@ import static jcuda.jcusparse.JCusparse.cusparseDcsrmv; import static jcuda.jcusparse.cusparseOperation.CUSPARSE_OPERATION_NON_TRANSPOSE; import static jcuda.jcusparse.cusparseOperation.CUSPARSE_OPERATION_TRANSPOSE; import static jcuda.runtime.JCuda.cudaDeviceSynchronize; -import static jcuda.runtime.JCuda.cudaFree; -import static jcuda.runtime.JCuda.cudaMalloc; import static jcuda.runtime.JCuda.cudaMemcpy; import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost; import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice; import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice; import static jcuda.jcudnn.cudnnActivationMode.CUDNN_ACTIVATION_RELU; +import static org.apache.sysml.runtime.instructions.gpu.context.JCudaObject.allocate; +import static org.apache.sysml.runtime.instructions.gpu.context.JCudaObject.cudaFreeHelper; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.runtime.DMLRuntimeException; @@ -206,7 +206,7 @@ public class LibMatrixCUDA { CONVOLUTION_PREFERENCE, sizeInBytesArray[0], algos); cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle, srcTensorDesc, filterDesc, convDesc, dstTensorDesc, algo, sizeInBytesArray); if(sizeInBytesArray[0] != 0) - jcuda.runtime.JCuda.cudaMalloc(workSpace, sizeInBytesArray[0]); + workSpace = allocate(sizeInBytesArray[0]); sizeInBytes = sizeInBytesArray[0]; } else if(CONVOLUTION_PREFERENCE == cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT) { @@ -230,9 +230,9 @@ public class LibMatrixCUDA { finally { if(alpha != null) - cudaFree(alpha); + cudaFreeHelper(alpha); if(beta != null) - cudaFree(beta); + cudaFreeHelper(beta); if(srcTensorDesc != null) cudnnDestroyTensorDescriptor(srcTensorDesc); @@ -243,7 +243,7 @@ public class LibMatrixCUDA { if(convDesc != null) cudnnDestroyConvolutionDescriptor(convDesc); if(workSpace != null && sizeInBytes != 0) - cudaFree(workSpace); + cudaFreeHelper(workSpace); } } @@ -421,9 +421,9 @@ public class LibMatrixCUDA { } 
finally { if(alpha != null) - cudaFree(alpha); + cudaFreeHelper(alpha); if(beta != null) - cudaFree(beta); + cudaFreeHelper(beta); if(xTensorDesc != null) cudnnDestroyTensorDescriptor(xTensorDesc); if(doutTensorDesc != null) @@ -435,7 +435,7 @@ public class LibMatrixCUDA { cudnnDestroyConvolutionDescriptor(convDesc); if(workSpace != null && sizeInBytes != 0) - cudaFree(workSpace); + cudaFreeHelper(workSpace); } } @@ -483,9 +483,9 @@ public class LibMatrixCUDA { finally { if(alpha != null) - cudaFree(alpha); + cudaFreeHelper(alpha); if(beta != null) - cudaFree(beta); + cudaFreeHelper(beta); if(srcTensorDesc != null) cudnnDestroyTensorDescriptor(srcTensorDesc); @@ -678,11 +678,11 @@ public class LibMatrixCUDA { int colsA = (int)left.getNumColumns(); Pointer AT = JCudaObject.transpose(ADense, rowsA, colsA, colsA, rowsA); CSRPointer A = JCudaObject.columnMajorDenseToRowMajorSparse(cusparseHandle, rowsA, colsA, AT); - Statistics.cudaConversionTime.addAndGet(System.nanoTime() - t0); - Statistics.cudaConversionCount.addAndGet(1); + Statistics.cudaSparseConversionTime.addAndGet(System.nanoTime() - t0); + Statistics.cudaSparseConversionCount.addAndGet(1); sparseSparseMatmult(output, transA, transB, m, n, k, A, B); A.deallocate(); - cudaFree(AT); + cudaFreeHelper(AT); } else { LOG.debug(" GPU Dense-Sparse Matrix Multiplication (Converted to Dense-Dense)"); // Convert right to dense and do a cuBlas matmul @@ -696,7 +696,7 @@ public class LibMatrixCUDA { (int) right.getNumColumns(), (int) right.getNumRows(), isLeftTransposed, !isRightTransposed, ADense, BDenseTransposed); - cudaFree(BDenseTransposed); + cudaFreeHelper(BDenseTransposed); } } @@ -737,11 +737,11 @@ public class LibMatrixCUDA { int colsB = (int)right.getNumColumns(); Pointer BT = JCudaObject.transpose(BDense, rowsB, colsB, colsB, rowsB); CSRPointer B = JCudaObject.columnMajorDenseToRowMajorSparse(cusparseHandle, rowsB, colsB, BT); - Statistics.cudaConversionTime.addAndGet(System.nanoTime() - t0); - Statistics.cudaConversionCount.addAndGet(1); + Statistics.cudaSparseConversionTime.addAndGet(System.nanoTime() - t0); + Statistics.cudaSparseConversionCount.addAndGet(1); sparseSparseMatmult(output, transA, transB, m, n, k, A, B); B.deallocate(); - cudaFree(BT); + cudaFreeHelper(BT); } else { LOG.debug(" GPU Sparse-Dense Matrix Multiplication (Converted to Dense-Dense)"); // Convert left to dense and do a cuBlas matmul @@ -755,7 +755,7 @@ public class LibMatrixCUDA { (int) right.getNumRows(), (int) right.getNumColumns(), !isLeftTransposed, isRightTransposed, ADenseTransposed, BDense); - cudaFree(ADenseTransposed); + cudaFreeHelper(ADenseTransposed); } } } @@ -1142,7 +1142,7 @@ public class LibMatrixCUDA { default: throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for summation squared"); } - cudaFree(tmp); + cudaFreeHelper(tmp); break; } case OP_MEAN:{ @@ -1255,7 +1255,7 @@ public class LibMatrixCUDA { ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1); matrixScalarOp(tmpRow, clen - 1, rlen, clen, out, divideOp); - cudaFree(tmpRow); + cudaFreeHelper(tmpRow); break; } case REDUCTION_ROW: { @@ -1272,14 +1272,14 @@ public class LibMatrixCUDA { ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1); matrixScalarOp(tmpCol, rlen - 1, rlen, clen, out, divideOp); - cudaFree(tmpCol); + cudaFreeHelper(tmpCol); break; } default: throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for variance"); } - cudaFree(tmp); - 
cudaFree(tmp2); + cudaFreeHelper(tmp); + cudaFreeHelper(tmp2); break; } case OP_MAXINDEX : { @@ -1343,7 +1343,7 @@ public class LibMatrixCUDA { } double[] result = {-1f}; cudaMemcpy(Pointer.to(result), tempOut, Sizeof.DOUBLE, cudaMemcpyDeviceToHost); - cudaFree(tempOut); + cudaFreeHelper(tempOut); return result[0]; } @@ -1527,9 +1527,9 @@ public class LibMatrixCUDA { } finally { if(alpha != null) - cudaFree(alpha); + cudaFreeHelper(alpha); if(beta != null) - cudaFree(beta); + cudaFreeHelper(beta); if(dyDesc != null) cudnnDestroyTensorDescriptor(dyDesc); if(dxDesc != null) @@ -1541,7 +1541,7 @@ public class LibMatrixCUDA { cudnnDestroyConvolutionDescriptor(convDesc); if(workSpace != null && sizeInBytes != 0) - cudaFree(workSpace); + cudaFreeHelper(workSpace); } } @@ -1598,9 +1598,9 @@ public class LibMatrixCUDA { } finally { if(alpha != null) - cudaFree(alpha); + cudaFreeHelper(alpha); if(beta != null) - cudaFree(beta); + cudaFreeHelper(beta); if(yDesc != null) cudnnDestroyTensorDescriptor(yDesc); if(xDesc != null) @@ -1661,9 +1661,8 @@ public class LibMatrixCUDA { // Calling PoolForward first, y is one of the inputs for poolBackward // TODO: Remove calling poolForward after necessary changes at language level for poolBackward - Pointer y = new Pointer(); long numBytes = N*C*P*Q*Sizeof.DOUBLE; - cudaMalloc(y, numBytes); + Pointer y = allocate(numBytes); // Allocate data Pointer x = ((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr; @@ -1684,13 +1683,13 @@ public class LibMatrixCUDA { throw new DMLRuntimeException("Could not executed cudnnPoolingBackward: " + jcuda.jcudnn.cudnnStatus.stringFor(status)); } - cudaFree(y); + cudaFreeHelper(y); } finally { if(alpha != null) - cudaFree(alpha); + cudaFreeHelper(alpha); if(beta != null) - cudaFree(beta); + cudaFreeHelper(beta); if(yDesc != null) cudnnDestroyTensorDescriptor(yDesc); if(xDesc != null) @@ -2219,6 +2218,40 @@ public class LibMatrixCUDA { } /** + * Performs an "exp" operation on a matrix on the GPU + * @param ec execution context + * @param in1 input matrix + * @param outputName output matrix name + * @throws DMLRuntimeException if DMLRuntimeException occurs + */ + public static void exp(ExecutionContext ec, MatrixObject in1, String outputName) throws DMLRuntimeException { + JCudaObject in = ((JCudaObject)in1.getGPUObject()); + boolean isSparseAndEmpty = in.isSparseAndEmpty(); + boolean isSparse = in.isInSparseFormat(); + + if (isSparseAndEmpty) { + // e^0 = 1, create a dense block full of 1s + MatrixObject out = ec.getMatrixObject(outputName); + ec.allocateGPUMatrixObject(outputName); + ((JCudaObject)(out.getGPUObject())).allocateAndFillDense(1); + } else { + // Sparse + if (isSparse) { + // If the input is in sparse format, convert it to dense. + // The output will always be dense, because for all x, exp(x) > 0 + in.sparseToDense(); + } + // Dense + MatrixObject out = ec.getDenseMatrixOutputForGPUInstruction(outputName); + Pointer output = ((JCudaObject)out.getGPUObject()).jcudaDenseMatrixPtr; + Pointer input = in.jcudaDenseMatrixPtr; + int size = (int)(in1.getNumColumns() * in1.getNumRows()); + kernels.launchKernel("matrix_exp", ExecutionConfig.getConfigForSimpleVectorOperations(size), + input, output, size); + } + } + + /** * Convenience method for debugging matrices on the GPU. 
 * @param in Pointer to a double array (matrix) on the GPU * @param rlen row length http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index e371d9c..87cb64f 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -101,8 +101,12 @@ public class Statistics public static long cudaInitTime = 0; public static long cudaLibrariesInitTime = 0; - public static AtomicLong cudaConversionTime = new AtomicLong(0); // Measures time spent in converting between sparse block types - public static AtomicLong cudaConversionCount = new AtomicLong(0); + public static AtomicLong cudaSparseToDenseTime = new AtomicLong(0); // Measures time spent in converting sparse matrix block to dense + public static AtomicLong cudaSparseToDenseCount = new AtomicLong(0); + public static AtomicLong cudaDenseToSparseTime = new AtomicLong(0); // Measures time spent in converting dense matrix block to sparse + public static AtomicLong cudaDenseToSparseCount = new AtomicLong(0); + public static AtomicLong cudaSparseConversionTime = new AtomicLong(0); // Measures time spent in converting between sparse block types + public static AtomicLong cudaSparseConversionCount = new AtomicLong(0); public static AtomicLong cudaAllocTime = new AtomicLong(0); public static AtomicLong cudaDeAllocTime = new AtomicLong(0); public static AtomicLong cudaToDevTime = new AtomicLong(0); @@ -619,19 +623,26 @@ public class Statistics sb.append("CUDA/CuLibraries init time:\t" + String.format("%.3f", cudaInitTime*1e-9) + "/" + String.format("%.3f", cudaLibrariesInitTime*1e-9) + " sec.\n"); sb.append("Number of executed GPU inst:\t" + getNoOfExecutedGPUInst() + ".\n"); - sb.append("GPU mem tx time (alloc/dealloc/conversion/toDev/fromDev):\t" + sb.append("GPU mem tx time (alloc/dealloc/toDev/fromDev):\t" + String.format("%.3f", cudaAllocTime.get()*1e-9) + "/" + String.format("%.3f", cudaDeAllocTime.get()*1e-9) + "/" - + String.format("%.3f", cudaConversionTime.get()*1e-9) + "/" + String.format("%.3f", cudaToDevTime.get()*1e-9) + "/" + String.format("%.3f", cudaFromDevTime.get()*1e-9) + " sec.\n"); - sb.append("GPU mem tx count (alloc/dealloc/conversion/toDev/fromDev/evict):\t" + sb.append("GPU mem tx count (alloc/dealloc/sparseConv/toDev/fromDev/evict):\t" + cudaAllocCount.get() + "/" + cudaDeAllocCount.get() + "/" - + cudaConversionCount.get() + "/" + + cudaSparseConversionCount.get() + "/" + cudaToDevCount.get() + "/" + cudaFromDevCount.get() + "/" + cudaEvictionCount.get() + ".\n"); + sb.append("GPU conversion time (sparseConv/sp2dense/dense2sp):\t" + + String.format("%.3f", cudaSparseConversionTime.get()*1e-9) + "/" + + String.format("%.3f", cudaSparseToDenseTime.get()*1e-9) + "/" + + String.format("%.3f", cudaDenseToSparseTime.get()*1e-9) + " sec.\n"); + sb.append("GPU conversion count (sparseConv/sp2dense/dense2sp):\t" + + cudaSparseConversionCount.get() + "/" + + cudaSparseToDenseCount.get() + "/" + + cudaDenseToSparseCount.get() + ".\n"); } //show extended caching/compilation statistics http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownUaggTest.java ---------------------------------------------------------------------- diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownUaggTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownUaggTest.java index 348cfbb..84547ba 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownUaggTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewritePushdownUaggTest.java @@ -100,13 +100,7 @@ public class RewritePushdownUaggTest extends AutomatedTestBase testRewritePushdownUagg( TEST_NAME4, true ); } - - /** - * - * @param condition - * @param branchRemoval - * @param IPA - */ + private void testRewritePushdownUagg( String testname, boolean rewrites ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; @@ -146,8 +140,11 @@ public class RewritePushdownUaggTest extends AutomatedTestBase check = rewrites ? "uarmin" : "uacmin"; else if( testname.equals(TEST_NAME4) ) //rowmins check = rewrites ? "uacmin" : "uarmin"; - - Assert.assertTrue( "Missing opcode: "+check, Statistics.getCPHeavyHitterOpCodes().contains(check) ); + + String gpuCheck = "gpu_" + check; + boolean containsOpcode = Statistics.getCPHeavyHitterOpCodes().contains(check) || Statistics.getCPHeavyHitterOpCodes().contains(gpuCheck); + + Assert.assertTrue( "Missing opcode: "+check, containsOpcode); } finally { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4316efeb/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java index 2829bab..51ad19f 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java @@ -115,7 +115,10 @@ public class RewriteSimplifyRowColSumMVMultTest extends AutomatedTestBase TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); //check matrix mult existence - Assert.assertTrue( Statistics.getCPHeavyHitterOpCodes().contains("ba+*") == rewrites ); + String gpuBa = "gpu_ba+*"; + String ba = "ba+*"; + boolean isMatmultPresent = Statistics.getCPHeavyHitterOpCodes().contains(ba) || Statistics.getCPHeavyHitterOpCodes().contains(gpuBa); + Assert.assertTrue( isMatmultPresent == rewrites ); } finally {

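End to end, the new GPU "exp" works as follows: UnaryOp now routes OpOp1.EXP through findGPUExecTypeByMemEstimate, GPUInstructionParser maps the "exp" opcode to a BuiltinUnary GPU instruction, MatrixBuiltinGPUInstruction dispatches it to LibMatrixCUDA.exp, and that launches the new matrix_exp kernel with one thread per element. The Java snippet below mirrors the launch performed inside LibMatrixCUDA.exp in this commit; the input/output pointer variables are illustrative and assume dense, already-allocated device memory.

    // input and output: device Pointers to dense double data of identical shape
    int size = (int) (in1.getNumRows() * in1.getNumColumns());
    kernels.launchKernel("matrix_exp",
        ExecutionConfig.getConfigForSimpleVectorOperations(size),
        input, output, size);

Inside the kernel each thread computes index = blockIdx.x * blockDim.x + threadIdx.x and writes C[index] = exp(A[index]) only when index < size, so the grid can safely be rounded up to whole blocks without out-of-bounds writes.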