You are viewing a plain-text version of this content. The canonical link for it was a hyperlink on the original archive page and is not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/04 17:13:33 UTC

[GitHub] eric-haibin-lin commented on a change in pull request #11948: [MXNET-766] add unroll RNN for HybridBlock

eric-haibin-lin commented on a change in pull request #11948: [MXNET-766] add unroll RNN for HybridBlock
URL: https://github.com/apache/incubator-mxnet/pull/11948#discussion_r207714212
 
 

 ##########
 File path: python/mxnet/gluon/contrib/rnn/rnn_cell.py
 ##########
 @@ -315,3 +317,138 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
 
         return next_r, [next_r, next_c]
     # pylint: enable= arguments-differ
+
+
+def _contrib_format_sequence(inputs, layout, in_layout=None):
+    assert inputs is not None, \
+        "unroll(inputs=None) has been deprecated. " \
+        "Please create input variables outside unroll."
+
+    axis = layout.find('T')
+    batch_axis = layout.find('N')
+    batch_size = 0
+    in_axis = in_layout.find('T') if in_layout is not None else axis
+    assert isinstance(inputs, tensor_types)
+    if isinstance(inputs, symbol.Symbol):
+        F = symbol
+    else:
+        F = ndarray
+        batch_size = inputs.shape[batch_axis]
+
+    if axis != in_axis:
+        inputs = F.swapaxes(inputs, dim1=axis, dim2=in_axis)
+
+    return inputs, axis, F, batch_size
+
+
+def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
+           layout='TNC', valid_length=None):
+    """Unrolls an RNN cell across time steps.
+
+    Currently, 'TNC' is a preferred layout. unroll on the input of this layout
+    runs much faster.
+
+    Parameters
+    ----------
+    cell : an object whose base class is RNNCell.
+        The RNN cell to run on the input sequence.
+    inputs : Symbol
+        It should have shape (batch_size, length, ...) if `layout` is 'NTC',
+        or (length, batch_size, ...) if `layout` is 'TNC'.
+    begin_state : nested list of Symbol
+        The initial states of the RNN sequence.
+    drop_inputs : float, default 0.
+        The dropout rate for inputs. Won't apply dropout if it equals 0.
+    drop_outputs : float, default 0.
+        The dropout rate for outputs. Won't apply dropout if it equals 0.
+    layout : str, optional
+        `layout` of input symbol. Only used if inputs
+        is a single Symbol.
+    valid_length : Symbol, NDArray or None
+        `valid_length` specifies the length of the sequences in the batch without padding.
+        This option is especially useful for building sequence-to-sequence models where
+        the input and output sequences would potentially be padded.
+        If `valid_length` is None, all sequences are assumed to have the same length.
+        If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+        The ith element will be the length of the ith sequence in the batch.
+        The last valid state will be return and the padded outputs will be masked with 0.
+        Note that `valid_length` must be smaller or equal to `length`.
+
+    Returns
+    -------
+    outputs : Symbol
+        the output of the RNN from this unrolling.
+
+    states : list of Symbol
+        The new state of this RNN after this unrolling.
+        The type of this symbol is same as the output of `begin_state`.
+
+    Examples
+    --------
+    >>> seq_len = 3
+    >>> batch_size = 2
+    >>> input_size = 5
+    >>> cell = mx.gluon.rnn.LSTMCell(input_size, prefix='rnn_')
+    >>> cell.initialize(ctx=mx.cpu())
+    >>> rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+    >>> state_shape = (batch_size, input_size)
+    >>> states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(2)]
+    >>> valid_length = mx.nd.array([2, 3])
+    >>> output, states = mx.gluon.contrib.rnn.rnn_cell.unroll(cell, rnn_data, states,
+                                                              valid_length=valid_length,
+                                                              layout='TNC')
+    >>> print(output)
+    [[[ 0.00767238  0.00023103  0.03973929 -0.00925503 -0.05660512]
+      [ 0.00881535  0.05428379 -0.02493718 -0.01834097  0.02189514]]
+     [[-0.00676967  0.01447039  0.01287002 -0.00574152 -0.05734247]
+      [ 0.01568508  0.02650866 -0.04270559 -0.04328435  0.00904011]]
+     [[ 0.          0.          0.          0.          0.        ]
+      [ 0.01055336  0.02734251 -0.03153727 -0.03742751 -0.01378113]]]
+     <NDArray 3x2x5 @cpu(0)>
+    """
+
+    inputs, axis, F, _ = _contrib_format_sequence(inputs, layout)
+    if axis != 0:
+        axes = list(range(len(layout)))
+        tmp = axes[0]
 
 Review comment:
   nit: you can swap two elements with tuple assignment, `foo[i], foo[j] = foo[j], foo[i]`, in Python (no temporary variable needed).

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services