Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/04/05 22:51:46 UTC

[GitHub] [incubator-tvm] kevinthesun opened a new pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

kevinthesun opened a new pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243
 
 
   Improve the TensorFlow frontend's handling of static-shape tensor arrays. After this PR, most tensor array operators will have static input/output shapes.
   
   @wweic @zhiics @yongwww @masahi 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-609552445
 
 
   Hi @kevinthesun, I started experimenting with integrating the static tensor array into the Torch frontend. My use case is supporting Python tensor list append and stack. I ran into the two problems below:
   
   1. When I cons a tensor onto the tensor array, I can run shape inference on the input tensor to get the fixed shape the static tensor array expects. But after a few cons operations, when I try to stack the static tensor array, I have no way to tell which fixed shape the input tensor array expects. See
   https://github.com/masahi/tvm/blob/support-more-rnn/python/tvm/relay/frontend/pytorch.py#L990
   Since the shape is fixed, I think there should be an easy way to query the shape associated with a static array. I see you have such function `check_tensor_array_shape` in this PR (by parsing op name). Is this the recommended way?
   
   2. The output type of stack is currently `static_tensor_float32_?_2_4_t[]` in my test. Is there an easy way to unwrap the static tensor type wrapper and get a Relay `Tensor`? @wweic had such an unwrapper in https://github.com/apache/incubator-tvm/pull/4325 for generic arrays. We should have something equivalent for static arrays.
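Both questions hinge on how static ops are keyed. Judging only from the names printed in this thread (`static_tensor_float32_?_2_4_t`, `tensor_array_stack_float32_2_4`), the dtype and every dimension are appended to a base name, with `?` standing for `Any`. The sketch below reproduces that pattern; `static_name` and `static_type_name` are hypothetical helpers inferred from the printed names, not TVM's implementation, and `None` plays the role of `Any()`.

```python
from typing import Optional, Sequence

# Hypothetical helpers mimicking the naming pattern seen in this thread;
# they are not part of TVM. None plays the role of Any().


def static_name(base: str, dtype: str, shape: Sequence[Optional[int]]) -> str:
    """Append dtype and dims to an op name, e.g. 'tensor_get_data_float32_?_2_4'."""
    dims = ["?" if d is None else str(d) for d in shape]
    return "_".join([base, dtype] + dims)


def static_type_name(dtype: str, shape: Sequence[Optional[int]]) -> str:
    """Name of the wrapper type, e.g. 'static_tensor_float32_?_2_4_t'."""
    return static_name("static_tensor", dtype, shape) + "_t"


print(static_type_name("float32", (2, 4)))                   # static_tensor_float32_2_4_t
print(static_type_name("float32", (None, 2, 4)))             # static_tensor_float32_?_2_4_t
print(static_name("tensor_array_stack", "float32", (2, 4)))  # tensor_array_stack_float32_2_4
```

Because the shape is part of the name, an op registered for `(2, 4)` and one registered for `(?, 2, 4)` are distinct globals.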

[GitHub] [incubator-tvm] masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-610016636
 
 
   > @masahi You can use [tensor_get_data](https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/prelude.py#L515) to achieve this
   
   Ah, thanks. I tried to use it on the output of stack, but since the first axis is 'Any', I don't know how to pass the `shape` param in `get_var_static('tensor_get_data', "float32", shape)`. How do I do it?
   
   A better question might be: why do we need to pass `shape` all over the place? Intuitively, I'd imagine that stack and other ops that operate on an existing tensor array should be able to figure out the shape, no?

[GitHub] [incubator-tvm] kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-609924079
 
 
   > 2\. The output type of stack is currently `static_tensor_float32_?_2_4_t[]` in my test. Is there a way to easily unwrap static tensor type wrapper and get  relay `Tensor`? @wweic had such unwrapper in #4325 for generic arrays. We should have something equivalent for static arrays.
   
   @masahi You can use [tensor_get_data](https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/prelude.py#L515) to achieve this.

[GitHub] [incubator-tvm] masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-610032209
 
 
   hmm I tried this:
   
   ```Py
   def _tensor_array_stack(prelude):
       def _impl(inputs, input_types):
           # print(prelude.mod)
           # TODO: how to get the fixed shape of static_tensor_array inputs[0]?
           shape = (2, 4)
           stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
           stacked = stack(inputs[0])
   
           stacked_shape = (Any(), 2, 4)
           static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape)
           static_tensor_array_ops.register()
           get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape)
           return get_tensor(stacked)
       return _impl
   ```
   But I'm still getting `AttributeError: 'Prelude' object has no attribute 'tensor_get_data_float32_?_2_4'`.  Do I need a new prelude to register a new shape?

[GitHub] [incubator-tvm] kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-612176243
 
 
   > @kevinthesun @wweic Is it reasonable to add "axis" parameter to tensor array concat? I encountered a need to concat along the -1 axis.
   
   @masahi To support a different axis, we need to change both ```define_tensor_concatenate``` and ```define_tensor_array_concat``` to accept an axis argument. The main thing we need to take care of is the output shape.
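For illustration, this is roughly the output-shape bookkeeping an axis argument would add: the concatenated axis becomes dynamic whenever the array length (or that axis of the element shape) is unknown, while the other axes keep the element shape. It is a hedged, shape-only sketch; `concat_output_shape` is a hypothetical helper that does not touch TVM, and `None` stands in for `Any()`.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None stands in for Any()


def concat_output_shape(elem_shape: Shape, axis: int,
                        length: Optional[int] = None) -> Shape:
    """Static shape after concatenating `length` tensors of `elem_shape` on `axis`.

    length=None means the number of tensors in the array is not known statically.
    Hypothetical helper for illustration only.
    """
    rank = len(elem_shape)
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    axis %= rank  # normalize negative axes such as -1
    out = list(elem_shape)
    if length is None or elem_shape[axis] is None:
        out[axis] = None  # the concatenated extent is unknown
    else:
        out[axis] = length * elem_shape[axis]
    return tuple(out)


print(concat_output_shape((None, 2, 4), axis=-1))     # (None, 2, None)
print(concat_output_shape((2, 4), axis=0, length=3))  # (6, 4)
```

With an unknown array length, concatenating `(?, 2, 4)` elements along `-1` yields `(?, 2, ?)`, which is the shape masahi is after.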

[GitHub] [incubator-tvm] masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-610052206
 
 
   Great, I got the following working. I also confirmed that `get_tensor_array_shape` works. Happy now :) Thank you very much! Unwrapping should enable supporting the "stacked" LSTM in https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py#L252, where the output of `LSTMLayer` is pipelined `num_layers` times to build a bigger network. 
   
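   ```Py
   # Assumed import locations in the TVM Python API
   from tvm.relay import Any
   from tvm.relay.prelude import StaticTensorArrayOps, get_tensor_array_shape
   
   def _tensor_array_stack(prelude):
       def _impl(inputs, input_types):
           # Recover the element shape the static tensor array was built with
           shape = get_tensor_array_shape(inputs[0], "float32", prelude)
           stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
           stacked = stack(inputs[0])
   
           # Stacking adds a leading dynamic axis
           stacked_shape = (Any(),) + shape
           static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
           static_tensor_array_ops.define_tensor_get_data(stacked_shape)
           # passing stacked_shape below gives "'Prelude' object has no attribute" error;
           # the lookup uses the shape the ops object was created with
           get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
           return get_tensor(stacked)
       return _impl
   ```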

[GitHub] [incubator-tvm] kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-612251085
 
 
   @masahi @wweic PTAL

[GitHub] [incubator-tvm] kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-610029111
 
 
   @masahi The shape passed to ```get_var_static``` is only for identification. For tensor_get_data, it is only used to pick the corresponding global var from the prelude module. You just need to pass the shape with which you created StaticTensorArrayOps. For example, if you create a tensor array with shape (1, 2, 3), pass (1, 2, 3) to get_var_static. However, to define_tensor_get_data you should pass (Any(), 1, 2, 3), since that is the actual output shape.
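A toy model of this rule: the shape given to `get_var_static` only identifies which global to fetch (the shape the ops object was created with), while the shape given to `define_tensor_get_data` sets the output type. `ToyPrelude` and `ToyStaticOps` are hypothetical stand-ins that borrow names from this thread; they are not TVM classes. `register()` is modeled as skipping `tensor_get_data`, as described in the thread, and only the stack op is modeled.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None stands in for Any()


def _mangle(base: str, dtype: str, shape: Shape) -> str:
    # Naming pattern seen in this thread: dtype and dims appended, '?' for Any.
    return "_".join([base, dtype] + ["?" if d is None else str(d) for d in shape])


class ToyPrelude:
    """Hypothetical stand-in for the prelude: named globals -> output shape."""

    def __init__(self) -> None:
        self.globals: Dict[str, Shape] = {}

    def get_var_static(self, base: str, dtype: str, shape: Shape) -> Shape:
        # The shape argument only helps build the lookup name.
        name = _mangle(base, dtype, shape)
        if name not in self.globals:
            raise AttributeError(f"'Prelude' object has no attribute '{name}'")
        return self.globals[name]


class ToyStaticOps:
    """Hypothetical stand-in for StaticTensorArrayOps, keyed by its element shape."""

    def __init__(self, prelude: ToyPrelude, dtype: str, shape: Shape) -> None:
        self.prelude, self.dtype, self.shape = prelude, dtype, shape

    def register(self) -> None:
        # Registers the array ops (only stack modeled here), but not tensor_get_data.
        name = _mangle("tensor_array_stack", self.dtype, self.shape)
        self.prelude.globals[name] = (None,) + self.shape

    def define_tensor_get_data(self, output_shape: Shape) -> None:
        # Named after self.shape; output_shape only decides the result type.
        name = _mangle("tensor_get_data", self.dtype, self.shape)
        self.prelude.globals[name] = output_shape


prelude = ToyPrelude()
ops = ToyStaticOps(prelude, "float32", (2, 4))
ops.register()
ops.define_tensor_get_data((None, 2, 4))
print(prelude.get_var_static("tensor_get_data", "float32", (2, 4)))  # (None, 2, 4)
try:
    prelude.get_var_static("tensor_get_data", "float32", (None, 2, 4))
except AttributeError as err:
    print(err)  # 'Prelude' object has no attribute 'tensor_get_data_float32_?_2_4'
```

The failing lookup mirrors the `AttributeError` reported earlier: nothing was defined under the `?_2_4` name.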

[GitHub] [incubator-tvm] masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-612293655
 
 
   Thanks @kevinthesun 

[GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on a change in pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#discussion_r404255135
 
 

 ##########
 File path: python/tvm/relay/frontend/common.py
 ##########
 @@ -548,6 +558,28 @@ def new_var(name_hint,
     return _expr.var(name_hint, type_annotation, shape, dtype)
 
 
+def check_tensor_array_shape(expr, dtype, mod):
 
 Review comment:
   Yes. I'll move it to prelude.

[GitHub] [incubator-tvm] masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-612088720
 
 
   @kevinthesun @wweic Is it reasonable to add an "axis" parameter to tensor array concat? I ran into a case where I need to concat along the -1 axis.
   

[GitHub] [incubator-tvm] masahi edited a comment on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-612270589
 
 
   > @masahi To support different axis we need to change both `define_tensor_concatenate` and `define_tensor_array_concat` to support axis arguments. The main thing we need to take care is the output shape.
   
   OK, for now I took the easy route of just defining a concat_last op. It seems to work, but I'm getting the following typing error:
   
   ```
   ...
     %101 = @map(tensor_constructor_float32_?_2_4(Tensor[(?, 2, 4), float32]), %100);
     %102 = @tensor_array_concat_last_float32_?_2_?(%101) unable to unify: `static_tensor_float32_?_2_4_t` and `static_tensor_float32_?_2_?_t`; ;
     %103 = @tensor_get_data_float32_?_2_?(%102);
   ...
   ```
   
   The first axis is already Any because of tensor array stack. Now I'm trying to concat (?, 2, 4) tensors along the -1 axis to get a (?, 2, ?) tensor. Is this possible? The typing error suggests not.
   
   UPDATE: Solved by mapping tensor_constructor with the concatenated shape (?, 2, ?):
   ```
     %101 = @map(tensor_constructor_float32_?_2_?(Tensor[(?, 2, ?), float32]), %100) /* ty=List[static_tensor_float32_?_2_?_t[]] */;
     %102 = @tensor_array_concat_last_float32_?_2_?(%101) /* ty=static_tensor_float32_?_2_?_t[] */;
     %103 = @tensor_get_data_float32_?_2_?(%102) /* ty=Tensor[(?, 2, ?), float32] */;
   ```

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#discussion_r406498520
 
 

 ##########
 File path: python/tvm/relay/frontend/common.py
 ##########
 @@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False):
     return channels
 
 
+def infer_shape(inputs, mod=None):
+    """A method to get the output type of an intermediate node in the graph."""
+    out_type = infer_type(inputs, mod=mod)
+    checked_type = out_type.checked_type
+    if hasattr(checked_type, 'shape'):
+        # Regular operator that outputs tensors
+        return get_const_tuple(out_type.checked_type.shape)
+    # The return type is not a tensor, for example List
 
 Review comment:
   Nit: `out_type.checked_type` -> `checked_type`, since it is already stored in a local variable.

[GitHub] [incubator-tvm] masahi merged pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243
 
 
   

[GitHub] [incubator-tvm] masahi edited a comment on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-609552445
 
 
   Hi @kevinthesun, I started experimenting with integrating the static tensor array into the Torch frontend. My use case is supporting Python tensor list append and stack. I ran into the two problems below:
   
   1. When I append a tensor to the tensor array (by concat), I can run shape inference on the input tensor to get the fixed shape the static tensor array expects. But after a few appends, when I try to stack the static tensor array, I have no way to tell which fixed shape the input tensor array expects. See
   https://github.com/masahi/tvm/blob/support-more-rnn/python/tvm/relay/frontend/pytorch.py#L989-L990
   Since the shape is fixed, I think there should be an easy way to query the shape associated with a static array. I see you have such function `check_tensor_array_shape` in this PR (by parsing op name). Is this the recommended way?
   
   2. The output type of stack is currently `static_tensor_float32_?_2_4_t[]` in my test. Is there an easy way to unwrap the static tensor type wrapper and get a Relay `Tensor`? @wweic had such an unwrapper in https://github.com/apache/incubator-tvm/pull/4325 for generic arrays. We should have something equivalent for static arrays.

[GitHub] [incubator-tvm] masahi edited a comment on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-609602675
 
 
   Update: With the new static tensor array, I got the following PyTorch LSTM model, originally from the fastrnn benchmark in the PyTorch repo at https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py#L187, converted correctly to Relay, and it produced results identical to Torch! This was not possible with the generic tensor array. @kevinthesun @wweic 
   
   ```Py
   class LSTMCell(jit.ScriptModule):
       def __init__(self, input_size, hidden_size):
           super().__init__()
           self.input_size = input_size
           self.hidden_size = hidden_size
           self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
           self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
           self.bias_ih = Parameter(torch.randn(4 * hidden_size))
           self.bias_hh = Parameter(torch.randn(4 * hidden_size))
   
       @jit.script_method
       def forward(self, input, state):
           # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
           hx, cx = state
           gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                    torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
           ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
   
           ingate = torch.sigmoid(ingate)
           forgetgate = torch.sigmoid(forgetgate)
           cellgate = torch.tanh(cellgate)
           outgate = torch.sigmoid(outgate)
   
           cy = (forgetgate * cx) + (ingate * cellgate)
           hy = outgate * torch.tanh(cy)
   
           return hy, (hy, cy)
   
   
   class LSTMLayer(jit.ScriptModule):
       def __init__(self, cell, *cell_args):
           super().__init__()
           self.cell = cell(*cell_args)
   
       @jit.script_method
       def forward(self, input, state):
           # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
           outputs = []
           for i in range(input.size(0)):
               out, state = self.cell(input[i], state)
               outputs += [out]
           return torch.stack(outputs), state
   ```
   
   Here is the converted Relay IR:
   ```
   fn (%input: Tensor[(5, 2, 3), float32], %v25: Tensor[(16, 3), float32], %v28: Tensor[(16), float32], %v30: Tensor[(16, 4), float32], %v34: Tensor[(16), float32], %states: (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (static_tensor_float32_?_2_4_t[], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) {
     %0 = Nil /* ty=List[static_tensor_float32_2_4_t[]] */;
     %36 = (
       let %while_loop: fn (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) = fn (%i.1: int32, %outputs.6: List[static_tensor_float32_2_4_t[]], %state.6: (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) {
         %1 = less(%i.1, 5 /* ty=int32 */) /* ty=bool */;
         if (%1) {
           %2 = add(%i.1, 1 /* ty=int32 */) /* ty=int32 */;
           %3 = take(%input, %i.1, axis=0) /* ty=Tensor[(2, 3), float32] */;
           %4 = transpose(%v25, axes=[1, 0]) /* ty=Tensor[(3, 16), float32] */;
           %5 = transpose(%4, axes=[1, 0]) /* ty=Tensor[(16, 3), float32] */;
           %6 = nn.dense(%3, %5, units=None) /* ty=Tensor[(2, 16), float32] */;
           %7 = add(%6, %v28) /* ty=Tensor[(2, 16), float32] */;
           %8 = %state.6.0;
           %9 = transpose(%v30, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
           %10 = transpose(%9, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
           %11 = nn.dense(%8, %10, units=None) /* ty=Tensor[(2, 16), float32] */;
           %12 = add(%7, %11) /* ty=Tensor[(2, 16), float32] */;
           %13 = add(%12, %v34) /* ty=Tensor[(2, 16), float32] */;
           %14 = strided_slice(%13, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
           %15 = sigmoid(%14) /* ty=Tensor[(2, 4), float32] */;
           %16 = strided_slice(%13, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
           %17 = sigmoid(%16) /* ty=Tensor[(2, 4), float32] */;
           %18 = %state.6.1;
           %19 = multiply(%17, %18) /* ty=Tensor[(2, 4), float32] */;
           %20 = strided_slice(%13, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
           %21 = sigmoid(%20) /* ty=Tensor[(2, 4), float32] */;
           %22 = strided_slice(%13, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
           %23 = tanh(%22) /* ty=Tensor[(2, 4), float32] */;
           %24 = multiply(%21, %23) /* ty=Tensor[(2, 4), float32] */;
           %25 = add(%19, %24) /* ty=Tensor[(2, 4), float32] */;
           %26 = tanh(%25) /* ty=Tensor[(2, 4), float32] */;
           %27 = multiply(%15, %26) /* ty=Tensor[(2, 4), float32] */;
           %28 = (%27, %25);
           %29 = (%27, %28);
           %30 = %29.0;
           %31 = tensor_constructor_float32_2_4(%30) /* ty=static_tensor_float32_2_4_t[] */;
           %32 = Nil /* ty=List[static_tensor_float32_2_4_t[]] */;
           %33 = Cons(%31, %32) /* ty=List[static_tensor_float32_2_4_t[]] */;
           %34 = @concat(%outputs.6, %33) /* ty=List[static_tensor_float32_2_4_t[]] */;
           %35 = %29.1;
           %while_loop(%2, %34, %35) /* ty=(int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) */
         } else {
           (%i.1, %outputs.6, %state.6)
         }
       };
       %while_loop
     );
     %37 = %36(0 /* ty=int32 */, %0, %states) /* ty=(int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) */;
     %38 = %37.1;
     %39 = @tensor_array_stack_float32_2_4(%38) /* ty=static_tensor_float32_?_2_4_t[] */;
     %40 = %37.2;
     (%39, %40)
   }
   
   ```
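To spell out what the IR above does: `outputs += [out]` becomes `@concat(%outputs.6, Cons(x, Nil))`, appending a one-element list, and `@tensor_array_stack_float32_2_4` collapses the list into a single value whose static type has `?` as its first axis, since a `List` type does not carry its length. Below is a pure-Python model of those semantics, not Relay code; `Cons`, `Nil`, `concat`, `stack` and `stack_static_shape` are hypothetical stand-ins, and nested lists play the role of tensors.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class Nil:
    """Empty list (hypothetical stand-in for Relay's Nil)."""


@dataclass(frozen=True)
class Cons:
    """Non-empty list cell (hypothetical stand-in for Relay's Cons)."""
    head: object
    tail: object  # another Cons, or Nil


def concat(xs, ys):
    """Append list ys after list xs, rebuilding the cells of xs."""
    if isinstance(xs, Nil):
        return ys
    return Cons(xs.head, concat(xs.tail, ys))


def stack(xs):
    """Collect the elements of a list in order, as a new leading axis."""
    out = []
    while isinstance(xs, Cons):
        out.append(xs.head)
        xs = xs.tail
    return out


def stack_static_shape(elem_shape: Sequence[Optional[int]]) -> Tuple[Optional[int], ...]:
    # The list length is only known at run time, so axis 0 becomes Any (None).
    return (None,) + tuple(elem_shape)


outputs = Nil()
for i in range(5):  # mirrors `outputs += [out]` in LSTMLayer.forward
    out = [[float(i)] * 4 for _ in range(2)]  # a 2x4 "tensor" as nested lists
    outputs = concat(outputs, Cons(out, Nil()))

stacked = stack(outputs)
print(len(stacked), len(stacked[0]), len(stacked[0][0]))  # 5 2 4
print(stack_static_shape((2, 4)))                        # (None, 2, 4)
```

In this model each append copies the existing cells, so appending n elements this way is quadratic in n.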

[GitHub] [incubator-tvm] kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-610039127
 
 
   Take tensor_array_gather as an example (https://github.com/apache/incubator-tvm/pull/5243/files#diff-eae8ecf976e0031823eeae454466f964R903): you create a new static tensor array ops object with your input tensor array shape and register all ops except tensor_get_data. After that (https://github.com/apache/incubator-tvm/pull/5243/files#diff-eae8ecf976e0031823eeae454466f964R924), you need to register tensor_get_data manually. It isn't registered automatically, since the input and output shapes might not match.

[GitHub] [incubator-tvm] masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-612256307
 
 
   @kevinthesun I'm not very familiar with TF, let alone its tensor array support. If that's fine, I can review.

[GitHub] [incubator-tvm] kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-609926012
 
 
   > 1. When I append the tensor to tensor array (by concat), I can do infer shape on the input tensor to get the fixed shape static tensor array expects. But after I've done some appends and try to stack the static tensor array, I don't have a way to tell what fixed shape the input tensor array to stack expects. See
   >    https://github.com/masahi/tvm/blob/support-more-rnn/python/tvm/relay/frontend/pytorch.py#L989-L990
   >    Since the shape is fixed, I think there should be an easy way to query the shape associated with a static array. I see you have such function `check_tensor_array_shape` in this PR (by parsing op name). Is this the recommended way?
   
   Yes, you can use ```check_tensor_array_shape```. I'll rename it to ```get_tensor_array_shape```.
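Because the shape is spelled out in the type name, recovering it is a matter of string parsing. The following is a minimal sketch that works on the name string alone. `parse_static_shape` is a hypothetical helper; the `get_tensor_array_shape` discussed here takes a Relay expression and may differ in detail.

```python
from typing import Optional, Tuple


def parse_static_shape(type_name: str, dtype: str) -> Tuple[Optional[int], ...]:
    """Recover the shape from a name like 'static_tensor_float32_?_2_4_t'.

    Hypothetical helper: '?' (Any) dims come back as None.
    """
    prefix, suffix = f"static_tensor_{dtype}", "_t"
    if (not type_name.startswith(prefix) or not type_name.endswith(suffix)
            or len(type_name) < len(prefix) + len(suffix)):
        raise ValueError(f"not a static tensor type for {dtype}: {type_name}")
    rest = type_name[len(prefix):len(type_name) - len(suffix)]  # e.g. '_?_2_4'
    if rest and not rest.startswith("_"):
        # Guards against e.g. dtype 'float' matching 'float32'.
        raise ValueError(f"dtype mismatch in {type_name}")
    dims = rest.split("_")[1:] if rest else []
    return tuple(None if d == "?" else int(d) for d in dims)


print(parse_static_shape("static_tensor_float32_?_2_4_t", "float32"))  # (None, 2, 4)
print(parse_static_shape("static_tensor_float32_2_4_t", "float32"))    # (2, 4)
```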

[GitHub] [incubator-tvm] masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-609602675
 
 
   Update: With the new static tensor array, the following PyTorch LSTM model, originally from the fastrnns benchmark in the PyTorch repo (https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py#L187), converts correctly to Relay and produces results identical to Torch! This was not possible with the generic tensor array. @kevinthesun @wweic 
   
   ```Py
   from typing import Tuple

   import torch
   from torch import Tensor, jit
   from torch.nn import Parameter
   
   class LSTMCell(jit.ScriptModule):
       def __init__(self, input_size, hidden_size):
           super().__init__()
           self.input_size = input_size
           self.hidden_size = hidden_size
           self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
           self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
           self.bias_ih = Parameter(torch.randn(4 * hidden_size))
           self.bias_hh = Parameter(torch.randn(4 * hidden_size))
   
       @jit.script_method
       def forward(self, input, state):
           # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
           hx, cx = state
           gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                    torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
           ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
   
           ingate = torch.sigmoid(ingate)
           forgetgate = torch.sigmoid(forgetgate)
           cellgate = torch.tanh(cellgate)
           outgate = torch.sigmoid(outgate)
   
           cy = (forgetgate * cx) + (ingate * cellgate)
           hy = outgate * torch.tanh(cy)
   
           return hy, (hy, cy)
   
   
   class LSTMLayer(jit.ScriptModule):
       def __init__(self, cell, *cell_args):
           super().__init__()
           self.cell = cell(*cell_args)
   
       @jit.script_method
       def forward(self, input, state):
           # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
           outputs = []
           for i in range(input.size(0)):
               out, state = self.cell(input[i], state)
               outputs += [out]
           return torch.stack(outputs), state
   ```
   
   Here is the converted Relay IR:
   ```
   fn (%input: Tensor[(5, 2, 3), float32], %v25: Tensor[(16, 3), float32], %v28: Tensor[(16), float32], %v30: Tensor[(16, 4), float32], %v34: Tensor[(16), float32], %states: (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (static_tensor_float32_?_2_4_t[], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) {
     %0 = Nil /* ty=List[static_tensor_float32_2_4_t[]] */;
     %34 = (
       let %while_loop: fn (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) = fn (%i.1: int32, %outputs.6: List[static_tensor_float32_2_4_t[]], %state.6: (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) {
         %1 = less(%i.1, 5 /* ty=int32 */) /* ty=bool */;
         if (%1) {
           %2 = add(%i.1, 1 /* ty=int32 */) /* ty=int32 */;
           %3 = take(%input, %i.1, axis=0) /* ty=Tensor[(2, 3), float32] */;
           %4 = transpose(%v25, axes=[1, 0]) /* ty=Tensor[(3, 16), float32] */;
           %5 = transpose(%4, axes=[1, 0]) /* ty=Tensor[(16, 3), float32] */;
           %6 = nn.dense(%3, %5, units=None) /* ty=Tensor[(2, 16), float32] */;
           %7 = add(%6, %v28) /* ty=Tensor[(2, 16), float32] */;
           %8 = %state.6.0;
           %9 = transpose(%v30, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
           %10 = transpose(%9, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
           %11 = nn.dense(%8, %10, units=None) /* ty=Tensor[(2, 16), float32] */;
           %12 = add(%7, %11) /* ty=Tensor[(2, 16), float32] */;
           %13 = add(%12, %v34) /* ty=Tensor[(2, 16), float32] */;
           %14 = strided_slice(%13, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
           %15 = sigmoid(%14) /* ty=Tensor[(2, 4), float32] */;
           %16 = strided_slice(%13, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
           %17 = sigmoid(%16) /* ty=Tensor[(2, 4), float32] */;
           %18 = %state.6.1;
           %19 = multiply(%17, %18) /* ty=Tensor[(2, 4), float32] */;
           %20 = strided_slice(%13, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
           %21 = sigmoid(%20) /* ty=Tensor[(2, 4), float32] */;
           %22 = strided_slice(%13, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
           %23 = tanh(%22) /* ty=Tensor[(2, 4), float32] */;
           %24 = multiply(%21, %23) /* ty=Tensor[(2, 4), float32] */;
           %25 = add(%19, %24) /* ty=Tensor[(2, 4), float32] */;
           %26 = tanh(%25) /* ty=Tensor[(2, 4), float32] */;
           %27 = multiply(%15, %26) /* ty=Tensor[(2, 4), float32] */;
           %28 = (%27, %25);
           %29 = (%27, %28);
           %30 = %29.0;
           %31 = tensor_constructor_float32_2_4(%30) /* ty=static_tensor_float32_2_4_t[] */;
           %32 = Cons(%31, %outputs.6) /* ty=List[static_tensor_float32_2_4_t[]] */;
           %33 = %29.1;
           %while_loop(%2, %32, %33) /* ty=(int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) */
         } else {
           (%i.1, %outputs.6, %state.6)
         }
       };
       %while_loop
     );
     %35 = %34(0 /* ty=int32 */, %0, %states) /* ty=(int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) */;
     %36 = %35.1;
     %37 = @rev(%36) /* ty=List[static_tensor_float32_2_4_t[]] */;
     %38 = @tensor_array_stack_float32_2_4(%37) /* ty=static_tensor_float32_?_2_4_t[] */;
     %39 = %35.2;
     (%38, %39)
   }
   ```
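For readers less familiar with Relay's list ADT, the loop above can be mirrored in plain NumPy. This is an illustrative sketch, not code from the PR, and it assumes the shapes in the IR dump (input `(5, 2, 3)`, hidden size 4, so weights of shape `(16, 3)` and `(16, 4)`). It reproduces two details visible in the IR: the gate columns `[0:4]`, `[4:8]`, `[8:12]`, `[12:16]` are ingate, forgetgate, cellgate and outgate (the `chunk(4, 1)` in the Torch code), and outputs are prepended with `Cons`, so the list is reversed with `@rev` before `tensor_array_stack`.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_layer_like_relay(inp, w_ih, b_ih, w_hh, b_hh, state):
    """NumPy mirror of the Relay while_loop; returns (stacked outputs, final state)."""
    hx, cx = state
    h = hx.shape[1]  # hidden size (4 in the IR dump)
    outputs = []  # stands in for the Relay List built with Cons
    for i in range(inp.shape[0]):
        # nn.dense(x, W) computes x @ W.T, so the two transposes in the IR cancel.
        gates = inp[i] @ w_ih.T + b_ih + hx @ w_hh.T + b_hh  # (batch, 4 * h)
        ingate = sigmoid(gates[:, 0:h])
        forgetgate = sigmoid(gates[:, h:2 * h])
        cellgate = np.tanh(gates[:, 2 * h:3 * h])
        outgate = sigmoid(gates[:, 3 * h:4 * h])
        cx = forgetgate * cx + ingate * cellgate
        hx = outgate * np.tanh(cx)
        outputs = [hx] + outputs  # Cons: newest output goes to the head
    outputs = outputs[::-1]  # @rev: restore time order before stacking
    return np.stack(outputs), (hx, cx)  # tensor_array_stack


rng = np.random.default_rng(0)
inp = rng.standard_normal((5, 2, 3))
w_ih, b_ih = rng.standard_normal((16, 3)), rng.standard_normal(16)
w_hh, b_hh = rng.standard_normal((16, 4)), rng.standard_normal(16)
state = (np.zeros((2, 4)), np.zeros((2, 4)))

out, (hy, cy) = lstm_layer_like_relay(inp, w_ih, b_ih, w_hh, b_hh, state)
print(out.shape)  # (5, 2, 4)
```

Building the list by prepending and reversing once at the end keeps each `Cons` constant-time, which is presumably why the converted IR has the extra `@rev` before the stack.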

[GitHub] [incubator-tvm] kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on issue #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#issuecomment-612260354
 
 
   @masahi Sure. Please go ahead and review. I think a lot of the logic can be reused in the PyTorch frontend.

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5243: [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5243#discussion_r403818296
 
 

 ##########
 File path: python/tvm/relay/frontend/common.py
 ##########
 @@ -548,6 +558,28 @@ def new_var(name_hint,
     return _expr.var(name_hint, type_annotation, shape, dtype)
 
 
+def check_tensor_array_shape(expr, dtype, mod):
 
 Review comment:
   I think querying the fixed shape associated with a static tensor array would be a very common thing to do. Does it make sense to have such a utility function in the prelude, rather than in frontend code?
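One way to picture this suggestion: if the prelude recorded each static tensor array's dtype and shape at the moment it generates the type and ops, every frontend could look the shape up directly instead of re-parsing names. The sketch below illustrates that idea only; `StaticArrayRegistry`, `register` and `shape_of` are hypothetical names and do not correspond to an existing TVM API. The generated names follow the pattern seen in the Relay IR in this thread.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[Optional[int], ...]


class StaticArrayRegistry:
    """Hypothetical prelude-side table: generated type name -> (dtype, shape)."""

    def __init__(self) -> None:
        self._entries: Dict[str, Tuple[str, Shape]] = {}

    def register(self, dtype: str, shape: Shape) -> str:
        # Build the same style of name that appears in the IR dump;
        # None stands for a dynamic dimension and is printed as '?'.
        shape_str = "_".join("?" if d is None else str(d) for d in shape)
        name = f"static_tensor_{dtype}_{shape_str}_t"
        self._entries[name] = (dtype, shape)
        return name

    def shape_of(self, name: str) -> Shape:
        # A dictionary lookup replaces string parsing in each frontend.
        return self._entries[name][1]


reg = StaticArrayRegistry()
name = reg.register("float32", (2, 4))
print(name)                # static_tensor_float32_2_4_t
print(reg.shape_of(name))  # (2, 4)
```

The design choice is that the shape is stored where the type is created, so it cannot drift from the name; the trade-off is that the table has to live somewhere shared (for example alongside the prelude) so that the TF and PyTorch frontends see the same entries.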
