Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/05/30 21:01:00 UTC

[incubator-mxnet] branch master updated: fix the if condition for LayerNorm (#15094)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b8e107  fix the if condition for LayerNorm (#15094)
6b8e107 is described below

commit 6b8e107f19d994dc44e408b809e97e79ea5b44e3
Author: Tao Lv <ta...@intel.com>
AuthorDate: Fri May 31 05:00:14 2019 +0800

    fix the if condition for LayerNorm (#15094)
---
 src/operator/nn/layer_norm.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 7404e04..e95f472 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -86,8 +86,9 @@ void LayerNormComputeMKL(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 3U);
   int axis = GetRealAxis(param.axis, inputs[0].ndim());
 
-  if (axis == (inputs[layernorm::kData].ndim() - 1) ||
-      (inputs[0].type_flag_ != kFloat32 && inputs[0].type_flag_ != kFloat64)) {
+  // This optimization only applies to LayerNorm on the last dimension with dtype FP32 or FP64.
+  if (axis == (inputs[layernorm::kData].ndim() - 1) &&
+      (inputs[0].type_flag_ == kFloat32 || inputs[0].type_flag_ == kFloat64)) {
     // Compute necessary data for the reduce operation.
     mxnet::TShape red_src_shape, red_dst_shape;
     BroadcastReduceShapeCompact(inputs[layernorm::kData].shape_, outputs[layernorm::kMean].shape_,
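
For reference, the corrected guard takes the MKL-accelerated path only when the normalization axis is the last dimension and the input dtype is FP32 or FP64; every other case falls through to the generic fallback implementation later in the function. Below is a minimal standalone sketch of the same predicate; the helper name and the stand-in type-flag values are illustrative only and are not part of the MXNet source.

#include <cassert>

// Stand-in dtype flags for illustration; MXNet's kFloat32/kFloat64 come from
// its mshadow type enums.
enum TypeFlag { kFloat32 = 0, kFloat64 = 1, kFloat16 = 2 };

// Hypothetical helper mirroring the fixed condition: use the MKL path only
// when normalizing over the last axis of an FP32 or FP64 tensor.
bool UseMKLLayerNorm(int axis, int ndim, int type_flag) {
  return axis == ndim - 1 &&
         (type_flag == kFloat32 || type_flag == kFloat64);
}

int main() {
  // Last axis + FP32: eligible for the MKL-accelerated path.
  assert(UseMKLLayerNorm(2, 3, kFloat32));
  // Non-last axis: falls back to the generic path, even for FP32.
  assert(!UseMKLLayerNorm(1, 3, kFloat32));
  // FP16 on the last axis: also falls back. The previous `||`/`!=` condition
  // wrongly admitted this case to the MKL path.
  assert(!UseMKLLayerNorm(2, 3, kFloat16));
  return 0;
}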