You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/01 03:40:04 UTC

[incubator-mxnet] branch master updated: Gluon RNN fixes for seqlen 1 (#7260)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5393002  Gluon RNN fixes for seqlen 1 (#7260)
5393002 is described below

commit 5393002a2299ea79ad857cc015d703d63bc641ec
Author: Leonard Lausen <le...@lausen.nl>
AuthorDate: Tue Aug 1 12:40:01 2017 +0900

    Gluon RNN fixes for seqlen 1 (#7260)
    
    * Gluon RNN fixes for seqlen 1
    
    * Use _as_list from base_module
    
    * Move _as_list to base and allow tuples
---
 python/mxnet/base.py               | 18 ++++++++++++++++++
 python/mxnet/gluon/rnn/rnn_cell.py |  7 ++++---
 python/mxnet/module/base_module.py | 18 +-----------------
 3 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index f714924..6d53752 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -300,3 +300,21 @@ def add_fileline_to_docstring(module, incursive=True):
             _add_fileline(obj.__func__)
         if inspect.isclass(obj) and incursive:
             add_fileline_to_docstring(obj, False)
+
+def _as_list(obj):
+    """A utility function that converts the argument to a list if it is not already.
+
+    Parameters
+    ----------
+    obj : object
+
+    Returns
+    -------
+    If `obj` is a list or tuple, return it. Otherwise, return `[obj]` as a
+    single-element list.
+
+    """
+    if isinstance(obj, (list, tuple)):
+        return obj
+    else:
+        return [obj]
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index e6ce65b..87c656c 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -6,7 +6,7 @@
 from __future__ import print_function
 
 from ... import symbol, ndarray
-from ...base import string_types, numeric_types
+from ...base import string_types, numeric_types, _as_list
 from ..block import Block, HybridBlock
 from ..utils import _indent
 from .. import tensor_types
@@ -50,8 +50,9 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
         batch_size = inputs.shape[batch_axis]
         if merge is False:
             assert length is None or length == inputs.shape[in_axis]
-            inputs = ndarray.split(inputs, axis=in_axis, num_outputs=inputs.shape[in_axis],
-                                   squeeze_axis=1)
+            inputs = _as_list(ndarray.split(inputs, axis=in_axis,
+                                            num_outputs=inputs.shape[in_axis],
+                                            squeeze_axis=1))
     else:
         assert length is None or len(inputs) == length
         if isinstance(inputs[0], symbol.Symbol):
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index cb6cfcc..cacce25 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -12,23 +12,7 @@ from ..context import cpu
 from ..model import BatchEndParam
 from ..initializer import Uniform
 from ..io import DataDesc
-
-
-def _as_list(obj):
-    """A utility function that treat the argument as a list.
-
-    Parameters
-    ----------
-    obj : object
-
-    Returns
-    -------
-    If `obj` is a list, return it. Otherwise, return `[obj]` as a single-element list.
-    """
-    if isinstance(obj, list):
-        return obj
-    else:
-        return [obj]
+from ..base import _as_list
 
 
 def _check_input_names(symbol, names, typename, throw):

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.