Posted to issues@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/12/01 15:40:49 UTC

[GitHub] [incubator-mxnet] feevos opened a new issue #19609: mxnet 2.0 - GroupNorm does not work with as_nd_ndarray() upon activation of mx.npx.set_np()

feevos opened a new issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609


   ## Description
   The operator ```gluon.nn.GroupNorm``` does not work properly (I believe) when ```mx.npx.set_np()``` is active. That is, something in its definition breaks, and it cannot accept the legacy array returned by ```mx.np.ndarray.as_nd_ndarray()```.
   
   ### Error Message
   In short (the complete error message is given further below):
   
   ```python
   ValueError: Deferred initialization failed because shape cannot be inferred. Operator `GroupNorm` registered in backend is known as `GroupNorm` in Python. This is a legacy operator which can only accept legacy ndarrays, while received an MXNet numpy ndarray. Please call `as_nd_ndarray()` upon the numpy ndarray to convert it to a legacy ndarray, and then feed the converted array to this operator.
   ```
   ## To Reproduce
   A. Working example: 
   ```python
   import mxnet as mx
   net = mx.gluon.nn.GroupNorm(num_groups=4)
   net.initialize()
   xx = mx.nd.random.uniform(shape=[3,32,512,512])
   net.summary(xx) # works as expected
   
   --------------------------------------------------------------------------------
           Layer (type)                                Output Shape         Param #
   ================================================================================
                  Input                           (3, 32, 512, 512)               0
            GroupNorm-1                           (3, 32, 512, 512)              64
   ================================================================================
   Parameters in forward computation graph, duplicate included
      Total params: 64
      Trainable params: 64
      Non-trainable params: 0
   Shared params in forward computation graph: 0
   Unique parameters in model: 64
   --------------------------------------------------------------------------------
   ```
   B. Breaking example:
   
   ```python
   import mxnet as mx
   mx.npx.set_np()
   
   net = mx.gluon.nn.GroupNorm(num_groups=4)
   net.initialize()
   xx = mx.np.random.rand(3,32,512,512)
   yy = xx.as_nd_ndarray()
   net.summary(yy) # This now fails, even though it is fed a legacy mx.nd.NDArray
   ```
   ## Error message: 
   This is the complete error message. 
   ```python
   ---------------------------------------------------------------------------
   DeferredInitializationError               Traceback (most recent call last)
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in forward(self, x, *args)
      1455                 try:
   -> 1456                     params = {k: v.data(ctx) for k, v in self._reg_params.items()}
      1457                 except DeferredInitializationError:
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in <dictcomp>(.0)
      1455                 try:
   -> 1456                     params = {k: v.data(ctx) for k, v in self._reg_params.items()}
      1457                 except DeferredInitializationError:
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/parameter.py in data(self, ctx)
       584                                "instead." % (self.name, str(ctx), self._stype))
   --> 585         data = self._check_and_get(self._data, ctx)
       586         dc.set_variable(data, self.var())
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/parameter.py in _check_and_get(self, arr_list, ctx)
       239         if self._deferred_init:
   --> 240             raise DeferredInitializationError(
       241                 "Parameter '%s' has not been initialized yet because initialization was " \
   
   DeferredInitializationError: Parameter 'gamma' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers.
   
   During handling of the above exception, another exception occurred:
   
   TypeError                                 Traceback (most recent call last)
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _deferred_infer_shape(self, *args)
      1089         try:
   -> 1090             self.infer_shape(*args)
      1091         except Exception as e:
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in infer_shape(self, *args)
      1282             # Gluon 1 based on F:  hybrid_forward is defined by user
   -> 1283             self._infer_attrs('infer_shape', 'shape', *args)
      1284         else:
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _infer_attrs(self, infer_fn, attr, *args)
      1264         """Generic infer attributes."""
   -> 1265         inputs, out = self._get_graph(*args)
      1266         args, _ = _flatten(args, "input")
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _get_graph(self, *args)
       988             if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
   --> 989                 return self._get_graph_v1(*args)
       990             else:  # Gluon 2 based on deferred compute mode
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _get_graph_v1(self, *args)
       952                 params = {i: j.var() for i, j in self._reg_params.items()}
   --> 953                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
       954             out, self._out_format = _flatten(out, "output")
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/nn/basic_layers.py in hybrid_forward(self, F, data, gamma, beta)
       842     def hybrid_forward(self, F, data, gamma, beta):
   --> 843         norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
       844         return norm_data
   
   ~/.local/lib/python3.8/site-packages/mxnet/symbol/register.py in GroupNorm(data, gamma, beta, num_groups, eps, output_mean_var, name, attr, out, **kwargs)
   
   ~/.local/lib/python3.8/site-packages/mxnet/symbol/register.py in _verify_legacy_symbol(op_name, func_name, sym)
        73     if isinstance(sym, np_symbol):
   ---> 74         raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
        75                         'This is a legacy operator which can only accept '
   
   TypeError: Operator `GroupNorm` registered in backend is known as `GroupNorm` in Python. This is a legacy operator which can only accept legacy ndarrays, while received an MXNet numpy ndarray. Please call `as_nd_ndarray()` upon the numpy ndarray to convert it to a legacy ndarray, and then feed the converted array to this operator.
   
   During handling of the above exception, another exception occurred:
   
   ValueError                                Traceback (most recent call last)
   <ipython-input-7-d686fa49b9c0> in <module>
   ----> 1 net.summary(yy)
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in summary(self, *inputs)
       829         try:
       830             self.apply(_register_summary_hook)
   --> 831             self(*inputs)
       832 
       833             line_format = '{:>20}  {:>42} {:>15}'
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
      1405         if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
      1406             # Gluon 1 based on F:  hybrid_forward is defined by user
   -> 1407             return super().__call__(x, *args)
      1408         else:  # Gluon 2 based on deferred compute mode
      1409             assert self.forward is not HybridBlock.forward, (
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in __call__(self, *args)
       709             hook(self, args)
       710 
   --> 711         out = self.forward(*args)
       712 
       713         for hook in self._forward_hooks.values():
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in forward(self, x, *args)
      1456                     params = {k: v.data(ctx) for k, v in self._reg_params.items()}
      1457                 except DeferredInitializationError:
   -> 1458                     self._deferred_infer_shape(x, *args)
      1459                     for _, v in self.params.items():
      1460                         v._finish_deferred_init()
   
   ~/.local/lib/python3.8/site-packages/mxnet/gluon/block.py in _deferred_infer_shape(self, *args)
      1092             error_msg = "Deferred initialization failed because shape"\
      1093                         " cannot be inferred. {}".format(e)
   -> 1094             raise ValueError(error_msg)
      1095 
      1096     def _call_cached_op(self, *args):
   
   ValueError: Deferred initialization failed because shape cannot be inferred. Operator `GroupNorm` registered in backend is known as `GroupNorm` in Python. This is a legacy operator which can only accept legacy ndarrays, while received an MXNet numpy ndarray. Please call `as_nd_ndarray()` upon the numpy ndarray to convert it to a legacy ndarray, and then feed the converted array to this operator.
   
   ```
   
   ### Steps to reproduce
   Provided above (see the working and breaking examples under "To Reproduce").
   
   ## What have you tried to solve it?
   I cannot solve it. 
   ## Environment
   
   ***We recommend using our script for collecting the diagnostic information with the following command***
   `curl --retry 10 -s https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/diagnose.py | python3`
   
   <details>
   <summary>Environment Information</summary>
   
   ```
   ----------Python Info----------
   Version      : 3.8.5
   Compiler     : GCC 9.3.0
   Build        : ('default', 'Jul 28 2020 12:59:40')
   Arch         : ('64bit', 'ELF')
   ------------Pip Info-----------
   Version      : 20.0.2
   Directory    : /usr/lib/python3/dist-packages/pip
   ----------MXNet Info-----------
   Version      : 2.0.0
   Directory    : /home/foivos/.local/lib/python3.8/site-packages/mxnet
   Commit hash file "/home/foivos/.local/lib/python3.8/site-packages/mxnet/COMMIT_HASH" not found. Not installed from pre-built package or built from source.
   Library      : ['/home/foivos/.local/lib/python3.8/site-packages/mxnet/libmxnet.so']
   Build features:
   ✔ CUDA
   ✔ CUDNN
   ✖ NCCL
   ✖ TENSORRT
   ✔ CPU_SSE
   ✔ CPU_SSE2
   ✔ CPU_SSE3
   ✖ CPU_SSE4_1
   ✖ CPU_SSE4_2
   ✖ CPU_SSE4A
   ✖ CPU_AVX
   ✖ CPU_AVX2
   ✔ OPENMP
   ✖ SSE
   ✖ F16C
   ✖ JEMALLOC
   ✔ BLAS_OPEN
   ✖ BLAS_ATLAS
   ✖ BLAS_MKL
   ✖ BLAS_APPLE
   ✔ LAPACK
   ✔ MKLDNN
   ✔ OPENCV
   ✔ DIST_KVSTORE
   ✖ INT64_TENSOR_SIZE
   ✔ SIGNAL_HANDLER
   ✖ DEBUG
   ✖ TVM_OP
   ----------System Info----------
   Platform     : Linux-5.4.0-7642-generic-x86_64-with-glibc2.29
   system       : Linux
   node         : dep59910
   release      : 5.4.0-7642-generic
   version      : #46~1598628707~20.04~040157c-Ubuntu SMP Fri Aug 28 18:02:16 UTC 
   ----------Hardware Info----------
   machine      : x86_64
   processor    : x86_64
   Architecture:                    x86_64
   CPU op-mode(s):                  32-bit, 64-bit
   Byte Order:                      Little Endian
   Address sizes:                   39 bits physical, 48 bits virtual
   CPU(s):                          16
   On-line CPU(s) list:             0-15
   Thread(s) per core:              2
   Core(s) per socket:              8
   Socket(s):                       1
   NUMA node(s):                    1
   Vendor ID:                       GenuineIntel
   CPU family:                      6
   Model:                           165
   Model name:                      Intel(R) Core(TM) i7-10875H CPU @ 2.30GHz
   Stepping:                        2
   CPU MHz:                         3484.141
   CPU max MHz:                     5100.0000
   CPU min MHz:                     800.0000
   BogoMIPS:                        4599.93
   Virtualisation:                  VT-x
   L1d cache:                       256 KiB
   L1i cache:                       256 KiB
   L2 cache:                        2 MiB
   L3 cache:                        16 MiB
   NUMA node0 CPU(s):               0-15
   Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
   Vulnerability L1tf:              Not affected
   Vulnerability Mds:               Not affected
   Vulnerability Meltdown:          Not affected
   Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and s
                                    eccomp
   Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanit
                                    ization
   Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
   Vulnerability Srbds:             Not affected
   Vulnerability Tsx async abort:   Not affected
   Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov
                                     pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe sy
                                    scall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs
                                     bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni
                                     pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg f
                                    ma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_
                                    deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowpre
                                    fetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_en
                                    hanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase 
                                    tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx sm
                                    ap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm 
                                    ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp pku os
                                    pke md_clear flush_l1d arch_capabilities
   ----------Network Test----------
   Setting timeout: 10
   Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0186 sec, LOAD: 0.8468 sec.
   Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.4067 sec, LOAD: 0.6158 sec.
   Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1123)>, DNS finished in 0.714545488357544 sec.
   Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.1729 sec, LOAD: 1.0142 sec.
   Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0186 sec, LOAD: 1.7548 sec.
   Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.13265538215637207 sec.
   ----------Environment----------
   KMP_DUPLICATE_LIB_OK="True"
   KMP_INIT_AT_FORK="FALSE"
   
   
   ```
   
   </details>
   



[GitHub] [incubator-mxnet] feevos commented on issue #19609: mxnet 2.0 - GroupNorm does not work with as_nd_ndarray() upon activation of mx.npx.set_np()

Posted by GitBox <gi...@apache.org>.
feevos commented on issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609#issuecomment-774105316


   Dear all, this is the final solution that works for me, using the legacy operator ```mx.nd.GroupNorm```:
   
   ```Python
   import mxnet as mx
   from mxnet.gluon import HybridBlock
   from mxnet.gluon.parameter import Parameter
   
   @mx.use_np
   class GroupNorm(HybridBlock):
       r"""
       Applies group normalization to the n-dimensional input array.
       This operator takes an n-dimensional input array where the leftmost 2 axis are
       `batch` and `channel` respectively:
       .. math::
         x = x.reshape((N, num_groups, C // num_groups, ...))
         axis = (2, ...)
         out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta
       Parameters
       ----------
       num_groups: int, default 1
           Number of groups to separate the channel axis into.
       epsilon: float, default 1e-5
           Small float added to variance to avoid dividing by zero.
       center: bool, default True
           If True, add offset of `beta` to normalized tensor.
           If False, `beta` is ignored.
       scale: bool, default True
           If True, multiply by `gamma`. If False, `gamma` is not used.
       beta_initializer: str or `Initializer`, default 'zeros'
           Initializer for the beta weight.
       gamma_initializer: str or `Initializer`, default 'ones'
           Initializer for the gamma weight.
       Inputs:
           - **data**: input tensor with shape (N, C, ...).
       Outputs:
           - **out**: output tensor with the same shape as `data`.
       References
       ----------
           `Group Normalization
           <https://arxiv.org/pdf/1803.08494.pdf>`_
       Examples
       --------
       # Input of shape (2, 3, 4)
       x = mx.nd.array([[[ 0,  1,  2,  3],
                             [ 4,  5,  6,  7],
                             [ 8,  9, 10, 11]],
                            [[12, 13, 14, 15],
                             [16, 17, 18, 19],
                             [20, 21, 22, 23]]])
       # Group normalization is calculated with the above formula
       layer = GroupNorm()
       layer.initialize(ctx=mx.cpu(0))
       layer(x)
       [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
         [-0.4345239 -0.1448413  0.1448413  0.4345239]
         [ 0.7242065  1.0138891  1.3035717  1.5932543]]
        [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
         [-0.4345239 -0.1448413  0.1448413  0.4345239]
         [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
       <NDArray 2x3x4 @cpu(0)>
       """
       def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                    beta_initializer='zeros', gamma_initializer='ones',
                    in_channels=0):
           super(GroupNorm, self).__init__()
           self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
           self._num_groups = num_groups
           self._epsilon = epsilon
           self._center = center
           self._scale = scale
           self.gamma = Parameter('gamma', grad_req='write' if scale else 'null',
                                  shape=(in_channels,), init=gamma_initializer,
                                  allow_deferred_init=True)
           self.beta = Parameter('beta', grad_req='write' if center else 'null',
                                 shape=(in_channels,), init=beta_initializer,
                                 allow_deferred_init=True)
   
       def infer_shape(self, in_shape):
           # Necessary for mxnet 2.0: called with the input array so that the
           # deferred gamma/beta parameters get their shape from the channel axis.
           tshape = in_shape.shape
           self.gamma.shape = (tshape[1],)
           self.beta.shape = (tshape[1],)
   
       def forward(self, x):
           # mx.nd.GroupNorm is a legacy operator, so convert the parameters
           # and the input to legacy NDArrays before calling it.
           gamma = self.gamma.data().as_nd_ndarray()
           beta = self.beta.data().as_nd_ndarray()
           x = mx.nd.GroupNorm(data=x.as_nd_ndarray(), gamma=gamma, beta=beta,
                               num_groups=self._num_groups, eps=self._epsilon)
           # Convert the result back to an mx.np ndarray for numpy semantics.
           return x.as_np_ndarray()
   
   
       def __repr__(self):
           s = '{name}({content}'
           in_channels = self.gamma.shape[0]
           s += ', in_channels={0}'.format(in_channels)
           s += ')'
           return s.format(name=self.__class__.__name__,
                           content=', '.join(['='.join([k, v.__repr__()])
                                              for k, v in self._kwargs.items()]))
   
   
   ```
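
   For completeness, here is a quick usage sketch of the block above with numpy semantics active (the shape is just illustrative); the conversions to legacy NDArrays happen inside `forward`, so the caller only ever sees `mx.np` arrays:

   ```python
   import mxnet as mx
   mx.npx.set_np()

   net = GroupNorm(num_groups=4)            # the class defined above
   net.initialize()

   xx = mx.np.random.rand(3, 32, 64, 64)    # (N, C, H, W), illustrative shape
   out = net(xx)                            # deferred init resolves gamma/beta via infer_shape
   print(type(out), out.shape)              # expected: an mx.np ndarray of shape (3, 32, 64, 64)
   ```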



[GitHub] [incubator-mxnet] feevos commented on issue #19609: mxnet 2.0 - GroupNorm does not work with as_nd_ndarray() upon activation of mx.npx.set_np()

Posted by GitBox <gi...@apache.org>.
feevos commented on issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609#issuecomment-738546600


   Workaround that allows the operations to go through:
   
   ```python
   
   import mxnet as mx
   from mxnet import gluon
   mx.npx.set_np()

   net = gluon.nn.GroupNorm(num_groups=4)
   net.initialize()
   xx = mx.np.random.rand(3,32,512,512)
   yy = xx.as_nd_ndarray()
   mx.npx.reset_np() # <====== This fixes the situation by temporarily resetting numpy behaviour (I think!)
   net.summary(yy)
   mx.npx.set_np()   # <======== This restores numpy behaviour.
   ```
   
   Output: 
   
   ```python
   --------------------------------------------------------------------------------
           Layer (type)                                Output Shape         Param #
   ================================================================================
                  Input                           (3, 32, 512, 512)               0
            GroupNorm-1                           (3, 32, 512, 512)              64
   ================================================================================
   Parameters in forward computation graph, duplicate included
      Total params: 64
      Trainable params: 64
      Non-trainable params: 0
   Shared params in forward computation graph: 0
   Unique parameters in model: 64
   --------------------------------------------------------------------------------
   ```
   
   This is a small replacement hack for GroupNorm: 
   
   ```python
   import mxnet as mx
   from mxnet  import gluon
   
   class GroupNormHack(gluon.nn.HybridBlock):
       """
       This is a partial fix for issue #19609
       see https://github.com/apache/incubator-mxnet/issues/19609
       """
       def __init__(self, num_groups, **kwargs):
           super().__init__(**kwargs)
           self.norm = gluon.nn.GroupNorm(num_groups=num_groups, **kwargs)

       def forward(self, input):
           # Hand the wrapped GroupNorm a legacy NDArray and temporarily
           # switch off numpy semantics around the call, then switch back.
           tinput = input.as_nd_ndarray() if mx.npx.is_np_array() else input
           mx.npx.reset_np()
           out = self.norm(tinput)
           mx.npx.set_np()
           return out.as_np_ndarray()
   ```
   Toy run:
   ```python
   
   In [11]: net = GroupNormHack(num_groups=4) 
       ...: net.initialize() 
       ...: xx = mx.np.random.rand(3,32,512,512) 
       ...: net.summary(xx)                                                                      
   --------------------------------------------------------------------------------
           Layer (type)                                Output Shape         Param #
   ================================================================================
                  Input                           (3, 32, 512, 512)               0
            GroupNorm-1                           (3, 32, 512, 512)              64
        GroupNormHack-2                           (3, 32, 512, 512)               0
   ================================================================================
   Parameters in forward computation graph, duplicate included
      Total params: 64
      Trainable params: 64
      Non-trainable params: 0
   Shared params in forward computation graph: 0
   Unique parameters in model: 64
   --------------------------------------------------------------------------------
   
   In [12]: yy = xx.as_nd_ndarray()                                                              
   
   In [13]: net.summary(yy)                                                                      
   --------------------------------------------------------------------------------
           Layer (type)                                Output Shape         Param #
   ================================================================================
                  Input                           (3, 32, 512, 512)               0
            GroupNorm-1                           (3, 32, 512, 512)              64
        GroupNormHack-2                           (3, 32, 512, 512)               0
   ================================================================================
   Parameters in forward computation graph, duplicate included
      Total params: 64
      Trainable params: 64
      Non-trainable params: 0
   Shared params in forward computation graph: 0
   Unique parameters in model: 64
   --------------------------------------------------------------------------------
   
   In [14]: net.hybridize()                                                                      
   
   In [15]: out = net(xx)                                                                        
   
   In [16]: out = net(yy)                                                                        
   
   In [17]:  
   
   ```



[GitHub] [incubator-mxnet] feevos commented on issue #19609: mxnet 2.0 - GroupNorm does not work with as_nd_ndarray() upon activation of mx.npx.set_np()

Posted by GitBox <gi...@apache.org>.
feevos commented on issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609#issuecomment-773565727


   The workaround presented above does not allow me to load saved weights - it gives the same error.  


