Posted to commits@mxnet.apache.org by sx...@apache.org on 2020/04/30 17:04:36 UTC

[incubator-mxnet] branch master updated: [Bug Fix] Fix GroupNorm Implementation (#18199)

This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1496c91  [Bug Fix] Fix GroupNorm Implementation (#18199)
1496c91 is described below

commit 1496c91871b9d81d6a18785bdc8a1c3450bedbca
Author: Huang, Guangtai <gu...@amazon.com>
AuthorDate: Fri May 1 01:03:41 2020 +0800

    [Bug Fix] Fix GroupNorm Implementation (#18199)
    
    * init
    
    * add in_channels
---
 python/mxnet/gluon/nn/basic_layers.py  | 11 +++++++----
 src/operator/nn/group_norm-inl.h       | 25 +++++++++++++------------
 src/operator/nn/group_norm.cc          |  4 ++--
 tests/python/unittest/test_operator.py | 16 ++++++++--------
 4 files changed, 30 insertions(+), 26 deletions(-)
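In short: GroupNorm's learnable scale and shift were previously created with shape (num_groups,), i.e. one gamma/beta per group; this patch makes them per-channel with shape (in_channels,), which is the standard GroupNorm formulation. For reference, a minimal NumPy sketch of the corrected computation is below (illustrative shapes only, not part of the patch):

    import numpy as np

    # Illustrative shapes: batch=2, channels=6, groups=3, spatial 4x4.
    N, C, H, W = 2, 6, 4, 4
    num_groups, eps = 3, 1e-5

    x = np.random.uniform(0.2, 1.0, (N, C, H, W))
    gamma = np.random.uniform(-1.0, 1.0, (C,))  # per-channel, no longer per-group
    beta = np.random.uniform(-1.0, 1.0, (C,))

    # Normalize within each group of C // num_groups channels.
    xg = x.reshape(N, num_groups, C // num_groups, H, W)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    std = np.sqrt(xg.var(axis=(2, 3, 4), keepdims=True) + eps)
    x_hat = ((xg - mean) / std).reshape(N, C, H, W)

    # The affine step now broadcasts the (C,) parameters over channels.
    out = x_hat * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)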

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 70b0a71..797392a 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -820,7 +820,7 @@ class GroupNorm(HybridBlock):
     """
     def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                  beta_initializer='zeros', gamma_initializer='ones',
-                 prefix=None, params=None):
+                 in_channels=0, prefix=None, params=None):
         super(GroupNorm, self).__init__(prefix=prefix, params=params)
         self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
         self._num_groups = num_groups
@@ -828,10 +828,10 @@ class GroupNorm(HybridBlock):
         self._center = center
         self._scale = scale
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(num_groups,), init=gamma_initializer,
+                                     shape=(in_channels,), init=gamma_initializer,
                                      allow_deferred_init=True)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(num_groups,), init=beta_initializer,
+                                    shape=(in_channels,), init=beta_initializer,
                                     allow_deferred_init=True)
 
     def hybrid_forward(self, F, data, gamma, beta):
@@ -839,7 +839,10 @@ class GroupNorm(HybridBlock):
         return norm_data
 
     def __repr__(self):
-        s = '{name}({content})'
+        s = '{name}({content}'
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels)
+        s += ')'
         return s.format(name=self.__class__.__name__,
                         content=', '.join(['='.join([k, v.__repr__()])
                                            for k, v in self._kwargs.items()]))
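The new in_channels argument mirrors the in_channels argument on Gluon's other normalization blocks: when given, gamma and beta are created with a known shape up front instead of relying on deferred shape inference at the first forward. A small usage sketch against the patched layer (values and shapes are illustrative):

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.GroupNorm(num_groups=4, in_channels=32)
    net.initialize()

    x = mx.nd.random.uniform(shape=(2, 32, 8, 8))
    y = net(x)                 # same shape as x
    print(net.gamma.shape)     # (32,) -- one scale per channel, not per group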
diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h
index 69d5a30..143e216 100644
--- a/src/operator/nn/group_norm-inl.h
+++ b/src/operator/nn/group_norm-inl.h
@@ -136,16 +136,16 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
   TBlob data_grp = data.reshape(temp_data_shape);
   const TBlob& mean_grp = mean.reshape(moments_shape);
   const TBlob& std_grp = std.reshape(moments_shape);
-  const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape);
+  const TBlob& output_grp = outputs[groupnorm::kOut].reshape(temp_data_shape);
 
   // Calculate data = data - mean
   BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
                                                      {data_grp, mean_grp},
-                                                     {kWriteTo}, {output});
+                                                     {kWriteTo}, {output_grp});
 
   // Calculate std
   const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
-  MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
         s, std_, req[0], workspace, centered_out);
@@ -157,11 +157,12 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
 
   // Calculate data = data / std
   BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
-                                               {output, std_grp},
-                                               {kWriteTo}, {output});
+                                               {output_grp, std_grp},
+                                               {kWriteTo}, {output_grp});
 
-  mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1);
-  new_param_shape[1] = num_groups;
+  const TBlob& output = outputs[groupnorm::kOut];
+  mxnet::TShape new_param_shape(data_shape.ndim(), 1);
+  new_param_shape[1] = data_shape[1];
 
   const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape);
   const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape);
@@ -215,8 +216,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
 
   Stream<xpu> *s = ctx.get_stream<xpu>();
   // Reshape gamma to be broadcastable
-  mxnet::TShape new_param_shape(dshape.ndim() + 1, 1);
-  new_param_shape[1] = num_groups;
+  mxnet::TShape new_param_shape(dshape.ndim(), 1);
+  new_param_shape[1] = dshape[1];
 
   const TBlob& gamma = inputs[2].reshape(new_param_shape);
 
@@ -233,7 +234,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   // Prepare the necessary shapes for reduction
   mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
   BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape);
-  BroadcastReduceShapeCompact(temp_dshape, gamma.shape_,
+  BroadcastReduceShapeCompact(dshape, gamma.shape_,
                               &red_exclude_src_shape, &red_exclude_dst_shape);
 
   int N = red_src_shape.Size() / red_dst_shape.Size();
@@ -308,8 +309,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   if (req[0] != kNullOp) {
     const TBlob output_ = outputs[0].reshape(data_.shape_);
     BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
-                                                    {ograd, gamma},
-                                                    {kWriteTo}, {ograd_mult});
+                                                    {inputs[0], gamma},
+                                                    {kWriteTo}, {ograd_mult.reshape(data.shape_)});
     BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
                                                     {ograd_mult, std_},
                                                     {kWriteTo}, {ograd_mult});
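The key change in the kernel is that gamma/beta are now reshaped to (1, C, 1, ...) with the same rank as the data, so the affine step and its backward pass broadcast over channels instead of groups. The parameter-gradient shapes this implies can be sketched in NumPy as follows (illustrative, assuming NCHW data; it mirrors the updated reference in the test changes below):

    import numpy as np

    def groupnorm_param_grads(ograd, x_hat):
        # ograd, x_hat: arrays of shape (N, C, H, W)
        reduce_axes = (0, 2, 3)                              # every axis except channels
        beta_grad = ograd.sum(axis=reduce_axes)              # shape (C,)
        gamma_grad = (x_hat * ograd).sum(axis=reduce_axes)   # shape (C,)
        return gamma_grad, beta_grad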
diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc
index 6b8fe9b..c939b44 100644
--- a/src/operator/nn/group_norm.cc
+++ b/src/operator/nn/group_norm.cc
@@ -47,8 +47,8 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
     return false;
   }
 
-  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
-  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
+  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(dshape[1]));
+  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(dshape[1]));
 
   out_shape->clear();
   out_shape->push_back(dshape);
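With GroupNormShape fixed, shape inference resolves gamma and beta to the channel count of the input rather than to num_groups. A hedged sketch of checking this through the symbolic API (shapes are illustrative):

    import mxnet as mx

    data = mx.sym.Variable('data')
    gamma = mx.sym.Variable('gamma')
    beta = mx.sym.Variable('beta')
    sym = mx.sym.GroupNorm(data=data, gamma=gamma, beta=beta, num_groups=3)

    arg_shapes, _, _ = sym.infer_shape(data=(2, 6, 4, 4))
    print(dict(zip(sym.list_arguments(), arg_shapes)))
    # expected: {'data': (2, 6, 4, 4), 'gamma': (6,), 'beta': (6,)}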
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 32812b1..0baa941 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1960,10 +1960,10 @@ def test_groupnorm():
         return x_hat, mean, std
 
     def np_groupnorm(data, gamma, beta, num_groups, eps):
-        new_param_shape = (1, num_groups, 1, 1, 1)
+        new_param_shape = (1, dshape[1], 1, 1)
         x_hat, mean, std = x_hat_helper(data, num_groups, eps)
-        out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
-        return out.reshape(dshape), mean, std
+        out = x_hat.reshape(dshape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
+        return out, mean, std
 
     def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
         x_hat, mean, std = x_hat_helper(data, num_groups, eps)
@@ -1971,7 +1971,7 @@ def test_groupnorm():
         dshape = data.shape
         dtype = data.dtype
         new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
-        new_param_shape = (1, num_groups, 1, 1, 1)
+        new_param_shape = (1, dshape[1], 1, 1)
         acc_type = acc_types[str(dtype)]
         ograd = ograd.reshape(new_shape)
         data = data.reshape(new_shape)
@@ -1979,9 +1979,9 @@ def test_groupnorm():
         beta = beta.reshape(new_param_shape)
         mean = mean.reshape(new_moments_shape)
         std = std.reshape(new_moments_shape)
-        beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
-        gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
-        x_hat_grad = ograd * gamma
+        beta_grad = np.sum(ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+        gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+        x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1)
         ograd_mult = x_hat_grad / std
         red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
         data_grad = ograd_mult - red_out
@@ -1996,7 +1996,7 @@ def test_groupnorm():
     height = random.randint(1, 5)
     width = random.randint(1, 5)
     dshape = (batch_size, num_channels, height, width)
-    param_shape = (num_groups,)
+    param_shape = (num_channels,)
     temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width)
     np_data = np.random.uniform(0.2, 1.0, dshape)
     np_gamma = np.random.uniform(-1.0, 1.0, param_shape)