Posted to issues@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/12/01 15:40:49 UTC
[GitHub] [incubator-mxnet] feevos opened a new issue #19609: mxnet 2.0 - GroupNorm does not work with as_nd_ndarray() upon activation of mx.npx.set_np()
feevos opened a new issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609
## Description
The operator `gluon.nn.GroupNorm` does not work properly (I believe) when we use `mx.npx.set_np()`. That is, something in its definition breaks, and it cannot accept the output of `mx.np.ndarray.as_nd_ndarray()`.
### Error Message
In short (see below for the complete error message):
```python
ValueError: Deferred initialization failed because shape cannot be inferred. Operator `GroupNorm` registered in backend is known as `GroupNorm` in Python. This is a legacy operator which can only accept legacy ndarrays, while received an MXNet numpy ndarray. Please call `as_nd_ndarray()` upon the numpy ndarray to convert it to a legacy ndarray, and then feed the converted array to this operator.
```
## To Reproduce
A. Working example:
```python
import mxnet as mx
net = mx.gluon.nn.GroupNorm(num_groups=4)
net.initialize()
xx = mx.nd.random.uniform(shape=[3,32,128,128])
net.summary(xx) # works as expected
```
Output:
```
--------------------------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================================
Input (3, 32, 128, 128) 0
GroupNorm-1 (3, 32, 128, 128) 64
================================================================================
Parameters in forward computation graph, duplicate included
Total params: 64
Trainable params: 64
Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 64
--------------------------------------------------------------------------------
```
B. When things break:
```python
import mxnet as mx
from mxnet import gluon
mx.npx.set_np()
net = gluon.nn.GroupNorm(num_groups=4)
net.initialize()
xx = mx.np.random.rand(3,32,512,512)
yy = xx.as_nd_ndarray()
net.summary(yy) # Fails, even though yy is a legacy mx.nd.NDArray
```
## Error message:
This is the complete error message.
```python
---------------------------------------------------------------------------
DeferredInitializationError Traceback (most recent call last)
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in forward(self, x, *args)
1455 try:
-> 1456 params = {k: v.data(ctx) for k, v in self._reg_params.items()}
1457 except DeferredInitializationError:
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in <dictcomp>(.0)
1455 try:
-> 1456 params = {k: v.data(ctx) for k, v in self._reg_params.items()}
1457 except DeferredInitializationError:
~/.local/lib/python3.8/site-packages/mxnet/gluon/parameter.py in data(self, ctx)
584 "instead." % (self.name, str(ctx), self._stype))
--> 585 data = self._check_and_get(self._data, ctx)
586 dc.set_variable(data, self.var())
~/.local/lib/python3.8/site-packages/mxnet/gluon/parameter.py in _check_and_get(self, arr_list, ctx)
239 if self._deferred_init:
--> 240 raise DeferredInitializationError(
241 "Parameter '%s' has not been initialized yet because initialization was " \
DeferredInitializationError: Parameter 'gamma' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _deferred_infer_shape(self, *args)
1089 try:
-> 1090 self.infer_shape(*args)
1091 except Exception as e:
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in infer_shape(self, *args)
1282 # Gluon 1 based on F: hybrid_forward is defined by user
-> 1283 self._infer_attrs('infer_shape', 'shape', *args)
1284 else:
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _infer_attrs(self, infer_fn, attr, *args)
1264 """Generic infer attributes."""
-> 1265 inputs, out = self._get_graph(*args)
1266 args, _ = _flatten(args, "input")
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _get_graph(self, *args)
988 if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
--> 989 return self._get_graph_v1(*args)
990 else: # Gluon 2 based on deferred compute mode
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _get_graph_v1(self, *args)
952 params = {i: j.var() for i, j in self._reg_params.items()}
--> 953 out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter
954 out, self._out_format = _flatten(out, "output")
~/.local/lib/python3.8/site-packages/mxnet/gluon/nn/basic_layers.py in hybrid_forward(self, F, data, gamma, beta)
842 def hybrid_forward(self, F, data, gamma, beta):
--> 843 norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
844 return norm_data
~/.local/lib/python3.8/site-packages/mxnet/symbol/register.py in GroupNorm(data, gamma, beta, num_groups, eps, output_mean_var, name, attr, out, **kwargs)
~/.local/lib/python3.8/site-packages/mxnet/symbol/register.py in _verify_legacy_symbol(op_name, func_name, sym)
73 if isinstance(sym, np_symbol):
---> 74 raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
75 'This is a legacy operator which can only accept '
TypeError: Operator `GroupNorm` registered in backend is known as `GroupNorm` in Python. This is a legacy operator which can only accept legacy ndarrays, while received an MXNet numpy ndarray. Please call `as_nd_ndarray()` upon the numpy ndarray to convert it to a legacy ndarray, and then feed the converted array to this operator.
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-7-d686fa49b9c0> in <module>
----> 1 net.summary(yy)
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in summary(self, *inputs)
829 try:
830 self.apply(_register_summary_hook)
--> 831 self(*inputs)
832
833 line_format = '{:>20} {:>42} {:>15}'
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
1405 if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
1406 # Gluon 1 based on F: hybrid_forward is defined by user
-> 1407 return super().__call__(x, *args)
1408 else: # Gluon 2 based on deferred compute mode
1409 assert self.forward is not HybridBlock.forward, (
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in __call__(self, *args)
709 hook(self, args)
710
--> 711 out = self.forward(*args)
712
713 for hook in self._forward_hooks.values():
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in forward(self, x, *args)
1456 params = {k: v.data(ctx) for k, v in self._reg_params.items()}
1457 except DeferredInitializationError:
-> 1458 self._deferred_infer_shape(x, *args)
1459 for _, v in self.params.items():
1460 v._finish_deferred_init()
~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _deferred_infer_shape(self, *args)
1092 error_msg = "Deferred initialization failed because shape"\
1093 " cannot be inferred. {}".format(e)
-> 1094 raise ValueError(error_msg)
1095
1096 def _call_cached_op(self, *args):
ValueError: Deferred initialization failed because shape cannot be inferred. Operator `GroupNorm` registered in backend is known as `GroupNorm` in Python. This is a legacy operator which can only accept legacy ndarrays, while received an MXNet numpy ndarray. Please call `as_nd_ndarray()` upon the numpy ndarray to convert it to a legacy ndarray, and then feed the converted array to this operator.
```
### Steps to reproduce
Provided above.
## What have you tried to solve it?
I cannot solve it.
## Environment
***We recommend using our script for collecting the diagnostic information with the following command***
`curl --retry 10 -s https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/diagnose.py | python3`
<details>
<summary>Environment Information</summary>
```
----------Python Info----------
Version : 3.8.5
Compiler : GCC 9.3.0
Build : ('default', 'Jul 28 2020 12:59:40')
Arch : ('64bit', 'ELF')
------------Pip Info-----------
Version : 20.0.2
Directory : /usr/lib/python3/dist-packages/pip
----------MXNet Info-----------
Version : 2.0.0
Directory : /home/foivos/.local/lib/python3.8/site-packages/mxnet
Commit hash file "/home/foivos/.local/lib/python3.8/site-packages/mxnet/COMMIT_HASH" not found. Not installed from pre-built package or built from source.
Library : ['/home/foivos/.local/lib/python3.8/site-packages/mxnet/libmxnet.so']
Build features:
✔ CUDA
✔ CUDNN
✖ NCCL
✖ TENSORRT
✔ CPU_SSE
✔ CPU_SSE2
✔ CPU_SSE3
✖ CPU_SSE4_1
✖ CPU_SSE4_2
✖ CPU_SSE4A
✖ CPU_AVX
✖ CPU_AVX2
✔ OPENMP
✖ SSE
✖ F16C
✖ JEMALLOC
✔ BLAS_OPEN
✖ BLAS_ATLAS
✖ BLAS_MKL
✖ BLAS_APPLE
✔ LAPACK
✔ MKLDNN
✔ OPENCV
✔ DIST_KVSTORE
✖ INT64_TENSOR_SIZE
✔ SIGNAL_HANDLER
✖ DEBUG
✖ TVM_OP
----------System Info----------
Platform : Linux-5.4.0-7642-generic-x86_64-with-glibc2.29
system : Linux
node : dep59910
release : 5.4.0-7642-generic
version : #46~1598628707~20.04~040157c-Ubuntu SMP Fri Aug 28 18:02:16 UTC
----------Hardware Info----------
machine : x86_64
processor : x86_64
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 39 bits physical, 48 bits virtual
CPU(s): 16
On-line CPU(s) list: 0-15
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 165
Model name: Intel(R) Core(TM) i7-10875H CPU @ 2.30GHz
Stepping: 2
CPU MHz: 3484.141
CPU max MHz: 5100.0000
CPU min MHz: 800.0000
BogoMIPS: 4599.93
Virtualisation: VT-x
L1d cache: 256 KiB
L1i cache: 256 KiB
L2 cache: 2 MiB
L3 cache: 16 MiB
NUMA node0 CPU(s): 0-15
Vulnerability Itlb multihit: KVM: Mitigation: Split huge pages
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov
pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall
nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs
bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni
pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma
cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer
aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch
cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase
tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap
clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm
ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp pku ospke
md_clear flush_l1d arch_capabilities
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0186 sec, LOAD: 0.8468 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.4067 sec, LOAD: 0.6158 sec.
Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1123)>, DNS finished in 0.714545488357544 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.1729 sec, LOAD: 1.0142 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0186 sec, LOAD: 1.7548 sec.
Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.13265538215637207 sec.
----------Environment----------
KMP_DUPLICATE_LIB_OK="True"
KMP_INIT_AT_FORK="FALSE"
```
</details>
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@mxnet.apache.org
For additional commands, e-mail: issues-help@mxnet.apache.org
[GitHub] [incubator-mxnet] feevos commented on issue #19609: mxnet 2.0 - GroupNorm does not work with as_nd_ndarray() upon activation of mx.npx.set_np()
Posted by GitBox <gi...@apache.org>.
feevos commented on issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609#issuecomment-774105316
Dear all, this is the final solution that works for me. It uses the legacy operator `mx.nd.GroupNorm`:
```python
import mxnet as mx
from mxnet.gluon import HybridBlock
from mxnet.gluon.parameter import Parameter


@mx.use_np
class GroupNorm(HybridBlock):
    r"""
    Applies group normalization to the n-dimensional input array.
    This operator takes an n-dimensional input array where the leftmost 2 axes are
    `batch` and `channel` respectively:

    .. math::

        x = x.reshape((N, num_groups, C // num_groups, ...))
        axis = (2, ...)
        out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta

    Parameters
    ----------
    num_groups: int, default 1
        Number of groups to separate the channel axis into.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.

    Inputs:
        - **data**: input tensor with shape (N, C, ...).

    Outputs:
        - **out**: output tensor with the same shape as `data`.

    References
    ----------
        `Group Normalization
        <https://arxiv.org/pdf/1803.08494.pdf>`_

    Examples
    --------
    # Input of shape (2, 3, 4)
    x = mx.nd.array([[[ 0,  1,  2,  3],
                      [ 4,  5,  6,  7],
                      [ 8,  9, 10, 11]],
                     [[12, 13, 14, 15],
                      [16, 17, 18, 19],
                      [20, 21, 22, 23]]])
    # Group normalization is calculated with the above formula
    layer = GroupNorm()
    layer.initialize(ctx=mx.cpu(0))
    layer(x)
    [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413  0.1448413  0.4345239]
      [ 0.7242065  1.0138891  1.3035717  1.5932543]]
     [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413  0.1448413  0.4345239]
      [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
    <NDArray 2x3x4 @cpu(0)>
    """
    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 in_channels=0):
        super(GroupNorm, self).__init__()
        self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
        self._num_groups = num_groups
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        self.gamma = Parameter('gamma', grad_req='write' if scale else 'null',
                               shape=(in_channels,), init=gamma_initializer,
                               allow_deferred_init=True)
        self.beta = Parameter('beta', grad_req='write' if center else 'null',
                              shape=(in_channels,), init=beta_initializer,
                              allow_deferred_init=True)

    def infer_shape(self, x):
        # Necessary for mxnet 2.0: receives the input array itself (not its shape).
        # gamma and beta hold one value per channel, i.e. per entry of axis 1.
        channels = x.shape[1]
        self.gamma.shape = (channels,)
        self.beta.shape = (channels,)

    def forward(self, x):
        # The legacy operator only accepts legacy NDArrays, so convert data and parameters.
        gamma = self.gamma.data().as_nd_ndarray()
        beta = self.beta.data().as_nd_ndarray()
        x = mx.nd.GroupNorm(data=x.as_nd_ndarray(), gamma=gamma, beta=beta,
                            num_groups=self._num_groups, eps=self._epsilon)
        # Convert back so downstream numpy-style blocks receive an mx.np.ndarray.
        return x.as_np_ndarray()

    def __repr__(self):
        s = '{name}({content}'
        in_channels = self.gamma.shape[0]
        s += ', in_channels={0}'.format(in_channels)
        s += ')'
        return s.format(name=self.__class__.__name__,
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))
```
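To check the formula in the docstring independently of MXNet, here is a minimal NumPy sketch of the same computation. `group_norm_reference` is a hypothetical helper, not part of MXNet. It assumes the population variance (`ddof=0`), which is what reproduces the numbers in the docstring example (the first element comes out at about -1.5933), and it treats `gamma` and `beta` as per-channel vectors of length C.

```python
import numpy as np


def group_norm_reference(x, gamma, beta, num_groups=1, eps=1e-5):
    """Hypothetical NumPy reference for the GroupNorm formula (not an MXNet API).

    x: array of shape (N, C, ...); gamma, beta: arrays of shape (C,).
    """
    n, c = x.shape[0], x.shape[1]
    if c % num_groups != 0:
        raise ValueError("the channel count must be divisible by num_groups")
    # Split the channel axis: (N, C, ...) -> (N, num_groups, C // num_groups, ...).
    grouped = x.reshape((n, num_groups, c // num_groups) + x.shape[2:])
    # Normalize over every axis from 2 onward: channels within a group plus spatial dims.
    axes = tuple(range(2, grouped.ndim))
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)  # population variance (ddof=0)
    normed = ((grouped - mean) / np.sqrt(var + eps)).reshape(x.shape)
    # gamma and beta are per channel; broadcast them over the remaining axes.
    bshape = (1, c) + (1,) * (x.ndim - 2)
    return normed * gamma.reshape(bshape) + beta.reshape(bshape)
```

For example, `group_norm_reference(np.arange(24, dtype=np.float32).reshape(2, 3, 4), np.ones(3), np.zeros(3))` should match the docstring output up to floating-point rounding, with both samples normalizing to the same values because the second is just the first shifted by 12.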
[GitHub] [incubator-mxnet] feevos commented on issue #19609: mxnet 2.0 - GroupNorm does not work with as_nd_ndarray() upon activation of mx.npx.set_np()
Posted by GitBox <gi...@apache.org>.
feevos commented on issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609#issuecomment-738546600
A workaround that makes the failing example run:
```python
import mxnet as mx
from mxnet import gluon
mx.npx.set_np() # numpy semantics active, as in the original report
net = gluon.nn.GroupNorm(num_groups=4)
net.initialize()
xx = mx.np.random.rand(3,32,512,512)
yy = xx.as_nd_ndarray()
mx.npx.reset_np() # <====== This fixes the situation by resetting numpy behaviour (I think!)
net.summary(yy)
mx.npx.set_np() # <======== This restores numpy behaviour.
```
Output:
```python
--------------------------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================================
Input (3, 32, 512, 512) 0
GroupNorm-1 (3, 32, 512, 512) 64
================================================================================
Parameters in forward computation graph, duplicate included
Total params: 64
Trainable params: 64
Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 64
--------------------------------------------------------------------------------
```
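The `reset_np()` / `set_np()` pair above leaves numpy semantics switched off if the wrapped call raises. Below is a minimal sketch of a context manager that restores the previous state either way. `legacy_array_mode` is a hypothetical helper, not an MXNet API; it takes the `mx.npx` module as an argument (an assumption made so the sketch can be exercised without MXNet installed) and relies only on `is_np_array()`, `reset_np()` and `set_np()`, the same calls used in this thread.

```python
import contextlib


@contextlib.contextmanager
def legacy_array_mode(npx):
    """Hypothetical helper: temporarily disable MXNet numpy semantics.

    `npx` is assumed to be the `mx.npx` module; it is passed in explicitly so
    this helper has no import-time dependency on MXNet.
    """
    was_np = npx.is_np_array()
    if was_np:
        # Switch to legacy NDArray/Symbol semantics for the wrapped call.
        npx.reset_np()
    try:
        yield
    finally:
        # Restore numpy semantics even if the wrapped call raised.
        if was_np:
            npx.set_np()
```

With MXNet, the workaround would then read `with legacy_array_mode(mx.npx): net.summary(yy)`. Note that `set_np()` turns on both numpy shape and numpy array semantics, so, like the hack below, this assumes both were active before entering the block.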
This is a small replacement hack for GroupNorm:
```python
import mxnet as mx
from mxnet import gluon

class GroupNormHack(gluon.nn.HybridBlock):
    """
    Partial workaround for issue #19609;
    see https://github.com/apache/incubator-mxnet/issues/19609
    """
    def __init__(self, num_groups, **kwargs):
        super().__init__()
        # Extra keyword arguments (epsilon, center, scale, ...) go to GroupNorm only.
        self.norm = gluon.nn.GroupNorm(num_groups=num_groups, **kwargs)

    def forward(self, x):
        # The legacy operator only accepts legacy NDArrays.
        tinput = x.as_nd_ndarray() if mx.npx.is_np_array() else x
        # Temporarily disable numpy semantics so GroupNorm traces legacy symbols.
        mx.npx.reset_np()
        try:
            out = self.norm(tinput)
        finally:
            # Restore numpy semantics even if the call above raises.
            mx.npx.set_np()
        return out.as_np_ndarray()
```
Toy run:
```python
In [11]: net = GroupNormHack(num_groups=4)
...: net.initialize()
...: xx = mx.np.random.rand(3,32,512,512)
...: net.summary(xx)
--------------------------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================================
Input (3, 32, 512, 512) 0
GroupNorm-1 (3, 32, 512, 512) 64
GroupNormHack-2 (3, 32, 512, 512) 0
================================================================================
Parameters in forward computation graph, duplicate included
Total params: 64
Trainable params: 64
Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 64
--------------------------------------------------------------------------------
In [12]: yy = xx.as_nd_ndarray()
In [13]: net.summary(yy)
--------------------------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================================
Input (3, 32, 512, 512) 0
GroupNorm-1 (3, 32, 512, 512) 64
GroupNormHack-2 (3, 32, 512, 512) 0
================================================================================
Parameters in forward computation graph, duplicate included
Total params: 64
Trainable params: 64
Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 64
--------------------------------------------------------------------------------
In [14]: net.hybridize()
In [15]: out = net(xx)
In [16]: out = net(yy)
In [17]:
```
[GitHub] [incubator-mxnet] feevos commented on issue #19609: mxnet 2.0 - GroupNorm does not work with as_nd_ndarray() upon activation of mx.npx.set_np()
Posted by GitBox <gi...@apache.org>.
feevos commented on issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609#issuecomment-773565727
The workaround presented above does not let me load saved weights; it raises the same error.