You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sx...@apache.org on 2020/04/30 17:04:36 UTC
[incubator-mxnet] branch master updated: [Bug Fix] Fix GroupNorm Implementation (#18199)
This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 1496c91 [Bug Fix] Fix GroupNorm Implementation (#18199)
1496c91 is described below
commit 1496c91871b9d81d6a18785bdc8a1c3450bedbca
Author: Huang, Guangtai <gu...@amazon.com>
AuthorDate: Fri May 1 01:03:41 2020 +0800
[Bug Fix] Fix GroupNorm Implementation (#18199)
* init
* add in_channels
---
python/mxnet/gluon/nn/basic_layers.py | 11 +++++++----
src/operator/nn/group_norm-inl.h | 25 +++++++++++++------------
src/operator/nn/group_norm.cc | 4 ++--
tests/python/unittest/test_operator.py | 16 ++++++++--------
4 files changed, 30 insertions(+), 26 deletions(-)
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 70b0a71..797392a 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -820,7 +820,7 @@ class GroupNorm(HybridBlock):
"""
def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
beta_initializer='zeros', gamma_initializer='ones',
- prefix=None, params=None):
+ in_channels=0, prefix=None, params=None):
super(GroupNorm, self).__init__(prefix=prefix, params=params)
self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
self._num_groups = num_groups
@@ -828,10 +828,10 @@ class GroupNorm(HybridBlock):
self._center = center
self._scale = scale
self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
- shape=(num_groups,), init=gamma_initializer,
+ shape=(in_channels,), init=gamma_initializer,
allow_deferred_init=True)
self.beta = self.params.get('beta', grad_req='write' if center else 'null',
- shape=(num_groups,), init=beta_initializer,
+ shape=(in_channels,), init=beta_initializer,
allow_deferred_init=True)
def hybrid_forward(self, F, data, gamma, beta):
@@ -839,7 +839,10 @@ class GroupNorm(HybridBlock):
return norm_data
def __repr__(self):
- s = '{name}({content})'
+ s = '{name}({content}'
+ in_channels = self.gamma.shape[0]
+ s += ', in_channels={0}'.format(in_channels)
+ s += ')'
return s.format(name=self.__class__.__name__,
content=', '.join(['='.join([k, v.__repr__()])
for k, v in self._kwargs.items()]))
diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h
index 69d5a30..143e216 100644
--- a/src/operator/nn/group_norm-inl.h
+++ b/src/operator/nn/group_norm-inl.h
@@ -136,16 +136,16 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
TBlob data_grp = data.reshape(temp_data_shape);
const TBlob& mean_grp = mean.reshape(moments_shape);
const TBlob& std_grp = std.reshape(moments_shape);
- const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape);
+ const TBlob& output_grp = outputs[groupnorm::kOut].reshape(temp_data_shape);
// Calculate data = data - mean
BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
{data_grp, mean_grp},
- {kWriteTo}, {output});
+ {kWriteTo}, {output_grp});
// Calculate std
const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
- MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+ MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
s, std_, req[0], workspace, centered_out);
@@ -157,11 +157,12 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
// Calculate data = data / std
BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
- {output, std_grp},
- {kWriteTo}, {output});
+ {output_grp, std_grp},
+ {kWriteTo}, {output_grp});
- mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1);
- new_param_shape[1] = num_groups;
+ const TBlob& output = outputs[groupnorm::kOut];
+ mxnet::TShape new_param_shape(data_shape.ndim(), 1);
+ new_param_shape[1] = data_shape[1];
const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape);
const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape);
@@ -215,8 +216,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
Stream<xpu> *s = ctx.get_stream<xpu>();
// Reshape gamma to be broadcastable
- mxnet::TShape new_param_shape(dshape.ndim() + 1, 1);
- new_param_shape[1] = num_groups;
+ mxnet::TShape new_param_shape(dshape.ndim(), 1);
+ new_param_shape[1] = dshape[1];
const TBlob& gamma = inputs[2].reshape(new_param_shape);
@@ -233,7 +234,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
// Prepare the necessary shapes for reduction
mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape);
- BroadcastReduceShapeCompact(temp_dshape, gamma.shape_,
+ BroadcastReduceShapeCompact(dshape, gamma.shape_,
&red_exclude_src_shape, &red_exclude_dst_shape);
int N = red_src_shape.Size() / red_dst_shape.Size();
@@ -308,8 +309,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
if (req[0] != kNullOp) {
const TBlob output_ = outputs[0].reshape(data_.shape_);
BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
- {ograd, gamma},
- {kWriteTo}, {ograd_mult});
+ {inputs[0], gamma},
+ {kWriteTo}, {ograd_mult.reshape(data.shape_)});
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{ograd_mult, std_},
{kWriteTo}, {ograd_mult});
diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc
index 6b8fe9b..c939b44 100644
--- a/src/operator/nn/group_norm.cc
+++ b/src/operator/nn/group_norm.cc
@@ -47,8 +47,8 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
return false;
}
- in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
- in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
+ in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(dshape[1]));
+ in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(dshape[1]));
out_shape->clear();
out_shape->push_back(dshape);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 32812b1..0baa941 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1960,10 +1960,10 @@ def test_groupnorm():
return x_hat, mean, std
def np_groupnorm(data, gamma, beta, num_groups, eps):
- new_param_shape = (1, num_groups, 1, 1, 1)
+ new_param_shape = (1, dshape[1], 1, 1)
x_hat, mean, std = x_hat_helper(data, num_groups, eps)
- out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
- return out.reshape(dshape), mean, std
+ out = x_hat.reshape(dshape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
+ return out, mean, std
def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
x_hat, mean, std = x_hat_helper(data, num_groups, eps)
@@ -1971,7 +1971,7 @@ def test_groupnorm():
dshape = data.shape
dtype = data.dtype
new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
- new_param_shape = (1, num_groups, 1, 1, 1)
+ new_param_shape = (1, dshape[1], 1, 1)
acc_type = acc_types[str(dtype)]
ograd = ograd.reshape(new_shape)
data = data.reshape(new_shape)
@@ -1979,9 +1979,9 @@ def test_groupnorm():
beta = beta.reshape(new_param_shape)
mean = mean.reshape(new_moments_shape)
std = std.reshape(new_moments_shape)
- beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
- gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
- x_hat_grad = ograd * gamma
+ beta_grad = np.sum(ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+ gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+ x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1)
ograd_mult = x_hat_grad / std
red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
data_grad = ograd_mult - red_out
@@ -1996,7 +1996,7 @@ def test_groupnorm():
height = random.randint(1, 5)
width = random.randint(1, 5)
dshape = (batch_size, num_channels, height, width)
- param_shape = (num_groups,)
+ param_shape = (num_channels,)
temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width)
np_data = np.random.uniform(0.2, 1.0, dshape)
np_gamma = np.random.uniform(-1.0, 1.0, param_shape)