You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/09/26 01:01:13 UTC

[GitHub] [incubator-mxnet] sxjscience commented on a change in pull request #16280: [Gluon] Support None argument in HybridBlock

sxjscience commented on a change in pull request #16280: [Gluon] Support None argument in HybridBlock
URL: https://github.com/apache/incubator-mxnet/pull/16280#discussion_r328398016
 
 

 ##########
 File path: tests/python/unittest/test_gluon.py
 ##########
 @@ -441,6 +441,63 @@ def test_sparse_hybrid_block():
     # an exception is expected when forwarding a HybridBlock w/ sparse param
     y = net(x)
 
+@with_seed()
+def test_hybrid_block_none_args():
+    class Foo(HybridBlock):
+        def hybrid_forward(self, F, a, b):
+            if a is None and b is not None:
+                return b
+            elif b is None and a is not None:
+                return a
+            elif a is not None and b is not None:
+                return a + b
+            else:
+                raise NotImplementedError
+
+    class FooNested(HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(FooNested, self).__init__(prefix=prefix, params=params)
+            self.f1 = Foo(prefix='foo1')
+            self.f2 = Foo(prefix='foo2')
+            self.f3 = Foo(prefix='foo3')
+
+        def hybrid_forward(self, F, a, b):
+            data = self.f1(a, b)
+            data = self.f2(a, data)
+            data = self.f3(data, b)
+            return data
+
+    for arg_inputs in [(None, mx.nd.ones((10,))),
+                       (mx.nd.ones((10,)), mx.nd.ones((10,))),
+                       (mx.nd.ones((10,)), None)]:
+        foo1 = FooNested(prefix='foo_nested_hybridized')
+        foo1.hybridize()
+        foo2 = FooNested(prefix='foo_nested_nohybrid')
+        for _ in range(2):
 
 Review comment:
  Yes, it triggers the code path that runs when the operators are cached.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services