You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/01 03:40:04 UTC
[incubator-mxnet] branch master updated: Gluon RNN fixes for seqlen 1 (#7260)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 5393002 Gluon RNN fixes for seqlen 1 (#7260)
5393002 is described below
commit 5393002a2299ea79ad857cc015d703d63bc641ec
Author: Leonard Lausen <le...@lausen.nl>
AuthorDate: Tue Aug 1 12:40:01 2017 +0900
Gluon RNN fixes for seqlen 1 (#7260)
* Gluon RNN fixes for seqlen 1
* Use _as_list from base_module
* Move _as_list to base and allow tuples
---
python/mxnet/base.py | 18 ++++++++++++++++++
python/mxnet/gluon/rnn/rnn_cell.py | 7 ++++---
python/mxnet/module/base_module.py | 18 +-----------------
3 files changed, 23 insertions(+), 20 deletions(-)
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index f714924..6d53752 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -300,3 +300,21 @@ def add_fileline_to_docstring(module, incursive=True):
_add_fileline(obj.__func__)
if inspect.isclass(obj) and incursive:
add_fileline_to_docstring(obj, False)
+
+def _as_list(obj):
+ """Wrap the argument in a list unless it is already a list or tuple.
+
+ Parameters
+ ----------
+ obj : object
+
+ Returns
+ -------
+ list or tuple
+ `obj` itself (not a copy) if it is a list or tuple; otherwise a new
+ single-element list `[obj]`.
+ """
+ if isinstance(obj, (list, tuple)):
+ return obj
+ else:
+ return [obj]
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index e6ce65b..87c656c 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -6,7 +6,7 @@
from __future__ import print_function
from ... import symbol, ndarray
-from ...base import string_types, numeric_types
+from ...base import string_types, numeric_types, _as_list
from ..block import Block, HybridBlock
from ..utils import _indent
from .. import tensor_types
@@ -50,8 +50,9 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
batch_size = inputs.shape[batch_axis]
if merge is False:
assert length is None or length == inputs.shape[in_axis]
- inputs = ndarray.split(inputs, axis=in_axis, num_outputs=inputs.shape[in_axis],
- squeeze_axis=1)
+ inputs = _as_list(ndarray.split(inputs, axis=in_axis,
+ num_outputs=inputs.shape[in_axis],
+ squeeze_axis=1))
else:
assert length is None or len(inputs) == length
if isinstance(inputs[0], symbol.Symbol):
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index cb6cfcc..cacce25 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -12,23 +12,7 @@ from ..context import cpu
from ..model import BatchEndParam
from ..initializer import Uniform
from ..io import DataDesc
-
-
-def _as_list(obj):
- """A utility function that treat the argument as a list.
-
- Parameters
- ----------
- obj : object
-
- Returns
- -------
- If `obj` is a list, return it. Otherwise, return `[obj]` as a single-element list.
- """
- if isinstance(obj, list):
- return obj
- else:
- return [obj]
+from ..base import _as_list
def _check_input_names(symbol, names, typename, throw):
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.