Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/24 22:52:18 UTC

[GitHub] [tvm] apivovarov edited a comment on pull request #7513: [Torch] Add copy_ operator

apivovarov edited a comment on pull request #7513:
URL: https://github.com/apache/tvm/pull/7513#issuecomment-785436558


   @masahi I tried adding the following to `_run_jit_passes(graph)`:
   ```
   torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
   torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
   ```
   
   Yes, it replaces `copy_` with `expand_as` + `index_put`, but in that case the `index_put` operator receives an empty indices list.
   
   Test Module:
   ```python
   import torch
   
   class MyCopy(torch.nn.Module):
       def __init__(self, shape):
           super().__init__()
           self.shape = shape
   
       def forward(self, values):
           A = torch.zeros(self.shape)
           B = A.copy_(values)  # in-place copy: broadcasts (4,) -> (2, 4) and casts Long -> Float
           return B
   
   
   MP = MyCopy((2, 4))
   a = torch.tensor([0, 1, 2, 6])
   MP(a)
   
   traced_MP = torch.jit.trace(MP, (a,))
   ```
   
   ```
   traced_MP.graph
   graph(%self : __torch__.MyCopy,
         %values : Long(4, strides=[1], requires_grad=0, device=cpu)):
     %4 : int = prim::Constant[value=2]() # <stdin>:7:0
     %5 : int = prim::Constant[value=4]() # <stdin>:7:0
     %6 : int[] = prim::ListConstruct(%4, %5)
     %7 : int = prim::Constant[value=6]() # <stdin>:7:0
     %8 : None = prim::Constant()
     %9 : Device = prim::Constant[value="cpu"]() # <stdin>:7:0
     %10 : bool = prim::Constant[value=0]() # <stdin>:7:0
     %A : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::zeros(%6, %7, %8, %9, %10) # <stdin>:7:0
     %12 : bool = prim::Constant[value=0]()
     %13 : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::copy_(%A, %values, %12) # <stdin>:8:0
     return (%13)
   ```
   
   After running the JIT passes:
   ```
   graph = traced_MP.graph.copy()
   torch._C._jit_pass_onnx_function_substitution(graph)
   torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
   torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
   graph
   graph(%self : __torch__.MyCopy,
         %values.1 : Long(4, strides=[1], requires_grad=0, device=cpu)):
     %2 : int = prim::Constant[value=2]() # <stdin>:7:0
     %3 : int = prim::Constant[value=4]() # <stdin>:7:0
     %4 : int[] = prim::ListConstruct(%2, %3)
     %5 : int = prim::Constant[value=6]() # <stdin>:7:0
     %6 : None = prim::Constant()
     %7 : Device = prim::Constant[value="cpu"]() # <stdin>:7:0
     %8 : bool = prim::Constant[value=0]() # <stdin>:7:0
     %A : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::zeros(%4, %5, %6, %7, %8) # <stdin>:7:0
     %10 : bool = prim::Constant[value=0]()
     %values : Long(4, strides=[1], requires_grad=0, device=cpu) = aten::expand_as(%values.1, %A) # <stdin>:8:0
     %15 : Tensor?[] = prim::ListConstruct()
     %16 : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::index_put(%A, %15, %values, %10)
     return (%16)
   ```
   
   As you can see, `%15` (the indices argument of `aten::index_put`) is an empty list.
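Semantically, an `index_put` with an empty indices list should write every element of the target (like `A[()] = values`), so `expand_as` followed by that `index_put` ought to produce the same result as the original `copy_`. The sketch below checks this equivalence with NumPy standing in for the PyTorch ops; it is an illustration under that assumption, and the variable names are hypothetical.

```python
import numpy as np

# NumPy stand-ins for the tensors in the traced graph (illustrative only).
A = np.zeros((2, 4), dtype=np.float32)           # aten::zeros, dtype 6 = float32
values = np.array([0, 1, 2, 6], dtype=np.int64)  # the Long(4) input

# Original graph: aten::copy_ broadcasts `values` to A's shape and casts to A's dtype.
copied = A.copy()
copied[...] = values

# Rewritten graph: aten::expand_as, then an index_put with an empty indices list.
expanded = np.broadcast_to(values, A.shape)
put = A.copy()
put[()] = expanded  # an empty index tuple selects the whole array

print(np.array_equal(copied, put))  # True
print(put.tolist())  # [[0.0, 1.0, 2.0, 6.0], [0.0, 1.0, 2.0, 6.0]]
```

So the rewrite itself is sound; the problem is only that a converter expecting at least one index tensor cannot handle it.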
   
   As a result, `relay.frontend.from_pytorch` fails, because the `index_put` converter calls `_op.stack` on the empty indices list:
   ```
   import tvm
   from tvm import relay
   ctx = tvm.cpu(0)
   target = 'llvm'
   
   shape_list = [("input0", [4,]),]
   mod, params = relay.frontend.from_pytorch(traced_MP, shape_list)
   
   Traceback (most recent call last):
     File "<stdin>", line 1, in <module>
     File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 3186, in from_pytorch
       ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
     File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2605, in convert_operators
       relay_out = relay_op(
     File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2059, in index_put
       index_tensor = _op.stack(indices, axis=0)
     File "/Users/pivovaa/workspace/tvm/python/tvm/relay/op/tensor.py", line 1124, in stack
       raise ValueError("relay.stack requires data to be non-empty.")
   ValueError: relay.stack requires data to be non-empty.
   ```
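One possible direction, sketched here as a hypothetical NumPy model rather than actual TVM frontend code: have the `index_put` converter check for an empty indices list before stacking, and in that case return `values` broadcast to the input shape (or added to the input when `accumulate` is set, which is an assumption about PyTorch's semantics). The name `index_put_sketch` is made up for illustration. The non-empty branch stacks the index tensors as the frontend's `_op.stack(indices, axis=0)` does and then scatters with NumPy fancy indexing. Note that `np.stack([])` raises a `ValueError` much like `relay.stack` does, which is why the guard has to come first.

```python
import numpy as np


def index_put_sketch(data, indices, values, accumulate=False):
    """Hypothetical NumPy model of an index_put converter with an empty-indices guard."""
    if len(indices) == 0:
        # No index tensors: every element of `data` is selected, so the result is
        # `values` broadcast to data's shape (added to data if accumulate is set).
        full = np.broadcast_to(values, data.shape).astype(data.dtype)
        return data + full if accumulate else full

    # Non-empty case: stack the index tensors into one (N, ...) array, like the
    # frontend's _op.stack(indices, axis=0), then scatter `values` into a copy.
    index_tensor = np.stack(indices, axis=0)
    out = data.copy()
    if accumulate:
        np.add.at(out, tuple(index_tensor), values)  # repeated indices accumulate
    else:
        out[tuple(index_tensor)] = values
    return out


A = np.zeros((2, 4), dtype=np.float32)
vals = np.array([0, 1, 2, 6])
print(index_put_sketch(A, [], vals).tolist())
# [[0.0, 1.0, 2.0, 6.0], [0.0, 1.0, 2.0, 6.0]]
```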
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org