This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6b8e107 fix the if condition for LayerNorm (#15094)
6b8e107 is described below
commit 6b8e107f19d994dc44e408b809e97e79ea5b44e3
Author: Tao Lv <[email protected]>
AuthorDate: Fri May 31 05:00:14 2019 +0800
fix the if condition for LayerNorm (#15094)
---
src/operator/nn/layer_norm.cc | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 7404e04..e95f472 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -86,8 +86,9 @@ void LayerNormComputeMKL(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 3U);
int axis = GetRealAxis(param.axis, inputs[0].ndim());
- if (axis == (inputs[layernorm::kData].ndim() - 1) ||
- (inputs[0].type_flag_ != kFloat32 && inputs[0].type_flag_ != kFloat64)) {
+ // This optimization only applys for LayerNorm on the last dimension with dtype FP32 or FP64.
+ if (axis == (inputs[layernorm::kData].ndim() - 1) &&
+ (inputs[0].type_flag_ == kFloat32 || inputs[0].type_flag_ == kFloat64)) {
// Compute necessary data for the reduce operation.
mxnet::TShape red_src_shape, red_dst_shape;
BroadcastReduceShapeCompact(inputs[layernorm::kData].shape_,
outputs[layernorm::kMean].shape_,