This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b8e107  fix the if condition for LayerNorm (#15094)
6b8e107 is described below

commit 6b8e107f19d994dc44e408b809e97e79ea5b44e3
Author: Tao Lv <[email protected]>
AuthorDate: Fri May 31 05:00:14 2019 +0800

    fix the if condition for LayerNorm (#15094)
---
 src/operator/nn/layer_norm.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 7404e04..e95f472 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -86,8 +86,9 @@ void LayerNormComputeMKL(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 3U);
   int axis = GetRealAxis(param.axis, inputs[0].ndim());
 
-  if (axis == (inputs[layernorm::kData].ndim() - 1) ||
-      (inputs[0].type_flag_ != kFloat32 && inputs[0].type_flag_ != kFloat64)) {
+  // This optimization only applies to LayerNorm on the last dimension with dtype FP32 or FP64.
+  if (axis == (inputs[layernorm::kData].ndim() - 1) &&
+      (inputs[0].type_flag_ == kFloat32 || inputs[0].type_flag_ == kFloat64)) {
     // Compute necessary data for the reduce operation.
     mxnet::TShape red_src_shape, red_dst_shape;
     BroadcastReduceShapeCompact(inputs[layernorm::kData].shape_, 
outputs[layernorm::kMean].shape_,

Reply via email to