Repository: systemml Updated Branches: refs/heads/master 1e5984cca -> e4220e3bc
[SYSTEMML-445] Added builtin functions for efficient computation of lstm_backward function - The current implementation treats lstm and lstm_backward as stateless functions for simplicity. We can revisit this after performance testing. - Removed reserve parameter from lstm builtin function. - Updated the language reference and lstm_staging.dml file. - Added necessary kernels for transforming input to the format required by lstm_backward function. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e4220e3b Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e4220e3b Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e4220e3b Branch: refs/heads/master Commit: e4220e3bc2559056f94ef08cd1c9006315cf229a Parents: 1e5984c Author: Niketan Pansare <[email protected]> Authored: Fri Jun 15 13:07:59 2018 -0700 Committer: Niketan Pansare <[email protected]> Committed: Fri Jun 15 13:07:59 2018 -0700 ---------------------------------------------------------------------- docs/dml-language-reference.md | 2 +- scripts/nn/layers/lstm_staging.dml | 85 +-- src/main/cpp/kernels/SystemML.cu | 121 ++- src/main/cpp/kernels/SystemML.ptx | 760 ++++++++++++++----- .../java/org/apache/sysml/hops/FunctionOp.java | 9 +- .../sysml/parser/BuiltinFunctionExpression.java | 60 +- .../org/apache/sysml/parser/DMLTranslator.java | 1 + .../org/apache/sysml/parser/Expression.java | 2 +- .../instructions/GPUInstructionParser.java | 1 + .../gpu/ConvolutionGPUInstruction.java | 79 +- .../instructions/gpu/context/GPUObject.java | 4 +- .../runtime/matrix/data/LibMatrixCuDNN.java | 98 ++- .../matrix/data/LibMatrixCuDNNRnnAlgorithm.java | 48 +- 13 files changed, 962 insertions(+), 308 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/docs/dml-language-reference.md 
---------------------------------------------------------------------- diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md index 3212806..5bf9099 100644 --- a/docs/dml-language-reference.md +++ b/docs/dml-language-reference.md @@ -1520,7 +1520,7 @@ Hence, the images are internally represented as a matrix with dimension (N, C * | max_pool_backward, avg_pool_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_out* width_out] | [batch_size X num_channels* height_image* width_image] | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool] | Computes the gradients wrt input of 2D max pooling, average pooling | | bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels | | bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels | -| lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut, reserveSpace) | +| lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut) | | batch_norm2d | input | [batch_size X 
num_channels* height_image* width_image] | | [batch_size X num_channels* height_image* width_image] | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum | Performs batch normalization operation (outputs: updated exponential moving average mean and variance, cache of the batch mean and variance) | | batch_norm2d_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward) | Computed backpropagation error for batch normalization operation | http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/scripts/nn/layers/lstm_staging.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/layers/lstm_staging.dml b/scripts/nn/layers/lstm_staging.dml index d0949d9..886b88c 100644 --- a/scripts/nn/layers/lstm_staging.dml +++ b/scripts/nn/layers/lstm_staging.dml @@ -57,17 +57,15 @@ forward = function(matrix[double] X, matrix[double] W, matrix[double] b, * - out: If `return_sequences` is True, outputs for all timesteps, * of shape (N, T*M). Else, outputs for the final timestep, of * shape (N, M). - * - c: Cell state for final timestep, of shape (N, M). - * - reserveSpace: reserveSpace to be passed to output (row-vector whose size is determined at runtime). + * - c: Cell state for final timestep, of shape (N, M). 
*/ - [out, c, reserveSpace] = lstm(X, W, b, out0, c0, return_sequences) + out = 0; c = c0; + [out, c] = lstm(X, W, b, out0, c0, return_sequences) } -# TODO: backward = function(matrix[double] dout, matrix[double] dc, - matrix[double] X, matrix[double] W, matrix[double] b, int T, int D, - boolean given_sequences, matrix[double] out0, matrix[double] c0, - matrix[double] cache_out, matrix[double] cache_c, matrix[double] cache_ifog) + matrix[double] X, matrix[double] W, matrix[double] b, + boolean given_sequences, matrix[double] out0, matrix[double] c0) return (matrix[double] dX, matrix[double] dW, matrix[double] db, matrix[double] dout0, matrix[double] dc0) { /* @@ -87,8 +85,6 @@ backward = function(matrix[double] dout, matrix[double] dc, * - X: Inputs, of shape (N, T*D). * - W: Weights, of shape (D+M, 4M). * - b: Biases, of shape (1, 4M). - * - T: Length of example sequences (number of timesteps). - * - D: Dimensionality of the input features. * - given_sequences: Whether `dout` is for all timesteps, * or just for the final timestep. This is based on whether * `return_sequences` was true in the forward pass. @@ -110,75 +106,8 @@ backward = function(matrix[double] dout, matrix[double] dc, * - dout0: Gradient wrt `out0`, of shape (N, M). * - dc0: Gradient wrt `c0`, of shape (N, M). 
*/ - N = nrow(X) - M = as.integer(ncol(W)/4) - N1 = nrow(out0) - if(N != N1) { - # Allow for smaller out0 for last batch - # out0 = out0[1:N,] - # c0 = c0[1:N,] - stop("Unsupported operation: The batch size of previous iteration " + N1 + " is different than the batch size of current iteration " + N) - } - dX = matrix(0, rows=N, cols=T*D) - dW = matrix(0, rows=D+M, cols=4*M) - db = matrix(0, rows=1, cols=4*M) - dout0 = matrix(0, rows=N, cols=M) - dc0 = matrix(0, rows=N, cols=M) - dct = dc - if (!given_sequences) { - # only given dout for output at final timestep, so prepend empty douts for all other timesteps - dout = cbind(matrix(0, rows=N, cols=(T-1)*M), dout) # shape (N, T*M) - } - - t = T - for (iter in 1:T) { # each timestep in reverse order - X_t = X[,(t-1)*D+1:t*D] # shape (N, D) - dout_t = dout[,(t-1)*M+1:t*M] # shape (N, M) - out_t = matrix(cache_out[t,], rows=N, cols=M) # shape (N, M) - ct = matrix(cache_c[t,], rows=N, cols=M) # shape (N, M) - if (t == 1) { - out_prev = out0 # shape (N, M) - c_prev = c0 # shape (N, M) - } - else { - out_prev = matrix(cache_out[t-1,], rows=N, cols=M) # shape (N, M) - c_prev = matrix(cache_c[t-1,], rows=N, cols=M) # shape (N, M) - } - input = cbind(X_t, out_prev) # shape (N, D+M) - ifog = matrix(cache_ifog[t,], rows=N, cols=4*M) - i = ifog[,1:M] # input gate, shape (N, M) - f = ifog[,M+1:2*M] # forget gate, shape (N, M) - o = ifog[,2*M+1:3*M] # output gate, shape (N, M) - g = ifog[,3*M+1:4*M] # g gate, shape (N, M) - - dct = dct + o*tanh::backward(dout_t, ct) # shape (N, M) - do = tanh::forward(ct) * dout_t # output gate, shape (N, M) - df = c_prev * dct # forget gate, shape (N, M) - dc_prev = f * dct # shape (N, M) - di = g * dct # input gate, shape (N, M) - dg = i * dct # g gate, shape (N, M) - - di_raw = i * (1-i) * di - df_raw = f * (1-f) * df - do_raw = o * (1-o) * do - dg_raw = (1-g^2) * dg - difog_raw = cbind(di_raw, df_raw, do_raw, dg_raw) # shape (N, 4M) - - dW = dW + t(input) %*% difog_raw # shape (D+M, 4M) - db = 
db + colSums(difog_raw) # shape (1, 4M) - dinput = difog_raw %*% t(W) # shape (N, D+M) - dX[,(t-1)*D+1:t*D] = dinput[,1:D] - dout_prev = dinput[,D+1:D+M] # shape (N, M) - if (t == 1) { - dout0 = dout_prev # shape (N, M) - dc0 = dc_prev # shape (N, M) - } - else { - dout[,(t-2)*M+1:(t-1)*M] = dout[,(t-2)*M+1:(t-1)*M] + dout_prev # shape (N, M) - dct = dc_prev # shape (N, M) - } - t = t - 1 - } + dX = X; dW = W; db = b; dout0 = out0; dc0 = c0 + [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, given_sequences, dout, dc) } init = function(int N, int D, int M) http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index cc2d531..082daac 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -1987,11 +1987,7 @@ __device__ int swap_co(int offset) { return (offset < 2) ? offset : (offset == 2 ? 3 : 2); } -template <typename T> -__device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, int D, int M) { - int DM = D*M; int MM = M*M; int DM4 = DM*4; - int M4 = M*4; - int index = blockIdx.x * blockDim.x + threadIdx.x; +__device__ void compute_lstm_weight_indexes(int index, int D, int M, int* ret) { // input: cbind(X_t, out_prev) => [N, D+M], weight: [D+M, 4M] // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnGetRNNLinLayerMatrixParams states that // Elements in each weight matrix are arranged in the row-major order, but the column-major format works !! 
@@ -1999,20 +1995,16 @@ __device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, in // CuDNN weight order: w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o // SystemML weight order: i, f, o, c; TF weight order: i, c, f, o // SystemML performs (X_t %*% W + out_prev %*% R) => [N, 4*M] - - // bias layout: bi bf bc bo 0 0 0 0 - // where W: [DxM], R: [MxM] and b: [1x1] - - // Maximum (D+M+2)*M4 threads - int srcIndex = -1; int destIndex; + int DM = D*M; int MM = M*M; int DM4 = DM*4; + int M4 = M*4; if(index < DM4) { // Fill w_i, w_f, w_c and w_o int localIndex = index%DM; int smlRowIndex = localIndex/M; int smlColIndex = swap_co(index/(DM))*M + localIndex%M; // Convert index to column-major where index = (index/(DM))*DM + (localIndex/M)*M + localIndex%M - destIndex = (index/(DM))*DM + (localIndex%M)*D + localIndex/M; - srcIndex = smlRowIndex*M4+smlColIndex; + ret[1] = (index/(DM))*DM + (localIndex%M)*D + localIndex/M; + ret[0] = smlRowIndex*M4+smlColIndex; } else if(index < (D+M)*M4) { // Fill r_i, r_f, r_c and r_o @@ -2021,18 +2013,29 @@ __device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, in int smlRowIndex = D + (localIndex / M); int smlColIndex = swap_co(tmpIndex/(MM))*M + localIndex%M; // Convert index to column-major where index = DM4 + (tmpIndex/(MM))*MM + (localIndex/M)*M + localIndex%M - destIndex = DM4 + (tmpIndex/(MM))*MM + (localIndex%M)*M + localIndex/M; - srcIndex = smlRowIndex*M4+smlColIndex; + ret[1] = DM4 + (tmpIndex/(MM))*MM + (localIndex%M)*M + localIndex/M; + ret[0] = smlRowIndex*M4+smlColIndex; + } +} + +template <typename T> +__device__ void prepare_lstm_weight(T* smlWeight, T* smlBias, T* cudnnWeight, int D, int M) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + // Maximum (D+M+2)*M4 threads + int M4 = M*4; + if(index < (D+M)*M4) { + int indexes[2]; + compute_lstm_weight_indexes(index, D, M, indexes); + cudnnWeight[indexes[1]] = smlWeight[indexes[0]]; } else if(index < (D+M+1)*M4) { // Fill bias + // 
bias layout: bi bf bc bo 0 0 0 0 + // where W: [DxM], R: [MxM] and b: [1x1] int tmpIndex = index - (D+M)*M4; int smlColIndex = swap_co(tmpIndex/(M))*M + tmpIndex%M; cudnnWeight[index] = smlBias[smlColIndex]; } - // __syncthreads(); - if(srcIndex != -1) - cudnnWeight[destIndex] = smlWeight[srcIndex]; } extern "C" __global__ void prepare_lstm_weight_d(double* smlWeight, double* smlBias, double* cudnnWeight, int D, int M) { @@ -2162,3 +2165,85 @@ extern "C" __global__ void prepare_lstm_output_d(double* smlInput, double* cudnn extern "C" __global__ void prepare_lstm_output_f(float* smlInput, float* cudnnInput, int N, int T, int M, int size) { prepare_lstm_output(smlInput, cudnnInput, N, T, M, size); } + +template <typename T> +__device__ void prepare_lstm_backward_gradients(T* smlDout, T* cudnnDy, int N, int T1, int M, int size, int return_sequences) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < size && return_sequences != 0) { + // smlDout = [N, T, M] + int TM = T1*M; + int NT = T1*N; + int n = index / TM; + int tm = index % TM; + int t = tm / M; + int m = tm % M; + T val = smlDout[index]; + cudnnDy[t*N*M + n*M + m] = val; + } + else if(index < size) { + // smlDout = [N, T, M] + int n = index / M; + int m = index % M; + T val = smlDout[index]; + cudnnDy[(T1-1)*N*M + n*M + m] = val; + } +} + + +extern "C" __global__ void prepare_lstm_backward_gradients_d(double* smlInput, double* cudnnDy, int N, int T, int M, int size, int return_sequences) { + prepare_lstm_backward_gradients(smlInput, cudnnDy, N, T, M, size, return_sequences); +} + +extern "C" __global__ void prepare_lstm_backward_gradients_f(float* smlInput, float* cudnnDy, int N, int T, int M, int size, int return_sequences) { + prepare_lstm_backward_gradients(smlInput, cudnnDy, N, T, M, size, return_sequences); +} + +template <typename T> +__device__ void prepare_lstm_dweight(T* smldWeight, T* smldBias, T* cudnndWeight, int D, int M) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + // 
Maximum (D+M+2)*M4 threads + int M4 = M*4; + if(index < (D+M)*M4) { + int indexes[2]; + compute_lstm_weight_indexes(index, D, M, indexes); + smldWeight[indexes[0]] = cudnndWeight[indexes[1]]; + } + else if(index < (D+M+1)*M4) { + // Fill bias + // bias layout: bi bf bc bo 0 0 0 0 + // where W: [DxM], R: [MxM] and b: [1x1] + int tmpIndex = index - (D+M)*M4; + int smlColIndex = swap_co(tmpIndex/(M))*M + tmpIndex%M; + smldBias[smlColIndex] = cudnndWeight[index]; + } +} + +extern "C" __global__ void prepare_lstm_dweight_d(double* smldWeight, double* smldBias, double* cudnndWeight, int D, int M) { + prepare_lstm_dweight(smldWeight, smldBias, cudnndWeight, D, M); +} + +extern "C" __global__ void prepare_lstm_dweight_f(float* smldWeight, float* smldBias, float* cudnndWeight, int D, int M) { + prepare_lstm_dweight(smldWeight, smldBias, cudnndWeight, D, M); +} + +template <typename T> +__device__ void prepare_lstm_dinput(T* smlInput, T* cudnnInput, int N, int D, int TD, int size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < size) { + int n = index / TD; + int td = index % TD; + int t = td / D; + int d = td % D; + smlInput[index] = cudnnInput[t*N*D + n*D + d]; + } +} + + +extern "C" __global__ void prepare_lstm_dinput_d(double* smlInput, double* cudnnInput, int N, int D, int TD, int size) { + prepare_lstm_dinput(smlInput, cudnnInput, N, D, TD, size); +} + +extern "C" __global__ void prepare_lstm_dinput_f(float* smlInput, float* cudnnInput, int N, int D, int TD, int size) { + prepare_lstm_dinput(smlInput, cudnnInput, N, D, TD, size); +} + http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/cpp/kernels/SystemML.ptx ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx index ed1a100..9d8b178 100644 --- a/src/main/cpp/kernels/SystemML.ptx +++ b/src/main/cpp/kernels/SystemML.ptx @@ -11794,8 +11794,8 @@ BB99_2: .param .u32 
prepare_lstm_weight_d_param_4 ) { - .reg .pred %p<11>; - .reg .b32 %r<53>; + .reg .pred %p<8>; + .reg .b32 %r<48>; .reg .f64 %fd<3>; .reg .b64 %rd<15>; @@ -11803,98 +11803,87 @@ BB99_2: ld.param.u64 %rd2, [prepare_lstm_weight_d_param_0]; ld.param.u64 %rd3, [prepare_lstm_weight_d_param_1]; ld.param.u64 %rd4, [prepare_lstm_weight_d_param_2]; - ld.param.u32 %r13, [prepare_lstm_weight_d_param_3]; - ld.param.u32 %r14, [prepare_lstm_weight_d_param_4]; + ld.param.u32 %r45, [prepare_lstm_weight_d_param_3]; + ld.param.u32 %r21, [prepare_lstm_weight_d_param_4]; cvta.to.global.u64 %rd1, %rd4; - mul.lo.s32 %r1, %r14, %r13; - shl.b32 %r2, %r1, 2; - shl.b32 %r3, %r14, 2; - mov.u32 %r15, %ntid.x; - mov.u32 %r16, %ctaid.x; - mov.u32 %r17, %tid.x; - mad.lo.s32 %r4, %r15, %r16, %r17; - setp.lt.s32 %p1, %r4, %r2; - @%p1 bra BB100_5; + mov.u32 %r22, %ntid.x; + mov.u32 %r23, %ctaid.x; + mov.u32 %r24, %tid.x; + mad.lo.s32 %r1, %r22, %r23, %r24; + add.s32 %r2, %r21, %r45; + shl.b32 %r3, %r21, 2; + mul.lo.s32 %r4, %r2, %r3; + setp.lt.s32 %p1, %r1, %r4; + @%p1 bra BB100_3; bra.uni BB100_1; +BB100_3: + mul.lo.s32 %r5, %r21, %r45; + mul.lo.s32 %r47, %r21, %r21; + shl.b32 %r7, %r5, 2; + setp.lt.s32 %p5, %r1, %r7; + @%p5 bra BB100_5; + bra.uni BB100_4; + BB100_5: - rem.s32 %r42, %r4, %r1; - div.s32 %r43, %r42, %r14; - div.s32 %r44, %r4, %r1; - setp.lt.s32 %p8, %r44, 2; - setp.eq.s32 %p9, %r44, 2; - selp.b32 %r45, 3, 2, %p9; - selp.b32 %r46, %r44, %r45, %p8; - rem.s32 %r47, %r42, %r14; - sub.s32 %r48, %r4, %r42; - add.s32 %r49, %r48, %r43; - mad.lo.s32 %r52, %r47, %r13, %r49; - mad.lo.s32 %r50, %r43, %r3, %r47; - mad.lo.s32 %r51, %r46, %r14, %r50; + rem.s32 %r44, %r1, %r5; + div.s32 %r42, %r44, %r21; + mov.u32 %r43, %r42; + mov.u32 %r46, %r1; + mov.u32 %r47, %r5; bra.uni BB100_6; BB100_1: - add.s32 %r5, %r14, %r13; - mul.lo.s32 %r6, %r5, %r3; - setp.lt.s32 %p2, %r4, %r6; - @%p2 bra BB100_4; - bra.uni BB100_2; - -BB100_4: - mul.lo.s32 %r30, %r14, %r14; - sub.s32 %r31, %r4, %r2; - rem.s32 %r32, 
%r31, %r30; - div.s32 %r33, %r32, %r14; - add.s32 %r34, %r33, %r13; - div.s32 %r35, %r31, %r30; - setp.lt.s32 %p6, %r35, 2; - setp.eq.s32 %p7, %r35, 2; - selp.b32 %r36, 3, 2, %p7; - selp.b32 %r37, %r35, %r36, %p6; - rem.s32 %r38, %r32, %r14; - sub.s32 %r39, %r4, %r32; - add.s32 %r40, %r39, %r33; - mad.lo.s32 %r52, %r38, %r14, %r40; - mad.lo.s32 %r41, %r34, %r3, %r38; - mad.lo.s32 %r51, %r37, %r14, %r41; - bra.uni BB100_6; - -BB100_2: - add.s32 %r20, %r5, 1; - mul.lo.s32 %r21, %r20, %r3; - mov.u32 %r51, -1; - setp.ge.s32 %p3, %r4, %r21; - @%p3 bra BB100_6; + add.s32 %r25, %r2, 1; + mul.lo.s32 %r26, %r25, %r3; + setp.ge.s32 %p2, %r1, %r26; + @%p2 bra BB100_7; cvta.to.global.u64 %rd5, %rd3; - sub.s32 %r24, %r4, %r6; - div.s32 %r25, %r24, %r14; - setp.lt.s32 %p4, %r25, 2; - setp.eq.s32 %p5, %r25, 2; - selp.b32 %r26, 3, 2, %p5; - selp.b32 %r27, %r25, %r26, %p4; - rem.s32 %r28, %r24, %r14; - mad.lo.s32 %r29, %r27, %r14, %r28; - mul.wide.s32 %rd6, %r29, 8; + sub.s32 %r27, %r1, %r4; + div.s32 %r28, %r27, %r21; + setp.lt.s32 %p3, %r28, 2; + setp.eq.s32 %p4, %r28, 2; + selp.b32 %r29, 3, 2, %p4; + selp.b32 %r30, %r28, %r29, %p3; + rem.s32 %r31, %r27, %r21; + mad.lo.s32 %r32, %r30, %r21, %r31; + mul.wide.s32 %rd6, %r32, 8; add.s64 %rd7, %rd5, %rd6; ld.global.f64 %fd1, [%rd7]; - mul.wide.s32 %rd8, %r4, 8; + mul.wide.s32 %rd8, %r1, 8; add.s64 %rd9, %rd1, %rd8; st.global.f64 [%rd9], %fd1; + bra.uni BB100_7; -BB100_6: - setp.eq.s32 %p10, %r51, -1; - @%p10 bra BB100_8; +BB100_4: + sub.s32 %r46, %r1, %r7; + rem.s32 %r44, %r46, %r47; + div.s32 %r43, %r44, %r21; + add.s32 %r42, %r43, %r45; + mov.u32 %r45, %r21; +BB100_6: cvta.to.global.u64 %rd10, %rd2; - mul.wide.s32 %rd11, %r51, 8; + div.s32 %r33, %r46, %r47; + setp.eq.s32 %p6, %r33, 2; + selp.b32 %r34, 3, 2, %p6; + setp.lt.s32 %p7, %r33, 2; + selp.b32 %r35, %r33, %r34, %p7; + rem.s32 %r36, %r44, %r21; + sub.s32 %r37, %r1, %r44; + add.s32 %r38, %r37, %r43; + mad.lo.s32 %r39, %r36, %r45, %r38; + mad.lo.s32 %r40, %r42, %r3, %r36; + 
mad.lo.s32 %r41, %r35, %r21, %r40; + mul.wide.s32 %rd11, %r41, 8; add.s64 %rd12, %rd10, %rd11; ld.global.f64 %fd2, [%rd12]; - mul.wide.s32 %rd13, %r52, 8; + mul.wide.s32 %rd13, %r39, 8; add.s64 %rd14, %rd1, %rd13; st.global.f64 [%rd14], %fd2; -BB100_8: +BB100_7: ret; } @@ -11907,107 +11896,96 @@ BB100_8: .param .u32 prepare_lstm_weight_f_param_4 ) { - .reg .pred %p<11>; + .reg .pred %p<8>; .reg .f32 %f<3>; - .reg .b32 %r<53>; + .reg .b32 %r<48>; .reg .b64 %rd<15>; ld.param.u64 %rd2, [prepare_lstm_weight_f_param_0]; ld.param.u64 %rd3, [prepare_lstm_weight_f_param_1]; ld.param.u64 %rd4, [prepare_lstm_weight_f_param_2]; - ld.param.u32 %r13, [prepare_lstm_weight_f_param_3]; - ld.param.u32 %r14, [prepare_lstm_weight_f_param_4]; + ld.param.u32 %r45, [prepare_lstm_weight_f_param_3]; + ld.param.u32 %r21, [prepare_lstm_weight_f_param_4]; cvta.to.global.u64 %rd1, %rd4; - mul.lo.s32 %r1, %r14, %r13; - shl.b32 %r2, %r1, 2; - shl.b32 %r3, %r14, 2; - mov.u32 %r15, %ntid.x; - mov.u32 %r16, %ctaid.x; - mov.u32 %r17, %tid.x; - mad.lo.s32 %r4, %r15, %r16, %r17; - setp.lt.s32 %p1, %r4, %r2; - @%p1 bra BB101_5; + mov.u32 %r22, %ntid.x; + mov.u32 %r23, %ctaid.x; + mov.u32 %r24, %tid.x; + mad.lo.s32 %r1, %r22, %r23, %r24; + add.s32 %r2, %r21, %r45; + shl.b32 %r3, %r21, 2; + mul.lo.s32 %r4, %r2, %r3; + setp.lt.s32 %p1, %r1, %r4; + @%p1 bra BB101_3; bra.uni BB101_1; +BB101_3: + mul.lo.s32 %r5, %r21, %r45; + mul.lo.s32 %r47, %r21, %r21; + shl.b32 %r7, %r5, 2; + setp.lt.s32 %p5, %r1, %r7; + @%p5 bra BB101_5; + bra.uni BB101_4; + BB101_5: - rem.s32 %r42, %r4, %r1; - div.s32 %r43, %r42, %r14; - div.s32 %r44, %r4, %r1; - setp.lt.s32 %p8, %r44, 2; - setp.eq.s32 %p9, %r44, 2; - selp.b32 %r45, 3, 2, %p9; - selp.b32 %r46, %r44, %r45, %p8; - rem.s32 %r47, %r42, %r14; - sub.s32 %r48, %r4, %r42; - add.s32 %r49, %r48, %r43; - mad.lo.s32 %r52, %r47, %r13, %r49; - mad.lo.s32 %r50, %r43, %r3, %r47; - mad.lo.s32 %r51, %r46, %r14, %r50; + rem.s32 %r44, %r1, %r5; + div.s32 %r42, %r44, %r21; + mov.u32 %r43, 
%r42; + mov.u32 %r46, %r1; + mov.u32 %r47, %r5; bra.uni BB101_6; BB101_1: - add.s32 %r5, %r14, %r13; - mul.lo.s32 %r6, %r5, %r3; - setp.lt.s32 %p2, %r4, %r6; - @%p2 bra BB101_4; - bra.uni BB101_2; - -BB101_4: - mul.lo.s32 %r30, %r14, %r14; - sub.s32 %r31, %r4, %r2; - rem.s32 %r32, %r31, %r30; - div.s32 %r33, %r32, %r14; - add.s32 %r34, %r33, %r13; - div.s32 %r35, %r31, %r30; - setp.lt.s32 %p6, %r35, 2; - setp.eq.s32 %p7, %r35, 2; - selp.b32 %r36, 3, 2, %p7; - selp.b32 %r37, %r35, %r36, %p6; - rem.s32 %r38, %r32, %r14; - sub.s32 %r39, %r4, %r32; - add.s32 %r40, %r39, %r33; - mad.lo.s32 %r52, %r38, %r14, %r40; - mad.lo.s32 %r41, %r34, %r3, %r38; - mad.lo.s32 %r51, %r37, %r14, %r41; - bra.uni BB101_6; - -BB101_2: - add.s32 %r20, %r5, 1; - mul.lo.s32 %r21, %r20, %r3; - mov.u32 %r51, -1; - setp.ge.s32 %p3, %r4, %r21; - @%p3 bra BB101_6; + add.s32 %r25, %r2, 1; + mul.lo.s32 %r26, %r25, %r3; + setp.ge.s32 %p2, %r1, %r26; + @%p2 bra BB101_7; cvta.to.global.u64 %rd5, %rd3; - sub.s32 %r24, %r4, %r6; - div.s32 %r25, %r24, %r14; - setp.lt.s32 %p4, %r25, 2; - setp.eq.s32 %p5, %r25, 2; - selp.b32 %r26, 3, 2, %p5; - selp.b32 %r27, %r25, %r26, %p4; - rem.s32 %r28, %r24, %r14; - mad.lo.s32 %r29, %r27, %r14, %r28; - mul.wide.s32 %rd6, %r29, 4; + sub.s32 %r27, %r1, %r4; + div.s32 %r28, %r27, %r21; + setp.lt.s32 %p3, %r28, 2; + setp.eq.s32 %p4, %r28, 2; + selp.b32 %r29, 3, 2, %p4; + selp.b32 %r30, %r28, %r29, %p3; + rem.s32 %r31, %r27, %r21; + mad.lo.s32 %r32, %r30, %r21, %r31; + mul.wide.s32 %rd6, %r32, 4; add.s64 %rd7, %rd5, %rd6; ld.global.f32 %f1, [%rd7]; - mul.wide.s32 %rd8, %r4, 4; + mul.wide.s32 %rd8, %r1, 4; add.s64 %rd9, %rd1, %rd8; st.global.f32 [%rd9], %f1; + bra.uni BB101_7; -BB101_6: - setp.eq.s32 %p10, %r51, -1; - @%p10 bra BB101_8; +BB101_4: + sub.s32 %r46, %r1, %r7; + rem.s32 %r44, %r46, %r47; + div.s32 %r43, %r44, %r21; + add.s32 %r42, %r43, %r45; + mov.u32 %r45, %r21; +BB101_6: cvta.to.global.u64 %rd10, %rd2; - mul.wide.s32 %rd11, %r51, 4; + div.s32 %r33, %r46, %r47; 
+ setp.eq.s32 %p6, %r33, 2; + selp.b32 %r34, 3, 2, %p6; + setp.lt.s32 %p7, %r33, 2; + selp.b32 %r35, %r33, %r34, %p7; + rem.s32 %r36, %r44, %r21; + sub.s32 %r37, %r1, %r44; + add.s32 %r38, %r37, %r43; + mad.lo.s32 %r39, %r36, %r45, %r38; + mad.lo.s32 %r40, %r42, %r3, %r36; + mad.lo.s32 %r41, %r35, %r21, %r40; + mul.wide.s32 %rd11, %r41, 4; add.s64 %rd12, %rd10, %rd11; ld.global.f32 %f2, [%rd12]; - mul.wide.s32 %rd13, %r52, 4; + mul.wide.s32 %rd13, %r39, 4; add.s64 %rd14, %rd1, %rd13; st.global.f32 [%rd14], %f2; -BB101_8: +BB101_7: ret; } @@ -12463,12 +12441,448 @@ BB105_2: ret; } + // .globl prepare_lstm_backward_gradients_d +.visible .entry prepare_lstm_backward_gradients_d( + .param .u64 prepare_lstm_backward_gradients_d_param_0, + .param .u64 prepare_lstm_backward_gradients_d_param_1, + .param .u32 prepare_lstm_backward_gradients_d_param_2, + .param .u32 prepare_lstm_backward_gradients_d_param_3, + .param .u32 prepare_lstm_backward_gradients_d_param_4, + .param .u32 prepare_lstm_backward_gradients_d_param_5, + .param .u32 prepare_lstm_backward_gradients_d_param_6 +) +{ + .reg .pred %p<5>; + .reg .b32 %r<20>; + .reg .f64 %fd<3>; + .reg .b64 %rd<11>; + + + ld.param.u64 %rd3, [prepare_lstm_backward_gradients_d_param_0]; + ld.param.u64 %rd4, [prepare_lstm_backward_gradients_d_param_1]; + ld.param.u32 %r2, [prepare_lstm_backward_gradients_d_param_2]; + ld.param.u32 %r3, [prepare_lstm_backward_gradients_d_param_3]; + ld.param.u32 %r4, [prepare_lstm_backward_gradients_d_param_4]; + ld.param.u32 %r5, [prepare_lstm_backward_gradients_d_param_5]; + ld.param.u32 %r6, [prepare_lstm_backward_gradients_d_param_6]; + cvta.to.global.u64 %rd1, %rd4; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %ctaid.x; + mov.u32 %r9, %tid.x; + mad.lo.s32 %r1, %r7, %r8, %r9; + setp.lt.s32 %p1, %r1, %r5; + setp.ne.s32 %p2, %r6, 0; + and.pred %p3, %p1, %p2; + cvta.to.global.u64 %rd5, %rd3; + mul.wide.s32 %rd6, %r1, 8; + add.s64 %rd2, %rd5, %rd6; + @%p3 bra BB106_3; + bra.uni BB106_1; + +BB106_3: + 
mul.lo.s32 %r13, %r4, %r3; + div.s32 %r14, %r1, %r13; + rem.s32 %r15, %r1, %r13; + div.s32 %r16, %r15, %r4; + rem.s32 %r17, %r15, %r4; + ld.global.f64 %fd2, [%rd2]; + mad.lo.s32 %r18, %r16, %r2, %r14; + mad.lo.s32 %r19, %r18, %r4, %r17; + mul.wide.s32 %rd9, %r19, 8; + add.s64 %rd10, %rd1, %rd9; + st.global.f64 [%rd10], %fd2; + bra.uni BB106_4; + +BB106_1: + setp.ge.s32 %p4, %r1, %r5; + @%p4 bra BB106_4; + + ld.global.f64 %fd1, [%rd2]; + add.s32 %r10, %r3, -1; + mul.lo.s32 %r11, %r10, %r2; + mad.lo.s32 %r12, %r11, %r4, %r1; + mul.wide.s32 %rd7, %r12, 8; + add.s64 %rd8, %rd1, %rd7; + st.global.f64 [%rd8], %fd1; + +BB106_4: + ret; +} + + // .globl prepare_lstm_backward_gradients_f +.visible .entry prepare_lstm_backward_gradients_f( + .param .u64 prepare_lstm_backward_gradients_f_param_0, + .param .u64 prepare_lstm_backward_gradients_f_param_1, + .param .u32 prepare_lstm_backward_gradients_f_param_2, + .param .u32 prepare_lstm_backward_gradients_f_param_3, + .param .u32 prepare_lstm_backward_gradients_f_param_4, + .param .u32 prepare_lstm_backward_gradients_f_param_5, + .param .u32 prepare_lstm_backward_gradients_f_param_6 +) +{ + .reg .pred %p<5>; + .reg .f32 %f<3>; + .reg .b32 %r<20>; + .reg .b64 %rd<11>; + + + ld.param.u64 %rd3, [prepare_lstm_backward_gradients_f_param_0]; + ld.param.u64 %rd4, [prepare_lstm_backward_gradients_f_param_1]; + ld.param.u32 %r2, [prepare_lstm_backward_gradients_f_param_2]; + ld.param.u32 %r3, [prepare_lstm_backward_gradients_f_param_3]; + ld.param.u32 %r4, [prepare_lstm_backward_gradients_f_param_4]; + ld.param.u32 %r5, [prepare_lstm_backward_gradients_f_param_5]; + ld.param.u32 %r6, [prepare_lstm_backward_gradients_f_param_6]; + cvta.to.global.u64 %rd1, %rd4; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %ctaid.x; + mov.u32 %r9, %tid.x; + mad.lo.s32 %r1, %r7, %r8, %r9; + setp.lt.s32 %p1, %r1, %r5; + setp.ne.s32 %p2, %r6, 0; + and.pred %p3, %p1, %p2; + cvta.to.global.u64 %rd5, %rd3; + mul.wide.s32 %rd6, %r1, 4; + add.s64 %rd2, %rd5, %rd6; + 
@%p3 bra BB107_3; + bra.uni BB107_1; + +BB107_3: + mul.lo.s32 %r13, %r4, %r3; + div.s32 %r14, %r1, %r13; + rem.s32 %r15, %r1, %r13; + div.s32 %r16, %r15, %r4; + rem.s32 %r17, %r15, %r4; + ld.global.f32 %f2, [%rd2]; + mad.lo.s32 %r18, %r16, %r2, %r14; + mad.lo.s32 %r19, %r18, %r4, %r17; + mul.wide.s32 %rd9, %r19, 4; + add.s64 %rd10, %rd1, %rd9; + st.global.f32 [%rd10], %f2; + bra.uni BB107_4; + +BB107_1: + setp.ge.s32 %p4, %r1, %r5; + @%p4 bra BB107_4; + + ld.global.f32 %f1, [%rd2]; + add.s32 %r10, %r3, -1; + mul.lo.s32 %r11, %r10, %r2; + mad.lo.s32 %r12, %r11, %r4, %r1; + mul.wide.s32 %rd7, %r12, 4; + add.s64 %rd8, %rd1, %rd7; + st.global.f32 [%rd8], %f1; + +BB107_4: + ret; +} + + // .globl prepare_lstm_dweight_d +.visible .entry prepare_lstm_dweight_d( + .param .u64 prepare_lstm_dweight_d_param_0, + .param .u64 prepare_lstm_dweight_d_param_1, + .param .u64 prepare_lstm_dweight_d_param_2, + .param .u32 prepare_lstm_dweight_d_param_3, + .param .u32 prepare_lstm_dweight_d_param_4 +) +{ + .reg .pred %p<8>; + .reg .b32 %r<48>; + .reg .f64 %fd<3>; + .reg .b64 %rd<15>; + + + ld.param.u64 %rd2, [prepare_lstm_dweight_d_param_0]; + ld.param.u64 %rd3, [prepare_lstm_dweight_d_param_1]; + ld.param.u64 %rd4, [prepare_lstm_dweight_d_param_2]; + ld.param.u32 %r45, [prepare_lstm_dweight_d_param_3]; + ld.param.u32 %r21, [prepare_lstm_dweight_d_param_4]; + cvta.to.global.u64 %rd1, %rd4; + mov.u32 %r22, %ntid.x; + mov.u32 %r23, %ctaid.x; + mov.u32 %r24, %tid.x; + mad.lo.s32 %r1, %r22, %r23, %r24; + add.s32 %r2, %r21, %r45; + shl.b32 %r3, %r21, 2; + mul.lo.s32 %r4, %r2, %r3; + setp.lt.s32 %p1, %r1, %r4; + @%p1 bra BB108_3; + bra.uni BB108_1; + +BB108_3: + mul.lo.s32 %r5, %r21, %r45; + mul.lo.s32 %r47, %r21, %r21; + shl.b32 %r7, %r5, 2; + setp.lt.s32 %p5, %r1, %r7; + @%p5 bra BB108_5; + bra.uni BB108_4; + +BB108_5: + rem.s32 %r44, %r1, %r5; + div.s32 %r42, %r44, %r21; + mov.u32 %r43, %r42; + mov.u32 %r46, %r1; + mov.u32 %r47, %r5; + bra.uni BB108_6; + +BB108_1: + add.s32 %r25, %r2, 1; 
+ mul.lo.s32 %r26, %r25, %r3; + setp.ge.s32 %p2, %r1, %r26; + @%p2 bra BB108_7; + + cvta.to.global.u64 %rd5, %rd3; + sub.s32 %r27, %r1, %r4; + div.s32 %r28, %r27, %r21; + setp.lt.s32 %p3, %r28, 2; + setp.eq.s32 %p4, %r28, 2; + selp.b32 %r29, 3, 2, %p4; + selp.b32 %r30, %r28, %r29, %p3; + rem.s32 %r31, %r27, %r21; + mad.lo.s32 %r32, %r30, %r21, %r31; + mul.wide.s32 %rd6, %r1, 8; + add.s64 %rd7, %rd1, %rd6; + ld.global.f64 %fd1, [%rd7]; + mul.wide.s32 %rd8, %r32, 8; + add.s64 %rd9, %rd5, %rd8; + st.global.f64 [%rd9], %fd1; + bra.uni BB108_7; + +BB108_4: + sub.s32 %r46, %r1, %r7; + rem.s32 %r44, %r46, %r47; + div.s32 %r43, %r44, %r21; + add.s32 %r42, %r43, %r45; + mov.u32 %r45, %r21; + +BB108_6: + cvta.to.global.u64 %rd10, %rd2; + div.s32 %r33, %r46, %r47; + setp.eq.s32 %p6, %r33, 2; + selp.b32 %r34, 3, 2, %p6; + setp.lt.s32 %p7, %r33, 2; + selp.b32 %r35, %r33, %r34, %p7; + rem.s32 %r36, %r44, %r21; + sub.s32 %r37, %r1, %r44; + add.s32 %r38, %r37, %r43; + mad.lo.s32 %r39, %r36, %r45, %r38; + mad.lo.s32 %r40, %r42, %r3, %r36; + mad.lo.s32 %r41, %r35, %r21, %r40; + mul.wide.s32 %rd11, %r39, 8; + add.s64 %rd12, %rd1, %rd11; + ld.global.f64 %fd2, [%rd12]; + mul.wide.s32 %rd13, %r41, 8; + add.s64 %rd14, %rd10, %rd13; + st.global.f64 [%rd14], %fd2; + +BB108_7: + ret; +} + + // .globl prepare_lstm_dweight_f +.visible .entry prepare_lstm_dweight_f( + .param .u64 prepare_lstm_dweight_f_param_0, + .param .u64 prepare_lstm_dweight_f_param_1, + .param .u64 prepare_lstm_dweight_f_param_2, + .param .u32 prepare_lstm_dweight_f_param_3, + .param .u32 prepare_lstm_dweight_f_param_4 +) +{ + .reg .pred %p<8>; + .reg .f32 %f<3>; + .reg .b32 %r<48>; + .reg .b64 %rd<15>; + + + ld.param.u64 %rd2, [prepare_lstm_dweight_f_param_0]; + ld.param.u64 %rd3, [prepare_lstm_dweight_f_param_1]; + ld.param.u64 %rd4, [prepare_lstm_dweight_f_param_2]; + ld.param.u32 %r45, [prepare_lstm_dweight_f_param_3]; + ld.param.u32 %r21, [prepare_lstm_dweight_f_param_4]; + cvta.to.global.u64 %rd1, %rd4; + mov.u32 
%r22, %ntid.x; + mov.u32 %r23, %ctaid.x; + mov.u32 %r24, %tid.x; + mad.lo.s32 %r1, %r22, %r23, %r24; + add.s32 %r2, %r21, %r45; + shl.b32 %r3, %r21, 2; + mul.lo.s32 %r4, %r2, %r3; + setp.lt.s32 %p1, %r1, %r4; + @%p1 bra BB109_3; + bra.uni BB109_1; + +BB109_3: + mul.lo.s32 %r5, %r21, %r45; + mul.lo.s32 %r47, %r21, %r21; + shl.b32 %r7, %r5, 2; + setp.lt.s32 %p5, %r1, %r7; + @%p5 bra BB109_5; + bra.uni BB109_4; + +BB109_5: + rem.s32 %r44, %r1, %r5; + div.s32 %r42, %r44, %r21; + mov.u32 %r43, %r42; + mov.u32 %r46, %r1; + mov.u32 %r47, %r5; + bra.uni BB109_6; + +BB109_1: + add.s32 %r25, %r2, 1; + mul.lo.s32 %r26, %r25, %r3; + setp.ge.s32 %p2, %r1, %r26; + @%p2 bra BB109_7; + + cvta.to.global.u64 %rd5, %rd3; + sub.s32 %r27, %r1, %r4; + div.s32 %r28, %r27, %r21; + setp.lt.s32 %p3, %r28, 2; + setp.eq.s32 %p4, %r28, 2; + selp.b32 %r29, 3, 2, %p4; + selp.b32 %r30, %r28, %r29, %p3; + rem.s32 %r31, %r27, %r21; + mad.lo.s32 %r32, %r30, %r21, %r31; + mul.wide.s32 %rd6, %r1, 4; + add.s64 %rd7, %rd1, %rd6; + ld.global.f32 %f1, [%rd7]; + mul.wide.s32 %rd8, %r32, 4; + add.s64 %rd9, %rd5, %rd8; + st.global.f32 [%rd9], %f1; + bra.uni BB109_7; + +BB109_4: + sub.s32 %r46, %r1, %r7; + rem.s32 %r44, %r46, %r47; + div.s32 %r43, %r44, %r21; + add.s32 %r42, %r43, %r45; + mov.u32 %r45, %r21; + +BB109_6: + cvta.to.global.u64 %rd10, %rd2; + div.s32 %r33, %r46, %r47; + setp.eq.s32 %p6, %r33, 2; + selp.b32 %r34, 3, 2, %p6; + setp.lt.s32 %p7, %r33, 2; + selp.b32 %r35, %r33, %r34, %p7; + rem.s32 %r36, %r44, %r21; + sub.s32 %r37, %r1, %r44; + add.s32 %r38, %r37, %r43; + mad.lo.s32 %r39, %r36, %r45, %r38; + mad.lo.s32 %r40, %r42, %r3, %r36; + mad.lo.s32 %r41, %r35, %r21, %r40; + mul.wide.s32 %rd11, %r39, 4; + add.s64 %rd12, %rd1, %rd11; + ld.global.f32 %f2, [%rd12]; + mul.wide.s32 %rd13, %r41, 4; + add.s64 %rd14, %rd10, %rd13; + st.global.f32 [%rd14], %f2; + +BB109_7: + ret; +} + + // .globl prepare_lstm_dinput_d +.visible .entry prepare_lstm_dinput_d( + .param .u64 prepare_lstm_dinput_d_param_0, + 
.param .u64 prepare_lstm_dinput_d_param_1, + .param .u32 prepare_lstm_dinput_d_param_2, + .param .u32 prepare_lstm_dinput_d_param_3, + .param .u32 prepare_lstm_dinput_d_param_4, + .param .u32 prepare_lstm_dinput_d_param_5 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<15>; + .reg .f64 %fd<2>; + .reg .b64 %rd<9>; + + + ld.param.u64 %rd1, [prepare_lstm_dinput_d_param_0]; + ld.param.u64 %rd2, [prepare_lstm_dinput_d_param_1]; + ld.param.u32 %r2, [prepare_lstm_dinput_d_param_2]; + ld.param.u32 %r3, [prepare_lstm_dinput_d_param_3]; + ld.param.u32 %r4, [prepare_lstm_dinput_d_param_4]; + ld.param.u32 %r5, [prepare_lstm_dinput_d_param_5]; + mov.u32 %r6, %ctaid.x; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %tid.x; + mad.lo.s32 %r1, %r7, %r6, %r8; + setp.ge.s32 %p1, %r1, %r5; + @%p1 bra BB110_2; + + cvta.to.global.u64 %rd3, %rd2; + rem.s32 %r9, %r1, %r4; + div.s32 %r10, %r9, %r3; + rem.s32 %r11, %r9, %r3; + div.s32 %r12, %r1, %r4; + mad.lo.s32 %r13, %r10, %r2, %r12; + mad.lo.s32 %r14, %r13, %r3, %r11; + mul.wide.s32 %rd4, %r14, 8; + add.s64 %rd5, %rd3, %rd4; + ld.global.f64 %fd1, [%rd5]; + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 8; + add.s64 %rd8, %rd6, %rd7; + st.global.f64 [%rd8], %fd1; + +BB110_2: + ret; +} + + // .globl prepare_lstm_dinput_f +.visible .entry prepare_lstm_dinput_f( + .param .u64 prepare_lstm_dinput_f_param_0, + .param .u64 prepare_lstm_dinput_f_param_1, + .param .u32 prepare_lstm_dinput_f_param_2, + .param .u32 prepare_lstm_dinput_f_param_3, + .param .u32 prepare_lstm_dinput_f_param_4, + .param .u32 prepare_lstm_dinput_f_param_5 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<2>; + .reg .b32 %r<15>; + .reg .b64 %rd<9>; + + + ld.param.u64 %rd1, [prepare_lstm_dinput_f_param_0]; + ld.param.u64 %rd2, [prepare_lstm_dinput_f_param_1]; + ld.param.u32 %r2, [prepare_lstm_dinput_f_param_2]; + ld.param.u32 %r3, [prepare_lstm_dinput_f_param_3]; + ld.param.u32 %r4, [prepare_lstm_dinput_f_param_4]; + ld.param.u32 %r5, [prepare_lstm_dinput_f_param_5]; + mov.u32 %r6, 
%ctaid.x; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %tid.x; + mad.lo.s32 %r1, %r7, %r6, %r8; + setp.ge.s32 %p1, %r1, %r5; + @%p1 bra BB111_2; + + cvta.to.global.u64 %rd3, %rd2; + rem.s32 %r9, %r1, %r4; + div.s32 %r10, %r9, %r3; + rem.s32 %r11, %r9, %r3; + div.s32 %r12, %r1, %r4; + mad.lo.s32 %r13, %r10, %r2, %r12; + mad.lo.s32 %r14, %r13, %r3, %r11; + mul.wide.s32 %rd4, %r14, 4; + add.s64 %rd5, %rd3, %rd4; + ld.global.f32 %f1, [%rd5]; + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 4; + add.s64 %rd8, %rd6, %rd7; + st.global.f32 [%rd8], %f1; + +BB111_2: + ret; +} + .func (.param .b64 func_retval0) __internal_trig_reduction_slowpathd( .param .b64 __internal_trig_reduction_slowpathd_param_0, .param .b64 __internal_trig_reduction_slowpathd_param_1 ) { - .local .align 8 .b8 __local_depot106[40]; + .local .align 8 .b8 __local_depot112[40]; .reg .b64 %SP; .reg .b64 %SPL; .reg .pred %p<9>; @@ -12477,7 +12891,7 @@ BB105_2: .reg .b64 %rd<102>; - mov.u64 %rd101, __local_depot106; + mov.u64 %rd101, __local_depot112; cvta.local.u64 %SP, %rd101; ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0]; ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1]; @@ -12491,7 +12905,7 @@ BB105_2: shr.u32 %r3, %r1, 20; bfe.u32 %r4, %r1, 20, 11; setp.eq.s32 %p1, %r4, 2047; - @%p1 bra BB106_13; + @%p1 bra BB112_13; add.s32 %r15, %r4, -1024; shr.u32 %r16, %r15, 6; @@ -12504,7 +12918,7 @@ BB105_2: mov.u64 %rd94, 0; setp.ge.s32 %p2, %r5, %r6; mov.u64 %rd93, %rd1; - @%p2 bra BB106_4; + @%p2 bra BB112_4; mov.b64 %rd41, %fd4; shl.b64 %rd42, %rd41, 11; @@ -12521,7 +12935,7 @@ BB105_2: mov.u64 %rd91, %rd1; mov.u32 %r39, %r5; -BB106_3: +BB112_3: .pragma "nounroll"; ld.const.u64 %rd47, [%rd89]; // inline asm @@ -12551,15 +12965,15 @@ BB106_3: add.s64 %rd93, %rd93, 8; add.s64 %rd89, %rd89, 8; setp.lt.s32 %p3, %r39, %r6; - @%p3 bra BB106_3; + @%p3 bra BB112_3; -BB106_4: +BB112_4: st.local.u64 [%rd93], %rd94; ld.local.u64 %rd95, [%rd1+16]; ld.local.u64 %rd96, [%rd1+24]; 
and.b32 %r9, %r3, 63; setp.eq.s32 %p4, %r9, 0; - @%p4 bra BB106_6; + @%p4 bra BB112_6; mov.u32 %r27, 64; sub.s32 %r28, %r27, %r9; @@ -12571,7 +12985,7 @@ BB106_4: shr.u64 %rd55, %rd54, %r28; or.b64 %rd95, %rd55, %rd53; -BB106_6: +BB112_6: cvta.to.local.u64 %rd56, %rd37; shr.u64 %rd57, %rd96, 62; cvt.u32.u64 %r29, %rd57; @@ -12588,7 +13002,7 @@ BB106_6: selp.b32 %r34, %r32, %r33, %p5; st.local.u32 [%rd56], %r34; setp.eq.s32 %p6, %r31, 0; - @%p6 bra BB106_8; + @%p6 bra BB112_8; mov.u64 %rd64, 0; // inline asm @@ -12608,10 +13022,10 @@ BB106_6: // inline asm xor.b32 %r40, %r40, -2147483648; -BB106_8: +BB112_8: clz.b64 %r41, %rd98; setp.eq.s32 %p7, %r41, 0; - @%p7 bra BB106_10; + @%p7 bra BB112_10; shl.b64 %rd67, %rd98, %r41; mov.u32 %r35, 64; @@ -12619,7 +13033,7 @@ BB106_8: shr.u64 %rd68, %rd97, %r36; or.b64 %rd98, %rd68, %rd67; -BB106_10: +BB112_10: mov.u64 %rd72, -3958705157555305931; // inline asm { @@ -12640,7 +13054,7 @@ BB106_10: } // inline asm setp.lt.s64 %p8, %rd100, 1; - @%p8 bra BB106_12; + @%p8 bra BB112_12; // inline asm { @@ -12659,7 +13073,7 @@ BB106_10: // inline asm add.s32 %r41, %r41, 1; -BB106_12: +BB112_12: cvt.u64.u32 %rd79, %r40; shl.b64 %rd80, %rd79, 32; mov.u32 %r37, 1022; @@ -12674,7 +13088,7 @@ BB106_12: or.b64 %rd88, %rd87, %rd80; mov.b64 %fd4, %rd88; -BB106_13: +BB112_13: st.param.f64 [func_retval0+0], %fd4; ret; } @@ -12702,7 +13116,7 @@ BB106_13: } shr.u32 %r51, %r50, 20; setp.ne.s32 %p1, %r51, 0; - @%p1 bra BB107_2; + @%p1 bra BB113_2; mul.f64 %fd14, %fd12, 0d4350000000000000; { @@ -12716,13 +13130,13 @@ BB106_13: shr.u32 %r16, %r50, 20; add.s32 %r51, %r16, -54; -BB107_2: +BB113_2: add.s32 %r52, %r51, -1023; and.b32 %r17, %r50, -2146435073; or.b32 %r18, %r17, 1072693248; mov.b64 %fd135, {%r49, %r18}; setp.lt.u32 %p2, %r18, 1073127583; - @%p2 bra BB107_4; + @%p2 bra BB113_4; { .reg .b32 %temp; @@ -12736,7 +13150,7 @@ BB107_2: mov.b64 %fd135, {%r19, %r21}; add.s32 %r52, %r51, -1022; -BB107_4: +BB113_4: add.f64 %fd15, %fd135, 
0d3FF0000000000000; rcp.approx.ftz.f64 %fd16, %fd15; neg.f64 %fd17, %fd15; @@ -12899,13 +13313,13 @@ BB107_4: mov.b32 %f2, %r35; abs.f32 %f1, %f2; setp.lt.f32 %p4, %f1, 0f4086232B; - @%p4 bra BB107_7; + @%p4 bra BB113_7; setp.lt.f64 %p5, %fd4, 0d0000000000000000; add.f64 %fd129, %fd4, 0d7FF0000000000000; selp.f64 %fd136, 0d0000000000000000, %fd129, %p5; setp.geu.f32 %p6, %f1, 0f40874800; - @%p6 bra BB107_7; + @%p6 bra BB113_7; mov.f64 %fd134, 0d4338000000000000; mov.f64 %fd133, 0d3FF71547652B82FE; @@ -12927,26 +13341,26 @@ BB107_4: mov.b64 %fd131, {%r44, %r43}; mul.f64 %fd136, %fd130, %fd131; -BB107_7: +BB113_7: { .reg .b32 %temp; mov.b64 {%temp, %r45}, %fd136; } and.b32 %r46, %r45, 2147483647; setp.ne.s32 %p7, %r46, 2146435072; - @%p7 bra BB107_9; + @%p7 bra BB113_9; { .reg .b32 %temp; mov.b64 {%r47, %temp}, %fd136; } setp.eq.s32 %p8, %r47, 0; - @%p8 bra BB107_10; + @%p8 bra BB113_10; -BB107_9: +BB113_9: fma.rn.f64 %fd136, %fd136, %fd5, %fd136; -BB107_10: +BB113_10: st.param.f64 [func_retval0+0], %fd136; ret; } http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/hops/FunctionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java index 428f357..64963d9 100644 --- a/src/main/java/org/apache/sysml/hops/FunctionOp.java +++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java @@ -169,7 +169,7 @@ public class FunctionOp extends Hop long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0); return outputVectors+outputValues; } - else if ( getFunctionName().equalsIgnoreCase("lstm") ) { + else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) { // TODO: To allow for initial version to always run on the GPU return 0; } @@ -218,7 +218,7 @@ public class FunctionOp extends Hop else if 
(getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) { return 0; } - else if ( getFunctionName().equalsIgnoreCase("lstm") ) { + else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) { // TODO: To allow for initial version to always run on the GPU return 0; } @@ -239,7 +239,8 @@ public class FunctionOp extends Hop @Override public boolean isGPUEnabled() { - if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) + if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") || + getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) return true; else return false; @@ -288,7 +289,7 @@ public class FunctionOp extends Hop || (getMemEstimate() >= OptimizerUtils.getLocalMemBudget() && OptimizerUtils.isSparkExecutionMode())) ? 
ExecType.SPARK : ExecType.CP); } - else if( getFunctionName().equalsIgnoreCase("lstm")) { + else if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward")) { if(!DMLScript.USE_ACCELERATOR) throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU."); _etype = ExecType.GPU; http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index 3514e6b..3ca8e1d 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -102,6 +102,15 @@ public class BuiltinFunctionExpression extends DataIdentifier public Expression getSixthExpr() { return (_args.length >= 6 ? _args[5] : null); } + + public Expression getSeventhExpr() { + return (_args.length >= 7 ? _args[6] : null); + } + + public Expression getEighthExpr() { + return (_args.length >= 8 ? _args[7] : null); + } + public Expression[] getAllExpr(){ return _args; @@ -210,13 +219,12 @@ public class BuiltinFunctionExpression extends DataIdentifier checkMatrixParam(getFifthExpr()); // setup output properties - if(getOutputs() == null || getOutputs().length != 3) { + if(getOutputs() == null || getOutputs().length != 2) { int numOutputs = getOutputs() == null ? 
0 : getOutputs().length; - raiseValidateError("The builtin function lstm has three outputs, but instead found: " + numOutputs, conditional); + raiseValidateError("The builtin function lstm has two outputs, but instead found: " + numOutputs, conditional); } DataIdentifier out = (DataIdentifier) getOutputs()[0]; DataIdentifier cy = (DataIdentifier) getOutputs()[1]; - DataIdentifier reserveSpace = (DataIdentifier) getOutputs()[2]; // Output1 - out: If `return_sequences` is True, outputs for all timesteps, else outputs for the final timestep. out.setDataType(DataType.MATRIX); @@ -230,12 +238,36 @@ public class BuiltinFunctionExpression extends DataIdentifier cy.setDimensions(getExpr(4).getOutput().getDim1(), getExpr(4).getOutput().getDim2()); cy.setBlockDimensions(getExpr(4).getOutput().getRowsInBlock(), getExpr(4).getOutput().getColumnsInBlock()); - // Output3 - reserve space. - reserveSpace.setDataType(DataType.MATRIX); - reserveSpace.setValueType(ValueType.DOUBLE); - reserveSpace.setDimensions(-1, -1); - reserveSpace.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock()); + break; + } + case LSTM_BACKWARD: + { + // Input: X, W, b, out0, c0, return_sequences, dout, cy + checkNumParameters(8); + checkMatrixParam(getFirstExpr()); + checkMatrixParam(getSecondExpr()); + checkMatrixParam(getThirdExpr()); + checkMatrixParam(getFourthExpr()); + checkMatrixParam(getFifthExpr()); + checkMatrixParam(getSeventhExpr()); + checkMatrixParam(getEighthExpr()); + // Output: dx, dw, db, dout0, dc0 + // setup output properties + if(getOutputs().length != 5) + raiseValidateError("lstm_backward has 5 outputs", false); + + DataIdentifier dx = (DataIdentifier) getOutputs()[0]; + DataIdentifier dw = (DataIdentifier) getOutputs()[1]; + DataIdentifier db = (DataIdentifier) getOutputs()[2]; + DataIdentifier dout0 = (DataIdentifier) getOutputs()[3]; + DataIdentifier dc0 = (DataIdentifier) getOutputs()[4]; + + setDimensions(dx, 
getFirstExpr()); + setDimensions(dw, getSecondExpr()); + setDimensions(db, getThirdExpr()); + setDimensions(dout0, getFourthExpr()); + setDimensions(dc0, getFifthExpr()); break; } case BATCH_NORM2D: @@ -1414,7 +1446,8 @@ public class BuiltinFunctionExpression extends DataIdentifier // always unconditional (because unsupported operation) BuiltinFunctionOp op = getOpCode(); if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR || op==BuiltinFunctionOp.SVD - || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.BATCH_NORM2D || op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD) + || op==BuiltinFunctionOp.LSTM || op==BuiltinFunctionOp.LSTM_BACKWARD + || op==BuiltinFunctionOp.BATCH_NORM2D || op==BuiltinFunctionOp.BATCH_NORM2D_BACKWARD) raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS); else raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS); @@ -1496,6 +1529,7 @@ public class BuiltinFunctionExpression extends DataIdentifier case LU: case EIGEN: case LSTM: + case LSTM_BACKWARD: case BATCH_NORM2D: case BATCH_NORM2D_BACKWARD: case SVD: @@ -1624,7 +1658,8 @@ public class BuiltinFunctionExpression extends DataIdentifier raiseValidateError("Missing argument for function " + this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS); } - + + // Not sure of the rationale for the first two if checks, but keeping them for backward compatibility if (((count == 1) && (getSecondExpr() != null || getThirdExpr() != null)) || ((count == 2) && (getThirdExpr() != null))) { raiseValidateError("Invalid number of arguments for function " + this.getOpCode().toString().toLowerCase() @@ -1633,6 +1668,9 @@ public class BuiltinFunctionExpression extends DataIdentifier || ((count == 3) && (getSecondExpr() == null || getThirdExpr() == null))) { raiseValidateError("Missing argument for function " + this.getOpCode(), false,
LanguageErrorCodes.INVALID_PARAMETERS); + } else if(count > 0 && (_args == null || _args.length < count)) { + raiseValidateError("Missing argument for function " + this.getOpCode(), false, + LanguageErrorCodes.INVALID_PARAMETERS); } } @@ -1909,6 +1947,8 @@ public class BuiltinFunctionExpression extends DataIdentifier bifop = Expression.BuiltinFunctionOp.EIGEN; else if (functionName.equals("lstm")) bifop = Expression.BuiltinFunctionOp.LSTM; + else if (functionName.equals("lstm_backward")) + bifop = Expression.BuiltinFunctionOp.LSTM_BACKWARD; else if (functionName.equals("batch_norm2d")) bifop = Expression.BuiltinFunctionOp.BATCH_NORM2D; else if (functionName.equals("batch_norm2d_backward")) http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 102ca54..d29a8f4 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2225,6 +2225,7 @@ public class DMLTranslator case LU: case EIGEN: case LSTM: + case LSTM_BACKWARD: case BATCH_NORM2D: case BATCH_NORM2D_BACKWARD: case SVD: http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index 2808609..1708ed3 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -92,7 +92,7 @@ public abstract class Expression implements ParseInfo EXISTS, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIAS_ADD, BIAS_MULTIPLY, MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, 
AVG_POOL_BACKWARD, - LSTM, BATCH_NORM2D, BATCH_NORM2D_BACKWARD, + LSTM, LSTM_BACKWARD, BATCH_NORM2D, BATCH_NORM2D_BACKWARD, EXP, FLOOR, IFELSE, http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index 049577b..8e9bb47 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -57,6 +57,7 @@ public class GPUInstructionParser extends InstructionParser String2GPUInstructionType.put( "bias_multiply", GPUINSTRUCTION_TYPE.Convolution); String2GPUInstructionType.put( "channel_sums", GPUINSTRUCTION_TYPE.Convolution); String2GPUInstructionType.put( "lstm", GPUINSTRUCTION_TYPE.Convolution); + String2GPUInstructionType.put( "lstm_backward", GPUINSTRUCTION_TYPE.Convolution); String2GPUInstructionType.put( "batch_norm2d", GPUINSTRUCTION_TYPE.Convolution); String2GPUInstructionType.put( "batch_norm2d_backward", GPUINSTRUCTION_TYPE.Convolution); http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java index e523a45..7d8a0fc 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java @@ -73,7 +73,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction { 
_intermediateMemoryBudget = intermediateMemoryBudget; } public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, - CPOperand out, CPOperand out2, CPOperand out3, String opcode, String istr, + CPOperand out, CPOperand out2, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException { super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); _input1 = in1; @@ -85,7 +85,6 @@ public class ConvolutionGPUInstruction extends GPUInstruction { _gputype = GPUINSTRUCTION_TYPE.Convolution; _output = out; _output2 = out2; - _output3 = out3; _intermediateMemoryBudget = intermediateMemoryBudget; } @@ -283,7 +282,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction { return new ConvolutionGPUInstruction(in, in2, in3, out, opcode, str, 0); } else if (opcode.equalsIgnoreCase("lstm")) { - InstructionUtils.checkNumFields(parts, 9); + InstructionUtils.checkNumFields(parts, 8); CPOperand in1 = new CPOperand(parts[1]); CPOperand in2 = new CPOperand(parts[2]); CPOperand in3 = new CPOperand(parts[3]); @@ -292,10 +291,9 @@ public class ConvolutionGPUInstruction extends GPUInstruction { CPOperand in6 = new CPOperand(parts[6]); CPOperand out = new CPOperand(parts[7]); CPOperand out2 = new CPOperand(parts[8]); - CPOperand out3 = new CPOperand(parts[9]); - return new ConvolutionGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, out3, opcode, str, 0); + return new ConvolutionGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0); } - else if (opcode.equalsIgnoreCase("batch_norm2d")) { + else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) { InstructionUtils.checkNumFields(parts, 13); CPOperand in1 = new CPOperand(parts[1]); // image CPOperand in2 = new CPOperand(parts[2]); // scale @@ -471,6 +469,66 @@ public class ConvolutionGPUInstruction extends GPUInstruction { // return tX; // } + private void 
processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException { + GPUStatistics.incrementNoOfExecutedGPUInst(); + GPUContext gCtx = ec.getGPUContext(0); + String instructionName = getExtendedOpcode(); + + MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName()); + int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M) + Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName); + + MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); + long numRowsW = W.getNumRows(); + int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures + Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M); + Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M); + Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", + ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), + sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + + + MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName()); + Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); + int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D) + long numColsX = X.getNumColumns(); + int T = toInt(numColsX/ D); // since X:(N, T*D) ... 
seqLength + Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", + ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D), + xPointer, cudnnInput, N, D, T*D, N*T*D); + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + + Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); + boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue(); + + // LibMatrixCuDNN.lstm(ec, gCtx, instructionName, + // cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); + // String xName, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input + // String dxName, String dwName, String dbName, String dhxName, String dcxName, // output + String dxName = _output.getName(); + String dwName = _output2.getName(); + String dbName = _output3.getName(); + String dhxName = _output4.getName(); + String dcxName = _output5.getName(); + String doutName = _input7.getName(); + String dcyName = _input8.getName(); + LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, + cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName, // input + dxName, dwName, dbName, dhxName, dcxName, // output + return_sequences, N, M, D, T); + gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE); + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input4.getName()); + ec.releaseMatrixInputForGPUInstruction(_input5.getName()); + } + private void processLstmInstruction(ExecutionContext ec) throws DMLRuntimeException { // batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M // input X:(N, T*D), ==> (T, D, N) @@ -496,6 +554,7 @@ public 
class ConvolutionGPUInstruction extends GPUInstruction { ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M); ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue(); @@ -513,17 +572,15 @@ public class ConvolutionGPUInstruction extends GPUInstruction { Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); - LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), _output3.getName(), N, M, D, T); + LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE); gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE); // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); ec.releaseMatrixInputForGPUInstruction(_input4.getName()); ec.releaseMatrixInputForGPUInstruction(_input5.getName()); ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); } @Override @@ -544,6 +601,10 @@ public class ConvolutionGPUInstruction extends GPUInstruction { processLstmInstruction(ec); return; } + else if (instOpcode.equalsIgnoreCase("lstm_backward")) { + processLstmBackwardInstruction(ec); + return; + } else if (instOpcode.equalsIgnoreCase("batch_norm2d")) { processBatchNorm2dInstruction(ec); return; 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java index 576584b..328d1d4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java @@ -739,9 +739,9 @@ public class GPUObject { long rows = mat.getNumRows(); long cols = mat.getNumColumns(); if(rows <= 0) - throw new DMLRuntimeException("Internal error - invalid number of rows when allocating dense matrix"); + throw new DMLRuntimeException("Internal error - invalid number of rows when allocating dense matrix:" + rows); if(cols <= 0) - throw new DMLRuntimeException("Internal error - invalid number of columns when allocating dense matrix;"); + throw new DMLRuntimeException("Internal error - invalid number of columns when allocating dense matrix:" + cols); long size = getDatatypeSizeOf(rows * cols); Pointer tmp = allocate(size); setDensePointer(tmp); http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java index e84dce7..a692739 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java @@ -846,6 +846,12 @@ public class LibMatrixCuDNN extends LibMatrixCUDA { } } + static Pointer getDenseInputPointer(ExecutionContext ec, GPUContext gCtx, String instName, String inputName, + long 
numRows, long numCols) throws DMLRuntimeException { + MatrixObject output = ec.getMatrixInputForGPUInstruction(inputName, instName); + return LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, output, instName, toInt(numRows), toInt(numCols)); + } + static Pointer getDenseOutputPointer(ExecutionContext ec, GPUContext gCtx, String instName, String outputName, long numRows, long numCols) throws DMLRuntimeException { MatrixObject output = ec.getMatrixObject(outputName); @@ -867,7 +873,6 @@ public class LibMatrixCuDNN extends LibMatrixCUDA { * @param return_sequences Whether to return `out` at all timesteps, or just for the final timestep. * @param outputName name of the out variable. If `return_sequences` is True, outputs for all timesteps. * @param cyName name of the output cell state. Cell state for final timestep. - * @param reserveSpaceName name of reserve space. * @param N minibatch size * @param M hidden size * @param D number of features @@ -876,13 +881,13 @@ public class LibMatrixCuDNN extends LibMatrixCUDA { */ public static void lstm(ExecutionContext ec, GPUContext gCtx, String instName, Pointer X, Pointer wPointer, Pointer out0, Pointer c0, boolean return_sequences, - String outputName, String cyName, String reserveSpaceName, int N, int M, int D, int T) throws DMLRuntimeException { - singleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, out0, c0, wPointer, outputName, cyName, reserveSpaceName, "lstm", return_sequences, N, M, D, T); + String outputName, String cyName, int N, int M, int D, int T) throws DMLRuntimeException { + singleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, out0, c0, wPointer, outputName, cyName, "lstm", return_sequences, N, M, D, T); } private static void singleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, String instName, Pointer x, Pointer hx, Pointer cx, Pointer wPointer, // input - String outputName, String cyName, String reserveSpaceName, // output + String outputName, String cyName, // output String 
rnnMode, boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException { boolean hasCarry = rnnMode.equalsIgnoreCase("lstm"); // Get output pointers @@ -891,8 +896,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA { Pointer cyPointer = hasCarry ? getDenseOutputPointer(ec, gCtx, instName, cyName, N, M) : new Pointer(); // Pointer wPointer = getDensePointerForCuDNN(gCtx, w, instName, D+M+2, 4*M); - try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, wPointer, reserveSpaceName)) { - jcuda.runtime.JCuda.cudaDeviceSynchronize(); + try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, wPointer)) { JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, T, algo.xDesc, x, algo.hxDesc, hx, @@ -915,6 +919,88 @@ public class LibMatrixCuDNN extends LibMatrixCUDA { gCtx.cudaFreeHelper(instName, cudnnYPointer, DMLScript.EAGER_CUDA_FREE); } + public static void lstmBackward(ExecutionContext ec, GPUContext gCtx, String instName, + Pointer x, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input + String dxName, String dwName, String dbName, String dhxName, String dcxName, // output + boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException { + // Transform the input dout and prepare them for cudnnRNNBackwardData + Pointer dy = gCtx.allocate(instName, N*T*M*sizeOfDataType); + int size = return_sequences ? N*T*M : N*M; + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_backward_gradients", + ExecutionConfig.getConfigForSimpleVectorOperations(size), + getDenseInputPointer(ec, gCtx, instName, doutName, N, return_sequences ? T*M : M), + dy, N, T, M, size, return_sequences ? 
1 : 0); + ec.releaseMatrixInputForGPUInstruction(doutName); + + // Allocate intermediate pointers computed by forward + Pointer yPointer = gCtx.allocate(instName, N*T*M*sizeOfDataType); + try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", N, T, M, D, true, wPointer)) { + JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, T, + algo.xDesc, x, + algo.hxDesc, hx, + algo.cxDesc, cx, + algo.wDesc, wPointer, + algo.yDesc, yPointer, + algo.hyDesc, new Pointer(), + algo.cyDesc, new Pointer(), + algo.workSpace, algo.sizeInBytes, + algo.reserveSpace, algo.reserveSpaceSizeInBytes); + + Pointer cudnnDx = gCtx.allocate(instName, N*T*D*LibMatrixCUDA.sizeOfDataType); + JCudnn.cudnnRNNBackwardData(gCtx.getCudnnHandle(), algo.rnnDesc, T, + algo.yDesc, yPointer, + // ---------------------- + // Additional inputs: + algo.dyDesc, dy, + algo.dhyDesc, new Pointer(), + algo.dcyDesc, getDenseInputPointer(ec, gCtx, instName, dcyName, N, M), + // ---------------------- + algo.wDesc, wPointer, + algo.hxDesc, hx, + algo.cxDesc, cx, + // ---------------------- + // Output: + algo.dxDesc, cudnnDx, + algo.dhxDesc, getDenseOutputPointer(ec, gCtx, instName, dhxName, N, M), + algo.dcxDesc, getDenseOutputPointer(ec, gCtx, instName, dcxName, N, M), + // ---------------------- + algo.workSpace, algo.sizeInBytes, + algo.reserveSpace, algo.reserveSpaceSizeInBytes); + gCtx.cudaFreeHelper(instName, dy, DMLScript.EAGER_CUDA_FREE); + ec.releaseMatrixInputForGPUInstruction(dcyName); + ec.releaseMatrixOutputForGPUInstruction(dhxName); + ec.releaseMatrixOutputForGPUInstruction(dcxName); + + Pointer smlDx = getDenseOutputPointer(ec, gCtx, instName, dxName, N, T*D); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput", + ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D), + smlDx, cudnnDx, N, D, T*D, N*T*D); + ec.releaseMatrixOutputForGPUInstruction(dxName); + + // 
------------------------------------------------------------------------------------------- + Pointer cudnnDwPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); + JCudnn.cudnnRNNBackwardWeights(gCtx.getCudnnHandle(), algo.rnnDesc, T, + algo.xDesc, x, + algo.hxDesc, hx, + algo.yDesc, yPointer, + algo.workSpace, algo.sizeInBytes, + algo.dwDesc, cudnnDwPointer, + algo.reserveSpace, algo.reserveSpaceSizeInBytes); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight", + ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), + getDenseOutputPointer(ec, gCtx, instName, dwName, D+M, 4*M), + getDenseOutputPointer(ec, gCtx, instName, dbName, 1, 4*M), cudnnDwPointer, D, M); + gCtx.cudaFreeHelper(instName, cudnnDwPointer, DMLScript.EAGER_CUDA_FREE); + ec.releaseMatrixOutputForGPUInstruction(dwName); + ec.releaseMatrixOutputForGPUInstruction(dbName); + // ------------------------------------------------------------------------------------------- + + gCtx.cudaFreeHelper(instName, yPointer, DMLScript.EAGER_CUDA_FREE); + } + } + + + /** * Performs the forward BatchNormalization layer computation for training * @param gCtx a valid {@link GPUContext} http://git-wip-us.apache.org/repos/asf/systemml/blob/e4220e3b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java index d772a55..68d308e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java @@ -49,27 +49,36 @@ public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable { String instName; cudnnDropoutDescriptor dropoutDesc; cudnnRNNDescriptor rnnDesc; - 
cudnnTensorDescriptor[] xDesc, yDesc; // of length T - cudnnTensorDescriptor hxDesc, cxDesc, hyDesc, cyDesc; + cudnnTensorDescriptor[] xDesc, dxDesc, yDesc, dyDesc; // of length T + cudnnTensorDescriptor hxDesc, cxDesc, hyDesc, cyDesc, dhxDesc, dcxDesc, dhyDesc, dcyDesc; cudnnFilterDescriptor wDesc; + cudnnFilterDescriptor dwDesc; long sizeInBytes; Pointer workSpace; long reserveSpaceSizeInBytes; Pointer reserveSpace; public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, String instName, - String rnnMode, int N, int T, int M, int D, boolean isTraining, Pointer w, String reserveSpaceName) throws DMLRuntimeException { + String rnnMode, int N, int T, int M, int D, boolean isTraining, Pointer w) throws DMLRuntimeException { this.gCtx = gCtx; this.instName = instName; // Allocate input/output descriptors xDesc = new cudnnTensorDescriptor[T]; + dxDesc = new cudnnTensorDescriptor[T]; yDesc = new cudnnTensorDescriptor[T]; + dyDesc = new cudnnTensorDescriptor[T]; for(int t = 0; t < T; t++) { xDesc[t] = allocateTensorDescriptorWithStride(N, D, 1); + dxDesc[t] = allocateTensorDescriptorWithStride(N, D, 1); yDesc[t] = allocateTensorDescriptorWithStride(N, M, 1); + dyDesc[t] = allocateTensorDescriptorWithStride(N, M, 1); } hxDesc = allocateTensorDescriptorWithStride(1, N, M); + dhxDesc = allocateTensorDescriptorWithStride(1, N, M); cxDesc = allocateTensorDescriptorWithStride(1, N, M); + dcxDesc = allocateTensorDescriptorWithStride(1, N, M); hyDesc = allocateTensorDescriptorWithStride(1, N, M); + dhyDesc = allocateTensorDescriptorWithStride(1, N, M); cyDesc = allocateTensorDescriptorWithStride(1, N, M); + dcyDesc = allocateTensorDescriptorWithStride(1, N, M); // Initial dropout descriptor dropoutDesc = new cudnnDropoutDescriptor(); @@ -94,6 +103,7 @@ public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable { throw new DMLRuntimeException("Incorrect number of RNN parameters " + (D+M+2)*4*M + " != " + expectedNumWeights + ", where numFeatures=" 
+ D + ", hiddenSize=" + M); } wDesc = allocateFilterDescriptor(expectedNumWeights); + dwDesc = allocateFilterDescriptor(expectedNumWeights); // Setup workspace workSpace = new Pointer(); reserveSpace = new Pointer(); @@ -104,12 +114,11 @@ public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable { if(isTraining) { reserveSpaceSizeInBytes = getReservespaceSize(T); if (reserveSpaceSizeInBytes != 0) { - int numCols = (int) Math.ceil(((double)reserveSpaceSizeInBytes) / LibMatrixCUDA.sizeOfDataType); - reserveSpace = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, reserveSpaceName, 1, numCols); + reserveSpace = gCtx.allocate(reserveSpaceSizeInBytes); } } if (reserveSpaceSizeInBytes == 0) { - reserveSpace = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName, reserveSpaceName, 1, 1); + reserveSpace = gCtx.allocate(reserveSpaceSizeInBytes); } /* @@ -241,30 +250,57 @@ public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable { if(hxDesc != null) cudnnDestroyTensorDescriptor(hxDesc); hxDesc = null; + if(dhxDesc != null) + cudnnDestroyTensorDescriptor(dhxDesc); + dhxDesc = null; if(hyDesc != null) cudnnDestroyTensorDescriptor(hyDesc); hyDesc = null; + if(dhyDesc != null) + cudnnDestroyTensorDescriptor(dhyDesc); + dhyDesc = null; if(cxDesc != null) cudnnDestroyTensorDescriptor(cxDesc); cxDesc = null; + if(dcxDesc != null) + cudnnDestroyTensorDescriptor(dcxDesc); + dcxDesc = null; if(cyDesc != null) cudnnDestroyTensorDescriptor(cyDesc); cyDesc = null; + if(dcyDesc != null) + cudnnDestroyTensorDescriptor(dcyDesc); + dcyDesc = null; if(wDesc != null) cudnnDestroyFilterDescriptor(wDesc); wDesc = null; + if(dwDesc != null) + cudnnDestroyFilterDescriptor(dwDesc); + dwDesc = null; if(xDesc != null) { for(cudnnTensorDescriptor dsc : xDesc) { cudnnDestroyTensorDescriptor(dsc); } xDesc = null; } + if(dxDesc != null) { + for(cudnnTensorDescriptor dsc : dxDesc) { + cudnnDestroyTensorDescriptor(dsc); + } + dxDesc = null; + } if(yDesc != null) 
{ for(cudnnTensorDescriptor dsc : yDesc) { cudnnDestroyTensorDescriptor(dsc); } yDesc = null; } + if(dyDesc != null) { + for(cudnnTensorDescriptor dsc : dyDesc) { + cudnnDestroyTensorDescriptor(dsc); + } + dyDesc = null; + } if(sizeInBytes != 0) { try { gCtx.cudaFreeHelper(instName, workSpace, DMLScript.EAGER_CUDA_FREE);
