Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/24 22:52:18 UTC
[GitHub] [tvm] apivovarov edited a comment on pull request #7513: [Torch] Add copy_ operator
apivovarov edited a comment on pull request #7513:
URL: https://github.com/apache/tvm/pull/7513#issuecomment-785436558
@masahi I tried adding the following to `_run_jit_passes(graph)`:
```python
torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
```
Yes, it replaces `copy_` with `expand_as` + `index_put`. But in that case the `index_put` operator receives an empty indices list.
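The rewrite works because `copy_` can be treated as a special case of `index_put`: copying into the whole destination selects every element, so there is nothing to index with and the indices list ends up empty. The toy sketch below models that rewrite on a minimal node list. `Node` and `rewrite_copy_` are hypothetical helpers for illustration only, not PyTorch's actual pass; unlike the real graph, the sketch also emits its own `accumulate=False` constant instead of reusing an existing one.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Node:
    """Toy IR node (hypothetical): op kind, input value names, output value name."""
    kind: str
    inputs: List[str]
    output: str


def rewrite_copy_(nodes: List[Node]) -> List[Node]:
    """Hypothetical model of the copy_ -> expand_as + index_put rewrite.

    aten::copy_(dst, src, non_blocking) overwrites every element of dst, so it is
    modelled as index_put(dst, [], expand_as(src, dst), accumulate=False). Nothing
    selects a sub-region of dst, so the ListConstruct of indices has no inputs.
    """
    result: List[Node] = []
    for node in nodes:
        if node.kind != "aten::copy_":
            result.append(node)
            continue
        dst, src = node.inputs[0], node.inputs[1]
        expanded = f"{src}.expanded"
        indices = f"{node.output}.indices"
        accumulate = f"{node.output}.accumulate"
        result.append(Node("aten::expand_as", [src, dst], expanded))
        result.append(Node("prim::ListConstruct", [], indices))  # empty indices list
        result.append(Node("prim::Constant[value=0]", [], accumulate))  # accumulate=False
        result.append(
            Node("aten::index_put", [dst, indices, expanded, accumulate], node.output)
        )
    return result


if __name__ == "__main__":
    graph = [
        Node("aten::zeros", ["shape"], "A"),
        Node("aten::copy_", ["A", "values", "non_blocking"], "out"),
    ]
    for n in rewrite_copy_(graph):
        print(f"%{n.output} = {n.kind}({', '.join('%' + i for i in n.inputs)})")
```

The input-less `prim::ListConstruct` in this toy corresponds to `%15` in the graph after the JIT passes.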
Test module:
```python
import torch
import numpy as np

class MyCopy(torch.nn.Module):
    def __init__(self, shape):
        super(MyCopy, self).__init__()
        self.shape = shape
    def forward(self, values):
        A = torch.zeros(self.shape)
        B = A.copy_(values)  # in-place copy; `values` (shape [4]) is broadcast to (2, 4)
        return B

MP = MyCopy((2, 4))
a = torch.tensor([0, 1, 2, 6])
MP(a)
traced_MP = torch.jit.trace(MP, (a,))  # example inputs passed as a 1-tuple
```
```
>>> traced_MP.graph
graph(%self : __torch__.MyCopy,
      %values : Long(4, strides=[1], requires_grad=0, device=cpu)):
  %4 : int = prim::Constant[value=2]() # <stdin>:7:0
  %5 : int = prim::Constant[value=4]() # <stdin>:7:0
  %6 : int[] = prim::ListConstruct(%4, %5)
  %7 : int = prim::Constant[value=6]() # <stdin>:7:0
  %8 : None = prim::Constant()
  %9 : Device = prim::Constant[value="cpu"]() # <stdin>:7:0
  %10 : bool = prim::Constant[value=0]() # <stdin>:7:0
  %A : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::zeros(%6, %7, %8, %9, %10) # <stdin>:7:0
  %12 : bool = prim::Constant[value=0]()
  %13 : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::copy_(%A, %values, %12) # <stdin>:8:0
  return (%13)
```
After running the JIT passes:
```
>>> graph = traced_MP.graph.copy()
>>> torch._C._jit_pass_onnx_function_substitution(graph)
>>> torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
>>> torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
>>> graph
graph(%self : __torch__.MyCopy,
      %values.1 : Long(4, strides=[1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=2]() # <stdin>:7:0
  %3 : int = prim::Constant[value=4]() # <stdin>:7:0
  %4 : int[] = prim::ListConstruct(%2, %3)
  %5 : int = prim::Constant[value=6]() # <stdin>:7:0
  %6 : None = prim::Constant()
  %7 : Device = prim::Constant[value="cpu"]() # <stdin>:7:0
  %8 : bool = prim::Constant[value=0]() # <stdin>:7:0
  %A : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::zeros(%4, %5, %6, %7, %8) # <stdin>:7:0
  %10 : bool = prim::Constant[value=0]()
  %values : Long(4, strides=[1], requires_grad=0, device=cpu) = aten::expand_as(%values.1, %A) # <stdin>:8:0
  %15 : Tensor?[] = prim::ListConstruct()
  %16 : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::index_put(%A, %15, %values, %10)
  return (%16)
```
As you can see, `%15` is an empty list.
As a result, TVM raises an error: the `index_put` converter passes the empty indices list to `_op.stack`, which requires non-empty input:
```
>>> import tvm
>>> from tvm import relay
>>> ctx = tvm.cpu(0)
>>> target = 'llvm'
>>> shape_list = [("input0", [4,]),]
>>> mod, params = relay.frontend.from_pytorch(traced_MP, shape_list)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 3186, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2605, in convert_operators
    relay_out = relay_op(
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2059, in index_put
    index_tensor = _op.stack(indices, axis=0)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/op/tensor.py", line 1124, in stack
    raise ValueError("relay.stack requires data to be non-empty.")
ValueError: relay.stack requires data to be non-empty.
```
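One possible way a converter could avoid this error: when `index_put` receives no index tensors, every element is selected, so the result is simply `values` broadcast to the input's shape (added to the input when `accumulate` is true), and the stack-then-scatter path can be skipped. The sketch below is a pure-Python toy model using nested lists as tensors. `stack`, `scatter_nd` and `convert_index_put` are hypothetical stand-ins, not TVM's frontend code; it assumes the stacked indices feed a scatter-style op and only handles the row-shaped `values` and 2-D data seen in this example.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def stack(tensors: Sequence[Sequence[int]]) -> List[List[int]]:
    """Hypothetical stand-in that mimics the non-empty check of relay.stack(..., axis=0)."""
    if not tensors:
        raise ValueError("relay.stack requires data to be non-empty.")
    return [list(t) for t in tensors]  # one row per indexed dimension


def scatter_nd(data: Matrix, index: List[List[int]], updates, mode: str) -> Matrix:
    """Hypothetical 2-D scatter: write (mode="update") or add (mode="add")
    updates[k] at (index[0][k], index[1][k]); expects exactly two index rows."""
    out = [row[:] for row in data]
    for k, v in enumerate(updates):
        r, c = index[0][k], index[1][k]
        out[r][c] = out[r][c] + v if mode == "add" else v
    return out


def convert_index_put(data: Matrix, indices, values, accumulate: bool, guard: bool = True) -> Matrix:
    """Hypothetical converter: lower index_put to stack + scatter_nd.

    With guard=True, an empty `indices` list is handled up front instead of
    being passed to `stack`, which would reject it.
    """
    if guard and not indices:
        # No index tensors selects every element: broadcast the row-shaped
        # `values` over all rows, cast each value to data's element type
        # (index_put keeps the dtype of its first input), then overwrite or add.
        if any(len(row) != len(values) for row in data):
            raise ValueError("values must match the last dimension of data")
        return [
            [d + type(d)(v) if accumulate else type(d)(v) for d, v in zip(row, values)]
            for row in data
        ]
    mode = "add" if accumulate else "update"
    return scatter_nd(data, stack(indices), values, mode)


if __name__ == "__main__":
    A = [[0.0] * 4 for _ in range(2)]
    print(convert_index_put(A, [], [0, 1, 2, 6], accumulate=False))
```

With `guard=False` the toy reproduces the failure above, because the empty list reaches `stack`; with the guard, the empty-indices case behaves like the original `copy_`.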