You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/08 11:52:40 UTC

[GitHub] [incubator-tvm] lixiaoquan commented on a change in pull request #5963: [TF]Refine LSTMBlockCell to support dynamic rnn

lixiaoquan commented on a change in pull request #5963:
URL: https://github.com/apache/incubator-tvm/pull/5963#discussion_r451483442



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1990,6 +1990,66 @@ def _impl(inputs, attr, params, mod):
         return  _res
     return _impl
 
+def _LSTMBlockCell():
+    def _impl(inputs, attr, params, mod):
+        """LSTM Block cell.
+        Calculations and return values are described in:
+        https://github.com/tensorflow/tensorflow/blob/
+        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
+
+        Parameters
+        ----------
+        inputs : relay.Expr
+            Input data
+        in_state_c: list of relay.Expr
+            Cell state input values for all the layers
+        in_state_h: list of relay.Expr
+            Hidden state input values for all the layers
+        attrs : dict
+            Dict of operator attributes
+        params : dict
+            List of pretrained weights and bias
+
+        Returns
+        -------
+        relay.Expr.TupleWapper
+            [dummy, cs, dummy, dummy, dummy, dummy, h]
+            Only cs and h which are useful are returned
+        """
+        in_data = inputs[0]
+        in_state_c = inputs[1]
+        in_state_h = inputs[2]
+        in_weight = inputs[3]
+        in_bias = inputs[7]
+        forget_bias = attr.pop('forget_bias')
+        input_shape = _infer_shape(inputs[0], mod)
+        weight_shape = _infer_shape(inputs[3], mod)
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
+
+        in_data = _op.reshape(in_data,
+                              newshape=(batch_size, input_size))
+        ixh = _op.concatenate([in_data, in_state_h], axis=1)
+        in_weight = _op.transpose(in_weight, axes=None)
+        gates = _op.nn.dense(ixh, in_weight,
+                             units=num_hidden_layers)
+        gates_bias = _op.add(gates, in_bias)
+        gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
+        in_gate = _op.sigmoid(gate_list[0])
+        in_transform = _op.tanh(gate_list[1])
+        forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name))
+        forget_gate = _op.sigmoid(forget_gate)
+        out_gate = _op.sigmoid(gate_list[3])
+        next_c = _op.add(_op.multiply(forget_gate, in_state_c),
+                         _op.multiply(in_gate, in_transform))
+        next_h = out_gate * _op.tanh(next_c)
+        # Return dummy for those unused values
+        dummy = tvm.relay.const(0)
+        return tvm.relay.TupleWrapper(
+            tvm.relay.Tuple([dummy, next_c, dummy, dummy, dummy, dummy, next_h]), 7)

Review comment:
       I added all the return values, but I can't test them, because tensorflow.contrib.rnn.LSTMBlockCell.call() only returns h and new_state (see the excerpt below):
   
   I think that if users use the API to construct the graph, the other values won't be used.
   
       (cs_prev, h_prev) = state
       (_, cs, _, _, _, _, h) = _lstm_block_cell(
           inputs,
           cs_prev,
           h_prev,
           self._kernel,
           self._bias,
           wci=wci,
           wcf=wcf,
           wco=wco,
           forget_bias=self._forget_bias,
           cell_clip=self._cell_clip,
           use_peephole=self._use_peephole)
   
       new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
       return h, new_state
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org