[SYSTEMML-445] Removed batch_norm builtin functions

- Removed the batch_norm builtin functions so that CP can exploit codegen.
- Added rewrites that compile efficient CuDNN operators.
- Added rewrites for SGD update operations.
- To simplify adding new GPU rewrites, added a HopDagPatternMatcher that performs pattern matching at the HOP level; it can be extended to other rewrites as well (a usage sketch follows below).
- Added GPU tests to validate the rewrites.
- Updated the DML language documentation.
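The new matcher composes HOP-level patterns from static factory methods (`leaf`, `plus`, `mult`, `minus`, `bias_add`, ...) plus predicates such as `fitsOnGPU`, and `HopPatternRewriter` pairs a matcher with a replacer function. The following is a minimal sketch of the intended usage, assuming only the API visible in this diff; the class name `UpdateEmaRewriteSketch` and the elided `UPDATE_EMA` construction are illustrative and not part of the commit — the actual replacers live in `RewriteGPUSpecificOps`.

```java
import java.util.function.Function;

import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.rewrite.HopDagPatternMatcher;
import org.apache.sysml.hops.rewrite.HopPatternRewriter;
import org.apache.sysml.parser.Expression.DataType;

import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;

public class UpdateEmaRewriteSketch {
    // Reusing one leaf instance for "mu" forces both occurrences in the
    // pattern to resolve to the same scalar value (the matcher keeps a
    // visited-map of already-matched sub-patterns).
    private static final HopDagPatternMatcher mu   = leaf("mu",   DataType.SCALAR);
    private static final HopDagPatternMatcher ema  = leaf("ema",  DataType.MATRIX);
    private static final HopDagPatternMatcher mean = leaf("mean", DataType.MATRIX);

    // Pattern: mu*ema + (1-mu)*mean, guarded by the GPU memory-budget check.
    private static final HopDagPatternMatcher pattern =
        plus(mult(mu, ema), mult(minus(1, mu), mean)).fitsOnGPU(3);

    public static Hop apply(Hop root) {
        Function<Hop, Hop> replacer = h -> {
            // Matched leaves are retrieved by the names registered above.
            Hop emaHop   = pattern.getMatchedHop("ema");
            Hop meanHop  = pattern.getMatchedHop("mean");
            double muVal = pattern.getLiteralValue("mu");
            // Constructing the replacement OpOpDnn.UPDATE_EMA DnnOp is elided
            // here; returning h unchanged means "no rewrite applied".
            return h;
        };
        return new HopPatternRewriter("updateEma", pattern, replacer).rewrite(root);
    }
}
```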
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0f36780a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0f36780a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0f36780a Branch: refs/heads/master Commit: 0f36780a8244c6e728d37c32a79e00ed181211ad Parents: 81419ae Author: Niketan Pansare <npan...@us.ibm.com> Authored: Thu Aug 30 15:40:44 2018 -0700 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Thu Aug 30 15:40:44 2018 -0700 ---------------------------------------------------------------------- docs/dml-language-reference.md | 2 - scripts/nn/layers/batch_norm2d.dml | 60 +- scripts/nn/layers/batch_norm2d_old.dml | 200 ---- src/main/cpp/kernels/SystemML.cu | 56 +- src/main/cpp/kernels/SystemML.ptx | 321 +++++- src/main/java/org/apache/sysml/hops/DnnOp.java | 56 +- .../java/org/apache/sysml/hops/FunctionOp.java | 30 +- src/main/java/org/apache/sysml/hops/Hop.java | 8 +- .../hops/rewrite/HopDagPatternMatcher.java | 378 +++++++ .../sysml/hops/rewrite/HopPatternRewriter.java | 72 ++ .../HopRewriteRuleWithPatternMatcher.java | 98 ++ .../sysml/hops/rewrite/HopRewriteUtils.java | 20 + .../hops/rewrite/RewriteGPUSpecificOps.java | 1027 +++++------------- .../org/apache/sysml/lops/DnnTransform.java | 53 +- .../sysml/parser/BuiltinFunctionExpression.java | 61 +- .../org/apache/sysml/parser/DMLTranslator.java | 2 - .../org/apache/sysml/parser/Expression.java | 2 +- .../instructions/GPUInstructionParser.java | 10 +- .../instructions/gpu/DnnGPUInstruction.java | 526 +++++---- .../gpu/GPUDenseInputPointerFetcher.java | 111 ++ .../gpu/context/GPUMemoryManager.java | 2 +- .../runtime/matrix/data/LibMatrixCUDA.java | 110 +- .../runtime/matrix/data/LibMatrixCuDNN.java | 37 +- .../apache/sysml/test/gpu/BatchNormTest.java | 47 +- 24 files changed, 1818 insertions(+), 1471 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/docs/dml-language-reference.md ---------------------------------------------------------------------- diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md index 924336a..cdcc529 100644 --- a/docs/dml-language-reference.md +++ b/docs/dml-language-reference.md @@ -1522,8 +1522,6 @@ Hence, the images are internally represented as a matrix with dimension (N, C * | bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels | | bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels | | lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut) | -| batch_norm2d | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_image* width_image] | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum | Performs batch normalization operation (outputs: updated 
exponential moving average mean and variance, cache of the batch mean and variance) | -| batch_norm2d_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward) | Computed backpropagation error for batch normalization operation | Note: the builtin functions `batch_norm2d` and `batch_norm2d_backward` are deprecated and will be removed in the next release. The `lstm` builtin function is in experimental phase and is only supported for the GPU backend. http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/layers/batch_norm2d.dml b/scripts/nn/layers/batch_norm2d.dml index 2a98857..c68f23d 100644 --- a/scripts/nn/layers/batch_norm2d.dml +++ b/scripts/nn/layers/batch_norm2d.dml @@ -83,8 +83,41 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, * - cache_inv_var: Cache of the inverse variance, of shape (C, 1). * Note: This is used for performance during training. */ - out = X; ema_mean_upd = ema_mean; ema_var_upd = ema_var; cache_mean = ema_mean; cache_inv_var = ema_var - [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu) + N = nrow(X) + + if (mode == 'train') { + # Compute channel-wise mean and variance + # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion. + # - mean of total group is mean of subgroup means + # - variance is the mean of the subgroup variances + the variance of the subgroup means + subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win) + subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) # uncorrected variances + mean = rowMeans(subgrp_means) # shape (C, 1) + var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) # shape (C, 1) + # Update moving averages + ema_mean_upd = mu*ema_mean + (1-mu)*mean + ema_var_upd = mu*ema_var + (1-mu)*var + } + else { + # Use moving averages of mean and variance during testing + mean = ema_mean + var = ema_var + ema_mean_upd = ema_mean + ema_var_upd = ema_var + } + + # Save variable for backward pass + cache_mean = mean + cache_inv_var = 1/sqrt(var+epsilon) + + # Normalize, shift, and scale + # norm = (X-mean)*(var+epsilon)^(-1/2) + # = (X-mean) / sqrt(var+epsilon) + centered = bias_add(X, -mean) # shape (N, C*Hin*Win) + norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) + # out = norm*gamma + beta + scaled = bias_multiply(norm, gamma) # shape (N, C*Hin*Win) + out = bias_add(scaled, beta) # shape (N, C*Hin*Win) } backward = function(matrix[double] dout, @@ -119,9 +152,27 @@ backward = function(matrix[double] dout, * - dbeta: Gradient wrt `b`, of shape (C, 1). 
* */ + N = nrow(X) + oneByN = 1/N + oneByHW = 1/(Hin*Win) + + mean = cache_mean + centered = bias_add(X, -mean) # shape (N, C*Hin*Win) + norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) # Compute gradients during training - dX = X; dgamma = gamma; dbeta = gamma; - [dX, dgamma, dbeta] = batch_norm2d_backward(X, dout, gamma, epsilon, cache_mean, cache_inv_var) + dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1) + dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1) + dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win) + dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm, + C, Hin, Win) # shape (C, 1) + dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win) + dmean_var_branch = util::channel_sums((-2*oneByN*oneByHW) * centered, C, Hin, Win) + dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet + dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1) + dX_norm_branch = bias_multiply(dnorm, cache_inv_var) + dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean) + dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, dvar) + dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win) } init = function(int C) @@ -149,3 +200,4 @@ init = function(int C) ema_mean = matrix(0, rows=C, cols=1) ema_var = matrix(1, rows=C, cols=1) } + http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d_old.dml ---------------------------------------------------------------------- diff --git a/scripts/nn/layers/batch_norm2d_old.dml b/scripts/nn/layers/batch_norm2d_old.dml deleted file mode 100644 index 2aba2e6..0000000 --- a/scripts/nn/layers/batch_norm2d_old.dml +++ /dev/null @@ -1,200 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -/* - * 2D (Spatial) Batch Normalization layer. - */ -source("nn/util.dml") as util - -forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, - int C, int Hin, int Win, string mode, - matrix[double] ema_mean, matrix[double] ema_var, - double mu, double epsilon) - return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd, - matrix[double] cache_mean, matrix[double] cache_inv_var) { - /* - * Computes the forward pass for a 2D (spatial) batch normalization - * layer. The input data has N examples, each represented as a 3D - * volume unrolled into a single vector. 
- * - * A spatial batch normalization layer uses the per-channel sample - * mean and per-channel uncorrected sample variance during training - * to normalize each channel of the input data. Additionally, it - * introduces learnable parameters (gamma, beta) to control the - * amount of normalization. - * - * `y = ((x-mean) / sqrt(var+eps)) * gamma + beta` - * - * This implementation maintains exponential moving averages of the - * mean and variance during training for use during testing. - * - * Reference: - * - Batch Normalization: Accelerating Deep Network Training by - * Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015 - * - https://arxiv.org/abs/1502.03167 - * - * Inputs: - * - X: Inputs, of shape (N, C*Hin*Win). - * - gamma: Scale parameters, of shape (C, 1). - * - beta: Shift parameters, of shape (C, 1). - * - C: Number of input channels (dimensionality of input depth). - * - Hin: Input height. - * - Win: Input width. - * - mode: 'train' or 'test' to indicate if the model is currently - * being trained or tested. During training, the current batch - * mean and variance will be used to normalize the inputs, while - * during testing, the exponential average of the mean and - * variance over all previous batches will be used. - * - ema_mean: Exponential moving average of the mean, of - * shape (C, 1). - * - ema_var: Exponential moving average of the variance, of - * shape (C, 1). - * - mu: Momentum value for moving averages. - * Typical values are in the range of [0.9, 0.999]. - * - epsilon: Smoothing term to avoid divide by zero errors. - * Typical values are in the range of [1e-5, 1e-3]. - * - * Outputs: - * - out: Outputs, of shape (N, C*Hin*Win). - * - ema_mean_upd: Updated exponential moving average of the mean, - * of shape (C, 1). - * - ema_var_upd: Updated exponential moving average of the variance, - * of shape (C, 1). - * - cache_mean: Cache of the batch mean, of shape (C, 1). - * Note: This is used for performance during training. - * - cache_inv_var: Cache of the inverse variance, of shape (C, 1). - * Note: This is used for performance during training. - */ - N = nrow(X) - - if (mode == 'train') { - # Compute channel-wise mean and variance - # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion. 
- # - mean of total group is mean of subgroup means - # - variance is the mean of the subgroup variances + the variance of the subgroup means - subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win) - subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) # uncorrected variances - mean = rowMeans(subgrp_means) # shape (C, 1) - var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) # shape (C, 1) - # Update moving averages - ema_mean_upd = mu*ema_mean + (1-mu)*mean - ema_var_upd = mu*ema_var + (1-mu)*var - } - else { - # Use moving averages of mean and variance during testing - mean = ema_mean - var = ema_var - ema_mean_upd = ema_mean - ema_var_upd = ema_var - } - - # Save variable for backward pass - cache_mean = mean - cache_inv_var = 1/sqrt(var+epsilon) - - # Normalize, shift, and scale - # norm = (X-mean)*(var+epsilon)^(-1/2) - # = (X-mean) / sqrt(var+epsilon) - centered = bias_add(X, -mean) # shape (N, C*Hin*Win) - norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) - # out = norm*gamma + beta - scaled = bias_multiply(norm, gamma) # shape (N, C*Hin*Win) - out = bias_add(scaled, beta) # shape (N, C*Hin*Win) -} - -backward = function(matrix[double] dout, - matrix[double] cache_mean, matrix[double] cache_inv_var, - matrix[double] X, matrix[double] gamma, - int C, int Hin, int Win, double epsilon) - return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) { - /* - * Computes the backward pass for a 2D (spatial) batch normalization - * layer. - * - * Inputs: - * - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win). - * - cache_mean: Cache of the batch mean from the forward pass, of - * shape (C, 1). Note: This is used for performance during - * training. - * - cache_inv_var: Cache of the inverse variance from the forward pass, - * of shape (C, 1). Note: This is used for performance during - * training. - * - X: Input data matrix to the forward pass, of - * shape (N, C*Hin*Win). - * - gamma: Scale parameters, of shape (C, 1). - * - C: Number of input channels (dimensionality of input depth). - * - Hin: Input height. - * - Win: Input width. - * - epsilon: Smoothing term to avoid divide by zero errors. - * Typical values are in the range of [1e-5, 1e-3]. - * - * Outputs: - * - dX: Gradient wrt `X`, of shape (N, C*Hin*Win). - * - dgamma: Gradient wrt `W`, of shape (C, 1). - * - dbeta: Gradient wrt `b`, of shape (C, 1). 
- * - */ - N = nrow(X) - mean = cache_mean - centered = bias_add(X, -mean) # shape (N, C*Hin*Win) - norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win) - # Compute gradients during training - dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1) - dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1) - dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win) - dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm, - C, Hin, Win) # shape (C, 1) - dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win) - dmean_var_branch = util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win) - dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet - dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1) - dX_norm_branch = bias_multiply(dnorm, cache_inv_var) - dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean) - dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar) - dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win) -} - -init = function(int C) - return (matrix[double] gamma, matrix[double] beta, - matrix[double] ema_mean, matrix[double] ema_var) { - /* - * Initialize the parameters of this layer. - * - * Note: This is just a convenience function, and parameters - * may be initialized manually if needed. - * - * Inputs: - * - C: Number of input channels (dimensionality of input depth). - * - * Outputs: - * - gamma: Scale parameters, of shape (C, 1). - * - beta: Shift parameters, of shape (C, 1). - * - ema_mean: Exponential moving average of the mean, of - * shape (C, 1). - * - ema_var: Exponential moving average of the variance, of - * shape (C, 1). 
- */ - gamma = matrix(1, rows=C, cols=1) - beta = matrix(0, rows=C, cols=1) - ema_mean = matrix(0, rows=C, cols=1) - ema_var = matrix(1, rows=C, cols=1) -} - http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index 9ddaaff..b874cdd 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -2289,4 +2289,58 @@ extern "C" __global__ void update_nesterov_x_d(double *X, double *v, double *v_p extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float *v_prev, double mu, float *out, unsigned int size) { update_nesterov_x(X, v, v_prev, mu, out, size); -} \ No newline at end of file +} + +// Performs the operation: C = a*X + b*C +template <typename T> +__device__ void aXplusbC(T *X, T *C, double a, double b, unsigned int size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + C[index] = a*X[index] + b*C[index]; + } +} + +extern "C" __global__ void aXplusbC_d(double *X, double *C, double a, double b, unsigned int size) { + aXplusbC(X, C, a, b,size); +} + +extern "C" __global__ void aXplusbC_f(float *X, float *C, double a, double b, unsigned int size) { + aXplusbC(X, C, a, b,size);; +} + + +// Performs the operation: C = a*X + b*Y +template <typename T> +__device__ void aXplusbY(T *X, T* Y, T *C, double a, double b, unsigned int size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + C[index] = a*X[index] + b*Y[index]; + } +} + +extern "C" __global__ void aXplusbY_d(double *X, double* Y, double *C, double a, double b, unsigned int size) { + aXplusbY(X, Y, C, a, b, size); +} + +extern "C" __global__ void aXplusbY_f(float *X, float* Y, float *C, double a, double b, unsigned int size) { + aXplusbY(X, Y, C, a, b, size); +} + + +// Performs the operation: C = 1 / sqrt(X + eps) +template <typename T> +__device__ void invVar(T *X, T *C, double eps, unsigned int size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + C[index] = 1.0 / sqrt(X[index] + eps); + } +} + +extern "C" __global__ void invVar_d(double *X, double *C, double eps, unsigned int size) { + invVar(X, C, eps, size); +} + +extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned int size) { + invVar(X, C, eps, size); +} + http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.ptx ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx index 8a14876..1ab32f5 100644 --- a/src/main/cpp/kernels/SystemML.ptx +++ b/src/main/cpp/kernels/SystemML.ptx @@ -13084,12 +13084,279 @@ BB115_2: ret; } + // .globl aXplusbC_d +.visible .entry aXplusbC_d( + .param .u64 aXplusbC_d_param_0, + .param .u64 aXplusbC_d_param_1, + .param .f64 aXplusbC_d_param_2, + .param .f64 aXplusbC_d_param_3, + .param .u32 aXplusbC_d_param_4 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<6>; + .reg .f64 %fd<7>; + .reg .b64 %rd<8>; + + + ld.param.u64 %rd1, [aXplusbC_d_param_0]; + ld.param.u64 %rd2, [aXplusbC_d_param_1]; + ld.param.f64 %fd1, [aXplusbC_d_param_2]; + ld.param.f64 %fd2, [aXplusbC_d_param_3]; + ld.param.u32 %r2, [aXplusbC_d_param_4]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r4, %r3, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra BB116_2; + + 
cvta.to.global.u64 %rd3, %rd2; + cvta.to.global.u64 %rd4, %rd1; + mul.wide.s32 %rd5, %r1, 8; + add.s64 %rd6, %rd4, %rd5; + ld.global.f64 %fd3, [%rd6]; + add.s64 %rd7, %rd3, %rd5; + ld.global.f64 %fd4, [%rd7]; + mul.f64 %fd5, %fd4, %fd2; + fma.rn.f64 %fd6, %fd3, %fd1, %fd5; + st.global.f64 [%rd7], %fd6; + +BB116_2: + ret; +} + + // .globl aXplusbC_f +.visible .entry aXplusbC_f( + .param .u64 aXplusbC_f_param_0, + .param .u64 aXplusbC_f_param_1, + .param .f64 aXplusbC_f_param_2, + .param .f64 aXplusbC_f_param_3, + .param .u32 aXplusbC_f_param_4 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<4>; + .reg .b32 %r<6>; + .reg .f64 %fd<7>; + .reg .b64 %rd<8>; + + + ld.param.u64 %rd1, [aXplusbC_f_param_0]; + ld.param.u64 %rd2, [aXplusbC_f_param_1]; + ld.param.f64 %fd1, [aXplusbC_f_param_2]; + ld.param.f64 %fd2, [aXplusbC_f_param_3]; + ld.param.u32 %r2, [aXplusbC_f_param_4]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r4, %r3, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra BB117_2; + + cvta.to.global.u64 %rd3, %rd2; + cvta.to.global.u64 %rd4, %rd1; + mul.wide.s32 %rd5, %r1, 4; + add.s64 %rd6, %rd4, %rd5; + ld.global.f32 %f1, [%rd6]; + cvt.f64.f32 %fd3, %f1; + add.s64 %rd7, %rd3, %rd5; + ld.global.f32 %f2, [%rd7]; + cvt.f64.f32 %fd4, %f2; + mul.f64 %fd5, %fd4, %fd2; + fma.rn.f64 %fd6, %fd3, %fd1, %fd5; + cvt.rn.f32.f64 %f3, %fd6; + st.global.f32 [%rd7], %f3; + +BB117_2: + ret; +} + + // .globl aXplusbY_d +.visible .entry aXplusbY_d( + .param .u64 aXplusbY_d_param_0, + .param .u64 aXplusbY_d_param_1, + .param .u64 aXplusbY_d_param_2, + .param .f64 aXplusbY_d_param_3, + .param .f64 aXplusbY_d_param_4, + .param .u32 aXplusbY_d_param_5 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<6>; + .reg .f64 %fd<7>; + .reg .b64 %rd<11>; + + + ld.param.u64 %rd1, [aXplusbY_d_param_0]; + ld.param.u64 %rd2, [aXplusbY_d_param_1]; + ld.param.u64 %rd3, [aXplusbY_d_param_2]; + ld.param.f64 %fd1, [aXplusbY_d_param_3]; + ld.param.f64 %fd2, [aXplusbY_d_param_4]; + ld.param.u32 %r2, [aXplusbY_d_param_5]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r4, %r3, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra BB118_2; + + cvta.to.global.u64 %rd4, %rd1; + mul.wide.s32 %rd5, %r1, 8; + add.s64 %rd6, %rd4, %rd5; + ld.global.f64 %fd3, [%rd6]; + cvta.to.global.u64 %rd7, %rd2; + add.s64 %rd8, %rd7, %rd5; + ld.global.f64 %fd4, [%rd8]; + mul.f64 %fd5, %fd4, %fd2; + fma.rn.f64 %fd6, %fd3, %fd1, %fd5; + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd5; + st.global.f64 [%rd10], %fd6; + +BB118_2: + ret; +} + + // .globl aXplusbY_f +.visible .entry aXplusbY_f( + .param .u64 aXplusbY_f_param_0, + .param .u64 aXplusbY_f_param_1, + .param .u64 aXplusbY_f_param_2, + .param .f64 aXplusbY_f_param_3, + .param .f64 aXplusbY_f_param_4, + .param .u32 aXplusbY_f_param_5 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<4>; + .reg .b32 %r<6>; + .reg .f64 %fd<7>; + .reg .b64 %rd<11>; + + + ld.param.u64 %rd1, [aXplusbY_f_param_0]; + ld.param.u64 %rd2, [aXplusbY_f_param_1]; + ld.param.u64 %rd3, [aXplusbY_f_param_2]; + ld.param.f64 %fd1, [aXplusbY_f_param_3]; + ld.param.f64 %fd2, [aXplusbY_f_param_4]; + ld.param.u32 %r2, [aXplusbY_f_param_5]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r4, %r3, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra BB119_2; + + cvta.to.global.u64 %rd4, %rd1; + mul.wide.s32 %rd5, %r1, 4; + add.s64 %rd6, %rd4, %rd5; + ld.global.f32 %f1, [%rd6]; + cvt.f64.f32 %fd3, %f1; + cvta.to.global.u64 %rd7, %rd2; + add.s64 %rd8, 
%rd7, %rd5; + ld.global.f32 %f2, [%rd8]; + cvt.f64.f32 %fd4, %f2; + mul.f64 %fd5, %fd4, %fd2; + fma.rn.f64 %fd6, %fd3, %fd1, %fd5; + cvt.rn.f32.f64 %f3, %fd6; + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd5; + st.global.f32 [%rd10], %f3; + +BB119_2: + ret; +} + + // .globl invVar_d +.visible .entry invVar_d( + .param .u64 invVar_d_param_0, + .param .u64 invVar_d_param_1, + .param .f64 invVar_d_param_2, + .param .u32 invVar_d_param_3 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<6>; + .reg .f64 %fd<6>; + .reg .b64 %rd<8>; + + + ld.param.u64 %rd1, [invVar_d_param_0]; + ld.param.u64 %rd2, [invVar_d_param_1]; + ld.param.f64 %fd1, [invVar_d_param_2]; + ld.param.u32 %r2, [invVar_d_param_3]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r4, %r3, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra BB120_2; + + cvta.to.global.u64 %rd3, %rd1; + mul.wide.s32 %rd4, %r1, 8; + add.s64 %rd5, %rd3, %rd4; + ld.global.f64 %fd2, [%rd5]; + add.f64 %fd3, %fd2, %fd1; + sqrt.rn.f64 %fd4, %fd3; + rcp.rn.f64 %fd5, %fd4; + cvta.to.global.u64 %rd6, %rd2; + add.s64 %rd7, %rd6, %rd4; + st.global.f64 [%rd7], %fd5; + +BB120_2: + ret; +} + + // .globl invVar_f +.visible .entry invVar_f( + .param .u64 invVar_f_param_0, + .param .u64 invVar_f_param_1, + .param .f64 invVar_f_param_2, + .param .u32 invVar_f_param_3 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<3>; + .reg .b32 %r<6>; + .reg .f64 %fd<6>; + .reg .b64 %rd<8>; + + + ld.param.u64 %rd1, [invVar_f_param_0]; + ld.param.u64 %rd2, [invVar_f_param_1]; + ld.param.f64 %fd1, [invVar_f_param_2]; + ld.param.u32 %r2, [invVar_f_param_3]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r4, %r3, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra BB121_2; + + cvta.to.global.u64 %rd3, %rd1; + mul.wide.s32 %rd4, %r1, 4; + add.s64 %rd5, %rd3, %rd4; + ld.global.f32 %f1, [%rd5]; + cvt.f64.f32 %fd2, %f1; + add.f64 %fd3, %fd2, %fd1; + sqrt.rn.f64 %fd4, %fd3; + rcp.rn.f64 %fd5, %fd4; + cvt.rn.f32.f64 %f2, %fd5; + cvta.to.global.u64 %rd6, %rd2; + add.s64 %rd7, %rd6, %rd4; + st.global.f32 [%rd7], %f2; + +BB121_2: + ret; +} + .func (.param .b64 func_retval0) __internal_trig_reduction_slowpathd( .param .b64 __internal_trig_reduction_slowpathd_param_0, .param .b64 __internal_trig_reduction_slowpathd_param_1 ) { - .local .align 8 .b8 __local_depot116[40]; + .local .align 8 .b8 __local_depot122[40]; .reg .b64 %SP; .reg .b64 %SPL; .reg .pred %p<9>; @@ -13098,7 +13365,7 @@ BB115_2: .reg .b64 %rd<102>; - mov.u64 %rd101, __local_depot116; + mov.u64 %rd101, __local_depot122; cvta.local.u64 %SP, %rd101; ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0]; ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1]; @@ -13112,7 +13379,7 @@ BB115_2: shr.u32 %r3, %r1, 20; bfe.u32 %r4, %r1, 20, 11; setp.eq.s32 %p1, %r4, 2047; - @%p1 bra BB116_13; + @%p1 bra BB122_13; add.s32 %r15, %r4, -1024; shr.u32 %r16, %r15, 6; @@ -13125,7 +13392,7 @@ BB115_2: mov.u64 %rd94, 0; setp.ge.s32 %p2, %r5, %r6; mov.u64 %rd93, %rd1; - @%p2 bra BB116_4; + @%p2 bra BB122_4; mov.b64 %rd41, %fd4; shl.b64 %rd42, %rd41, 11; @@ -13142,7 +13409,7 @@ BB115_2: mov.u64 %rd91, %rd1; mov.u32 %r39, %r5; -BB116_3: +BB122_3: .pragma "nounroll"; ld.const.u64 %rd47, [%rd89]; // inline asm @@ -13172,15 +13439,15 @@ BB116_3: add.s64 %rd93, %rd93, 8; add.s64 %rd89, %rd89, 8; setp.lt.s32 %p3, %r39, %r6; - @%p3 bra BB116_3; + @%p3 bra BB122_3; -BB116_4: +BB122_4: st.local.u64 [%rd93], %rd94; ld.local.u64 %rd95, [%rd1+16]; ld.local.u64 %rd96, [%rd1+24]; 
and.b32 %r9, %r3, 63; setp.eq.s32 %p4, %r9, 0; - @%p4 bra BB116_6; + @%p4 bra BB122_6; mov.u32 %r27, 64; sub.s32 %r28, %r27, %r9; @@ -13192,7 +13459,7 @@ BB116_4: shr.u64 %rd55, %rd54, %r28; or.b64 %rd95, %rd55, %rd53; -BB116_6: +BB122_6: cvta.to.local.u64 %rd56, %rd37; shr.u64 %rd57, %rd96, 62; cvt.u32.u64 %r29, %rd57; @@ -13209,7 +13476,7 @@ BB116_6: selp.b32 %r34, %r32, %r33, %p5; st.local.u32 [%rd56], %r34; setp.eq.s32 %p6, %r31, 0; - @%p6 bra BB116_8; + @%p6 bra BB122_8; mov.u64 %rd64, 0; // inline asm @@ -13229,10 +13496,10 @@ BB116_6: // inline asm xor.b32 %r40, %r40, -2147483648; -BB116_8: +BB122_8: clz.b64 %r41, %rd98; setp.eq.s32 %p7, %r41, 0; - @%p7 bra BB116_10; + @%p7 bra BB122_10; shl.b64 %rd67, %rd98, %r41; mov.u32 %r35, 64; @@ -13240,7 +13507,7 @@ BB116_8: shr.u64 %rd68, %rd97, %r36; or.b64 %rd98, %rd68, %rd67; -BB116_10: +BB122_10: mov.u64 %rd72, -3958705157555305931; // inline asm { @@ -13261,7 +13528,7 @@ BB116_10: } // inline asm setp.lt.s64 %p8, %rd100, 1; - @%p8 bra BB116_12; + @%p8 bra BB122_12; // inline asm { @@ -13280,7 +13547,7 @@ BB116_10: // inline asm add.s32 %r41, %r41, 1; -BB116_12: +BB122_12: cvt.u64.u32 %rd79, %r40; shl.b64 %rd80, %rd79, 32; mov.u32 %r37, 1022; @@ -13295,7 +13562,7 @@ BB116_12: or.b64 %rd88, %rd87, %rd80; mov.b64 %fd4, %rd88; -BB116_13: +BB122_13: st.param.f64 [func_retval0+0], %fd4; ret; } @@ -13323,7 +13590,7 @@ BB116_13: } shr.u32 %r51, %r50, 20; setp.ne.s32 %p1, %r51, 0; - @%p1 bra BB117_2; + @%p1 bra BB123_2; mul.f64 %fd14, %fd12, 0d4350000000000000; { @@ -13337,13 +13604,13 @@ BB116_13: shr.u32 %r16, %r50, 20; add.s32 %r51, %r16, -54; -BB117_2: +BB123_2: add.s32 %r52, %r51, -1023; and.b32 %r17, %r50, -2146435073; or.b32 %r18, %r17, 1072693248; mov.b64 %fd135, {%r49, %r18}; setp.lt.u32 %p2, %r18, 1073127583; - @%p2 bra BB117_4; + @%p2 bra BB123_4; { .reg .b32 %temp; @@ -13357,7 +13624,7 @@ BB117_2: mov.b64 %fd135, {%r19, %r21}; add.s32 %r52, %r51, -1022; -BB117_4: +BB123_4: add.f64 %fd15, %fd135, 0d3FF0000000000000; rcp.approx.ftz.f64 %fd16, %fd15; neg.f64 %fd17, %fd15; @@ -13520,13 +13787,13 @@ BB117_4: mov.b32 %f2, %r35; abs.f32 %f1, %f2; setp.lt.f32 %p4, %f1, 0f4086232B; - @%p4 bra BB117_7; + @%p4 bra BB123_7; setp.lt.f64 %p5, %fd4, 0d0000000000000000; add.f64 %fd129, %fd4, 0d7FF0000000000000; selp.f64 %fd136, 0d0000000000000000, %fd129, %p5; setp.geu.f32 %p6, %f1, 0f40874800; - @%p6 bra BB117_7; + @%p6 bra BB123_7; mov.f64 %fd134, 0d4338000000000000; mov.f64 %fd133, 0d3FF71547652B82FE; @@ -13548,26 +13815,26 @@ BB117_4: mov.b64 %fd131, {%r44, %r43}; mul.f64 %fd136, %fd130, %fd131; -BB117_7: +BB123_7: { .reg .b32 %temp; mov.b64 {%temp, %r45}, %fd136; } and.b32 %r46, %r45, 2147483647; setp.ne.s32 %p7, %r46, 2146435072; - @%p7 bra BB117_9; + @%p7 bra BB123_9; { .reg .b32 %temp; mov.b64 {%r47, %temp}, %fd136; } setp.eq.s32 %p8, %r47, 0; - @%p8 bra BB117_10; + @%p8 bra BB123_10; -BB117_9: +BB123_9: fma.rn.f64 %fd136, %fd136, %fd5, %fd136; -BB117_10: +BB123_10: st.param.f64 [func_retval0+0], %fd136; ret; } http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/DnnOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java index a7d37dc..c4ce466 100644 --- a/src/main/java/org/apache/sysml/hops/DnnOp.java +++ b/src/main/java/org/apache/sysml/hops/DnnOp.java @@ -110,8 +110,6 @@ public class DnnOp extends MultiThreadedHop if( getLops() != null ) return getLops(); - 
ExecType et = optFindExecType(); - ArrayList<Hop> inputs = getInput(); switch( op ) { @@ -125,6 +123,7 @@ public class DnnOp extends MultiThreadedHop case BIASADD: case BIASMULT: { + ExecType et = optFindExecType(); if(et == ExecType.CP || et == ExecType.GPU) { setLops(constructDnnLops(et, inputs)); break; @@ -137,15 +136,15 @@ public class DnnOp extends MultiThreadedHop case BATCH_NORM2D_TEST: case CHANNEL_SUMS: case UPDATE_NESTEROV_X: + case UPDATE_EMA_VAR: + case RESHAPE_COLMEANS: + case UPDATE_EMA: + case INV_VAR: + case BATCH_NORM2D_BACKWARD_DX: { - if(et == ExecType.GPU) { - setLops(constructDnnLops(et, inputs)); - break; - } - else { - throw new HopsException("Unimplemented DnnOp for execution type: " + et.name()); - } - // break; + // GPU-specific operators + setLops(constructDnnLops(ExecType.GPU, inputs)); + break; } default: throw new HopsException("Unsupported lops construction for operation type '"+op+"'."); @@ -171,10 +170,16 @@ public class DnnOp extends MultiThreadedHop return 14; case BIASADD: case BIASMULT: + case INV_VAR: return 2; case BATCH_NORM2D_TEST: return 6; + case UPDATE_EMA_VAR: + case BATCH_NORM2D_BACKWARD_DX: + return 5; + case RESHAPE_COLMEANS: case CHANNEL_SUMS: + case UPDATE_EMA: return 3; case UPDATE_NESTEROV_X: return 4; @@ -532,7 +537,8 @@ public class DnnOp extends MultiThreadedHop long[] ret = new long[3]; if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || - op == OpOpDnn.UPDATE_NESTEROV_X) { + op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR || + op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) { // Same dimension as the first input MatrixCharacteristics[] mc = memo.getAllInputStats(getInput()); ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1; @@ -540,13 +546,21 @@ public class DnnOp extends MultiThreadedHop ret[2] = -1; return (ret[0]>=0 && ret[1]>=0) ? 
ret : null; } - else if(op == OpOpDnn.CHANNEL_SUMS) { + else if(op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_EMA_VAR) { long numChannels = Hop.computeSizeInformation(getInput().get(1)); ret[0] = numChannels; ret[1] = 1; ret[2] = -1; return ret; } + else if(op == OpOpDnn.RESHAPE_COLMEANS) { + long numChannels = Hop.computeSizeInformation(getInput().get(1)); + long HW = Hop.computeSizeInformation(getInput().get(2)); + ret[0] = numChannels; + ret[1] = HW; + ret[2] = -1; + return ret; + } refreshSizeInformation(); ret[0] = _dim1; ret[1] = _dim2; ret[2] = _nnz; @@ -739,7 +753,9 @@ public class DnnOp extends MultiThreadedHop @Override public void refreshSizeInformation() { - if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) { + if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || + op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR || + op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) { // Same dimension as the first input Hop input1 = getInput().get(0); setDim1(input1.getDim1()); @@ -747,13 +763,21 @@ public class DnnOp extends MultiThreadedHop _nnz = -1; // cannot infer stats return; } - else if(op == OpOpDnn.CHANNEL_SUMS) { + else if(op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_EMA_VAR) { long numChannels = Hop.computeSizeInformation(getInput().get(1)); setDim1(numChannels); setDim2(1); _nnz = -1; // cannot infer stats return; } + else if(op == OpOpDnn.RESHAPE_COLMEANS) { + long numChannels = Hop.computeSizeInformation(getInput().get(1)); + long HW = Hop.computeSizeInformation(getInput().get(2)); + setDim1(numChannels); + setDim2(HW); + _nnz = -1; // cannot infer stats + return; + } // Reset the _cachedParams to avoid incorrect sizes _cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads); @@ -847,7 +871,9 @@ public class DnnOp extends MultiThreadedHop */ private long getDim(String dimString) { if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS || - op == OpOpDnn.UPDATE_NESTEROV_X) { + op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.RESHAPE_COLMEANS || + op == OpOpDnn.UPDATE_EMA_VAR || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR || + op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) { throw new RuntimeException("getDim method should not be invoked for " + op.name()); } try { http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/FunctionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java index ea397db..5f177bd 100644 --- a/src/main/java/org/apache/sysml/hops/FunctionOp.java +++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java @@ -181,21 +181,6 @@ public class FunctionOp extends Hop // TODO: To allow for initial version to always run on the GPU return 0; } - else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_train")) { - return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) + - OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) + - OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) + - 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) + - OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0); - } - else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) { - return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0); - } - else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) { - return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) + - OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) + - OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0); - } else if ( getFunctionName().equalsIgnoreCase("svd") ) { long outputU = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0); long outputSigma = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0); @@ -226,10 +211,6 @@ public class FunctionOp extends Hop return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0) + 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0); } - else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") || - getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) { - return 0; - } else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) { // TODO: To allow for initial version to always run on the GPU return 0; @@ -251,9 +232,7 @@ public class FunctionOp extends Hop @Override public boolean isGPUEnabled() { - if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") || - getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") || - getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) + if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward")) return true; else return false; @@ -308,13 +287,6 @@ public class FunctionOp extends Hop throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU."); _etype = ExecType.GPU; } - else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) { - _etype = ConfigurationManager.isGPU() ? 
ExecType.GPU : ExecType.CP; - } - else if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("batch_norm2d_train")) { - // Only GPU implementation is supported - _etype = ExecType.GPU; - } else { // Since the memory estimate is only conservative, do not throw // exception if the estimated memory is larger than the budget http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 3b461a1..c8356e0 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1100,7 +1100,8 @@ public abstract class Hop implements ParseInfo MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD, CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS, - UPDATE_NESTEROV_X + UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR, + BATCH_NORM2D_BACKWARD_DX } public enum DataGenMethod { @@ -1174,8 +1175,13 @@ public abstract class Hop implements ParseInfo HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_FILTER, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_FILTER); HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA); HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST); + HopsConv2Lops.put(OpOpDnn.UPDATE_EMA_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA_VAR); HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS); HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X); + HopsConv2Lops.put(OpOpDnn.RESHAPE_COLMEANS, org.apache.sysml.lops.DnnTransform.OperationTypes.RESHAPE_COLMEANS); + HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA); + HopsConv2Lops.put(OpOpDnn.INV_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR); + HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX); } protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops; http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java new file mode 100644 index 0000000..7c70b7b --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysml.hops.rewrite; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.function.Function; +import java.util.function.Predicate; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.hops.AggUnaryOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.Hop.Direction; +import org.apache.sysml.hops.Hop.OpOp1; +import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.Hop.OpOpDnn; +import org.apache.sysml.hops.Hop.ReOrgOp; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; +import org.apache.sysml.utils.Explain; + +/** + * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation. + */ +public class HopDagPatternMatcher { + static final HashSet<String> DEBUG_PATTERNS; + static { + // DEBUG_PATTERNS = new HashSet<>(); + // DEBUG_PATTERNS.add("batchNormdX"); + DEBUG_PATTERNS = null; + } + + // Predicates for the current HOP + List<HopPredicate> _predicates = new ArrayList<>(); + // Child matchers + List<HopDagPatternMatcher> _children = new ArrayList<>(); + private boolean _isLeaf = false; + + static boolean DEBUG_REWRITES = false; // This is set by HopPatternRewriter. Please use DEBUG_PATTERNS instead. + + // Simple utility for debugging the rewrites + public static class HopPredicate implements Predicate<Hop> { + final String _name; + final Function<Hop, Boolean> _pred; + public HopPredicate(String name, Function<Hop, Boolean> pred) { + _name = name; + _pred = pred; + } + @Override + public boolean test(Hop h) { + return _pred.apply(h); + } + @Override + public String toString() { + return _name; + } + } + + /** + * Adds a predicate to the pattern matcher + * + * @param name name of the pattern for debugging + * @param pred higher order function that takes as an input a hop and returns true if the pattern matches else false + * @return this + */ + public HopDagPatternMatcher addPredicate(String name, Function<Hop, Boolean> pred) { + _predicates.add(new HopPredicate(name, pred)); + return this; + } + + /** + * Add child pattern matcher + * @param children list of childer + * @return this + */ + public HopDagPatternMatcher addChildMatcher(HopDagPatternMatcher... 
children) { + for(int i = 0; i < children.length; i++) { + _children.add(children[i]); + } + return this; + } + + /** + * Get the matched HOP DAGs + * @param varName variable names + * @return matched HOP + */ + public Hop getMatchedHop(String varName) { + + if(matchedHops == null || !matchedHops.containsKey(varName)) { + throw new RuntimeException("Incorrect usage: the variable " + varName + " is not registered as input."); + } + return matchedHops.get(varName); + } + + /** + * Return the value + * + * @param varName variable name + * @return the value of the LiteralOp + */ + public double getLiteralValue(String varName) { + return OptimizerUtils.rEvalSimpleDoubleExpression(getMatchedHop(varName), new HashMap<>()); + } + + @Override + public String toString() { + return _predicates.size() >= 1 ? _predicates.get(0).toString() : ""; + } + + /** + * Match the given HOP DAG + * + * @param h root node of the HOP DAG + * @return true if HOP DAG matches + */ + public boolean matches(Hop h) { + visited.clear(); + matchedHops.clear(); + return matchHelper(this, h); + } + + private HashMap<String, Hop> matchedHops = new HashMap<>(); + private String variableName; + private HashMap<HopDagPatternMatcher, Hop> visited = new HashMap<>(); // Map of matched hops + private boolean matchHelper(HopDagPatternMatcher root, Hop h) { + if(h == null) { + return false; + } + else if(_children.size() > 0 && h.getInput().size() < _children.size()) { + if(DEBUG_REWRITES) { + System.out.println("The expected number of children (" + _children.size() + ") didnot match the number of inputs (" + h.getInput().size() + ") " + this); + } + return false; + } + if(root.visited.containsKey(this)) { + Hop h1 = root.visited.get(this); + if(h == h1) { + if(DEBUG_REWRITES) + System.out.println("MATCHED: Early exit as the given HOP has been already matched by the matcher." 
+ this); + return true; // Early exit as the given HOP has been already matched by the matcher + } + else if(_isLeaf) { + if(h.getDataType() == h1.getDataType() && h.getDataType() == DataType.SCALAR) { + return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>()) == OptimizerUtils.rEvalSimpleDoubleExpression(h1, new HashMap<>()); + } + return false; // Mismatched or unknown datatypes or matched with different hops + } + } + + for(HopPredicate p : _predicates) { + if(!p.test(h)) { + if(DEBUG_REWRITES) { + System.out.println("The predicate " + p.toString() + " failed."); + } + return false; + } + } + int index = 0; + for(HopDagPatternMatcher child : _children) { + if(!child.matchHelper(root, h.getInput().get(index))) { + return false; + } + index++; + } + if(_isLeaf) { + root.matchedHops.put(variableName, h); + } + + root.visited.put(this, h); + if(DEBUG_REWRITES) + System.out.println("MATCHED: " + this + " to " + Explain.explain(h)); + return true; + } + + + // Simple helper utilities for adding predicates + private HopDagPatternMatcher isScalar() { + return this.addPredicate("isScalar", h -> h.getDataType() == DataType.SCALAR); + } + private HopDagPatternMatcher isMatrix() { + return this.addPredicate("isMatrix", h -> h.getDataType() == DataType.MATRIX); + } + public HopDagPatternMatcher fitsOnGPU(double constant) { + return this.addPredicate("fitsOnGPU", h -> _fitsOnGPU(h, constant)); + } + + // Factory methods: + public static HopDagPatternMatcher dummy = new HopDagPatternMatcher(); + public static HopDagPatternMatcher rowMeans(HopDagPatternMatcher child1) { + return new HopDagPatternMatcher().addPredicate("rowMeans", h -> + h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row) + .addChildMatcher(child1); + } + public static HopDagPatternMatcher rowVars(HopDagPatternMatcher child1) { + return new HopDagPatternMatcher().addPredicate("rowVars", h -> + h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row) + .addChildMatcher(child1); + } + public static HopDagPatternMatcher colVars(HopDagPatternMatcher child1) { + return new HopDagPatternMatcher().addPredicate("colVars", h -> + h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col) + .addChildMatcher(child1); + } + public static HopDagPatternMatcher leaf(String _variableName, DataType dt) { + HopDagPatternMatcher ret = new HopDagPatternMatcher(); + ret._isLeaf = true; + ret.variableName = _variableName; + if(dt == DataType.MATRIX) { + return ret.isMatrix(); + } + else if(dt == DataType.SCALAR) { + return ret.isScalar(); + } + else if(dt == DataType.UNKNOWN) { + return ret; + } + else { + throw new DMLRuntimeException("Unsupported datatype in pattern matcher:" + dt.name()); + } + } + public static HopDagPatternMatcher rowSums(HopDagPatternMatcher child1) { + return new HopDagPatternMatcher().addPredicate("rowSums", h -> + h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Row) + .addChildMatcher(child1); + } + public static HopDagPatternMatcher colSums(HopDagPatternMatcher child1) { + return new HopDagPatternMatcher().addPredicate("colSums", h -> + h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Col) + .addChildMatcher(child1); + } + public static HopDagPatternMatcher colMeans(HopDagPatternMatcher child1) { + return new 
HopDagPatternMatcher().addPredicate("colSums", h -> + h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col) + .addChildMatcher(child1); + } + public static HopDagPatternMatcher matrix(HopDagPatternMatcher X, HopDagPatternMatcher rows, HopDagPatternMatcher cols) { + return new HopDagPatternMatcher().addPredicate("matrix_reshape", h -> HopRewriteUtils.isReorg(h, ReOrgOp.RESHAPE)) + .addChildMatcher(X, rows, cols); + } + public static HopDagPatternMatcher matrix(double X, HopDagPatternMatcher rows, HopDagPatternMatcher cols) { + return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X)) + .addChildMatcher(rows, cols); + } + public static HopDagPatternMatcher matrix(double X, HopDagPatternMatcher rows, long cols) { + return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && + h.getDim2() == cols) + .addChildMatcher(rows, dummy); + } + public static HopDagPatternMatcher matrix(double X, long rows, HopDagPatternMatcher cols) { + return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && + h.getDim1() == rows) + .addChildMatcher(dummy, cols); + } + public static HopDagPatternMatcher matrix(double X, long rows, long cols) { + return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && + h.getDim1() == rows && h.getDim2() == cols) + .addChildMatcher(dummy, dummy); + } + public static HopDagPatternMatcher bias_add(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("bias_add", h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD)) + .addChildMatcher(child1, child2); + } + public static HopDagPatternMatcher bias_multiply(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("bias_multiply", h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT)) + .addChildMatcher(child1, child2); + } + public static HopDagPatternMatcher unaryMinus(HopDagPatternMatcher child) { + return new HopDagPatternMatcher().addPredicate("unaryMinus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) + && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0)) + .addChildMatcher(dummy, child); + } + public static HopDagPatternMatcher sqrt(HopDagPatternMatcher child) { + return new HopDagPatternMatcher().addPredicate("sqrt", h -> HopRewriteUtils.isUnary(h, OpOp1.SQRT)) + .addChildMatcher(child); + } + public static HopDagPatternMatcher div(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV)) + .addChildMatcher(child1, child2); + } + public static HopDagPatternMatcher div(double child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1)) + .addChildMatcher(dummy, child2); + } + public static HopDagPatternMatcher div(HopDagPatternMatcher child1, double child2) { + return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2)) + .addChildMatcher(child1, dummy); + } + + public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { + 
return new HopDagPatternMatcher().addPredicate("pow", h -> HopRewriteUtils.isBinary(h, OpOp2.POW)) + .addChildMatcher(child1, child2); + } + public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, double child2) { + return new HopDagPatternMatcher().addPredicate("pow", h -> HopRewriteUtils.isBinary(h, OpOp2.POW) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2)) + .addChildMatcher(child1, dummy); + } + private static boolean matchDimensions(Hop h1, Hop h2) { + return h1.getDim1() == h2.getDim1() && h1.getDim2() == h2.getDim2(); + } + // This is used to differentiate between matrix-matrix and matrix-vector operations. + public static HopDagPatternMatcher mm_plus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) + && matchDimensions(h.getInput().get(0), h.getInput().get(1))) + .addChildMatcher(child1, child2); + } + public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS)) + .addChildMatcher(child1, child2); + } + public static HopDagPatternMatcher plus(double child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1)) + .addChildMatcher(dummy, child2); + } + public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, double child2) { + return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2)) + .addChildMatcher(child1, dummy); + } + public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS)) + .addChildMatcher(child1, child2); + } + public static HopDagPatternMatcher minus(double child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1)) + .addChildMatcher(dummy, child2); + } + public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, double child2) { + return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2)) + .addChildMatcher(child1, dummy); + } + public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT)) + .addChildMatcher(child1, child2); + } + public static HopDagPatternMatcher mult(double child1, HopDagPatternMatcher child2) { + return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1)) + .addChildMatcher(dummy, child2); + } + public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, double child2) { + return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT) && + HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2)) + .addChildMatcher(child1, dummy); + } + private static boolean _fitsOnGPU(Hop h, double multiplier) { 
+		double memEst = multiplier*h.getMemEstimate();
+		return ConfigurationManager.isGPU() && h.dimsKnown() && OptimizerUtils.isMemoryBasedOptLevel() &&
+				memEst < OptimizerUtils.getLocalMemBudget() && memEst < GPUContextPool.initialGPUMemBudget();
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
new file mode 100644
index 0000000..02472ed
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.function.Function;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * This class is used with HopRewriteRuleWithPatternMatcher to implement the following pattern matching logic:
+ *
+ * ArrayList<HopPatternRewriter> patternRewriters = getPatternRewriter();
+ * for(HopPatternRewriter patternRewriter : patternRewriters) {
+ *   hi = patternRewriter.rewrite(hi);
+ * }
+ *
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public class HopPatternRewriter {
+	private final HopDagPatternMatcher _matcher;
+	private final Function<Hop, Hop> _replacer;
+	private final String _name;
+	public HopPatternRewriter(String name, HopDagPatternMatcher matcher, Function<Hop, Hop> replacer) {
+		_name = name;
+		_matcher = matcher;
+		_replacer = replacer;
+	}
+
+	public Hop rewrite(Hop root) {
+		boolean printMessage = HopDagPatternMatcher.DEBUG_PATTERNS != null && HopDagPatternMatcher.DEBUG_PATTERNS.contains(_name);
+		if(printMessage) {
+			HopDagPatternMatcher.DEBUG_REWRITES = true;
+			System.out.println("-----------------------------------");
+			System.out.println(org.apache.sysml.utils.Explain.explain(root));
+		}
+		if(_matcher.matches(root)) {
+			Hop newHop = _replacer.apply(root);
+			if(printMessage) {
+				if(newHop == root)
+					System.out.println("Initial pattern match for " + _name + " succeeded but replacer returned the same HOP.");
+				else
+					System.out.println("Pattern match for " + _name + " succeeded.");
+				HopDagPatternMatcher.DEBUG_REWRITES = false;
+				System.out.println("-----------------------------------");
+			}
+			return newHop;
+		}
+		else {
+			if(printMessage) {
+				System.out.println("Pattern match for " + _name + " failed.");
+				HopDagPatternMatcher.DEBUG_REWRITES = false;
+				System.out.println("-----------------------------------");
+			}
+			return root;
+		}
+	}
+}
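Tying the pieces together, a HopPatternRewriter pairs a matcher with a replacer function. A minimal sketch, reusing the shiftScale matcher from the previous sketch; OpOpDnn.FUSED_OP is an illustrative placeholder (not an enum value from this commit), and createDnnOp is the HopRewriteUtils helper added at the end of this diff:

    // Replace a matched subtree with a fused DNN operator whose inputs are
    // the HOPs bound by the matcher during the successful match.
    HopPatternRewriter shiftScaleRewriter = new HopPatternRewriter("shiftScale",
        shiftScale,
        root -> HopRewriteUtils.createDnnOp(shiftScale, OpOpDnn.FUSED_OP,
            "X", "mu", "gamma", "beta"));

    hi = shiftScaleRewriter.rewrite(hi); // fused HOP on a match, hi unchanged otherwise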

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
new file mode 100644
index 0000000..854eca3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * Simple utility class that implements the generic structure of a HopRewriteRule.
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public abstract class HopRewriteRuleWithPatternMatcher extends HopRewriteRule {
+
+	public abstract ArrayList<HopPatternRewriter> getPatternRewriter();
+
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+		if( roots == null )
+			return roots;
+
+		//one pass rewrite-descend (rewrite created pattern)
+		for( int i = 0; i < roots.size(); i++ )
+			applyRules(roots, roots.get(i), false );
+		Hop.resetVisitStatus(roots, true);
+
+		//one pass descend-rewrite (for rollup)
+		for( int i = 0; i < roots.size(); i++ )
+			applyRules(roots, roots.get(i), true );
+		Hop.resetVisitStatus(roots, true);
+
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		if( root == null )
+			return root;
+
+		//one pass rewrite-descend (rewrite created pattern)
+		applyRules(null, root, false );
+
+		root.resetVisitStatus();
+
+		//one pass descend-rewrite (for rollup)
+		applyRules(null, root, true );
+
+		return root;
+	}
+
+	/**
+	 * Apply rules
+	 *
+	 * @param roots root operators
+	 * @param hop high-level operator
+	 * @param descendFirst true if recursively process children first
+	 */
+	private void applyRules(ArrayList<Hop> roots, Hop hop, boolean descendFirst)
+	{
+		if(hop.isVisited())
+			return;
+
+		//recursively process children
+		for( int i=0; i<hop.getInput().size(); i++) {
+			Hop hi = hop.getInput().get(i);
+
+			//process children recursively first (to allow roll-up)
+			if( descendFirst )
+				applyRules(roots, hi, descendFirst); //see below
+
+			ArrayList<HopPatternRewriter> patternRewriters = getPatternRewriter();
+			for(HopPatternRewriter patternRewriter : patternRewriters) {
+				hi = patternRewriter.rewrite(hi);
+			}
+
+			if( !descendFirst )
+				applyRules(roots, hi, descendFirst);
+		}
+
+		hop.setVisited();
+	}
+}
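With the DAG traversal factored out above, a concrete rule reduces to supplying the list of pattern rewriters. A minimal sketch under stated assumptions (the class name is hypothetical; RewriteGPUSpecificOps in this commit follows the same shape):

    import java.util.ArrayList;
    import org.apache.sysml.hops.rewrite.HopPatternRewriter;
    import org.apache.sysml.hops.rewrite.HopRewriteRuleWithPatternMatcher;

    public class RewriteMyFusedOps extends HopRewriteRuleWithPatternMatcher {
        // getPatternRewriter() is invoked once per visited HOP, so the list
        // is built once and cached rather than reallocated on every call.
        private static final ArrayList<HopPatternRewriter> REWRITERS = new ArrayList<>();
        static {
            REWRITERS.add(shiftScaleRewriter); // the rewriter sketched earlier, assumed visible here
        }
        @Override
        public ArrayList<HopPatternRewriter> getPatternRewriter() {
            return REWRITERS;
        }
    }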

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 271142d..2351f5f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -719,6 +719,26 @@ public class HopRewriteUtils
 		return ternOp;
 	}
 	
+	public static DnnOp createDnnOp(OpOpDnn op, Hop... hops) {
+		ArrayList<Hop> inHops = new ArrayList<Hop>();
+		for(Hop h : hops) {
+			inHops.add(h);
+		}
+		return new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+				op, inHops);
+	}
+	
+	public static DnnOp createDnnOp(HopDagPatternMatcher matcher, OpOpDnn op, String... varNames) {
+		ArrayList<Hop> inHops = new ArrayList<Hop>();
+		for(String v : varNames) {
+			inHops.add(matcher.getMatchedHop(v));
+		}
+		return new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+				op, inHops);
+	}
+	
+	
+	
 	public static void setOutputParameters( Hop hop, long rlen, long clen, int brlen, int bclen, long nnz ) {
 		hop.setDim1( rlen );
 		hop.setDim2( clen );