Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/06 12:15:08 UTC

[GitHub] [tvm] yuanfz98 opened a new pull request, #12017: [Pytorch] add aten::rnn_tanh, aten::rnn_relu

yuanfz98 opened a new pull request, #12017:
URL: https://github.com/apache/tvm/pull/12017

   Hello,
   
   This PR adds support for the `aten::rnn_tanh` and `aten::rnn_relu` operators to the Relay PyTorch frontend. The approach is based on the existing GRU and LSTM conversions in Relay.
   
   Related issue: #11827
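For reference, `torch.nn.RNN` documents its recurrence as h_t = act(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh), where act is tanh or ReLU depending on its `nonlinearity` argument; `aten::rnn_tanh` and `aten::rnn_relu` are, as far as we can tell, the two ATen entry points for those variants. The pure-Python sketch below computes one unidirectional layer of that recurrence, which can help when reasoning about what a converted model should output. It is not the TVM converter from this PR: the helper names `matvec` and `rnn_layer` are hypothetical, and multi-layer and bidirectional stacking are left out.

```python
import math


def matvec(w, v):
    """Multiply a matrix w (a list of rows) by a vector v."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def _relu(z):
    return max(z, 0.0)


def rnn_layer(xs, h0, w_ih, w_hh, b_ih=None, b_hh=None, nonlinearity="tanh"):
    """Run one unidirectional Elman RNN layer over the sequence xs (hypothetical helper).

    xs:   list of input vectors, shape (seq_len, input_size)
    h0:   initial hidden state, shape (hidden_size,)
    w_ih: shape (hidden_size, input_size)
    w_hh: shape (hidden_size, hidden_size)
    b_ih, b_hh: optional biases of shape (hidden_size,); None means no bias
    Returns (outputs, h_n): the hidden state at every step and the final one.
    """
    # "tanh" mirrors aten::rnn_tanh, "relu" mirrors aten::rnn_relu.
    if nonlinearity == "tanh":
        act = math.tanh
    elif nonlinearity == "relu":
        act = _relu
    else:
        raise ValueError(f"unsupported nonlinearity: {nonlinearity!r}")

    zeros = [0.0] * len(h0)
    b_ih = zeros if b_ih is None else b_ih
    b_hh = zeros if b_hh is None else b_hh

    h = list(h0)
    outputs = []
    for x in xs:
        # h_t = act(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
        pre = [a + b + c + d for a, b, c, d in
               zip(matvec(w_ih, x), b_ih, matvec(w_hh, h), b_hh)]
        h = [act(z) for z in pre]
        outputs.append(h)
    return outputs, h
```

For example, with a one-dimensional state, `w_ih=[[1.0]]`, `w_hh=[[0.5]]`, inputs `[[1.0], [2.0]]` and `nonlinearity="relu"`, the hidden states are `[1.0]` and then `[2.5]`.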
   
   ```
   import torch
   from tvm import relay


   def test_RNN_torch(num_layers: int,
                      bidirectional: bool,
                      use_bias: bool,
                      hidden_size: int,
                      input_size: int,
                      seq_len: int,
                      batch_first: bool,
                      batch_size: int):
       r'''
       Args:
           num_layers (int): Number of recurrent layers passed to torch.nn.RNN
           bidirectional (bool): Whether to build a bidirectional RNN
           use_bias (bool): Whether the RNN layers use bias weights
           hidden_size (int): Number of features in the hidden state
           input_size (int): Number of input features
           seq_len (int): Number of timesteps in the input
           batch_first (bool): If True, the batch dimension is first in the input tensor; otherwise it is second
           batch_size (int): Batch size of the input. If 0, unbatched input is fed to the network
       '''

       # Build the input shape expected by torch.nn.RNN.
       if batch_size == 0:
           # Unbatched input has no batch dimension: (seq_len, input_size).
           input_shape = (seq_len, input_size)
       elif batch_first:
           input_shape = (batch_size, seq_len, input_size)
       else:
           input_shape = (seq_len, batch_size, input_size)

       pytorch_net = torch.nn.Sequential(
           torch.nn.RNN(input_size,
                        hidden_size,
                        batch_first=batch_first,
                        num_layers=num_layers,
                        bidirectional=bidirectional,
                        bias=use_bias)
       )

       # Trace the model so the frontend sees the aten::rnn_* call.
       scripted_model = torch.jit.trace(pytorch_net.eval(),
                                        torch.randn(input_shape))

       # Convert the traced module to Relay and run type inference.
       mod, params = relay.frontend.from_pytorch(scripted_model,
                                                 [('input', input_shape)])
       mod = relay.transform.InferType()(mod)
       print(mod.astext())


   if __name__ == "__main__":
       test_RNN_torch(num_layers=1,
                      bidirectional=False,
                      use_bias=True,
                      hidden_size=5,
                      input_size=5,
                      seq_len=15,
                      batch_first=True,
                      batch_size=32)
   ```
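When checking the converted module, it can help to know which shapes `torch.nn.RNN` returns. The sketch below encodes the output and `h_n` shapes as stated in the PyTorch `nn.RNN` documentation, using the script's convention that `batch_size == 0` means unbatched input. `expected_rnn_shapes` is a hypothetical helper, not part of TVM or PyTorch; `input_size` is omitted because it affects neither shape.

```python
def expected_rnn_shapes(num_layers, bidirectional, hidden_size,
                        seq_len, batch_first, batch_size):
    """Return (output_shape, h_n_shape) for torch.nn.RNN (hypothetical helper).

    Shapes follow the torch.nn.RNN documentation; batch_size == 0 stands for
    unbatched input, mirroring the convention of test_RNN_torch.
    """
    num_directions = 2 if bidirectional else 1
    out_features = num_directions * hidden_size
    if batch_size == 0:
        # Unbatched: neither result has a batch dimension.
        return (seq_len, out_features), (num_directions * num_layers, hidden_size)
    if batch_first:
        output_shape = (batch_size, seq_len, out_features)
    else:
        output_shape = (seq_len, batch_size, out_features)
    # h_n keeps the (layers * directions, batch, hidden) layout even when
    # batch_first=True.
    h_n_shape = (num_directions * num_layers, batch_size, hidden_size)
    return output_shape, h_n_shape
```

For the call in the script above (one layer, unidirectional, `hidden_size=5`, `seq_len=15`, `batch_first=True`, `batch_size=32`), it gives `(32, 15, 5)` for the output and `(1, 32, 5)` for `h_n`.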
   
   Output (truncated):
   
   ```
   #[version = "0.0.5"]
   type List[A] {
     Cons(A, List[A]),
     Nil,
   }
   
   type Option[A] {
     Some(A),
     None,
   }
   
   type Tree[A] {
     Rose(A, List[Tree[A]]),
   }
   
   type tensor_float16_t {
     tensor_nil_float16,
     tensor0_float16(float16),
     tensor1_float16(Tensor[(?), float16]),
     tensor2_float16(Tensor[(?, ?), float16]),
     tensor3_float16(Tensor[(?, ?, ?), float16]),
     tensor4_float16(Tensor[(?, ?, ?, ?), float16]),
     tensor5_float16(Tensor[(?, ?, ?, ?, ?), float16]),
     tensor6_float16(Tensor[(?, ?, ?, ?, ?, ?), float16]),
   }
   
   type tensor_float32_t {
     tensor_nil_float32,
     tensor0_float32(float32),
     tensor1_float32(Tensor[(?), float32]),
   ...
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] masahi merged pull request #12017: [Pytorch] add aten::rnn_tanh, aten::rnn_relu

Posted by GitBox <gi...@apache.org>.
masahi merged PR #12017:
URL: https://github.com/apache/tvm/pull/12017

