You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and did not survive the conversion to plain text.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/08/22 01:48:45 UTC

[incubator-mxnet] branch master updated: [MXNET-795] Fix a bug that CutSubgraph works only when each subgraph has its distinct name (#12106)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c692ffd  [MXNET-795] Fix a bug that CutSubgraph works only when each subgraph has its distinct name (#12106)
c692ffd is described below

commit c692ffde3d7fede821043f9ec50ccdd062fbdc6c
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Aug 22 09:48:33 2018 +0800

    [MXNET-795] Fix a bug that CutSubgraph works only when each subgraph has its distinct name (#12106)
    
    * Copy only when necessary
    
    * Fix typo
    
    * Add unittest
---
 python/mxnet/attribute.py                          |  2 ++
 python/mxnet/symbol/contrib.py                     | 11 ++++++
 tests/python/unittest/test_contrib_control_flow.py | 41 ++++++++++++++++++++++
 3 files changed, 54 insertions(+)

diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py
index 17044dd..1a7bd44 100644
--- a/python/mxnet/attribute.py
+++ b/python/mxnet/attribute.py
@@ -20,6 +20,7 @@
 from __future__ import absolute_import
 import threading
 import warnings
+from collections import defaultdict
 
 from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass
 
@@ -34,6 +35,7 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
         The attributes to set for all symbol creations in the scope.
     """
     _current = threading.local()
+    _subgraph_names = defaultdict(int)
 
     def __init__(self, **kwargs):
         self._old_scope = None
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 38195bd..f40a372 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -124,6 +124,14 @@ def _cut_subgraph(subg):
         syms.append(s)
     return syms
 
+def _get_unique_subgraph_name(subgraph_name):  # Return subgraph_name, qualified by any enclosing subgraph name, with a per-name counter appended.
+    attrs = AttrScope._current.value._attr  # NOTE(review): assumes _current.value (a threading.local attr) is already set in this thread -- confirm, otherwise AttributeError.
+    if attrs.get("__subgraph_name__", "") != "":  # Already inside another subgraph's AttrScope (e.g. a control-flow op nested in foreach/while_loop/cond).
+        subgraph_name = "".join([attrs["__subgraph_name__"], "$", subgraph_name])  # Qualify as "<outer>$<inner>".
+    AttrScope._subgraph_names[subgraph_name] += 1  # Class-level counter shared by every caller; no locking here.
+    subgraph_name = subgraph_name + str(AttrScope._subgraph_names[subgraph_name] - 1)  # NOTE(review): no separator before the index, so "cond"+"11" and "cond1"+"1" both give "cond11".
+    return subgraph_name
+
 # This construct a subgraph for given output nodes.
 # If an output node is one of the input nodes, we call identity to make sure
 # that outputs nodes are different from input nodes.
@@ -232,6 +240,7 @@ def foreach(body, data, init_states, name="foreach"):
     # the python function, we need to prune the computation graph constructed from
     # the function. One way of doing it is to mark the nodes in the computation graph
     # with AttrScope and prune the nodes without the special attribute.
+    name = _get_unique_subgraph_name(name)
     with AttrScope(__subgraph_name__=name):
         if isinstance(data, list):
             in_eles = [symbol.var(sym.name) for sym in data]
@@ -456,6 +465,7 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
         return list(step_output), list(new_loop_vars)
 
     def _create_subgraph(graph_vars, graph_func, subgraph_name):
+        subgraph_name = _get_unique_subgraph_name(subgraph_name)
         with AttrScope(__subgraph_name__=subgraph_name):
             # create new variables with the same name,
             # them feed them to the given func
@@ -619,6 +629,7 @@ def cond(pred, then_func, else_func, name="cond"):
         return inputs
 
     def _create_subgraph(graph_vars, graph_func, subgraph_name):
+        subgraph_name = _get_unique_subgraph_name(subgraph_name)
         with AttrScope(__subgraph_name__=subgraph_name):
             # create new variables with the same name,
             # them feed them to the given func
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 76d0218..54f22a8 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -20,8 +20,10 @@ import numpy as np
 import mxnet as mx
 from mxnet import gluon
 from numpy.testing import assert_allclose, assert_array_equal
+from collections import defaultdict
 from mxnet.test_utils import *
 from mxnet.base import _as_list
+from mxnet.attribute import AttrScope
 from common import with_seed
 
 
@@ -1765,6 +1767,45 @@ def test_cut_subgraph_cond():
     assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=1e-3, atol=1e-3)
 
 
+def test_scope():  # Two blocks whose cond ops share the name "my_cond" must each get a distinct subgraph name.
+    class TestBlock1(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestBlock1, self).__init__(prefix=prefix, params=params)
+        def hybrid_forward(self, F, data):
+            (new_data, ) = F.contrib.cond(
+                data > 0.5,
+                then_func=lambda: data * 2,
+                else_func=lambda: data * 3,
+                name="my_cond",
+            )
+            return new_data
+    class TestBlock2(gluon.HybridBlock):  # Identical to TestBlock1, deliberately reusing name="my_cond".
+        def __init__(self, prefix=None, params=None):
+            super(TestBlock2, self).__init__(prefix=prefix, params=params)
+        def hybrid_forward(self, F, data):
+            (new_data, ) = F.contrib.cond(
+                data > 0.5,
+                then_func=lambda: data * 2,
+                else_func=lambda: data * 3,
+                name="my_cond",
+            )
+            return new_data
+    AttrScope._subgraph_names = defaultdict(int)  # Reset the process-wide counter so the asserted counts start at zero; note it is not restored afterwards.
+    data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+    block1 = TestBlock1()
+    block1.initialize(ctx=default_context())
+    block1.hybridize()
+    _ = block1(data)  # First forward on a hybridized block builds its symbolic graph, invoking cond.
+    block2 = TestBlock2()
+    block2.initialize(ctx=default_context())
+    block2.hybridize()
+    _ = block2(data)
+    assert len(AttrScope._subgraph_names) == 3  # Counter is keyed by base name (before the index suffix): pred/then/else.
+    assert AttrScope._subgraph_names['my_cond_else'] == 2
+    assert AttrScope._subgraph_names['my_cond_pred'] == 2
+    assert AttrScope._subgraph_names['my_cond_then'] == 2
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()