You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/10 01:37:02 UTC

[incubator-mxnet] branch master updated: [MXNET-58]Layer Normalization in C++ (#10029)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 279ccb1  [MXNET-58]Layer Normalization in C++ (#10029)
279ccb1 is described below

commit 279ccb1c77d7fdfb2652a4f6574466ea1ecb3a09
Author: Xingjian Shi <xs...@ust.hk>
AuthorDate: Fri Mar 9 17:36:58 2018 -0800

    [MXNET-58]Layer Normalization in C++ (#10029)
    
    * add layer_norm + fix batch_norm doc
    
    * add test
    
    * add layer normalization in Gluon
    
    * update
    
    * fix __repr__ + lint
    
    * fix doc
    
    * fix threshold
    
    * fix doc
    
    * fix bug
    
    * enable inplace + fix test
    
    * try to fix test
    
    * fix doc
---
 docs/api/python/gluon/nn.md            |   1 +
 docs/api/python/ndarray/ndarray.md     |   1 +
 docs/api/python/symbol/symbol.md       |   1 +
 python/mxnet/gluon/nn/basic_layers.py  |  99 ++++++++++-
 src/operator/nn/batch_norm-inl.h       |   2 +-
 src/operator/nn/batch_norm.cc          |   3 +-
 src/operator/nn/layer_norm-inl.h       | 296 +++++++++++++++++++++++++++++++++
 src/operator/nn/layer_norm.cc          | 148 +++++++++++++++++
 src/operator/nn/layer_norm.cu          |  37 +++++
 tests/python/unittest/test_gluon.py    |   5 +
 tests/python/unittest/test_operator.py |  41 +++++
 11 files changed, 627 insertions(+), 7 deletions(-)

diff --git a/docs/api/python/gluon/nn.md b/docs/api/python/gluon/nn.md
index 91b93bd..1001f20 100644
--- a/docs/api/python/gluon/nn.md
+++ b/docs/api/python/gluon/nn.md
@@ -20,6 +20,7 @@ This document lists the neural network blocks in Gluon:
     Dropout
     BatchNorm
     InstanceNorm
+    LayerNorm
     Embedding
     Flatten
 ```
diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index 59ca4a6..08acc1a 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -640,6 +640,7 @@ The `ndarray` package provides several classes:
     Embedding
     LeakyReLU
     InstanceNorm
+    LayerNorm
     L2Normalization
     LRN
     ROIPooling
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index e383597..2eceead 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -641,6 +641,7 @@ Composite multiple symbols into a new one by an operator.
     Embedding
     LeakyReLU
     InstanceNorm
+    LayerNorm
     L2Normalization
     LRN
     ROIPooling
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 9dc1a24..eb33199 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -19,7 +19,7 @@
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
 __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'InstanceNorm', 'Flatten', 'Lambda', 'HybridLambda']
+           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
 import warnings
 import numpy as np
 
@@ -419,14 +419,18 @@ class InstanceNorm(HybridBlock):
 
     .. math::
 
-      out = \frac{x - mean[data]}{ \sqrt{Var[data]} + \epsilon} * gamma + beta
+      \bar{C} = \{i \mid i \neq 0, i \neq axis\}
+
+      out = \frac{x - mean[data, \bar{C}]}{ \sqrt{Var[data, \bar{C}]} + \epsilon}
+       * gamma + beta
 
     Parameters
     ----------
     axis : int, default 1
-        The axis that should be normalized. This is typically the channels
+        The axis that will be excluded in the normalization process. This is typically the channels
         (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
-        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`.
+        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`. Data will be
+        normalized along axes excluding the first axis and the axis given.
     epsilon: float, default 1e-5
         Small float added to variance to avoid dividing by zero.
     center: bool, default True
@@ -475,7 +479,7 @@ class InstanceNorm(HybridBlock):
                  beta_initializer='zeros', gamma_initializer='ones',
                  in_channels=0, **kwargs):
         super(InstanceNorm, self).__init__(**kwargs)
-        self._kwargs = {'eps': epsilon, 'axis': axis}
+        self._kwargs = {'eps': epsilon, 'axis': axis, 'center': center, 'scale': scale}
         self._axis = axis
         self._epsilon = epsilon
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
@@ -502,6 +506,91 @@ class InstanceNorm(HybridBlock):
                         content=', '.join(['='.join([k, v.__repr__()])
                                            for k, v in self._kwargs.items()]))
 
+
+class LayerNorm(HybridBlock):
+    r"""
+    Applies layer normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array and normalizes
+    the input using the given axis:
+
+    .. math::
+
+      out = \frac{x - mean[data, axis]}{ \sqrt{Var[data, axis] + \epsilon}} * gamma + beta
+
+    Parameters
+    ----------
+    axis : int, default -1
+        The axis that should be normalized. This is typically the axis of the channels.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    References
+    ----------
+        `Layer Normalization
+        <https://arxiv.org/pdf/1607.06450.pdf>`_
+
+    Examples
+    --------
+    >>> # Input of shape (2, 5)
+    >>> x = mx.nd.array([[1, 2, 3, 4, 5], [1, 1, 2, 2, 2]])
+    >>> # Layer normalization is calculated with the above formula
+    >>> layer = LayerNorm()
+    >>> layer.initialize(ctx=mx.cpu(0))
+    >>> layer(x)
+    [[-1.41421    -0.707105    0.          0.707105    1.41421   ]
+     [-1.2247195  -1.2247195   0.81647956  0.81647956  0.81647956]]
+    <NDArray 2x5 @cpu(0)>
+    """
+    def __init__(self, axis=-1, epsilon=1e-5, center=True, scale=True,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 in_channels=0, prefix=None, params=None):
+        super(LayerNorm, self).__init__(prefix=prefix, params=params)
+        # Only used by __repr__ to print the configuration.
+        self._kwargs = {'eps': epsilon, 'axis': axis, 'center': center, 'scale': scale}
+        self._axis = axis
+        self._epsilon = epsilon
+        self._center = center
+        self._scale = scale
+        # NOTE(review): `scale`/`center` only select grad_req ('write' vs 'null') below;
+        # hybrid_forward still passes both gamma and beta to F.LayerNorm unconditionally.
+        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
+                                     shape=(in_channels,), init=gamma_initializer,
+                                     allow_deferred_init=True)
+        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
+                                    shape=(in_channels,), init=beta_initializer,
+                                    allow_deferred_init=True)
+
+    def hybrid_forward(self, F, data, gamma, beta):
+        """Apply ``F.LayerNorm`` to `data` with this block's gamma, beta, axis and epsilon."""
+        norm_data = F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
+        return norm_data
+
+    def __repr__(self):
+        """Return ``ClassName(key=value, ..., in_channels=N)`` built from the stored kwargs."""
+        s = '{name}({content}'
+        # The channel count is read back from gamma's (possibly still deferred) shape.
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels)
+        s += ')'
+        return s.format(name=self.__class__.__name__,
+                        content=', '.join(['='.join([k, v.__repr__()])
+                                           for k, v in self._kwargs.items()]))
+
+
 class Lambda(Block):
     r"""Wraps an operator or an expression as a Block object.
 
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 8418ec3..48638de 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -79,7 +79,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     .describe("Whether use global moving statistics instead of local batch-norm. "
               "This will force change batch-norm into a scale shift operator.");
     DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
-    .describe("Output All,normal mean and var");
+    .describe("Output the mean and inverse std ");
     DMLC_DECLARE_FIELD(axis).set_default(mxnet::op::batchnorm::DEFAULT_AXIS)
       .describe("Specify which shape axis the channel is specified");
     DMLC_DECLARE_FIELD(cudnn_off).set_default(false)
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index e577125..c8b5d58 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -510,7 +510,8 @@ Both *mean* and *var* returns a scalar by treating the input as a vector.
 
 Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
 have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and
-``data_var`` as well, which are needed for the backward pass.
+the inverse of ``data_var``, which are needed for the backward pass. Note that gradient of these 
+two outputs are blocked.
 
 Besides the inputs and the outputs, this operator accepts two auxiliary
 states, ``moving_mean`` and ``moving_var``, which are *k*-length
diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
new file mode 100644
index 0000000..ff429df
--- /dev/null
+++ b/src/operator/nn/layer_norm-inl.h
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file layer_norm-inl.h
+ * \brief Implements Ba et al., Layer Normalization (https://arxiv.org/abs/1607.06450).
+*/
+#ifndef MXNET_OPERATOR_NN_LAYER_NORM_INL_H_
+#define MXNET_OPERATOR_NN_LAYER_NORM_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mshadow/base.h>
+#include <map>
+#include <algorithm>
+#include <vector>
+#include <string>
+#include <utility>
+#include "../mshadow_op.h"
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Named positions of the LayerNorm operator's inputs and outputs.
+namespace layernorm {
+enum LayerNormOpInputs {kData, kGamma, kBeta};  // kGamma: scaling parameters, kBeta: shift biases
+enum LayerNormOpOutputs {kOut, kMean, kStd};  // req, out_data
+}  // namespace layernorm
+
+/*! \brief Parameters accepted by the LayerNorm operator. */
+struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
+  int axis;  // axis to normalize along; negative values count from the last axis
+  float eps;  // added to the variance before the square root is taken
+  bool output_mean_var;  // when true, mean and std become visible outputs
+  DMLC_DECLARE_PARAMETER(LayerNormParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(-1)
+      .describe("The axis to perform layer normalization. "
+                "Usually, this should be the axis of the channel dimension. "
+                "Negative values mean indexing from right to left.");
+    DMLC_DECLARE_FIELD(eps).set_default(1e-5f)
+      .describe("An `epsilon` parameter to prevent division by 0.");
+    DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
+      .describe("Output the mean and std calculated along the given axis.");
+  }
+};
+
+
+/*!
+ * \brief Forward pass of layer normalization.
+ *
+ * inputs  = [data, gamma, beta]; outputs = [out, mean, std].
+ * The steps, all carried out on outputs[0] once it holds the centered data:
+ *   mean = sum(data, axis) / channel_size
+ *   out  = data - mean
+ *   std  = sqrt(sum(out^2, axis) / channel_size + eps)
+ *   out  = out / std * gamma + beta
+ * A kAddTo request on the output is rejected; kNullOp returns immediately.
+ */
+template<typename xpu>
+void LayerNormCompute(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx, const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  if (req[0] == kNullOp) return;
+  CHECK_NE(req[0], kAddTo);
+  // Map a negative axis onto the corresponding non-negative index.
+  int axis = param.axis;
+  if (axis < 0) {
+    axis += static_cast<int>(inputs[0].ndim());
+  }
+  CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
+  CHECK_EQ(inputs.size(), 3U);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  // Reshape gamma and beta to be broadcastable
+  // (size 1 on every axis except the normalized one).
+  TShape new_param_shape(inputs[0].shape_.begin(), inputs[0].shape_.end());
+  for (int i = 0; i < inputs[0].ndim(); i++) {
+    if (i != axis) {
+      new_param_shape[i] = 1;
+    }
+  }
+  const TBlob gamma = inputs[1].reshape(new_param_shape);
+  const TBlob beta = inputs[2].reshape(new_param_shape);
+  // Compute necessary data for the reduce operation.
+  TShape red_src_shape, red_dst_shape;
+  BroadcastReduceShapeCompact(inputs[0].shape_, outputs[layernorm::kMean].shape_,
+                              &red_src_shape, &red_dst_shape);
+  const TBlob in_data = inputs[0].reshape(red_src_shape);
+  const TBlob mean_data = outputs[layernorm::kMean].reshape(red_dst_shape);
+  const TBlob std_data = outputs[layernorm::kStd].reshape(red_dst_shape);
+  // Number of elements reduced into each mean/std entry.
+  int channel_size = red_src_shape.Size() / red_dst_shape.Size();
+  // Initialize the workspace
+  // (sized for the mean reduction; the same buffer is reused for the std reduction below).
+  Tensor<xpu, 1, char> workspace;
+  size_t workspace_size = 0;
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      workspace_size = broadcast::ReduceWorkspaceSize<NDim, DType>(s, mean_data, req[0], in_data);
+    });
+  });
+  workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+  // Calculate mean
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
+        s, mean_data, req[0], workspace, in_data);
+      Tensor<xpu, 1, DType> mean_data_tensor = mean_data.FlatTo1D<xpu, DType>(s);
+      mean_data_tensor /= scalar<DType>(channel_size);
+    });
+  });
+  // Calculate data = data - mean
+  BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                     {inputs[0], outputs[layernorm::kMean]},
+                                                     {kWriteTo}, {outputs[0]});
+  // Calculate std
+  // (outputs[0] already holds the centered data, so summing its squares gives the variance sum).
+  const TBlob centered_out = outputs[0].reshape(red_src_shape);
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::square>(
+        s, std_data, req[0], workspace, centered_out);
+      Tensor<xpu, 1, DType> std_data_tensor = std_data.FlatTo1D<xpu, DType>(s);
+      std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
+                        + scalar<DType>(param.eps));
+    });
+  });
+  // Calculate data = data / std
+  BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
+                                                   {outputs[0], outputs[layernorm::kStd]},
+                                                   {kWriteTo}, {outputs[0]});
+  // Calculate data = data * gamma
+  BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                   {outputs[0], gamma},
+                                                   {kWriteTo}, {outputs[0]});
+  // Calculate data = data + beta
+  BinaryBroadcastCompute<xpu, op::mshadow_op::plus>(attrs, ctx,
+                                                   {outputs[0], beta},
+                                                   {kWriteTo}, {outputs[0]});
+}
+
+/*
+Calculate the gradient of layer normalization.
+We have the following gradient for gamma, beta and x:
+
+\bar{x} = (x - mean) / std
+w = og * gamma / std
+grad_gamma = sum(\bar{x} og, exclude_axis)
+grad_beta = sum(og, exclude_axis)
+grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis)
+*/
+// Backward pass of layer normalization.
+//   inputs  = [ograd, data, gamma, mean, std]
+//   outputs = [grad_data, grad_gamma, grad_beta]
+// One temp-space allocation is carved into four consecutive regions:
+//   [reduce workspace | normalized_data | ograd_mult | red_out]
+template<typename xpu>
+void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx, const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  CHECK_EQ(inputs.size(), 5U);
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  // Map a negative axis onto the corresponding non-negative index.
+  int axis = param.axis;
+  if (axis < 0) {
+    axis += static_cast<int>(inputs[0].ndim());
+  }
+  CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  // Reshape gamma to be broadcastable
+  TShape new_param_shape(inputs[0].shape_.begin(), inputs[0].shape_.end());
+  for (int i = 0; i < inputs[0].ndim(); i++) {
+    if (i != axis) {
+      new_param_shape[i] = 1;
+    }
+  }
+  const TBlob ograd = inputs[0];
+  const TBlob data = inputs[1];
+  const TBlob gamma = inputs[2].reshape(new_param_shape);
+  const TBlob mean = inputs[3];
+  const TBlob std = inputs[4];
+  // Prepare the necessary shapes for reduction
+  //   red_*         : reduce to the shape of mean/std (over the normalized axis)
+  //   red_exclude_* : reduce to the shape of gamma/beta (over every other axis)
+  TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
+  BroadcastReduceShapeCompact(ograd.shape_, mean.shape_, &red_src_shape, &red_dst_shape);
+  BroadcastReduceShapeCompact(ograd.shape_, gamma.shape_,
+                              &red_exclude_src_shape, &red_exclude_dst_shape);
+  int channel_size = red_src_shape.Size() / red_dst_shape.Size();
+  // Initialize the workspace + Construct the temporary TBlobs
+  Tensor<xpu, 1, char> workspace;
+  size_t reduce_workspace_size = 0;
+  size_t data_size = 0;
+  size_t red_out_size = 0;
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    data_size = sizeof(DType) * data.Size();
+    red_out_size = sizeof(DType) * mean.Size();
+    // There are two types of reduction workloads: reduce over axis and reduce exclude axis
+    // We take the maximum of the workspace sizes required by these workloads.
+    // Also, we explicitly set the req_type=kAddto in case we want to use it.
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      reduce_workspace_size =
+        std::max(reduce_workspace_size,
+                 broadcast::ReduceWorkspaceSize<NDim, DType>(
+                   s, ograd.reshape(red_src_shape), kAddTo,
+                   mean.reshape(red_dst_shape)));
+    });
+    BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+      reduce_workspace_size =
+        std::max(reduce_workspace_size,
+                 broadcast::ReduceWorkspaceSize<NDim, DType>(
+                   s, ograd.reshape(red_exclude_src_shape), kAddTo,
+                   gamma.reshape(red_exclude_dst_shape)));
+    });
+  });
+  // Two data-sized buffers (normalized_data, ograd_mult) plus one mean-sized buffer (red_out)
+  // follow the reduction workspace inside the same allocation.
+  workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
+    Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s);
+  const TBlob normalized_data = TBlob(workspace.dptr_ + reduce_workspace_size,
+                                      data.shape_, data.dev_mask(), data.type_flag_, data.dev_id());
+  const TBlob ograd_mult = TBlob(workspace.dptr_ + reduce_workspace_size + data_size,
+                                 ograd.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id());
+  const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
+                              mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
+  // Compute normalized_data = (data - mean) / std
+  BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                    {data, mean},
+                                                    {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
+                                                   {normalized_data, std},
+                                                   {kWriteTo}, {normalized_data});
+  // Calculate grad_beta
+  if (req[2] != kNullOp) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
+          s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+          ograd.reshape(red_exclude_src_shape));
+      });
+    });
+  }
+  // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
+  // (ograd_mult is used as scratch here and overwritten again in the grad_data branch).
+  ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
+                                                      {kWriteTo}, {ograd_mult});
+  if (req[1] != kNullOp) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
+          s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+          ograd_mult.reshape(red_exclude_src_shape));
+      });
+    });
+  }
+  // Calculate grad_data:
+  //   ograd_mult = ograd * gamma / std
+  //   grad_data = ograd_mult - mean(ograd_mult, axis)
+  //               + normalized_data * (-mean(normalized_data * ograd_mult, axis))
+  if (req[0] != kNullOp) {
+    BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                    {ograd, gamma},
+                                                    {kWriteTo}, {ograd_mult});
+    BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
+                                                    {ograd_mult, std},
+                                                    {kWriteTo}, {ograd_mult});
+    // red_out = mean(ograd_mult, axis)
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
+          s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+          ograd_mult.reshape(red_src_shape));
+      });
+      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+      red_out_tensor /= scalar<DType>(channel_size);
+    });
+    // First term is written with the caller's req[0]; the second term is accumulated below.
+    BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                      {ograd_mult, red_out},
+                                                      {req[0]}, {outputs[0]});
+    ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
+                                                        {kWriteTo}, {ograd_mult});
+    // red_out = -mean(ograd_mult * normalized_data, axis); the sign is folded into the divisor
+    // so that the final product can simply be added to outputs[0].
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
+          s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+          ograd_mult.reshape(red_src_shape));
+      });
+      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+      red_out_tensor /=  scalar<DType>(- channel_size);
+    });
+    BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                     {normalized_data, red_out},
+                                                     {kAddTo}, {outputs[0]});
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_NN_LAYER_NORM_INL_H_
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
new file mode 100644
index 0000000..3a24242
--- /dev/null
+++ b/src/operator/nn/layer_norm.cc
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file layer_norm.cc
+ * \brief Implements Ba et al., Layer Normalization (https://arxiv.org/abs/1607.06450).
+*/
+
+#include "layer_norm-inl.h"
+#include <nnvm/op_attr_types.h>
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(LayerNormParam);
+
+/*!
+ * \brief Shape inference for LayerNorm.
+ *
+ * Reads the data shape, assigns gamma and beta a 1-D shape whose length is the
+ * size of the normalized axis, and emits three output shapes: the output (same
+ * as data) followed by mean and std (data shape with the normalized axis set to 1).
+ *
+ * \param attrs node attributes carrying the parsed LayerNormParam
+ * \param in_shape shapes of [data, gamma, beta]; gamma and beta are overwritten
+ * \param out_shape cleared and filled with the shapes of [out, mean, std]
+ * \return false when the data shape has no dimensions yet, true otherwise
+ */
+static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape> *in_shape,
+                           std::vector<TShape> *out_shape) {
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
+  const TShape &dshape = in_shape->at(layernorm::kData);
+  // This guard must run before the axis is resolved: with a 0-dim data shape the
+  // range CHECK below can never pass and dshape[axis] would be indexed out of range,
+  // so the early return would otherwise be unreachable.
+  if (dshape.ndim() == 0) {
+    return false;
+  }
+
+  int axis = param.axis;
+  if (axis < 0) {
+    axis += static_cast<int>(dshape.ndim());
+  }
+  CHECK(axis >= 0 && axis < static_cast<int>(dshape.ndim()))
+    << "Channel axis out of range: axis=" << param.axis;
+
+  const int channelCount = dshape[axis];
+
+  in_shape->at(layernorm::kGamma) = TShape(Shape1(channelCount));
+  in_shape->at(layernorm::kBeta) = TShape(Shape1(channelCount));
+
+  out_shape->clear();
+  out_shape->push_back(dshape);                // kOut
+  TShape moments_shape(dshape.begin(), dshape.end());
+  moments_shape[axis] = 1;
+  out_shape->push_back(moments_shape);  // kMean
+  out_shape->push_back(moments_shape);  // kStd
+  return true;
+}
+
+
+// Forward operator registration: 3 inputs (data, gamma, beta) and 3 outputs (output, mean, std).
+NNVM_REGISTER_OP(LayerNorm)
+.describe(R"code(Layer normalization.
+
+Normalizes the channels of the input tensor by mean and variance, and applies a scale ``gamma`` as
+well as offset ``beta``.
+
+Assume the input has more than one dimension and we normalize along axis 1.
+We first compute the mean and variance along this axis and then 
+compute the normalized output, which has the same shape as input, as following:
+
+.. math::
+
+  out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta
+
+Both ``gamma`` and ``beta`` are learnable parameters.
+
+Unlike BatchNorm and InstanceNorm,  the *mean* and *var* are computed along the channel dimension.
+
+Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and
+``data_std``. Note that no gradient will be passed through these two outputs.
+
+The parameter ``axis`` specifies which axis of the input shape denotes
+the 'channel' (separately normalized groups).  The default is -1, which sets the channel
+axis to be the last item in the input shape.
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<LayerNormParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"data", "gamma", "beta"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output", "mean", "std"};
+})
+// mean and std are always produced; they are only visible when output_mean_var is set.
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+    [](const NodeAttrs& attrs) {
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  return param.output_mean_var ? 3 : 1;
+})
+.set_attr<nnvm::FInferShape>("FInferShape", LayerNormShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
+.set_attr<FCompute>("FCompute<cpu>", LayerNormCompute<cpu>)
+// The backward node receives exactly the five entries LayerNormGradCompute expects,
+// in this order: ograd of the first output, data, gamma, mean, std.
+.set_attr<nnvm::FGradient>("FGradient", [](const nnvm::NodePtr& n,
+                                           const std::vector<nnvm::NodeEntry>& ograds) {
+  std::vector<nnvm::NodeEntry> heads;
+  heads.push_back(ograds[0]);  // ograd
+  heads.push_back(n->inputs[0]);  // data
+  heads.push_back(n->inputs[1]);  // gamma
+  heads.emplace_back(nnvm::NodeEntry{n, 1, 0});  // mean
+  heads.emplace_back(nnvm::NodeEntry{ n, 2, 0 });  // std
+  return MakeGradNode("_backward_LayerNorm", n, heads, n->attrs.dict);
+})
+// Input 0 (data) is declared as an in-place candidate for output 0.
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
+// Temp space backs the reduction workspace requested in LayerNormCompute.
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.add_argument("data", "NDArray-or-Symbol", "Input data to layer normalization")
+.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
+.add_argument("beta", "NDArray-or-Symbol", "beta array")
+.add_arguments(LayerNormParam::__FIELDS__());
+
+
+// Backward operator: 5 inputs (ograd, data, gamma, mean, std) as assembled by the
+// FGradient of LayerNorm, and 3 outputs (gradients of data, gamma and beta).
+NNVM_REGISTER_OP(_backward_LayerNorm)
+.set_num_inputs(5)
+.set_num_outputs(3)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<LayerNormParam>)
+.set_attr<FCompute>("FCompute<cpu>", LayerNormGradCompute<cpu>)
+// Temp space holds the reduction workspace and the intermediate buffers of the gradient.
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+});
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu
new file mode 100644
index 0000000..a146131
--- /dev/null
+++ b/src/operator/nn/layer_norm.cu
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file layer_norm.cu
+ * \brief Implements Ba et al., Layer Normalization (https://arxiv.org/abs/1607.06450).
+*/
+#include "./layer_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Attach the GPU instantiations of the forward and backward compute functions
+// to the LayerNorm and _backward_LayerNorm operators.
+NNVM_REGISTER_OP(LayerNorm)
+.set_attr<FCompute>("FCompute<gpu>", LayerNormCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_LayerNorm)
+.set_attr<FCompute>("FCompute<gpu>", LayerNormGradCompute<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 889d210..ba2e7ab 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -381,6 +381,11 @@ def test_instancenorm():
     layer = nn.InstanceNorm(in_channels=10)
     check_layer_forward(layer, (2, 10, 10, 10))
 
+@with_seed()
+def test_layernorm():
+    """Pass a Gluon LayerNorm block and a 4-D input shape to check_layer_forward."""
+    # in_channels=10 equals the size of the last axis of the (2, 10, 10, 10) input.
+    layer = nn.LayerNorm(in_channels=10)
+    check_layer_forward(layer, (2, 10, 10, 10))
+
 
 @with_seed()
 def test_reflectionpad():
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2208a33..87dfda5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2413,6 +2413,47 @@ def test_l2_normalization():
                         check_l2_normalization((nbatch, nchannel, height, width), mode)
 
 
+def check_layer_normalization(in_shape, axis, eps, dtype=np.float32):
+    """Compare mx.symbol.LayerNorm against a NumPy reference and check its gradients.
+
+    Random data, gamma and beta are drawn with `dtype`, the forward output is
+    compared with `npy_layer_norm`, and the numeric gradient is checked for the
+    'write' and 'add' gradient requests.
+    """
+    def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
+        """NumPy reference: gamma * (data - mean) / sqrt(var + eps) + beta along `axis`."""
+        if axis < 0:
+            axis += data.ndim
+        # gamma/beta are reshaped so they broadcast along every axis except `axis`.
+        broadcast_shape = [1 for _ in range(data.ndim)]
+        broadcast_shape[axis] = data.shape[axis]
+        mean = data.mean(axis=axis, keepdims=True)
+        var = data.var(axis=axis, keepdims=True)
+        std = np.sqrt(var + eps)
+        out = np.reshape(gamma, broadcast_shape) * (data - mean) / std + \
+              np.reshape(beta, broadcast_shape)
+        return out
+
+    ctx = default_context()
+    data = np.random.normal(0, 1, in_shape).astype(dtype)
+    gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    beta = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    data_s = mx.symbol.Variable('data')
+    gamma_s = mx.symbol.Variable('gamma')
+    beta_s = mx.symbol.Variable('beta')
+    out_s = mx.symbol.LayerNorm(data=data_s, gamma=gamma_s, beta=beta_s, axis=axis, eps=eps)
+    # NOTE(review): only the data shape is passed to simple_bind; `dtype` is applied to the
+    # NumPy arrays alone, so the executor presumably uses its default dtype - confirm.
+    exe = out_s.simple_bind(ctx, data=in_shape)
+    exe.arg_dict['data'][:] = data
+    exe.arg_dict['gamma'][:] = gamma
+    exe.arg_dict['beta'][:] = beta
+    out_nd = exe.forward()[0]
+    out = npy_layer_norm(data, gamma, beta, axis, eps)
+    assert_allclose(out, out_nd.asnumpy(), 1E-4, 1E-4)
+    for req in ['write', 'add']:
+        check_numeric_gradient(out_s, {'data': data, 'gamma': gamma, 'beta': beta},
+                               grad_nodes={'data': req, 'gamma': req, 'beta': req},
+                               numeric_eps=1e-2, rtol=1e-2, atol=1e-3)
+
+def test_layer_norm():
+    """Run check_layer_normalization over several shapes, every valid axis and two eps values."""
+    # NOTE(review): the `dtype` loop variable is never forwarded to check_layer_normalization,
+    # so all three iterations run with that helper's default dtype (np.float32).
+    for dtype in [np.float16, np.float32, np.float64]:
+        for in_shape in [(10, 6, 5), (5, 5)]:
+            # Covers both negative and non-negative spellings of each axis.
+            for axis in range(-len(in_shape), len(in_shape)):
+                for eps in [1E-3, 1E-4]:
+                    check_layer_normalization(in_shape, axis, eps)
+
+
 # Numpy Implementation of Sequence Ops
 def sequence_last_numpy(array, lengths, axis):
     # create new array of dims [batch, seqlen, ...]

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.