Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/08 07:14:51 UTC

[GitHub] [incubator-tvm] yongwww commented on a change in pull request #5963: [TF]Refine LSTMBlockCell to support dynamic rnn

yongwww commented on a change in pull request #5963:
URL: https://github.com/apache/incubator-tvm/pull/5963#discussion_r451330806



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1990,6 +1990,66 @@ def _impl(inputs, attr, params, mod):
         return  _res
     return _impl
 
+def _LSTMBlockCell():
+    def _impl(inputs, attr, params, mod):
+        """LSTM Block cell.
+        Calculations and return values are described in:
+        https://github.com/tensorflow/tensorflow/blob/
+        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
+
+        Parameters
+        ----------
+        inputs : list of relay.Expr
+            Input data
+        in_state_c : relay.Expr
+            Cell state input
+        in_state_h : relay.Expr
+            Hidden state input
+        attr : dict
+            Dict of operator attributes
+        params : dict
+            Dict of pretrained weights and bias
+
+        Returns
+        -------
+        relay.expr.TupleWrapper
+            [dummy, cs, dummy, dummy, dummy, dummy, h]
+            Only cs and h are meaningful; the remaining outputs are dummies
+        """
+        in_data = inputs[0]
+        in_state_c = inputs[1]
+        in_state_h = inputs[2]
+        in_weight = inputs[3]
+        in_bias = inputs[7]
+        forget_bias = attr.pop('forget_bias')
+        input_shape = _infer_shape(inputs[0], mod)
+        weight_shape = _infer_shape(inputs[3], mod)
+        batch_size, input_size = input_shape[0], input_shape[1]
+        # weight shape is (input_size + num_hidden, 4 * num_hidden)
+        num_hidden_layers = weight_shape[1]
+
+        in_data = _op.reshape(in_data,
+                              newshape=(batch_size, input_size))
+        ixh = _op.concatenate([in_data, in_state_h], axis=1)
+        in_weight = _op.transpose(in_weight, axes=None)
+        gates = _op.nn.dense(ixh, in_weight,
+                             units=num_hidden_layers)
+        gates_bias = _op.add(gates, in_bias)
+        # the dense output packs the four gates in TF's (i, ci, f, o) order
+        gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
+        in_gate = _op.sigmoid(gate_list[0])
+        in_transform = _op.tanh(gate_list[1])
+        forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name))
+        forget_gate = _op.sigmoid(forget_gate)
+        out_gate = _op.sigmoid(gate_list[3])
+        next_c = _op.add(_op.multiply(forget_gate, in_state_c),
+                         _op.multiply(in_gate, in_transform))
+        next_h = out_gate * _op.tanh(next_c)
+        # Return a dummy constant in place of the unused outputs
+        dummy = tvm.relay.const(0)
+        return tvm.relay.TupleWrapper(
+            tvm.relay.Tuple([dummy, next_c, dummy, dummy, dummy, dummy, next_h]), 7)

Review comment:
       I am not sure whether the dummy nodes might still get used in some cases. It would be good to generate the real nodes for them, as TF does, instead of dummies.
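
       For illustration, a minimal sketch of that idea, reusing the variable names already defined in the diff above and assuming TF's LSTMBlockCell emits its outputs in the order (i, cs, f, o, ci, co, h):

           # Sketch only: return the real intermediate tensors instead of a
           # shared dummy constant; the (i, cs, f, o, ci, co, h) output order
           # is an assumption based on TF's lstm_ops
           co = _op.tanh(next_c)                 # tanh of the new cell state
           return tvm.relay.TupleWrapper(
               tvm.relay.Tuple([in_gate,         # i:  input gate
                                next_c,          # cs: new cell state
                                forget_gate,     # f:  forget gate
                                out_gate,        # o:  output gate
                                in_transform,    # ci: cell input
                                co,              # co: cell state activation
                                next_h]), 7)     # h:  new hidden state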




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org