You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/08/22 01:48:45 UTC
[incubator-mxnet] branch master updated: [MXNET-795] Fix a bug that CutSubgraph works only when each subgraph has its distinct name (#12106)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c692ffd [MXNET-795] Fix a bug that CutSubgraph works only when each subgraph has its distinct name (#12106)
c692ffd is described below
commit c692ffde3d7fede821043f9ec50ccdd062fbdc6c
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Aug 22 09:48:33 2018 +0800
[MXNET-795] Fix a bug that CutSubgraph works only when each subgraph has its distinct name (#12106)
* Copy only when necessary
* Fix typo
* Add unittest
---
python/mxnet/attribute.py | 2 ++
python/mxnet/symbol/contrib.py | 11 ++++++
tests/python/unittest/test_contrib_control_flow.py | 41 ++++++++++++++++++++++
3 files changed, 54 insertions(+)
diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py
index 17044dd..1a7bd44 100644
--- a/python/mxnet/attribute.py
+++ b/python/mxnet/attribute.py
@@ -20,6 +20,7 @@
from __future__ import absolute_import
import threading
import warnings
+from collections import defaultdict
from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass
@@ -34,6 +35,7 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
The attributes to set for all symbol creations in the scope.
"""
_current = threading.local()
+ _subgraph_names = defaultdict(int)
def __init__(self, **kwargs):
self._old_scope = None
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 38195bd..f40a372 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -124,6 +124,14 @@ def _cut_subgraph(subg):
syms.append(s)
return syms
+def _get_unique_subgraph_name(subgraph_name):
+ attrs = AttrScope._current.value._attr
+ if attrs.get("__subgraph_name__", "") != "":
+ subgraph_name = "".join([attrs["__subgraph_name__"], "$", subgraph_name])
+ AttrScope._subgraph_names[subgraph_name] += 1
+ subgraph_name = subgraph_name + str(AttrScope._subgraph_names[subgraph_name] - 1)
+ return subgraph_name
+
# This construct a subgraph for given output nodes.
# If an output node is one of the input nodes, we call identity to make sure
# that outputs nodes are different from input nodes.
@@ -232,6 +240,7 @@ def foreach(body, data, init_states, name="foreach"):
# the python function, we need to prune the computation graph constructed from
# the function. One way of doing it is to mark the nodes in the computation graph
# with AttrScope and prune the nodes without the special attribute.
+ name = _get_unique_subgraph_name(name)
with AttrScope(__subgraph_name__=name):
if isinstance(data, list):
in_eles = [symbol.var(sym.name) for sym in data]
@@ -456,6 +465,7 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
return list(step_output), list(new_loop_vars)
def _create_subgraph(graph_vars, graph_func, subgraph_name):
+ subgraph_name = _get_unique_subgraph_name(subgraph_name)
with AttrScope(__subgraph_name__=subgraph_name):
# create new variables with the same name,
# them feed them to the given func
@@ -619,6 +629,7 @@ def cond(pred, then_func, else_func, name="cond"):
return inputs
def _create_subgraph(graph_vars, graph_func, subgraph_name):
+ subgraph_name = _get_unique_subgraph_name(subgraph_name)
with AttrScope(__subgraph_name__=subgraph_name):
# create new variables with the same name,
# them feed them to the given func
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 76d0218..54f22a8 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -20,8 +20,10 @@ import numpy as np
import mxnet as mx
from mxnet import gluon
from numpy.testing import assert_allclose, assert_array_equal
+from collections import defaultdict
from mxnet.test_utils import *
from mxnet.base import _as_list
+from mxnet.attribute import AttrScope
from common import with_seed
@@ -1765,6 +1767,45 @@ def test_cut_subgraph_cond():
assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=1e-3, atol=1e-3)
+def test_scope():
+ class TestBlock1(gluon.HybridBlock):
+ def __init__(self, prefix=None, params=None):
+ super(TestBlock1, self).__init__(prefix=prefix, params=params)
+ def hybrid_forward(self, F, data):
+ (new_data, ) = F.contrib.cond(
+ data > 0.5,
+ then_func=lambda: data * 2,
+ else_func=lambda: data * 3,
+ name="my_cond",
+ )
+ return new_data
+ class TestBlock2(gluon.HybridBlock):
+ def __init__(self, prefix=None, params=None):
+ super(TestBlock2, self).__init__(prefix=prefix, params=params)
+ def hybrid_forward(self, F, data):
+ (new_data, ) = F.contrib.cond(
+ data > 0.5,
+ then_func=lambda: data * 2,
+ else_func=lambda: data * 3,
+ name="my_cond",
+ )
+ return new_data
+ AttrScope._subgraph_names = defaultdict(int)
+ data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+ block1 = TestBlock1()
+ block1.initialize(ctx=default_context())
+ block1.hybridize()
+ _ = block1(data)
+ block2 = TestBlock2()
+ block2.initialize(ctx=default_context())
+ block2.hybridize()
+ _ = block2(data)
+ assert len(AttrScope._subgraph_names) == 3
+ assert AttrScope._subgraph_names['my_cond_else'] == 2
+ assert AttrScope._subgraph_names['my_cond_pred'] == 2
+ assert AttrScope._subgraph_names['my_cond_then'] == 2
+
+
if __name__ == '__main__':
import nose
nose.runmodule()