You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/09/15 05:15:45 UTC
[incubator-mxnet] branch master updated: Fix legacy codepath detection feature for decorated HybridBlocks (#19143)
This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 179262b Fix legacy codepath detection feature for decorated HybridBlocks (#19143)
179262b is described below
commit 179262b1e357f3ed7bd2825a5a16259b45a2c40f
Author: Leonard Lausen <la...@amazon.com>
AuthorDate: Tue Sep 15 05:14:21 2020 +0000
Fix legacy codepath detection feature for decorated HybridBlocks (#19143)
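The `if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward` check used to detect legacy (Gluon 1) Blocks yields false positives on Gluon 2 Blocks that are wrapped with a class decorator. As a result, hybridization silently fails on Gluon 2 Blocks that use a class decorator such as `@use_np`. The check now applies `inspect.unwrap` to the method before the identity comparison, so a decorator's wrapper function no longer causes a mismatch.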
---
python/mxnet/gluon/block.py | 7 +++--
tests/python/unittest/test_deferred_compute.py | 41 +++++++++++++++-----------
2 files changed, 28 insertions(+), 20 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index bac14ca..8fd7dd3 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -21,6 +21,7 @@
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']
import copy
+import inspect
import warnings
import weakref
from collections import OrderedDict, defaultdict
@@ -984,7 +985,7 @@ class HybridBlock(Block):
def _get_graph(self, *args):
if not self._cached_graph:
- if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward: # Gluon 1
+ if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
return self._get_graph_v1(*args)
else: # Gluon 2 based on deferred compute mode
return self._get_graph_v2(*args)
@@ -1277,7 +1278,7 @@ class HybridBlock(Block):
def infer_shape(self, *args):
"""Infers shape of Parameters from inputs."""
- if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:
+ if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
# Gluon 1 based on F: hybrid_forward is defined by user
self._infer_attrs('infer_shape', 'shape', *args)
else:
@@ -1388,7 +1389,7 @@ class HybridBlock(Block):
cld()._monitor_all = monitor_all
def __call__(self, x, *args):
- if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:
+ if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
# Gluon 1 based on F: hybrid_forward is defined by user
return super().__call__(x, *args)
else: # Gluon 2 based on deferred compute mode
diff --git a/tests/python/unittest/test_deferred_compute.py b/tests/python/unittest/test_deferred_compute.py
index 1565fb1..ea6f2b4 100644
--- a/tests/python/unittest/test_deferred_compute.py
+++ b/tests/python/unittest/test_deferred_compute.py
@@ -17,6 +17,7 @@
import functools
import operator
+import tempfile
import numpy as np
@@ -306,8 +307,7 @@ def test_dc_dynamic_shape():
def f(a, *, nd):
return [mx.nd.np.flatnonzero(a)]
- # Skip GraphExecutor test due to https://github.com/apache/incubator-mxnet/issues/17810
- for mode in ('imperative', 'imperativewithnondccompute'):
+ for mode in ('imperative', 'imperativewithnondccompute', 'symbolic', 'all'):
_assert_dc(_dc_simple_setup, f, mode=mode, numpy=True)
@@ -338,10 +338,6 @@ def test_dc_tuple_indexing():
def test_dc_simple_boolean_indexing():
- if mx.test_utils.default_context() == mx.gpu(0) and mx.runtime.Features().is_enabled("TVM_OP"):
- # Skip due to https://github.com/apache/incubator-mxnet/issues/17886
- return
-
def setup(*, nd):
assert nd is mx.np
x = mx.np.array([[0, 1], [1, 1], [2, 2]])
@@ -351,10 +347,6 @@ def test_dc_simple_boolean_indexing():
assert nd is mx.np
return [a[idx].reshape((2, 2))]
- # Skip GraphExecutor test due to https://github.com/apache/incubator-mxnet/issues/17810
- for mode in ('imperative', 'imperativewithnondccompute'):
- _assert_dc(setup, f, mode=mode)
-
def test_dc_list_indexing_error():
def f(a, *, nd):
@@ -428,6 +420,8 @@ def _assert_dc_gluon(setup, net, setup_is_deterministic=True, numpy=True, autogr
_all_same(ys_np, ys_hybrid_np)
+ with tempfile.TemporaryDirectory() as root:
+ net.export(root)
def _dc_gluon_simple_setup(shape=(8, 10), *, nd):
return [nd.ones(shape=shape, ctx=mx.context.current_context())]
@@ -452,12 +446,29 @@ def test_dc_hybridblock():
net = MyBlock()
net.initialize(ctx=contexts)
_assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=False, ctx=ctx)
- with mx.util.np_array(True):
+ with mx.util.np_shape(True), mx.util.np_array(True):
net = MyBlock()
net.initialize(ctx=contexts)
_assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=True, ctx=ctx)
+def test_dc_hybridblock_wrapped():
+ @mx.util.use_np
+ class MyBlock(mx.gluon.HybridBlock):
+ def __init__(self):
+ super().__init__()
+ self.dense = mx.gluon.nn.Dense(units=10, in_units=10)
+ self.weight = mx.gluon.Parameter('weight', shape=(10, ))
+
+ def forward(self, x):
+ assert x.shape[1] == 10 # due to in_units=10 above
+ return self.dense(x) + self.weight.data(x.context)
+
+ net = MyBlock()
+ net.initialize()
+ _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=True)
+
+
def test_dc_hybridblock_deferred_init_no_infer_shape_error():
class MyBlock(mx.gluon.HybridBlock):
def __init__(self):
@@ -491,17 +502,13 @@ def test_dc_hybridblock_deferred_init():
net = MyBlock()
net.initialize()
_assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=False)
- with mx.util.np_array(True):
+ with mx.util.np_shape(True), mx.util.np_array(True):
net = MyBlock()
net.initialize()
_assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=True)
def test_dc_hybridblock_dynamic_shape():
- if mx.test_utils.default_context() == mx.gpu(0) and mx.runtime.Features().is_enabled("TVM_OP"):
- # Skip due to https://github.com/apache/incubator-mxnet/issues/17886
- return
-
class MyBlock(mx.gluon.HybridBlock):
def __init__(self):
super().__init__()
@@ -515,7 +522,7 @@ def test_dc_hybridblock_dynamic_shape():
x = mx.np.array([[0, 1], [1, 1], [2, 2]])
return [x, x < 2]
- with mx.util.np_array(True):
+ with mx.util.np_shape(True), mx.util.np_array(True):
net = MyBlock()
net.initialize()
_assert_dc_gluon(setup, net, numpy=True)