[SYSTEMML-445] Removed batch_norm builtin functions

- Removed batch_norm builtin functions to exploit codegen in CP.
- Added rewrites for compiling efficient CuDNN operators.
- Added rewrites for SGD update operations.
- To simplify adding new GPU rewrites, added HopDagPatternMatcher, which
  enables pattern matching at the HOP level and can be extended to other
  rewrites as well (see the usage sketch below).
- Added GPU tests to validate the rewrites.
- Updated the DML language documentation.
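
A rough usage sketch (not part of this patch) of how the new HOP-level pattern
matcher can be wired up. The addPredicate/addChildMatcher/matches calls are
taken from the HopDagPatternMatcher class added below; the
HopRewriteUtils.isUnary/isBinary helpers, the leaf predicates, and the INV_VAR
replacement step are assumptions for illustration only, not the exact code of
this commit.

    import org.apache.sysml.hops.Hop;
    import org.apache.sysml.hops.Hop.OpOp1;
    import org.apache.sysml.hops.Hop.OpOp2;
    import org.apache.sysml.hops.LiteralOp;
    import org.apache.sysml.hops.rewrite.HopDagPatternMatcher;
    import org.apache.sysml.hops.rewrite.HopRewriteUtils;
    import org.apache.sysml.parser.Expression.DataType;

    public class InvVarPatternSketch {
      // Matcher for the sub-DAG 1 / sqrt(var + eps), which the GPU rewrite
      // replaces with the fused INV_VAR operator.
      public static HopDagPatternMatcher invVarPattern() {
        HopDagPatternMatcher var = new HopDagPatternMatcher()
          .addPredicate("var", h -> h.getDataType() == DataType.MATRIX);
        HopDagPatternMatcher eps = new HopDagPatternMatcher()
          .addPredicate("eps", h -> h instanceof LiteralOp);
        HopDagPatternMatcher varPlusEps = new HopDagPatternMatcher()
          .addPredicate("var_plus_eps", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS))
          .addChildMatcher(var, eps);
        HopDagPatternMatcher sqrt = new HopDagPatternMatcher()
          .addPredicate("sqrt", h -> HopRewriteUtils.isUnary(h, OpOp1.SQRT))
          .addChildMatcher(varPlusEps);
        HopDagPatternMatcher one = new HopDagPatternMatcher()
          .addPredicate("one", h -> h instanceof LiteralOp);
        return new HopDagPatternMatcher()
          .addPredicate("inv", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV))
          .addChildMatcher(one, sqrt);
      }

      // A rewrite rule would then call invVarPattern().matches(hop) on each
      // root HOP and, on a match, replace the matched sub-DAG with a DnnOp
      // of type OpOpDnn.INV_VAR (see RewriteGPUSpecificOps below).
    }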


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0f36780a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0f36780a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0f36780a

Branch: refs/heads/master
Commit: 0f36780a8244c6e728d37c32a79e00ed181211ad
Parents: 81419ae
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Thu Aug 30 15:40:44 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Thu Aug 30 15:40:44 2018 -0700

----------------------------------------------------------------------
 docs/dml-language-reference.md                  |    2 -
 scripts/nn/layers/batch_norm2d.dml              |   60 +-
 scripts/nn/layers/batch_norm2d_old.dml          |  200 ----
 src/main/cpp/kernels/SystemML.cu                |   56 +-
 src/main/cpp/kernels/SystemML.ptx               |  321 +++++-
 src/main/java/org/apache/sysml/hops/DnnOp.java  |   56 +-
 .../java/org/apache/sysml/hops/FunctionOp.java  |   30 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |    8 +-
 .../hops/rewrite/HopDagPatternMatcher.java      |  378 +++++++
 .../sysml/hops/rewrite/HopPatternRewriter.java  |   72 ++
 .../HopRewriteRuleWithPatternMatcher.java       |   98 ++
 .../sysml/hops/rewrite/HopRewriteUtils.java     |   20 +
 .../hops/rewrite/RewriteGPUSpecificOps.java     | 1027 +++++-------------
 .../org/apache/sysml/lops/DnnTransform.java     |   53 +-
 .../sysml/parser/BuiltinFunctionExpression.java |   61 +-
 .../org/apache/sysml/parser/DMLTranslator.java  |    2 -
 .../org/apache/sysml/parser/Expression.java     |    2 +-
 .../instructions/GPUInstructionParser.java      |   10 +-
 .../instructions/gpu/DnnGPUInstruction.java     |  526 +++++----
 .../gpu/GPUDenseInputPointerFetcher.java        |  111 ++
 .../gpu/context/GPUMemoryManager.java           |    2 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |  110 +-
 .../runtime/matrix/data/LibMatrixCuDNN.java     |   37 +-
 .../apache/sysml/test/gpu/BatchNormTest.java    |   47 +-
 24 files changed, 1818 insertions(+), 1471 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 924336a..cdcc529 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1522,8 +1522,6 @@ Hence, the images are internally represented as a matrix with dimension (N, C *
 | bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels |
 | bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels |
 | lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut) |
-| batch_norm2d | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_image* width_image] | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum | Performs batch normalization operation (outputs: updated exponential moving average mean and variance, cache of the batch mean and variance) |
-| batch_norm2d_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward) | Computed backpropagation error for batch normalization operation |
 
 Note: the builtin functions `batch_norm2d` and `batch_norm2d_backward` are deprecated and will be removed in the next release. The `lstm` builtin function is in experimental phase and is only supported for the GPU backend. 
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d.dml 
b/scripts/nn/layers/batch_norm2d.dml
index 2a98857..c68f23d 100644
--- a/scripts/nn/layers/batch_norm2d.dml
+++ b/scripts/nn/layers/batch_norm2d.dml
@@ -83,8 +83,41 @@ forward = function(matrix[double] X, matrix[double] gamma, 
matrix[double] beta,
    *  - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
    *      Note: This is used for performance during training.
    */
-  out = X; ema_mean_upd = ema_mean; ema_var_upd = ema_var;  cache_mean = 
ema_mean;  cache_inv_var = ema_var
-  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = 
batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu)
+  N = nrow(X)
+
+  if (mode == 'train') {
+    # Compute channel-wise mean and variance
+    # Since we don't have tensors, we will compute the means and variances in 
a piece-wise fashion.
+    #  - mean of total group is mean of subgroup means
+    #  - variance is the mean of the subgroup variances + the variance of the 
subgroup means
+    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # 
uncorrected variances
+    mean = rowMeans(subgrp_means)  # shape (C, 1)
+    var = rowMeans(subgrp_vars) + 
rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)
+    # Update moving averages
+    ema_mean_upd = mu*ema_mean + (1-mu)*mean
+    ema_var_upd = mu*ema_var + (1-mu)*var
+  }
+  else {
+    # Use moving averages of mean and variance during testing
+    mean = ema_mean
+    var = ema_var
+    ema_mean_upd = ema_mean
+    ema_var_upd = ema_var
+  }
+
+  # Save variable for backward pass
+  cache_mean = mean
+  cache_inv_var = 1/sqrt(var+epsilon)
+  
+  # Normalize, shift, and scale
+  # norm = (X-mean)*(var+epsilon)^(-1/2)
+  #      = (X-mean) / sqrt(var+epsilon)
+  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
+  # out = norm*gamma + beta
+  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
+  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
 }
 
 backward = function(matrix[double] dout, 
@@ -119,9 +152,27 @@ backward = function(matrix[double] dout,
    *  - dbeta: Gradient wrt `b`, of shape (C, 1).
    *
    */
+  N = nrow(X)
+  oneByN = 1/N
+  oneByHW = 1/(Hin*Win)
+  
+  mean = cache_mean
+  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
   # Compute gradients during training
-  dX = X; dgamma = gamma; dbeta = gamma;
-  [dX, dgamma, dbeta] = batch_norm2d_backward(X, dout, gamma, epsilon, 
cache_mean, cache_inv_var)
+  dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
+  dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
+  dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+  dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) 
* dnorm,
+                          C, Hin, Win)  # shape (C, 1)
+  dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), 
C, Hin, Win)
+  dmean_var_branch =  util::channel_sums((-2*oneByN*oneByHW) * centered, C, 
Hin, Win)
+  dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within 
an expression yet
+  dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
+  dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+  dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, rows=1, 
cols=C*Hin*Win), dmean)
+  dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, dvar)
+  dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
 }
 
 init = function(int C)
@@ -149,3 +200,4 @@ init = function(int C)
    ema_mean = matrix(0, rows=C, cols=1)
    ema_var = matrix(1, rows=C, cols=1)
 }
+
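
For reference, the channel statistics computed in the reimplemented forward
pass above follow the standard total-variance decomposition over equally
sized subgroups; a short LaTeX restatement (not part of the diff), with
m_{c,j} and v_{c,j} the per-column (subgroup) mean and uncorrected variance:

  \mu_c = \frac{1}{HW} \sum_{j=1}^{HW} m_{c,j}, \qquad
  \sigma_c^2 = \underbrace{\frac{1}{HW} \sum_{j=1}^{HW} v_{c,j}}_{\text{mean of subgroup variances}}
             + \underbrace{\frac{1}{HW} \sum_{j=1}^{HW} \left(m_{c,j} - \mu_c\right)^2}_{\text{variance of subgroup means}}

where m_{c,j} = \frac{1}{N} \sum_{n=1}^{N} x_{n,c,j} and
v_{c,j} = \frac{1}{N} \sum_{n=1}^{N} (x_{n,c,j} - m_{c,j})^2. The factors
(N-1)/N and (HW-1)/HW in the script convert colVars/rowVars, which use the
corrected 1/(N-1) and 1/(HW-1) denominators, into the uncorrected variances
required by this decomposition.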

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d_old.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d_old.dml 
b/scripts/nn/layers/batch_norm2d_old.dml
deleted file mode 100644
index 2aba2e6..0000000
--- a/scripts/nn/layers/batch_norm2d_old.dml
+++ /dev/null
@@ -1,200 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-/*
- * 2D (Spatial) Batch Normalization layer.
- */
-source("nn/util.dml") as util
-
-forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
-                   int C, int Hin, int Win, string mode,
-                   matrix[double] ema_mean, matrix[double] ema_var,
-                   double mu, double epsilon)
-    return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] 
ema_var_upd,
-            matrix[double] cache_mean, matrix[double] cache_inv_var) {
-  /*
-   * Computes the forward pass for a 2D (spatial) batch normalization
-   * layer.  The input data has N examples, each represented as a 3D
-   * volume unrolled into a single vector.
-   *
-   * A spatial batch normalization layer uses the per-channel sample
-   * mean and per-channel uncorrected sample variance during training
-   * to normalize each channel of the input data.  Additionally, it
-   * introduces learnable parameters (gamma, beta) to control the
-   * amount of normalization.
-   *
-   *   `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
-   *
-   * This implementation maintains exponential moving averages of the
-   * mean and variance during training for use during testing.
-   *
-   * Reference:
-   *  - Batch Normalization: Accelerating Deep Network Training by
-   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
-   *    - https://arxiv.org/abs/1502.03167
-   *
-   * Inputs:
-   *  - X: Inputs, of shape (N, C*Hin*Win).
-   *  - gamma: Scale parameters, of shape (C, 1).
-   *  - beta: Shift parameters, of shape (C, 1).
-   *  - C: Number of input channels (dimensionality of input depth).
-   *  - Hin: Input height.
-   *  - Win: Input width.
-   *  - mode: 'train' or 'test' to indicate if the model is currently
-   *      being trained or tested.  During training, the current batch
-   *      mean and variance will be used to normalize the inputs, while
-   *      during testing, the exponential average of the mean and
-   *      variance over all previous batches will be used.
-   *  - ema_mean: Exponential moving average of the mean, of
-   *      shape (C, 1).
-   *  - ema_var: Exponential moving average of the variance, of
-   *      shape (C, 1).
-   *  - mu: Momentum value for moving averages.
-   *      Typical values are in the range of [0.9, 0.999].
-   *  - epsilon: Smoothing term to avoid divide by zero errors.
-   *      Typical values are in the range of [1e-5, 1e-3].
-   *
-   * Outputs:
-   *  - out: Outputs, of shape (N, C*Hin*Win).
-   *  - ema_mean_upd: Updated exponential moving average of the mean,
-   *      of shape (C, 1).
-   *  - ema_var_upd: Updated exponential moving average of the variance,
-   *      of shape (C, 1).
-   *  - cache_mean: Cache of the batch mean, of shape (C, 1).
-   *      Note: This is used for performance during training.
-   *  - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
-   *      Note: This is used for performance during training.
-   */
-  N = nrow(X)
-
-  if (mode == 'train') {
-    # Compute channel-wise mean and variance
-    # Since we don't have tensors, we will compute the means and variances in 
a piece-wise fashion.
-    #  - mean of total group is mean of subgroup means
-    #  - variance is the mean of the subgroup variances + the variance of the 
subgroup means
-    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
-    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # 
uncorrected variances
-    mean = rowMeans(subgrp_means)  # shape (C, 1)
-    var = rowMeans(subgrp_vars) + 
rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)
-    # Update moving averages
-    ema_mean_upd = mu*ema_mean + (1-mu)*mean
-    ema_var_upd = mu*ema_var + (1-mu)*var
-  }
-  else {
-    # Use moving averages of mean and variance during testing
-    mean = ema_mean
-    var = ema_var
-    ema_mean_upd = ema_mean
-    ema_var_upd = ema_var
-  }
-
-  # Save variable for backward pass
-  cache_mean = mean
-  cache_inv_var = 1/sqrt(var+epsilon)
-  
-  # Normalize, shift, and scale
-  # norm = (X-mean)*(var+epsilon)^(-1/2)
-  #      = (X-mean) / sqrt(var+epsilon)
-  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
-  # out = norm*gamma + beta
-  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
-  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
-}
-
-backward = function(matrix[double] dout, 
-                    matrix[double] cache_mean, matrix[double] cache_inv_var,
-                    matrix[double] X, matrix[double] gamma, 
-                    int C, int Hin, int Win, double epsilon)
-      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
-  /*
-   * Computes the backward pass for a 2D (spatial) batch normalization
-   * layer.
-   *
-   * Inputs:
-   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
-   *  - cache_mean: Cache of the batch mean from the forward pass, of
-   *      shape (C, 1).  Note: This is used for performance during
-   *      training.
-   *  - cache_inv_var: Cache of the inverse variance from the forward pass,
-   *      of shape (C, 1).  Note: This is used for performance during
-   *      training.
-   *  - X: Input data matrix to the forward pass, of
-   *      shape (N, C*Hin*Win).
-   *  - gamma: Scale parameters, of shape (C, 1).
-   *  - C: Number of input channels (dimensionality of input depth).
-   *  - Hin: Input height.
-   *  - Win: Input width.
-   *  - epsilon: Smoothing term to avoid divide by zero errors.
-   *      Typical values are in the range of [1e-5, 1e-3].
-   *
-   * Outputs:
-   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
-   *  - dgamma: Gradient wrt `W`, of shape (C, 1).
-   *  - dbeta: Gradient wrt `b`, of shape (C, 1).
-   *
-   */
-  N = nrow(X)
-  mean = cache_mean
-  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
-  # Compute gradients during training
-  dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
-  dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
-  dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
-  dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) 
* dnorm,
-                          C, Hin, Win)  # shape (C, 1)
-  dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), 
C, Hin, Win)
-  dmean_var_branch =  util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, 
Win)
-  dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within 
an expression yet
-  dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
-  dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
-  dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, 
cols=C*Hin*Win), dmean)
-  dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
-  dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
-}
-
-init = function(int C)
-    return (matrix[double] gamma, matrix[double] beta,
-            matrix[double] ema_mean, matrix[double] ema_var) {
-  /*
-   * Initialize the parameters of this layer.
-   *
-   * Note: This is just a convenience function, and parameters
-   * may be initialized manually if needed.
-   *
-   * Inputs:
-   *  - C: Number of input channels (dimensionality of input depth).
-   *
-   * Outputs:
-   *  - gamma: Scale parameters, of shape (C, 1).
-   *  - beta: Shift parameters, of shape (C, 1).
-   *  - ema_mean: Exponential moving average of the mean, of
-   *      shape (C, 1).
-   *  - ema_var: Exponential moving average of the variance, of
-   *      shape (C, 1).
-   */
-   gamma = matrix(1, rows=C, cols=1)
-   beta = matrix(0, rows=C, cols=1)
-   ema_mean = matrix(0, rows=C, cols=1)
-   ema_var = matrix(1, rows=C, cols=1)
-}
-

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 9ddaaff..b874cdd 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2289,4 +2289,58 @@ extern "C" __global__ void update_nesterov_x_d(double 
*X, double *v, double *v_p
 
 extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float 
*v_prev, double mu, float *out, unsigned int size) {
   update_nesterov_x(X, v, v_prev, mu, out, size);
-}
\ No newline at end of file
+}
+
+// Performs the operation: C = a*X + b*C
+template <typename T>
+__device__ void aXplusbC(T *X, T *C, double a, double b, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+       C[index] = a*X[index] + b*C[index];
+  }
+}
+
+extern "C" __global__ void aXplusbC_d(double *X, double *C, double a, double 
b, unsigned int size) {
+  aXplusbC(X, C, a, b,size);
+}
+
+extern "C" __global__ void aXplusbC_f(float *X, float *C, double a, double b, 
unsigned int size) {
+  aXplusbC(X, C, a, b,size);;
+}
+
+
+// Performs the operation: C = a*X + b*Y
+template <typename T>
+__device__ void aXplusbY(T *X, T* Y, T *C, double a, double b, unsigned int 
size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+       C[index] = a*X[index] + b*Y[index];
+  }
+}
+
+extern "C" __global__ void aXplusbY_d(double *X, double* Y, double *C, double 
a, double b, unsigned int size) {
+  aXplusbY(X, Y, C, a, b, size);
+}
+
+extern "C" __global__ void aXplusbY_f(float *X, float* Y, float *C, double a, 
double b, unsigned int size) {
+  aXplusbY(X, Y, C, a, b, size);
+}
+
+
+// Performs the operation: C = 1 / sqrt(X + eps)
+template <typename T>
+__device__ void invVar(T *X, T *C, double eps, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+       C[index] = 1.0 / sqrt(X[index] + eps);
+  }
+}
+
+extern "C" __global__ void invVar_d(double *X, double *C, double eps, unsigned 
int size) {
+  invVar(X, C, eps, size);
+}
+
+extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned 
int size) {
+  invVar(X, C, eps, size);
+}
+
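
The kernels added above are compiled into SystemML.ptx and invoked through
SystemML's GPU instruction machinery (see DnnGPUInstruction and LibMatrixCUDA
in this patch). As a standalone illustration only, a minimal JCuda sketch that
launches the new invVar_d kernel directly from the shipped PTX; the PTX path
and the problem size are placeholders, not code from this commit:

    import jcuda.Pointer;
    import jcuda.driver.CUcontext;
    import jcuda.driver.CUdevice;
    import jcuda.driver.CUdeviceptr;
    import jcuda.driver.CUfunction;
    import jcuda.driver.CUmodule;
    import jcuda.driver.JCudaDriver;

    public class InvVarKernelSketch {
      public static void main(String[] args) {
        JCudaDriver.setExceptionsEnabled(true);
        JCudaDriver.cuInit(0);
        CUdevice device = new CUdevice();
        JCudaDriver.cuDeviceGet(device, 0);
        CUcontext context = new CUcontext();
        JCudaDriver.cuCtxCreate(context, 0, device);

        // Load the generated PTX and look up the double-precision kernel.
        CUmodule module = new CUmodule();
        JCudaDriver.cuModuleLoad(module, "SystemML.ptx");   // placeholder path
        CUfunction invVar = new CUfunction();
        JCudaDriver.cuModuleGetFunction(invVar, module, "invVar_d");

        int size = 1024;          // e.g. number of channels C
        double eps = 1e-5;
        double[] var = new double[size];
        java.util.Arrays.fill(var, 4.0);

        CUdeviceptr dX = new CUdeviceptr();
        CUdeviceptr dC = new CUdeviceptr();
        JCudaDriver.cuMemAlloc(dX, (long) size * Double.BYTES);
        JCudaDriver.cuMemAlloc(dC, (long) size * Double.BYTES);
        JCudaDriver.cuMemcpyHtoD(dX, Pointer.to(var), (long) size * Double.BYTES);

        // Kernel signature: invVar_d(double *X, double *C, double eps, unsigned int size)
        Pointer params = Pointer.to(
            Pointer.to(dX), Pointer.to(dC),
            Pointer.to(new double[]{ eps }), Pointer.to(new int[]{ size }));
        int blockDim = 256;
        int gridDim = (size + blockDim - 1) / blockDim;
        JCudaDriver.cuLaunchKernel(invVar, gridDim, 1, 1, blockDim, 1, 1, 0, null, params, null);
        JCudaDriver.cuCtxSynchronize();

        double[] out = new double[size];
        JCudaDriver.cuMemcpyDtoH(Pointer.to(out), dC, (long) size * Double.BYTES);
        System.out.println("1/sqrt(4 + eps) = " + out[0]);  // approximately 0.5
      }
    }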

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx 
b/src/main/cpp/kernels/SystemML.ptx
index 8a14876..1ab32f5 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -13084,12 +13084,279 @@ BB115_2:
        ret;
 }
 
+       // .globl       aXplusbC_d
+.visible .entry aXplusbC_d(
+       .param .u64 aXplusbC_d_param_0,
+       .param .u64 aXplusbC_d_param_1,
+       .param .f64 aXplusbC_d_param_2,
+       .param .f64 aXplusbC_d_param_3,
+       .param .u32 aXplusbC_d_param_4
+)
+{
+       .reg .pred      %p<2>;
+       .reg .b32       %r<6>;
+       .reg .f64       %fd<7>;
+       .reg .b64       %rd<8>;
+
+
+       ld.param.u64    %rd1, [aXplusbC_d_param_0];
+       ld.param.u64    %rd2, [aXplusbC_d_param_1];
+       ld.param.f64    %fd1, [aXplusbC_d_param_2];
+       ld.param.f64    %fd2, [aXplusbC_d_param_3];
+       ld.param.u32    %r2, [aXplusbC_d_param_4];
+       mov.u32         %r3, %ctaid.x;
+       mov.u32         %r4, %ntid.x;
+       mov.u32         %r5, %tid.x;
+       mad.lo.s32      %r1, %r4, %r3, %r5;
+       setp.ge.u32     %p1, %r1, %r2;
+       @%p1 bra        BB116_2;
+
+       cvta.to.global.u64      %rd3, %rd2;
+       cvta.to.global.u64      %rd4, %rd1;
+       mul.wide.s32    %rd5, %r1, 8;
+       add.s64         %rd6, %rd4, %rd5;
+       ld.global.f64   %fd3, [%rd6];
+       add.s64         %rd7, %rd3, %rd5;
+       ld.global.f64   %fd4, [%rd7];
+       mul.f64         %fd5, %fd4, %fd2;
+       fma.rn.f64      %fd6, %fd3, %fd1, %fd5;
+       st.global.f64   [%rd7], %fd6;
+
+BB116_2:
+       ret;
+}
+
+       // .globl       aXplusbC_f
+.visible .entry aXplusbC_f(
+       .param .u64 aXplusbC_f_param_0,
+       .param .u64 aXplusbC_f_param_1,
+       .param .f64 aXplusbC_f_param_2,
+       .param .f64 aXplusbC_f_param_3,
+       .param .u32 aXplusbC_f_param_4
+)
+{
+       .reg .pred      %p<2>;
+       .reg .f32       %f<4>;
+       .reg .b32       %r<6>;
+       .reg .f64       %fd<7>;
+       .reg .b64       %rd<8>;
+
+
+       ld.param.u64    %rd1, [aXplusbC_f_param_0];
+       ld.param.u64    %rd2, [aXplusbC_f_param_1];
+       ld.param.f64    %fd1, [aXplusbC_f_param_2];
+       ld.param.f64    %fd2, [aXplusbC_f_param_3];
+       ld.param.u32    %r2, [aXplusbC_f_param_4];
+       mov.u32         %r3, %ctaid.x;
+       mov.u32         %r4, %ntid.x;
+       mov.u32         %r5, %tid.x;
+       mad.lo.s32      %r1, %r4, %r3, %r5;
+       setp.ge.u32     %p1, %r1, %r2;
+       @%p1 bra        BB117_2;
+
+       cvta.to.global.u64      %rd3, %rd2;
+       cvta.to.global.u64      %rd4, %rd1;
+       mul.wide.s32    %rd5, %r1, 4;
+       add.s64         %rd6, %rd4, %rd5;
+       ld.global.f32   %f1, [%rd6];
+       cvt.f64.f32     %fd3, %f1;
+       add.s64         %rd7, %rd3, %rd5;
+       ld.global.f32   %f2, [%rd7];
+       cvt.f64.f32     %fd4, %f2;
+       mul.f64         %fd5, %fd4, %fd2;
+       fma.rn.f64      %fd6, %fd3, %fd1, %fd5;
+       cvt.rn.f32.f64  %f3, %fd6;
+       st.global.f32   [%rd7], %f3;
+
+BB117_2:
+       ret;
+}
+
+       // .globl       aXplusbY_d
+.visible .entry aXplusbY_d(
+       .param .u64 aXplusbY_d_param_0,
+       .param .u64 aXplusbY_d_param_1,
+       .param .u64 aXplusbY_d_param_2,
+       .param .f64 aXplusbY_d_param_3,
+       .param .f64 aXplusbY_d_param_4,
+       .param .u32 aXplusbY_d_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .b32       %r<6>;
+       .reg .f64       %fd<7>;
+       .reg .b64       %rd<11>;
+
+
+       ld.param.u64    %rd1, [aXplusbY_d_param_0];
+       ld.param.u64    %rd2, [aXplusbY_d_param_1];
+       ld.param.u64    %rd3, [aXplusbY_d_param_2];
+       ld.param.f64    %fd1, [aXplusbY_d_param_3];
+       ld.param.f64    %fd2, [aXplusbY_d_param_4];
+       ld.param.u32    %r2, [aXplusbY_d_param_5];
+       mov.u32         %r3, %ctaid.x;
+       mov.u32         %r4, %ntid.x;
+       mov.u32         %r5, %tid.x;
+       mad.lo.s32      %r1, %r4, %r3, %r5;
+       setp.ge.u32     %p1, %r1, %r2;
+       @%p1 bra        BB118_2;
+
+       cvta.to.global.u64      %rd4, %rd1;
+       mul.wide.s32    %rd5, %r1, 8;
+       add.s64         %rd6, %rd4, %rd5;
+       ld.global.f64   %fd3, [%rd6];
+       cvta.to.global.u64      %rd7, %rd2;
+       add.s64         %rd8, %rd7, %rd5;
+       ld.global.f64   %fd4, [%rd8];
+       mul.f64         %fd5, %fd4, %fd2;
+       fma.rn.f64      %fd6, %fd3, %fd1, %fd5;
+       cvta.to.global.u64      %rd9, %rd3;
+       add.s64         %rd10, %rd9, %rd5;
+       st.global.f64   [%rd10], %fd6;
+
+BB118_2:
+       ret;
+}
+
+       // .globl       aXplusbY_f
+.visible .entry aXplusbY_f(
+       .param .u64 aXplusbY_f_param_0,
+       .param .u64 aXplusbY_f_param_1,
+       .param .u64 aXplusbY_f_param_2,
+       .param .f64 aXplusbY_f_param_3,
+       .param .f64 aXplusbY_f_param_4,
+       .param .u32 aXplusbY_f_param_5
+)
+{
+       .reg .pred      %p<2>;
+       .reg .f32       %f<4>;
+       .reg .b32       %r<6>;
+       .reg .f64       %fd<7>;
+       .reg .b64       %rd<11>;
+
+
+       ld.param.u64    %rd1, [aXplusbY_f_param_0];
+       ld.param.u64    %rd2, [aXplusbY_f_param_1];
+       ld.param.u64    %rd3, [aXplusbY_f_param_2];
+       ld.param.f64    %fd1, [aXplusbY_f_param_3];
+       ld.param.f64    %fd2, [aXplusbY_f_param_4];
+       ld.param.u32    %r2, [aXplusbY_f_param_5];
+       mov.u32         %r3, %ctaid.x;
+       mov.u32         %r4, %ntid.x;
+       mov.u32         %r5, %tid.x;
+       mad.lo.s32      %r1, %r4, %r3, %r5;
+       setp.ge.u32     %p1, %r1, %r2;
+       @%p1 bra        BB119_2;
+
+       cvta.to.global.u64      %rd4, %rd1;
+       mul.wide.s32    %rd5, %r1, 4;
+       add.s64         %rd6, %rd4, %rd5;
+       ld.global.f32   %f1, [%rd6];
+       cvt.f64.f32     %fd3, %f1;
+       cvta.to.global.u64      %rd7, %rd2;
+       add.s64         %rd8, %rd7, %rd5;
+       ld.global.f32   %f2, [%rd8];
+       cvt.f64.f32     %fd4, %f2;
+       mul.f64         %fd5, %fd4, %fd2;
+       fma.rn.f64      %fd6, %fd3, %fd1, %fd5;
+       cvt.rn.f32.f64  %f3, %fd6;
+       cvta.to.global.u64      %rd9, %rd3;
+       add.s64         %rd10, %rd9, %rd5;
+       st.global.f32   [%rd10], %f3;
+
+BB119_2:
+       ret;
+}
+
+       // .globl       invVar_d
+.visible .entry invVar_d(
+       .param .u64 invVar_d_param_0,
+       .param .u64 invVar_d_param_1,
+       .param .f64 invVar_d_param_2,
+       .param .u32 invVar_d_param_3
+)
+{
+       .reg .pred      %p<2>;
+       .reg .b32       %r<6>;
+       .reg .f64       %fd<6>;
+       .reg .b64       %rd<8>;
+
+
+       ld.param.u64    %rd1, [invVar_d_param_0];
+       ld.param.u64    %rd2, [invVar_d_param_1];
+       ld.param.f64    %fd1, [invVar_d_param_2];
+       ld.param.u32    %r2, [invVar_d_param_3];
+       mov.u32         %r3, %ctaid.x;
+       mov.u32         %r4, %ntid.x;
+       mov.u32         %r5, %tid.x;
+       mad.lo.s32      %r1, %r4, %r3, %r5;
+       setp.ge.u32     %p1, %r1, %r2;
+       @%p1 bra        BB120_2;
+
+       cvta.to.global.u64      %rd3, %rd1;
+       mul.wide.s32    %rd4, %r1, 8;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f64   %fd2, [%rd5];
+       add.f64         %fd3, %fd2, %fd1;
+       sqrt.rn.f64     %fd4, %fd3;
+       rcp.rn.f64      %fd5, %fd4;
+       cvta.to.global.u64      %rd6, %rd2;
+       add.s64         %rd7, %rd6, %rd4;
+       st.global.f64   [%rd7], %fd5;
+
+BB120_2:
+       ret;
+}
+
+       // .globl       invVar_f
+.visible .entry invVar_f(
+       .param .u64 invVar_f_param_0,
+       .param .u64 invVar_f_param_1,
+       .param .f64 invVar_f_param_2,
+       .param .u32 invVar_f_param_3
+)
+{
+       .reg .pred      %p<2>;
+       .reg .f32       %f<3>;
+       .reg .b32       %r<6>;
+       .reg .f64       %fd<6>;
+       .reg .b64       %rd<8>;
+
+
+       ld.param.u64    %rd1, [invVar_f_param_0];
+       ld.param.u64    %rd2, [invVar_f_param_1];
+       ld.param.f64    %fd1, [invVar_f_param_2];
+       ld.param.u32    %r2, [invVar_f_param_3];
+       mov.u32         %r3, %ctaid.x;
+       mov.u32         %r4, %ntid.x;
+       mov.u32         %r5, %tid.x;
+       mad.lo.s32      %r1, %r4, %r3, %r5;
+       setp.ge.u32     %p1, %r1, %r2;
+       @%p1 bra        BB121_2;
+
+       cvta.to.global.u64      %rd3, %rd1;
+       mul.wide.s32    %rd4, %r1, 4;
+       add.s64         %rd5, %rd3, %rd4;
+       ld.global.f32   %f1, [%rd5];
+       cvt.f64.f32     %fd2, %f1;
+       add.f64         %fd3, %fd2, %fd1;
+       sqrt.rn.f64     %fd4, %fd3;
+       rcp.rn.f64      %fd5, %fd4;
+       cvt.rn.f32.f64  %f2, %fd5;
+       cvta.to.global.u64      %rd6, %rd2;
+       add.s64         %rd7, %rd6, %rd4;
+       st.global.f32   [%rd7], %f2;
+
+BB121_2:
+       ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
        .param .b64 __internal_trig_reduction_slowpathd_param_0,
        .param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-       .local .align 8 .b8     __local_depot116[40];
+       .local .align 8 .b8     __local_depot122[40];
        .reg .b64       %SP;
        .reg .b64       %SPL;
        .reg .pred      %p<9>;
@@ -13098,7 +13365,7 @@ BB115_2:
        .reg .b64       %rd<102>;
 
 
-       mov.u64         %rd101, __local_depot116;
+       mov.u64         %rd101, __local_depot122;
        cvta.local.u64  %SP, %rd101;
        ld.param.f64    %fd4, [__internal_trig_reduction_slowpathd_param_0];
        ld.param.u64    %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -13112,7 +13379,7 @@ BB115_2:
        shr.u32         %r3, %r1, 20;
        bfe.u32         %r4, %r1, 20, 11;
        setp.eq.s32     %p1, %r4, 2047;
-       @%p1 bra        BB116_13;
+       @%p1 bra        BB122_13;
 
        add.s32         %r15, %r4, -1024;
        shr.u32         %r16, %r15, 6;
@@ -13125,7 +13392,7 @@ BB115_2:
        mov.u64         %rd94, 0;
        setp.ge.s32     %p2, %r5, %r6;
        mov.u64         %rd93, %rd1;
-       @%p2 bra        BB116_4;
+       @%p2 bra        BB122_4;
 
        mov.b64          %rd41, %fd4;
        shl.b64         %rd42, %rd41, 11;
@@ -13142,7 +13409,7 @@ BB115_2:
        mov.u64         %rd91, %rd1;
        mov.u32         %r39, %r5;
 
-BB116_3:
+BB122_3:
        .pragma "nounroll";
        ld.const.u64    %rd47, [%rd89];
        // inline asm
@@ -13172,15 +13439,15 @@ BB116_3:
        add.s64         %rd93, %rd93, 8;
        add.s64         %rd89, %rd89, 8;
        setp.lt.s32     %p3, %r39, %r6;
-       @%p3 bra        BB116_3;
+       @%p3 bra        BB122_3;
 
-BB116_4:
+BB122_4:
        st.local.u64    [%rd93], %rd94;
        ld.local.u64    %rd95, [%rd1+16];
        ld.local.u64    %rd96, [%rd1+24];
        and.b32         %r9, %r3, 63;
        setp.eq.s32     %p4, %r9, 0;
-       @%p4 bra        BB116_6;
+       @%p4 bra        BB122_6;
 
        mov.u32         %r27, 64;
        sub.s32         %r28, %r27, %r9;
@@ -13192,7 +13459,7 @@ BB116_4:
        shr.u64         %rd55, %rd54, %r28;
        or.b64          %rd95, %rd55, %rd53;
 
-BB116_6:
+BB122_6:
        cvta.to.local.u64       %rd56, %rd37;
        shr.u64         %rd57, %rd96, 62;
        cvt.u32.u64     %r29, %rd57;
@@ -13209,7 +13476,7 @@ BB116_6:
        selp.b32        %r34, %r32, %r33, %p5;
        st.local.u32    [%rd56], %r34;
        setp.eq.s32     %p6, %r31, 0;
-       @%p6 bra        BB116_8;
+       @%p6 bra        BB122_8;
 
        mov.u64         %rd64, 0;
        // inline asm
@@ -13229,10 +13496,10 @@ BB116_6:
        // inline asm
        xor.b32         %r40, %r40, -2147483648;
 
-BB116_8:
+BB122_8:
        clz.b64         %r41, %rd98;
        setp.eq.s32     %p7, %r41, 0;
-       @%p7 bra        BB116_10;
+       @%p7 bra        BB122_10;
 
        shl.b64         %rd67, %rd98, %r41;
        mov.u32         %r35, 64;
@@ -13240,7 +13507,7 @@ BB116_8:
        shr.u64         %rd68, %rd97, %r36;
        or.b64          %rd98, %rd68, %rd67;
 
-BB116_10:
+BB122_10:
        mov.u64         %rd72, -3958705157555305931;
        // inline asm
        {
@@ -13261,7 +13528,7 @@ BB116_10:
        }
        // inline asm
        setp.lt.s64     %p8, %rd100, 1;
-       @%p8 bra        BB116_12;
+       @%p8 bra        BB122_12;
 
        // inline asm
        {
@@ -13280,7 +13547,7 @@ BB116_10:
        // inline asm
        add.s32         %r41, %r41, 1;
 
-BB116_12:
+BB122_12:
        cvt.u64.u32     %rd79, %r40;
        shl.b64         %rd80, %rd79, 32;
        mov.u32         %r37, 1022;
@@ -13295,7 +13562,7 @@ BB116_12:
        or.b64          %rd88, %rd87, %rd80;
        mov.b64          %fd4, %rd88;
 
-BB116_13:
+BB122_13:
        st.param.f64    [func_retval0+0], %fd4;
        ret;
 }
@@ -13323,7 +13590,7 @@ BB116_13:
        }
        shr.u32         %r51, %r50, 20;
        setp.ne.s32     %p1, %r51, 0;
-       @%p1 bra        BB117_2;
+       @%p1 bra        BB123_2;
 
        mul.f64         %fd14, %fd12, 0d4350000000000000;
        {
@@ -13337,13 +13604,13 @@ BB116_13:
        shr.u32         %r16, %r50, 20;
        add.s32         %r51, %r16, -54;
 
-BB117_2:
+BB123_2:
        add.s32         %r52, %r51, -1023;
        and.b32         %r17, %r50, -2146435073;
        or.b32          %r18, %r17, 1072693248;
        mov.b64         %fd135, {%r49, %r18};
        setp.lt.u32     %p2, %r18, 1073127583;
-       @%p2 bra        BB117_4;
+       @%p2 bra        BB123_4;
 
        {
        .reg .b32 %temp; 
@@ -13357,7 +13624,7 @@ BB117_2:
        mov.b64         %fd135, {%r19, %r21};
        add.s32         %r52, %r51, -1022;
 
-BB117_4:
+BB123_4:
        add.f64         %fd15, %fd135, 0d3FF0000000000000;
        rcp.approx.ftz.f64      %fd16, %fd15;
        neg.f64         %fd17, %fd15;
@@ -13520,13 +13787,13 @@ BB117_4:
        mov.b32          %f2, %r35;
        abs.f32         %f1, %f2;
        setp.lt.f32     %p4, %f1, 0f4086232B;
-       @%p4 bra        BB117_7;
+       @%p4 bra        BB123_7;
 
        setp.lt.f64     %p5, %fd4, 0d0000000000000000;
        add.f64         %fd129, %fd4, 0d7FF0000000000000;
        selp.f64        %fd136, 0d0000000000000000, %fd129, %p5;
        setp.geu.f32    %p6, %f1, 0f40874800;
-       @%p6 bra        BB117_7;
+       @%p6 bra        BB123_7;
 
        mov.f64         %fd134, 0d4338000000000000;
        mov.f64         %fd133, 0d3FF71547652B82FE;
@@ -13548,26 +13815,26 @@ BB117_4:
        mov.b64         %fd131, {%r44, %r43};
        mul.f64         %fd136, %fd130, %fd131;
 
-BB117_7:
+BB123_7:
        {
        .reg .b32 %temp; 
        mov.b64         {%temp, %r45}, %fd136;
        }
        and.b32         %r46, %r45, 2147483647;
        setp.ne.s32     %p7, %r46, 2146435072;
-       @%p7 bra        BB117_9;
+       @%p7 bra        BB123_9;
 
        {
        .reg .b32 %temp; 
        mov.b64         {%r47, %temp}, %fd136;
        }
        setp.eq.s32     %p8, %r47, 0;
-       @%p8 bra        BB117_10;
+       @%p8 bra        BB123_10;
 
-BB117_9:
+BB123_9:
        fma.rn.f64      %fd136, %fd136, %fd5, %fd136;
 
-BB117_10:
+BB123_10:
        st.param.f64    [func_retval0+0], %fd136;
        ret;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java 
b/src/main/java/org/apache/sysml/hops/DnnOp.java
index a7d37dc..c4ce466 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -110,8 +110,6 @@ public class DnnOp extends MultiThreadedHop
                if( getLops() != null )
                        return getLops();
                
-               ExecType et = optFindExecType();
-               
                ArrayList<Hop> inputs = getInput();
                switch( op )
                {
@@ -125,6 +123,7 @@ public class DnnOp extends MultiThreadedHop
                        case BIASADD:
                        case BIASMULT:
                        {       
+                               ExecType et = optFindExecType();
                                if(et == ExecType.CP || et == ExecType.GPU) {
                                        setLops(constructDnnLops(et, inputs));
                                        break;
@@ -137,15 +136,15 @@ public class DnnOp extends MultiThreadedHop
                        case BATCH_NORM2D_TEST:
                        case CHANNEL_SUMS:
                        case UPDATE_NESTEROV_X:
+                       case UPDATE_EMA_VAR:
+                       case RESHAPE_COLMEANS:
+                       case UPDATE_EMA:
+                       case INV_VAR:
+                       case BATCH_NORM2D_BACKWARD_DX:
                        {       
-                               if(et == ExecType.GPU) {
-                                       setLops(constructDnnLops(et, inputs));
-                                       break;
-                               }
-                               else {
-                                       throw new HopsException("Unimplemented 
DnnOp for execution type: " + et.name());
-                               }
-                               // break;
+                               // GPU-specific operators
+                               setLops(constructDnnLops(ExecType.GPU, inputs));
+                               break;
                        }
                        default: 
                                throw new HopsException("Unsupported lops 
construction for operation type '"+op+"'.");
@@ -171,10 +170,16 @@ public class DnnOp extends MultiThreadedHop
                                return 14;
                        case BIASADD:
                        case BIASMULT:
+                       case INV_VAR:
                                return 2;
                        case BATCH_NORM2D_TEST:
                                return 6;
+                       case UPDATE_EMA_VAR:
+                       case BATCH_NORM2D_BACKWARD_DX:
+                               return 5;
+                       case RESHAPE_COLMEANS:
                        case CHANNEL_SUMS:
+                       case UPDATE_EMA:
                                return 3;
                        case UPDATE_NESTEROV_X:
                                return 4;
@@ -532,7 +537,8 @@ public class DnnOp extends MultiThreadedHop
                long[] ret = new long[3];
                
                if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST ||
-                       op == OpOpDnn.UPDATE_NESTEROV_X) {
+                       op == OpOpDnn.UPDATE_NESTEROV_X || op == 
OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
                        // Same dimension as the first input
                        MatrixCharacteristics[] mc = 
memo.getAllInputStats(getInput());
                        ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
@@ -540,13 +546,21 @@ public class DnnOp extends MultiThreadedHop
                        ret[2] = -1;
                        return (ret[0]>=0 && ret[1]>=0) ? ret : null;
                }
-               else if(op == OpOpDnn.CHANNEL_SUMS) {
+               else if(op == OpOpDnn.CHANNEL_SUMS || op == 
OpOpDnn.UPDATE_EMA_VAR) {
                        long numChannels = 
Hop.computeSizeInformation(getInput().get(1));
                        ret[0] = numChannels;
                        ret[1] = 1;
                        ret[2] = -1;
                        return ret;
                }
+               else if(op == OpOpDnn.RESHAPE_COLMEANS) {
+                       long numChannels = 
Hop.computeSizeInformation(getInput().get(1));
+                       long HW = Hop.computeSizeInformation(getInput().get(2));
+                       ret[0] = numChannels;
+                       ret[1] = HW;
+                       ret[2] = -1;
+                       return ret;
+               }
                
                refreshSizeInformation();
                ret[0] = _dim1; ret[1] = _dim2; ret[2] = _nnz;
@@ -739,7 +753,9 @@ public class DnnOp extends MultiThreadedHop
        @Override
        public void refreshSizeInformation()
        {
-               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
+               if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST || 
+                       op == OpOpDnn.UPDATE_NESTEROV_X || op == 
OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
                        // Same dimension as the first input
                        Hop input1 = getInput().get(0);
                        setDim1(input1.getDim1());
@@ -747,13 +763,21 @@ public class DnnOp extends MultiThreadedHop
                        _nnz = -1; // cannot infer stats
                        return;
                }
-               else if(op == OpOpDnn.CHANNEL_SUMS) {
+               else if(op == OpOpDnn.CHANNEL_SUMS || op == 
OpOpDnn.UPDATE_EMA_VAR) {
                        long numChannels = 
Hop.computeSizeInformation(getInput().get(1));
                        setDim1(numChannels);
                        setDim2(1);
                        _nnz = -1; // cannot infer stats
                        return;
                }
+               else if(op == OpOpDnn.RESHAPE_COLMEANS) {
+                       long numChannels = 
Hop.computeSizeInformation(getInput().get(1));
+                       long HW = Hop.computeSizeInformation(getInput().get(2));
+                       setDim1(numChannels);
+                       setDim2(HW);
+                       _nnz = -1; // cannot infer stats
+                       return;
+               }
                
                // Reset the _cachedParams to avoid incorrect sizes
                _cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, _maxNumThreads);
@@ -847,7 +871,9 @@ public class DnnOp extends MultiThreadedHop
         */
        private long getDim(String dimString) {
                if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == 
OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
-                       op == OpOpDnn.UPDATE_NESTEROV_X) {
+                       op == OpOpDnn.UPDATE_NESTEROV_X || op == 
OpOpDnn.RESHAPE_COLMEANS ||
+                       op == OpOpDnn.UPDATE_EMA_VAR || op == 
OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+                       op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
                        throw new RuntimeException("getDim method should not be 
invoked for " + op.name());
                }
                try {

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java 
b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index ea397db..5f177bd 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -181,21 +181,6 @@ public class FunctionOp extends Hop
                                // TODO: To allow for initial version to always 
run on the GPU
                                return 0; 
                        }
-                       else if ( 
getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
-                               return 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0) +
-                                               
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 
getOutputs().get(1).getDim2(), 1.0) +
-                                               
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), 
getOutputs().get(2).getDim2(), 1.0) +
-                                               
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), 
getOutputs().get(3).getDim2(), 1.0) + 
-                                               
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), 
getOutputs().get(4).getDim2(), 1.0);
-                       }
-                       else if ( 
getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) {
-                       return 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0);
-               }
-                       else if ( 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
-                               return 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0) +
-                                               
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 
getOutputs().get(1).getDim2(), 1.0) +
-                                               
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), 
getOutputs().get(2).getDim2(), 1.0);
-                       }
                        else if ( getFunctionName().equalsIgnoreCase("svd") ) {
                                long outputU = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0);
                                long outputSigma = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 
getOutputs().get(1).getDim2(), 1.0);
@@ -226,10 +211,6 @@ public class FunctionOp extends Hop
                                return 
OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 
getInput().get(0).getDim2(), 1.0) 
                                                + 
3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 
1.0); 
                        }
-                       else if 
(getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
-                                       
getFunctionName().equalsIgnoreCase("batch_norm2d_train") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
-                               return 0; 
-                       }
                        else if ( getFunctionName().equalsIgnoreCase("lstm") || 
 getFunctionName().equalsIgnoreCase("lstm_backward") ) {
                                // TODO: To allow for initial version to always 
run on the GPU
                                return 0; 
@@ -251,9 +232,7 @@ public class FunctionOp extends Hop
        
        @Override
        public boolean isGPUEnabled() {
-               if(getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward") ||  
-                       getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
-                       
getFunctionName().equalsIgnoreCase("batch_norm2d_train") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_test")) 
+               if(getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward")) 
                        return true;
                else
                        return false;
@@ -308,13 +287,6 @@ public class FunctionOp extends Hop
                                        throw new RuntimeException("The 
function " + getFunctionName() + " is only supported on GPU.");
                                _etype = ExecType.GPU;
                        }
-                       else if(isBuiltinFunction && 
(getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) {
-                               _etype = ConfigurationManager.isGPU() ? 
ExecType.GPU : ExecType.CP;
-                       }
-                       else if(isBuiltinFunction && 
getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
-                               // Only GPU implementation is supported
-                               _etype = ExecType.GPU;
-                       }
                        else {
                                // Since the memory estimate is only 
conservative, do not throw
                                // exception if the estimated memory is larger 
than the budget

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 3b461a1..c8356e0 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1100,7 +1100,8 @@ public abstract class Hop implements ParseInfo
                MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
                CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
                BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
-               UPDATE_NESTEROV_X
+               UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, 
UPDATE_EMA, INV_VAR,
+               BATCH_NORM2D_BACKWARD_DX
        }
        
        public enum DataGenMethod {
@@ -1174,8 +1175,13 @@ public abstract class Hop implements ParseInfo
                HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_FILTER, 
org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_FILTER);
                HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, 
org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA);
                HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, 
org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST);
+               HopsConv2Lops.put(OpOpDnn.UPDATE_EMA_VAR, 
org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA_VAR);
                HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, 
org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS);
                HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, 
org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X);
+               HopsConv2Lops.put(OpOpDnn.RESHAPE_COLMEANS, 
org.apache.sysml.lops.DnnTransform.OperationTypes.RESHAPE_COLMEANS);
+               HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, 
org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA);
+               HopsConv2Lops.put(OpOpDnn.INV_VAR, 
org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR);
+               HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, 
org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX);
        }
 
        protected static final HashMap<Hop.Direction, 
org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;

http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
new file mode 100644
index 0000000..7c70b7b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOpDnn;
+import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysml.utils.Explain;
+
+/**
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for 
usage and design documentation.
+ */
+public class HopDagPatternMatcher {
+       static final HashSet<String> DEBUG_PATTERNS;
+       static {
+               // DEBUG_PATTERNS = new HashSet<>();
+               // DEBUG_PATTERNS.add("batchNormdX");
+               DEBUG_PATTERNS = null;
+       }
+       
+       // Predicates for the current HOP
+       List<HopPredicate> _predicates = new ArrayList<>();
+       // Child matchers
+       List<HopDagPatternMatcher> _children = new ArrayList<>();
+       private boolean _isLeaf = false;
+       
+       static boolean DEBUG_REWRITES = false; // This is set by 
HopPatternRewriter. Please use DEBUG_PATTERNS instead.
+       
+       // Simple utility for debugging the rewrites
+       public static class HopPredicate implements Predicate<Hop> {
+               final String _name;
+               final Function<Hop, Boolean> _pred;
+               public HopPredicate(String name, Function<Hop, Boolean> pred) {
+                       _name = name;
+                       _pred = pred;
+               }
+               @Override
+               public boolean test(Hop h) {
+                       return _pred.apply(h);
+               }
+               @Override
+           public String toString() {
+               return _name;
+           }
+       }
+       
+       /**
+        * Adds a predicate to the pattern matcher
+        * 
+        * @param name name of the pattern for debugging
+        * @param pred higher-order function that takes a HOP as input and returns true if the pattern matches, else false
+        * @return this
+        */
+       public HopDagPatternMatcher addPredicate(String name, Function<Hop, 
Boolean> pred) {
+               _predicates.add(new HopPredicate(name, pred));
+               return this;
+       }
+       
+       /**
+        * Adds child pattern matchers
+        * @param children list of children
+        * @return this
+        */
+       public HopDagPatternMatcher addChildMatcher(HopDagPatternMatcher... 
children) {
+               for(int i = 0; i < children.length; i++) {
+                       _children.add(children[i]);
+               }
+               return this;
+       }
+       
+       /**
+        * Gets the HOP matched for the given variable name
+        * @param varName variable name
+        * @return matched HOP
+        */
+       public Hop getMatchedHop(String varName) {
+               
+               if(matchedHops == null || !matchedHops.containsKey(varName)) {
+                       throw new RuntimeException("Incorrect usage: the 
variable " + varName + " is not registered as input.");
+               }
+               return matchedHops.get(varName);
+       }
+       
+       /**
+        * Returns the value of the matched scalar expression
+        * 
+        * @param varName variable name
+        * @return the evaluated double value of the matched HOP 
+        */
+       public double getLiteralValue(String varName) {
+               return 
OptimizerUtils.rEvalSimpleDoubleExpression(getMatchedHop(varName), new 
HashMap<>());
+       }
+       
+       @Override
+    public String toString() {
+        return _predicates.size() >= 1 ? _predicates.get(0).toString() : "";
+    }
+       
+       /**
+        * Match the given HOP DAG
+        * 
+        * @param h root node of the HOP DAG 
+        * @return true if HOP DAG matches
+        */
+       public boolean matches(Hop h) {
+               visited.clear();
+               matchedHops.clear();
+               return matchHelper(this, h);
+       }
+       
+       private HashMap<String, Hop> matchedHops = new HashMap<>();
+       private String variableName;
+       private HashMap<HopDagPatternMatcher, Hop> visited = new HashMap<>(); 
// Map of matched hops
+       private boolean matchHelper(HopDagPatternMatcher root, Hop h) {
+               if(h == null) {
+                       return false;
+               }
+               else if(_children.size() > 0 && h.getInput().size() < 
_children.size()) {
+                       if(DEBUG_REWRITES) {
+                               System.out.println("The expected number of 
children (" + _children.size() + ") didnot match the number of inputs (" + 
h.getInput().size() + ") " + this);
+                       }
+                       return false;
+               }
+               if(root.visited.containsKey(this)) {
+                       Hop h1 = root.visited.get(this);
+                       if(h == h1) {
+                               if(DEBUG_REWRITES)
+                                       System.out.println("MATCHED: Early exit 
as the given HOP has been already matched by the matcher." + this); 
+                               return true; // Early exit as the given HOP has 
been already matched by the matcher
+                       }
+                       else if(_isLeaf) {
+                               if(h.getDataType() == h1.getDataType() && 
h.getDataType() == DataType.SCALAR) {
+                                       return 
OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>()) == 
OptimizerUtils.rEvalSimpleDoubleExpression(h1, new HashMap<>());
+                               }
+                               return false; // Mismatched or unknown 
datatypes or matched with different hops
+                       }
+               }
+               
+               for(HopPredicate p : _predicates) {
+                       if(!p.test(h)) {
+                               if(DEBUG_REWRITES) {
+                                       System.out.println("The predicate " + 
p.toString() + " failed.");
+                               }
+                               return false;
+                       }
+               }
+               int index = 0;
+               for(HopDagPatternMatcher child : _children) {
+                       if(!child.matchHelper(root, h.getInput().get(index))) {
+                               return false;
+                       }
+                       index++;
+               }
+               if(_isLeaf) {
+                       root.matchedHops.put(variableName, h);
+               }
+               
+               root.visited.put(this, h);
+               if(DEBUG_REWRITES)
+                       System.out.println("MATCHED: " + this + " to " + 
Explain.explain(h));
+               return true;
+       }
+       
+
+       // Simple helper utilities for adding predicates
+       private HopDagPatternMatcher isScalar() {
+               return this.addPredicate("isScalar", h -> h.getDataType() == 
DataType.SCALAR);
+       }
+       private HopDagPatternMatcher isMatrix() {
+               return this.addPredicate("isMatrix", h -> h.getDataType() == 
DataType.MATRIX);
+       }
+       public HopDagPatternMatcher fitsOnGPU(double constant) {
+               return this.addPredicate("fitsOnGPU", h -> _fitsOnGPU(h, 
constant));
+       }
+       
+       // Factory methods:
+       public static HopDagPatternMatcher dummy = new HopDagPatternMatcher();
+       public static HopDagPatternMatcher rowMeans(HopDagPatternMatcher 
child1) {
+               return new HopDagPatternMatcher().addPredicate("rowMeans", h -> 
+                       h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row)
+                       .addChildMatcher(child1);
+       }
+       public static HopDagPatternMatcher rowVars(HopDagPatternMatcher child1) 
{
+               return new HopDagPatternMatcher().addPredicate("rowVars", h -> 
+                       h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row)
+                       .addChildMatcher(child1);
+       }
+       public static HopDagPatternMatcher colVars(HopDagPatternMatcher child1) 
{
+               return new HopDagPatternMatcher().addPredicate("colVars", h -> 
+                       h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col)
+                       .addChildMatcher(child1);
+       }
+       public static HopDagPatternMatcher leaf(String _variableName, DataType 
dt) {
+               HopDagPatternMatcher ret = new HopDagPatternMatcher();
+               ret._isLeaf = true;
+               ret.variableName = _variableName;
+               if(dt == DataType.MATRIX) {
+                       return ret.isMatrix();
+               }
+               else if(dt == DataType.SCALAR) {
+                       return ret.isScalar();
+               }
+               else if(dt == DataType.UNKNOWN) {
+                       return ret;
+               }
+               else {
+                       throw new DMLRuntimeException("Unsupported datatype in 
pattern matcher:" + dt.name());
+               }
+       }
+       public static HopDagPatternMatcher rowSums(HopDagPatternMatcher child1) 
{
+               return new HopDagPatternMatcher().addPredicate("rowSums", h -> 
+                       h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Row)
+                       .addChildMatcher(child1);
+       }
+       public static HopDagPatternMatcher colSums(HopDagPatternMatcher child1) 
{
+               return new HopDagPatternMatcher().addPredicate("colSums", h -> 
+                       h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Col)
+                       .addChildMatcher(child1);
+       }
+       public static HopDagPatternMatcher colMeans(HopDagPatternMatcher 
child1) {
+               return new HopDagPatternMatcher().addPredicate("colSums", h -> 
+                       h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == 
AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col)
+                       .addChildMatcher(child1);
+       }
+       public static HopDagPatternMatcher matrix(HopDagPatternMatcher X, 
HopDagPatternMatcher rows, HopDagPatternMatcher cols) {
+               return new 
HopDagPatternMatcher().addPredicate("matrix_reshape", h -> 
HopRewriteUtils.isReorg(h, ReOrgOp.RESHAPE))
+                               .addChildMatcher(X, rows, cols);
+       }
+       public static HopDagPatternMatcher matrix(double X, 
HopDagPatternMatcher rows, HopDagPatternMatcher cols) {
+               return new 
HopDagPatternMatcher().addPredicate("matrix_datagen", h -> 
HopRewriteUtils.isDataGenOpWithConstantValue(h, X))
+                               .addChildMatcher(rows, cols);
+       }
+       public static HopDagPatternMatcher matrix(double X, 
HopDagPatternMatcher rows, long cols) {
+               return new 
HopDagPatternMatcher().addPredicate("matrix_datagen", h -> 
HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && 
+                               h.getDim2() == cols)
+                               .addChildMatcher(rows, dummy);
+       }
+       public static HopDagPatternMatcher matrix(double X, long rows, 
HopDagPatternMatcher cols) {
+               return new 
HopDagPatternMatcher().addPredicate("matrix_datagen", h -> 
HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && 
+                               h.getDim1() == rows)
+                               .addChildMatcher(dummy, cols);
+       }
+       public static HopDagPatternMatcher matrix(double X, long rows, long 
cols) {
+               return new 
HopDagPatternMatcher().addPredicate("matrix_datagen", h -> 
HopRewriteUtils.isDataGenOpWithConstantValue(h, X) && 
+                               h.getDim1() == rows && h.getDim2() == cols)
+                               .addChildMatcher(dummy, dummy);
+       }
+       public static HopDagPatternMatcher bias_add(HopDagPatternMatcher 
child1, HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("bias_add", h -> 
HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD))
+                               .addChildMatcher(child1, child2);
+       }
+       public static HopDagPatternMatcher bias_multiply(HopDagPatternMatcher 
child1, HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("bias_multiply", 
h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT))
+                               .addChildMatcher(child1, child2);
+       }
+       public static HopDagPatternMatcher unaryMinus(HopDagPatternMatcher 
child) {
+               return new HopDagPatternMatcher().addPredicate("unaryMinus", h 
-> HopRewriteUtils.isBinary(h, OpOp2.MINUS)
+                               && 
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0))
+                               .addChildMatcher(dummy, child);
+       }
+       public static HopDagPatternMatcher sqrt(HopDagPatternMatcher child) {
+               return new HopDagPatternMatcher().addPredicate("sqrt", h -> 
HopRewriteUtils.isUnary(h, OpOp1.SQRT))
+                               .addChildMatcher(child);
+       }
+       public static HopDagPatternMatcher div(HopDagPatternMatcher child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("div", h -> 
HopRewriteUtils.isBinary(h, OpOp2.DIV))
+                               .addChildMatcher(child1, child2);
+       }
+       public static HopDagPatternMatcher div(double child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("div", h -> 
HopRewriteUtils.isBinary(h, OpOp2.DIV) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+                               .addChildMatcher(dummy, child2);
+       }
+       public static HopDagPatternMatcher div(HopDagPatternMatcher child1, 
double child2) {
+               return new HopDagPatternMatcher().addPredicate("div", h -> 
HopRewriteUtils.isBinary(h, OpOp2.DIV) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+                               .addChildMatcher(child1, dummy);
+       }
+       
+       public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("pow", h -> 
HopRewriteUtils.isBinary(h, OpOp2.POW))
+                               .addChildMatcher(child1, child2);
+       }
+       public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, 
double child2) {
+               return new HopDagPatternMatcher().addPredicate("pow", h -> 
HopRewriteUtils.isBinary(h, OpOp2.POW) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+                               .addChildMatcher(child1, dummy);
+       }
+       private static boolean matchDimensions(Hop h1, Hop h2) {
+               return h1.getDim1() == h2.getDim1() && h1.getDim2() == 
h2.getDim2();
+       }
+       // This is used to differentiate between matrix-matrix and 
matrix-vector operations.
+       public static HopDagPatternMatcher mm_plus(HopDagPatternMatcher child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("plus", h -> 
HopRewriteUtils.isBinary(h, OpOp2.PLUS)
+                               && matchDimensions(h.getInput().get(0), 
h.getInput().get(1)))
+                               .addChildMatcher(child1, child2);
+       }
+       public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("plus", h -> 
HopRewriteUtils.isBinary(h, OpOp2.PLUS))
+                               .addChildMatcher(child1, child2);
+       }
+       public static HopDagPatternMatcher plus(double child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("plus", h -> 
HopRewriteUtils.isBinary(h, OpOp2.PLUS) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+                               .addChildMatcher(dummy, child2);
+       }
+       public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, 
double child2) {
+               return new HopDagPatternMatcher().addPredicate("plus", h -> 
HopRewriteUtils.isBinary(h, OpOp2.PLUS) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+                               .addChildMatcher(child1, dummy);
+       }
+       public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("minus", h -> 
HopRewriteUtils.isBinary(h, OpOp2.MINUS))
+                               .addChildMatcher(child1, child2);
+       }
+       public static HopDagPatternMatcher minus(double child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("minus", h -> 
HopRewriteUtils.isBinary(h, OpOp2.MINUS) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+                               .addChildMatcher(dummy, child2);
+       }
+       public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, 
double child2) {
+               return new HopDagPatternMatcher().addPredicate("minus", h -> 
HopRewriteUtils.isBinary(h, OpOp2.MINUS) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+                               .addChildMatcher(child1, dummy);
+       }
+       public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("mult", h -> 
HopRewriteUtils.isBinary(h, OpOp2.MULT))
+                               .addChildMatcher(child1, child2);
+       }
+       public static HopDagPatternMatcher mult(double child1, 
HopDagPatternMatcher child2) {
+               return new HopDagPatternMatcher().addPredicate("mult", h -> 
HopRewriteUtils.isBinary(h, OpOp2.MULT) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+                               .addChildMatcher(dummy, child2);
+       }
+       public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, 
double child2) {
+               return new HopDagPatternMatcher().addPredicate("mult", h -> 
HopRewriteUtils.isBinary(h, OpOp2.MULT) && 
+                               
HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+                               .addChildMatcher(child1, dummy);
+       }
+       private static boolean _fitsOnGPU(Hop h, double multiplier) {
+               double memEst = multiplier*h.getMemEstimate();
+               return ConfigurationManager.isGPU() && h.dimsKnown() && 
OptimizerUtils.isMemoryBasedOptLevel() &&
+                               memEst < OptimizerUtils.getLocalMemBudget() && 
memEst < GPUContextPool.initialGPUMemBudget();
+       }
+}

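A minimal usage sketch of the matcher API above, assuming the pattern X / sqrt(var + eps)
(the normalization step of batch norm); the wrapper class and method names are illustrative,
and only the HopDagPatternMatcher calls come from this patch:

    import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;

    import org.apache.sysml.hops.Hop;
    import org.apache.sysml.hops.rewrite.HopDagPatternMatcher;
    import org.apache.sysml.parser.Expression.DataType;

    class PatternMatcherSketch {
        // Returns the hop matched as "X" if 'hi' has the shape X / sqrt(var + eps), else null.
        static Hop matchNormalize(Hop hi) {
            HopDagPatternMatcher X   = leaf("X",   DataType.MATRIX);
            HopDagPatternMatcher v   = leaf("var", DataType.MATRIX);
            HopDagPatternMatcher eps = leaf("eps", DataType.SCALAR);
            HopDagPatternMatcher pattern = div(X, sqrt(plus(v, eps)));
            if(pattern.matches(hi)) {
                // Matched leaves are retrieved by the variable names registered via leaf(...).
                double epsValue = pattern.getLiteralValue("eps");
                System.out.println("matched with eps = " + epsValue);
                return pattern.getMatchedHop("X");
            }
            return null;
        }
    }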
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
new file mode 100644
index 0000000..02472ed
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.function.Function;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * This class is used with HopRewriteRuleWithPatternMatcher to implement the 
following pattern matching logic:
+ * ArrayList<HopPatternRewriter> patternRewriters =  getPatternRewriter();
+ * for(HopPatternRewriter patternRewriter : patternRewriters) {
+ *   hi = patternRewriter.rewrite(hi);
+ * }
+ * 
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for 
usage and design documentation.
+ */
+public class HopPatternRewriter {
+       private final HopDagPatternMatcher _matcher;
+       private final Function<Hop, Hop> _replacer;
+       private final String _name;
+       public HopPatternRewriter(String name, HopDagPatternMatcher matcher, 
Function<Hop, Hop> replacer) {
+               _name = name;
+               _matcher = matcher;
+               _replacer = replacer;
+       }
+       
+       public Hop rewrite(Hop root) {
+               boolean printMessage = HopDagPatternMatcher.DEBUG_PATTERNS != 
null && HopDagPatternMatcher.DEBUG_PATTERNS.contains(_name);
+               if(printMessage) {
+                       HopDagPatternMatcher.DEBUG_REWRITES = true;
+                       
System.out.println("-----------------------------------");
+                       
System.out.println(org.apache.sysml.utils.Explain.explain(root));
+               }
+               if(_matcher.matches(root)) {
+                       Hop newHop = _replacer.apply(root);
+                       if(printMessage) {
+                               if(newHop == root)
+                                       System.out.println("Initial pattern 
match for " + _name + " succeeded but replacer returned the same HOP.");
+                               else
+                                       System.out.println("Pattern match for " 
+ _name + " succeeded.");
+                               HopDagPatternMatcher.DEBUG_REWRITES = false;
+                               
System.out.println("-----------------------------------");
+                       }
+                       return newHop;
+               }
+               else {
+                       if(printMessage) {
+                               System.out.println("Pattern match for " + _name 
+ " failed.");
+                               HopDagPatternMatcher.DEBUG_REWRITES = false;
+                               
System.out.println("-----------------------------------");
+                       }
+                       return root;
+               }
+       }
+}

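A hedged sketch of wiring a matcher and a replacer into HopPatternRewriter; the pattern, the
OpOpDnn value, and the variable names are placeholders chosen only to show the plumbing (a real
replacer would also rewire the parents of the matched root to the newly created hop):

    import static org.apache.sysml.hops.rewrite.HopDagPatternMatcher.*;

    import java.util.function.Function;
    import org.apache.sysml.hops.Hop;
    import org.apache.sysml.hops.Hop.OpOpDnn;
    import org.apache.sysml.hops.rewrite.HopDagPatternMatcher;
    import org.apache.sysml.hops.rewrite.HopPatternRewriter;
    import org.apache.sysml.hops.rewrite.HopRewriteUtils;
    import org.apache.sysml.parser.Expression.DataType;

    class PatternRewriterSketch {
        static HopPatternRewriter createBiasAddRewriter() {
            // Placeholder pattern: X + b, guarded by a GPU memory check on the root.
            HopDagPatternMatcher X = leaf("X", DataType.MATRIX);
            HopDagPatternMatcher b = leaf("b", DataType.MATRIX);
            HopDagPatternMatcher pattern = plus(X, b).fitsOnGPU(2);
            // Replacer: fuse the matched hops into a DnnOp via the new createDnnOp helper.
            Function<Hop, Hop> replacer = hi ->
                HopRewriteUtils.createDnnOp(pattern, OpOpDnn.BIASADD, "X", "b");
            return new HopPatternRewriter("biasAddSketch", pattern, replacer);
        }
    }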
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
new file mode 100644
index 0000000..854eca3
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * Simple utility class that implements generic structure for HopRewriteRule.
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for 
usage and design documentation.
+ */
+public abstract class HopRewriteRuleWithPatternMatcher extends HopRewriteRule {
+       
+       public abstract ArrayList<HopPatternRewriter> getPatternRewriter();
+       
+       @Override
+       public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
+               if( roots == null )
+                       return roots;
+
+               //one pass rewrite-descend (rewrite created pattern)
+               for( int i = 0; i < roots.size(); i++ )
+                       applyRules(roots, roots.get(i), false );
+               Hop.resetVisitStatus(roots, true);
+
+               //one pass descend-rewrite (for rollup) 
+               for( int i = 0; i < roots.size(); i++ )
+                       applyRules(roots, roots.get(i), true );
+               Hop.resetVisitStatus(roots, true);
+               
+               return roots;
+       }
+
+       @Override
+       public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+               if( root == null )
+                       return root;
+               
+               //one pass rewrite-descend (rewrite created pattern)
+               applyRules(null, root, false );
+               
+               root.resetVisitStatus();
+               
+               //one pass descend-rewrite (for rollup) 
+               applyRules(null, root, true );
+               
+               return root;
+       }
+       
+       /**
+        * Apply rules
+        * 
+        * @param roots root operators
+        * @param hop high-level operator
+        * @param descendFirst true if recursively process children first
+        */
+       private void applyRules(ArrayList<Hop> roots, Hop hop, boolean 
descendFirst) 
+       {
+               if(hop.isVisited())
+                       return;
+               
+               //recursively process children
+               for( int i=0; i<hop.getInput().size(); i++) {
+                       Hop hi = hop.getInput().get(i);
+                       
+                       //process children recursively first (to allow roll-up)
+                       if( descendFirst )
+                               applyRules(roots, hi, descendFirst); //see below
+                       
+                       ArrayList<HopPatternRewriter> patternRewriters =  
getPatternRewriter();
+                       for(HopPatternRewriter patternRewriter : 
patternRewriters) {
+                               hi = patternRewriter.rewrite(hi);
+                       }
+                       
+                       if( !descendFirst )
+                               applyRules(roots, hi, descendFirst);
+               }
+
+               hop.setVisited();
+       }
+}

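A concrete rule built on this base class only needs to supply its pattern rewriters; the two
traversal passes above apply them. A minimal sketch, reusing the illustrative rewriter from the
previous sketch:

    import java.util.ArrayList;
    import org.apache.sysml.hops.rewrite.HopPatternRewriter;
    import org.apache.sysml.hops.rewrite.HopRewriteRuleWithPatternMatcher;

    class RewriteSketchSpecificOps extends HopRewriteRuleWithPatternMatcher {
        // Constructed once and reused; each matcher resets its matched-hop state on every matches() call.
        private static final ArrayList<HopPatternRewriter> REWRITERS = new ArrayList<>();
        static {
            REWRITERS.add(PatternRewriterSketch.createBiasAddRewriter()); // from the sketch above
        }
        @Override
        public ArrayList<HopPatternRewriter> getPatternRewriter() {
            return REWRITERS;
        }
    }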
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 271142d..2351f5f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -719,6 +719,26 @@ public class HopRewriteUtils
                return ternOp;
        }
        
+       public static DnnOp createDnnOp(OpOpDnn op, Hop... hops) {
+               ArrayList<Hop> inHops = new ArrayList<Hop>();
+               for(Hop h : hops) {
+                       inHops.add(h);
+               }
+               return  new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+                               op, inHops);
+       }
+       
+       public static DnnOp createDnnOp(HopDagPatternMatcher matcher, OpOpDnn 
op, String... varNames) {
+               ArrayList<Hop> inHops = new ArrayList<Hop>();
+               for(String v : varNames) {
+                       inHops.add(matcher.getMatchedHop(v));
+               }
+               return  new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+                               op, inHops);
+       }
+       
+       
+       
        public static void setOutputParameters( Hop hop, long rlen, long clen, 
int brlen, int bclen, long nnz ) {
                hop.setDim1( rlen );
                hop.setDim2( clen );
