You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/10/02 01:10:45 UTC
[GitHub] [incubator-mxnet] wkcn commented on issue #19264: Same Network can hybridize on CPU but can not hybridize on GPU.
wkcn commented on issue #19264:
URL: https://github.com/apache/incubator-mxnet/issues/19264#issuecomment-702475163
It is a bug.
The temporary workaround is to pass the inputs as multiple separate NDArrays rather than as a single list of NDArrays.
For example:
```python
import os
# Ask DMLC for deeper native stack traces in error messages. This is set
# before importing mxnet, presumably so the library sees it when it loads.
os.environ["DMLC_LOG_STACK_TRACE_DEPTH"]="10"
import mxnet as mx
# Print the MXNet version so it appears alongside the repro output.
print(mx.__version__)
import mxnet.gluon as gluon
class nn(object):
    """Tiny namespace exposing ``torch.nn``-style factory names over Gluon blocks.

    Only the two factories this script needs are provided.
    """

    @staticmethod
    def Sequential(*args):
        """Return a ``gluon.nn.HybridSequential`` holding ``args`` in order."""
        container = gluon.nn.HybridSequential()
        for child in args:
            container.add(child)
        return container

    @staticmethod
    def Upsample(scale_factor, mode):
        """Return a block resizing height and width by ``scale_factor``.

        NOTE(review): ``mode`` is accepted but ignored -- the resize is always
        ``F.contrib.BilinearResize2D``, even when ``mode="nearest"`` is passed.
        """
        return mx.gluon.nn.HybridLambda(
            lambda F, x: F.contrib.BilinearResize2D(
                x,
                scale_height=scale_factor,
                scale_width=scale_factor,
                name="fwd",
            )
        )
class HighResolutionModule(gluon.nn.HybridBlock):
    """Exchange information between a high- and a low-resolution branch.

    The layers are built for branch 0 carrying 32 channels and branch 1
    carrying 64 channels at half of branch 0's spatial size:

    * ``y0 = relu(x0 + upsample(conv(x1)))``   -- 64 -> 32 channels, 2x resize
    * ``y1 = relu(bn(strided_conv(x0)) + x1)``  -- 32 -> 64 channels, stride 2
    """

    def __init__(self):
        super(HighResolutionModule, self).__init__()
        self.relu = mx.gluon.nn.Activation("relu")
        # Low-res branch (64 ch) -> high-res layout (32 ch): 3x3 conv, 2x resize.
        self.fff = nn.Sequential(
            mx.gluon.nn.Conv2D(in_channels=64, channels=32, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="nearest"),
        )
        # High-res branch (32 ch) -> low-res layout (64 ch): strided 3x3 conv,
        # then BatchNorm. The BatchNorm normalizes the conv's 64 output
        # channels, so its in_channels must be 64; 32 would size gamma/beta and
        # the running statistics for the wrong channel count.
        self.fff1 = nn.Sequential(
            mx.gluon.nn.Conv2D(in_channels=32, channels=64, kernel_size=3, padding=1, strides=2),
            mx.gluon.nn.BatchNorm(axis=1, momentum=.9, in_channels=64),
        )

    def hybrid_forward(self, F, *x, **kwargs):
        """Fuse the two branches.

        Parameters
        ----------
        F : module
            Gluon's operator namespace (ndarray when imperative, symbol when
            hybridized); not used directly here.
        *x : NDArray or Symbol
            ``x[0]`` is the high-resolution branch, ``x[1]`` the low-resolution
            one. Any further entries are ignored.
        **kwargs
            Extra keyword arguments from Gluon; unused.

        Returns
        -------
        list
            ``[y0, y1]``: the fused high- and low-resolution branches.
        """
        y0 = self.relu(x[0] + self.fff(x[1]))
        y1 = self.relu(self.fff1(x[0]) + x[1])
        return [y0, y1]
class HighResolutionNet(gluon.nn.HybridBlock):
    """Network made of one stage of two stacked ``HighResolutionModule`` blocks."""

    def __init__(self):
        super(HighResolutionNet, self).__init__()
        self.stage2 = self._make_stage()

    def _make_stage(self):
        """Chain two freshly constructed ``HighResolutionModule`` blocks."""
        return nn.Sequential(*[HighResolutionModule() for _ in range(2)])

    def hybrid_forward(self, F, *x_list):
        """Feed all input branches to ``stage2`` and return whatever it yields."""
        return self.stage2(*x_list)
def get_cls_net():
    """Build a new ``HighResolutionNet``; the caller initializes its parameters."""
    return HighResolutionNet()
if __name__ == '__main__':
    # Repro driver: build the net, hybridize it and run one forward pass on two
    # random branch inputs (32 ch @ 56x56 and 64 ch @ 28x28).
    # The issue reports the failure on GPU only; use mx.gpu() to reproduce it.
    ctx = mx.cpu()
    model = get_cls_net()
    model.initialize()
    model.reset_ctx(ctx)
    model.hybridize()
    # The branches are passed as separate NDArrays rather than one list -- the
    # workaround suggested in the issue comment.
    y_hat = model(
        mx.nd.random.randn(1, 32, 56, 56, ctx=ctx),
        mx.nd.random.randn(1, 64, 28, 28, ctx=ctx),
    )
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@mxnet.apache.org
For additional commands, e-mail: issues-help@mxnet.apache.org