Posted to commits@mxnet.apache.org by sx...@apache.org on 2019/07/19 05:44:56 UTC

[incubator-mxnet] branch master updated: Group Normalization (#14959)

This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new eec0fb4  Group Normalization (#14959)
eec0fb4 is described below

commit eec0fb4eda40f4fb8222a8d93d8face454aead09
Author: Hao Jin <hj...@gmail.com>
AuthorDate: Thu Jul 18 22:44:34 2019 -0700

    Group Normalization (#14959)
    
    * GroupNorm
    
    * add to amp list
    
    * re-write forward
---
 python/mxnet/contrib/amp/lists/symbol.py |   1 +
 python/mxnet/gluon/nn/basic_layers.py    |  91 +++++++-
 src/operator/nn/group_norm-inl.h         | 347 +++++++++++++++++++++++++++++++
 src/operator/nn/group_norm.cc            | 131 ++++++++++++
 src/operator/nn/group_norm.cu            |  37 ++++
 tests/python/unittest/test_gluon.py      |   9 +
 tests/python/unittest/test_operator.py   |  91 ++++++++
 7 files changed, 706 insertions(+), 1 deletion(-)
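
For readers skimming the diff, here is a minimal usage sketch of the new Gluon block (illustrative only, not part of the commit; it assumes the `nn.GroupNorm` API added in python/mxnet/gluon/nn/basic_layers.py below):

    import mxnet as mx
    from mxnet.gluon import nn

    # Normalize a (N, C, H, W) batch using 2 groups of channels;
    # gamma and beta each have shape (num_groups,).
    layer = nn.GroupNorm(num_groups=2)
    layer.initialize(ctx=mx.cpu(0))

    x = mx.nd.random.uniform(shape=(2, 10, 10, 10))
    y = layer(x)          # same shape as the input
    print(y.shape)        # (2, 10, 10, 10)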

diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py
index 9a587df..c6cc3d1 100644
--- a/python/mxnet/contrib/amp/lists/symbol.py
+++ b/python/mxnet/contrib/amp/lists/symbol.py
@@ -471,6 +471,7 @@ FP32_FUNCS = [
     'log_softmax',
     'InstanceNorm',
     'LayerNorm',
+    'GroupNorm',
     'L2Normalization',
     'LRN',
     'SoftmaxActivation',
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 3d6976c..b1482ce 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -19,7 +19,8 @@
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
 __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
+           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
+           'Flatten', 'Lambda', 'HybridLambda']
 import warnings
 import numpy as np
 
@@ -616,6 +617,94 @@ class LayerNorm(HybridBlock):
                                            for k, v in self._kwargs.items()]))
 
 
+class GroupNorm(HybridBlock):
+    r"""
+    Applies group normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array in which the leftmost two axes are
+    `batch` and `channel`, respectively:
+
+    .. math::
+
+      x = x.reshape((N, num_groups, C // num_groups, ...))
+      axis = (2, ...)
+      out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta
+
+    Parameters
+    ----------
+    num_groups: int, default 1
+        Number of groups to separate the channel axis into.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+
+
+    Inputs:
+        - **data**: input tensor with shape (N, C, ...).
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    References
+    ----------
+        `Group Normalization
+        <https://arxiv.org/pdf/1803.08494.pdf>`_
+
+    Examples
+    --------
+    >>> # Input of shape (2, 3, 4)
+    >>> x = mx.nd.array([[[ 0,  1,  2,  3],
+                          [ 4,  5,  6,  7],
+                          [ 8,  9, 10, 11]],
+                         [[12, 13, 14, 15],
+                          [16, 17, 18, 19],
+                          [20, 21, 22, 23]]])
+    >>> # Group normalization is calculated with the above formula
+    >>> layer = GroupNorm()
+    >>> layer.initialize(ctx=mx.cpu(0))
+    >>> layer(x)
+    [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
+      [-0.4345239 -0.1448413  0.1448413  0.4345239]
+      [ 0.7242065  1.0138891  1.3035717  1.5932543]]
+     [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
+      [-0.4345239 -0.1448413  0.1448413  0.4345239]
+      [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
+    <NDArray 2x3x4 @cpu(0)>
+    """
+    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 prefix=None, params=None):
+        super(GroupNorm, self).__init__(prefix=prefix, params=params)
+        self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
+        self._num_groups = num_groups
+        self._epsilon = epsilon
+        self._center = center
+        self._scale = scale
+        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
+                                     shape=(num_groups,), init=gamma_initializer,
+                                     allow_deferred_init=True)
+        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
+                                    shape=(num_groups,), init=beta_initializer,
+                                    allow_deferred_init=True)
+
+    def hybrid_forward(self, F, data, gamma, beta):
+        norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
+        return norm_data
+
+    def __repr__(self):
+        s = '{name}({content})'
+        return s.format(name=self.__class__.__name__,
+                        content=', '.join(['='.join([k, v.__repr__()])
+                                           for k, v in self._kwargs.items()]))
+
+
 class Lambda(Block):
     r"""Wraps an operator or an expression as a Block object.
 
diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h
new file mode 100644
index 0000000..69d5a30
--- /dev/null
+++ b/src/operator/nn/group_norm-inl.h
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file group_norm-inl.h
+ * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494).
+ * \author Hao Jin
+*/
+
+#ifndef MXNET_OPERATOR_NN_GROUP_NORM_INL_H_
+#define MXNET_OPERATOR_NN_GROUP_NORM_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mshadow/base.h>
+#include <map>
+#include <algorithm>
+#include <vector>
+#include <string>
+#include <utility>
+#include "./moments-inl.h"
+#include "../mshadow_op.h"
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+namespace groupnorm {
+enum GroupNormOpInputs {kData, kGamma, kBeta};  // kGamma: scaling parameters, kBeta: shift biases
+enum GroupNormOpOutputs {kOut, kMean, kStd};  // req, out_data
+}  // namespace groupnorm
+
+struct GroupNormParam : public dmlc::Parameter<GroupNormParam> {
+  int num_groups;
+  float eps;
+  bool output_mean_var;
+  DMLC_DECLARE_PARAMETER(GroupNormParam) {
+    DMLC_DECLARE_FIELD(num_groups).set_default(1)
+      .describe("Total number of groups.");
+    DMLC_DECLARE_FIELD(eps).set_default(1e-5f)
+      .describe("An `epsilon` parameter to prevent division by 0.");
+    DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
+      .describe("Output the mean and std calculated along the given axis.");
+  }
+};
+
+
+template<typename xpu>
+void GroupNormCompute(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
+  const int num_groups = param.num_groups;
+  if (req[0] == kNullOp) return;
+  CHECK_NE(req[0], kAddTo);
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& data = inputs[groupnorm::kData];
+  const TBlob& mean = outputs[groupnorm::kMean];
+  const TBlob& std = outputs[groupnorm::kStd];
+  const mxnet::TShape& data_shape = data.shape_;
+  CHECK_GE(data_shape.ndim(), 3U)
+    << "input should have at least 3 dims and "
+    << "the first 2 dims should be batch and channel respectively";
+  CHECK_EQ(data_shape[1] % num_groups, 0)
+    << "number of channels should be divisible by num_groups.";
+
+  mxnet::TShape temp_data_shape(data_shape.ndim() + 1, 1);
+  temp_data_shape[0] = data_shape[0];
+  temp_data_shape[1] = num_groups;
+  temp_data_shape[2] = data_shape[1] / num_groups;
+  for (int i = 2; i < data_shape.ndim(); ++i) {
+    temp_data_shape[i+1] = data_shape[i];
+  }
+
+  mxnet::TShape moments_shape(temp_data_shape.ndim(), 1);
+  for (int i = 0; i < data.shape_.ndim(); ++i) {
+    moments_shape[i] = (i < mean.shape_.ndim()) ? mean.shape_[i] : 1;
+  }
+
+  mxnet::TShape red_src_shape, red_dst_shape;
+  BroadcastReduceShapeCompact(temp_data_shape, moments_shape, &red_src_shape, &red_dst_shape);
+  int channel_size = red_src_shape.Size() / red_dst_shape.Size();
+
+  TBlob data_ = data.reshape(red_src_shape);
+  const TBlob& mean_ = mean.reshape(red_dst_shape);
+  const TBlob& std_ = std.reshape(red_dst_shape);
+
+  Tensor<xpu, 1, char> workspace;
+
+  size_t workspace_size = 0;
+  MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      workspace_size =
+        broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_dst_shape, req[0], red_src_shape);
+    });
+  });
+
+  workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+
+  // Calculate mean
+  MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+        s, mean_, req[0], workspace, data_);
+      Tensor<xpu, 1, DType> mean_data_tensor = mean_.FlatTo1D<xpu, DType>(s);
+      mean_data_tensor /= scalar<DType>(channel_size);
+    });
+  });
+
+  TBlob data_grp = data.reshape(temp_data_shape);
+  const TBlob& mean_grp = mean.reshape(moments_shape);
+  const TBlob& std_grp = std.reshape(moments_shape);
+  const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape);
+
+  // Calculate data = data - mean
+  BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                     {data_grp, mean_grp},
+                                                     {kWriteTo}, {output});
+
+  // Calculate std
+  const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
+  MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
+        s, std_, req[0], workspace, centered_out);
+      Tensor<xpu, 1, DType> std_data_tensor = std_.FlatTo1D<xpu, DType>(s);
+      std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
+                        + scalar<DType>(param.eps));
+    });
+  });
+
+  // Calculate data = data / std
+  BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
+                                               {output, std_grp},
+                                               {kWriteTo}, {output});
+
+  mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1);
+  new_param_shape[1] = num_groups;
+
+  const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape);
+  const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape);
+
+  // Calculate data = data * gamma
+  BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                   {output, gamma},
+                                                   {kWriteTo}, {output});
+  // Calculate data = data + beta
+  BinaryBroadcastCompute<xpu, op::mshadow_op::plus>(attrs, ctx,
+                                                   {output, beta},
+                                                   {kWriteTo}, {output});
+}
+
+/*
+Calculate the gradient of group normalization.
+With og denoting the output gradient, we have the following gradients for gamma, beta and x:
+
+\bar{x} = (x - mean) / std
+w = og * gamma / std
+grad_gamma = sum(\bar{x} * og, exclude_axis)
+grad_beta = sum(og, exclude_axis)
+grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis)
+*/
+template<typename xpu>
+void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 5U);
+  CHECK_EQ(outputs.size(), 3U);
+  const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
+  const int num_groups = param.num_groups;
+
+  const TBlob& data = inputs[1];
+  const mxnet::TShape& dshape = data.shape_;
+
+  mxnet::TShape temp_dshape(dshape.ndim() + 1, 1);
+  temp_dshape[0] = dshape[0];
+  temp_dshape[1] = num_groups;
+  temp_dshape[2] = dshape[1] / num_groups;
+  for (int i = 2; i < dshape.ndim(); ++i) {
+    temp_dshape[i+1] = dshape[i];
+  }
+  const TBlob& data_ = data.reshape(temp_dshape);
+  const TBlob& ograd = inputs[0].reshape(temp_dshape);
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  // Reshape gamma to be broadcastable
+  mxnet::TShape new_param_shape(dshape.ndim() + 1, 1);
+  new_param_shape[1] = num_groups;
+
+  const TBlob& gamma = inputs[2].reshape(new_param_shape);
+
+  const TBlob& mean = inputs[3];
+  const TBlob& std = inputs[4];
+
+  mxnet::TShape moments_shape(temp_dshape.ndim(), 1);
+  for (int i = 0; i < dshape.ndim(); ++i) {
+    moments_shape[i] = (i < mean.shape_.ndim()) ? mean.shape_[i] : 1;
+  }
+  const TBlob& mean_ = mean.reshape(moments_shape);
+  const TBlob& std_ = std.reshape(moments_shape);
+
+  // Prepare the necessary shapes for reduction
+  mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
+  BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape);
+  BroadcastReduceShapeCompact(temp_dshape, gamma.shape_,
+                              &red_exclude_src_shape, &red_exclude_dst_shape);
+
+  int N = red_src_shape.Size() / red_dst_shape.Size();
+
+  // Initialize the workspace + Construct the temporary TBlobs
+  Tensor<xpu, 1, char> workspace;
+  size_t reduce_workspace_size = 0;
+  size_t data_size = 0;
+  size_t red_out_size = 0;
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    data_size = sizeof(DType) * data.Size();
+    red_out_size = sizeof(DType) * mean.Size();
+    // There are two types of reduction workloads: reducing over the within-group axes,
+    // and reducing over all axes except the group axis. We take the maximum of the
+    // workspace sizes these require, and explicitly set req_type=kAddTo in case we want to use it.
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      reduce_workspace_size =
+        std::max(reduce_workspace_size,
+                 broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_dst_shape,
+                                                             kAddTo, red_src_shape));
+    });
+    BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+      reduce_workspace_size =
+        std::max(reduce_workspace_size,
+                 broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_dst_shape, kAddTo,
+                                                             red_exclude_src_shape));
+    });
+  });
+  workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
+    Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s);
+  const TBlob normalized_data =
+    TBlob(workspace.dptr_ + reduce_workspace_size,
+          data_.shape_, data.dev_mask(), data.type_flag_, data.dev_id());
+  const TBlob ograd_mult = TBlob(workspace.dptr_ + reduce_workspace_size + data_size,
+                                 data_.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id());
+  const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
+                              mean_.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
+  // Compute normalized_data = (data - mean) / std
+  BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                    {data_, mean_},
+                                                    {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
+                                                   {normalized_data, std_},
+                                                   {kWriteTo}, {normalized_data});
+  // Calculate grad_beta
+  if (req[2] != kNullOp) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity, true>(
+          s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+          ograd.reshape(red_exclude_src_shape));
+      });
+    });
+  }
+  // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
+  ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
+                                                      {kWriteTo}, {ograd_mult});
+  if (req[1] != kNullOp) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
+          s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+          ograd_mult.reshape(red_exclude_src_shape));
+      });
+    });
+  }
+
+  // Calculate grad_data:
+  //   ograd_mult = ograd * gamma / std
+  //   grad_data = ograd_mult - mean(ograd_mult, axis)
+  //               + normalized_data * (-mean(normalized_data * ograd_mult, axis))
+  if (req[0] != kNullOp) {
+    const TBlob output_ = outputs[0].reshape(data_.shape_);
+    BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                    {ograd, gamma},
+                                                    {kWriteTo}, {ograd_mult});
+    BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
+                                                    {ograd_mult, std_},
+                                                    {kWriteTo}, {ograd_mult});
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
+          s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+          ograd_mult.reshape(red_src_shape));
+      });
+      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+      red_out_tensor /= scalar<DType>(N);
+    });
+    BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                      {ograd_mult, red_out},
+                                                      {req[0]}, {output_});
+    ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
+                                                        {kWriteTo}, {ograd_mult});
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
+          s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+          ograd_mult.reshape(red_src_shape));
+      });
+      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+      red_out_tensor /= scalar<DType>(-N);
+    });
+    BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                     {normalized_data, red_out},
+                                                     {kAddTo}, {output_});
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_NN_GROUP_NORM_INL_H_
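
As a reading aid, here is a NumPy sketch of what GroupNormCompute and GroupNormGradCompute above compute (reshape into groups, normalize within each group, then scale and shift per group, plus the gradient formulas from the comment). It mirrors the reference implementation added to test_operator.py below and is a sketch, not the operator itself:

    import numpy as np

    def group_norm_reference(x, gamma, beta, num_groups=1, eps=1e-5):
        # x: (N, C, ...), gamma/beta: (num_groups,)
        n, c = x.shape[0], x.shape[1]
        g = x.reshape((n, num_groups, c // num_groups) + x.shape[2:])
        axes = tuple(range(2, g.ndim))                     # within-group axes
        mean = g.mean(axis=axes, keepdims=True)
        std = np.sqrt(g.var(axis=axes, keepdims=True) + eps)
        x_hat = (g - mean) / std
        pshape = (1, num_groups) + (1,) * (g.ndim - 2)     # broadcastable gamma/beta
        out = x_hat * gamma.reshape(pshape) + beta.reshape(pshape)
        return out.reshape(x.shape), mean.reshape(n, num_groups), std.reshape(n, num_groups)

    def group_norm_grad_reference(og, x, gamma, num_groups=1, eps=1e-5):
        # Mirrors the gradient formulas in the comment above GroupNormGradCompute.
        n, c = x.shape[0], x.shape[1]
        g = x.reshape((n, num_groups, c // num_groups) + x.shape[2:])
        ogg = og.reshape(g.shape)
        axes = tuple(range(2, g.ndim))
        excl_axes = (0,) + axes                            # everything except the group axis
        mean = g.mean(axis=axes, keepdims=True)
        std = np.sqrt(g.var(axis=axes, keepdims=True) + eps)
        x_hat = (g - mean) / std
        pshape = (1, num_groups) + (1,) * (g.ndim - 2)
        w = ogg * gamma.reshape(pshape) / std
        grad_gamma = (x_hat * ogg).sum(axis=excl_axes)
        grad_beta = ogg.sum(axis=excl_axes)
        grad_x = (w - w.mean(axis=axes, keepdims=True)
                  - x_hat * (w * x_hat).mean(axis=axes, keepdims=True))
        return grad_x.reshape(x.shape), grad_gamma, grad_beta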
diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc
new file mode 100644
index 0000000..b4698ab
--- /dev/null
+++ b/src/operator/nn/group_norm.cc
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file group_norm.cc
+ * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494).
+*/
+
+#include "group_norm-inl.h"
+#include <nnvm/op_attr_types.h>
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(GroupNormParam);
+
+static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
+                           mxnet::ShapeVector *in_shape,
+                           mxnet::ShapeVector *out_shape) {
+  const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
+  const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
+  CHECK_GE(dshape.ndim(), 3U);
+  const int num_groups = param.num_groups;
+  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
+
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+
+  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
+  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
+
+  out_shape->clear();
+  out_shape->push_back(dshape);
+
+  mxnet::TShape moments_shape(2, 1);
+  moments_shape[0] = dshape[0];
+  moments_shape[1] = num_groups;
+  out_shape->push_back(moments_shape);
+  out_shape->push_back(moments_shape);
+  return true;
+}
+
+NNVM_REGISTER_OP(GroupNorm)
+.describe(R"code(Group normalization.
+
+The input channels are separated into ``num_groups`` groups, each containing ``num_channels / num_groups`` channels.
+The mean and standard deviation are calculated separately over each group.
+
+.. math::
+
+  data = data.reshape((N, num_groups, C // num_groups, ...))
+  out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta
+
+Both ``gamma`` and ``beta`` are learnable parameters.
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<GroupNormParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"data", "gamma", "beta"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output", "mean", "std"};
+})
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+    [](const NodeAttrs& attrs) {
+  const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
+  return param.output_mean_var ? 3 : 1;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", GroupNormShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
+.set_attr<FCompute>("FCompute<cpu>", GroupNormCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", [](const nnvm::NodePtr& n,
+                                           const std::vector<nnvm::NodeEntry>& ograds) {
+  std::vector<nnvm::NodeEntry> heads;
+  heads.push_back(ograds[0]);  // ograd
+  heads.push_back(n->inputs[0]);  // data
+  heads.push_back(n->inputs[1]);  // gamma
+  heads.emplace_back(nnvm::NodeEntry{n, 1, 0});  // mean
+  heads.emplace_back(nnvm::NodeEntry{ n, 2, 0 });  // std
+  return MakeGradNode("_backward_GroupNorm", n, heads, n->attrs.dict);
+})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.add_argument("data", "NDArray-or-Symbol", "Input data")
+.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
+.add_argument("beta", "NDArray-or-Symbol", "beta array")
+.add_arguments(GroupNormParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_backward_GroupNorm)
+.set_num_inputs(5)
+.set_num_outputs(3)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<GroupNormParam>)
+.set_attr<FCompute>("FCompute<cpu>", GroupNormGradCompute<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+});
+
+}  // namespace op
+}  // namespace mxnet
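
The registration above exposes both symbolic and imperative front ends (mx.sym.GroupNorm / mx.nd.GroupNorm). A hedged sketch of the symbolic use, following the new test in test_operator.py; note that mean and std only become visible outputs when output_mean_var=True (see FNumVisibleOutputs):

    import mxnet as mx

    data = mx.sym.Variable("data")
    gamma = mx.sym.Variable("gamma")
    beta = mx.sym.Variable("beta")
    sym = mx.sym.GroupNorm(data=data, gamma=gamma, beta=beta,
                           num_groups=2, eps=1e-5, output_mean_var=True)
    print(len(sym.list_outputs()))   # 3: output, mean, std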
diff --git a/src/operator/nn/group_norm.cu b/src/operator/nn/group_norm.cu
new file mode 100644
index 0000000..136c333
--- /dev/null
+++ b/src/operator/nn/group_norm.cu
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file group_norm.cu
+ * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494).
+*/
+#include "./group_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(GroupNorm)
+.set_attr<FCompute>("FCompute<gpu>", GroupNormCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_GroupNorm)
+.set_attr<FCompute>("FCompute<gpu>", GroupNormGradCompute<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
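
The .cu file only re-registers the same templated compute functions for GPU. A minimal imperative sketch, assuming a CUDA-enabled build with at least one visible GPU (illustrative only):

    import mxnet as mx

    ctx = mx.gpu(0)
    x = mx.nd.random.uniform(shape=(2, 6, 4, 4), ctx=ctx)
    gamma = mx.nd.ones((3,), ctx=ctx)    # one scale per group
    beta = mx.nd.zeros((3,), ctx=ctx)    # one shift per group
    out = mx.nd.GroupNorm(x, gamma, beta, num_groups=3, eps=1e-5)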
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index d52e7f8..b59ce2d 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -744,6 +744,15 @@ def test_layernorm():
 
 
 @with_seed()
+def test_groupnorm():
+    layer = nn.GroupNorm()
+    check_layer_forward(layer, (2, 10, 10, 10))
+    layer = nn.GroupNorm(num_groups=2)
+    check_layer_forward(layer, (2, 10, 10, 10))
+    layer = nn.GroupNorm(num_groups=5)
+    check_layer_forward(layer, (2, 10, 10, 10))
+
+@with_seed()
 def test_reflectionpad():
     layer = nn.ReflectionPad2D(3)
     check_layer_forward(layer, (2, 3, 24, 24))
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index aeddc7a..749f0f2 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1831,6 +1831,97 @@ def test_batchnorm():
 
 
 @with_seed()
+def test_groupnorm():
+    acc_types = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64'}
+    def x_hat_helper(x, num_groups, eps):
+        dtype = x.dtype
+        dshape = x.shape
+        assert len(dshape) == 4
+        acc_type = acc_types[str(dtype)]
+        new_shape = (dshape[0], num_groups, int(dshape[1] / num_groups), dshape[2], dshape[3])
+        new_moments_shape = (dshape[0], num_groups, 1, 1, 1)
+        data = x.reshape(new_shape)
+        mean = np.mean(data, axis=(2, 3, 4), keepdims=False, dtype=acc_type).astype(dtype)
+        std = np.sqrt(np.var(data, axis=(2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + eps)
+        x_hat = (data - mean.reshape(new_moments_shape)) / std.reshape(new_moments_shape)
+        return x_hat, mean, std
+
+    def np_groupnorm(data, gamma, beta, num_groups, eps):
+        new_param_shape = (1, num_groups, 1, 1, 1)
+        x_hat, mean, std = x_hat_helper(data, num_groups, eps)
+        out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
+        return out.reshape(dshape), mean, std
+
+    def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
+        x_hat, mean, std = x_hat_helper(data, num_groups, eps)
+        new_shape = x_hat.shape
+        dshape = data.shape
+        dtype = data.dtype
+        new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
+        new_param_shape = (1, num_groups, 1, 1, 1)
+        acc_type = acc_types[str(dtype)]
+        ograd = ograd.reshape(new_shape)
+        data = data.reshape(new_shape)
+        gamma = gamma.reshape(new_param_shape)
+        beta = beta.reshape(new_param_shape)
+        mean = mean.reshape(new_moments_shape)
+        std = std.reshape(new_moments_shape)
+        beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
+        gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
+        x_hat_grad = ograd * gamma
+        ograd_mult = x_hat_grad / std
+        red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
+        data_grad = ograd_mult - red_out
+        red_out = np.mean(ograd_mult * x_hat, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
+        data_grad = data_grad - x_hat * red_out
+        return data_grad.reshape(dshape), gamma_grad, beta_grad
+
+
+    batch_size = random.randint(1, 8)
+    num_groups = random.randint(2, 3)
+    num_channels = random.randint(2, 3) * num_groups
+    height = random.randint(1, 5)
+    width = random.randint(1, 5)
+    dshape = (batch_size, num_channels, height, width)
+    param_shape = (num_groups,)
+    temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width)
+    np_data = np.random.uniform(0.2, 1.0, dshape)
+    np_gamma = np.random.uniform(-1.0, 1.0, param_shape)
+    np_beta = np.random.uniform(-1.0, 1.0, param_shape)
+    data_sym = mx.sym.Variable("data")
+    gamma_sym = mx.sym.Variable("gamma")
+    beta_sym = mx.sym.Variable("beta")
+    for dtype in [np.float16, np.float32, np.float64]:
+        eps = 1e-2 if dtype == np.float16 else 1e-5
+        mx_data = mx.nd.array(np_data, dtype=dtype)
+        mx_gamma = mx.nd.array(np_gamma, dtype=dtype)
+        mx_beta = mx.nd.array(np_beta, dtype=dtype)
+        np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype),
+                                               np_gamma.astype(dtype),
+                                               np_beta.astype(dtype),
+                                               num_groups=num_groups,
+                                               eps=eps)
+        mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
+                                  num_groups=num_groups, eps=eps, output_mean_var=True)
+        check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std],
+                               rtol=1e-2 if dtype == np.float16 else 1e-3,
+                               atol=5e-3 if dtype == np.float16 else 1e-5, dtype=dtype)
+        mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
+                                  num_groups=num_groups, eps=eps, output_mean_var=False)
+        np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype)
+        np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd,
+                                                                      np_data.astype(dtype),
+                                                                      np_gamma.astype(dtype),
+                                                                      np_beta.astype(dtype),
+                                                                      np_mean, np_std,
+                                                                      num_groups, eps)
+        check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd)],
+                                [np_data_grad, np_gamma_grad, np_beta_grad],
+                                rtol=1e-2 if dtype == np.float16 else 1e-3,
+                                atol=5e-2 if dtype == np.float16 else 1e-5, dtype=dtype)
+
+
+@with_seed()
 def test_convolution_grouping():
     for dim in [1, 2, 3]:
         num_filter = 4