You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/06/18 08:13:21 UTC
[GitHub] [incubator-tvm] siju-samuel commented on a change in pull request #5834: Improve type handling in PyTorch frontend
siju-samuel commented on a change in pull request #5834:
URL: https://github.com/apache/incubator-tvm/pull/5834#discussion_r442032927
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1733,12 +1780,19 @@ def _convert_dtype_value(val):
0:"torch.unit8",
None:"torch.int64"} # Default is torch.int64
if val in convert_torch_dtype_map:
- return convert_torch_dtype_map[val]
+ return _convert_data_type(convert_torch_dtype_map[val])
else:
msg = "Torch data type value %d is not handled yet." % (val)
raise NotImplementedError(msg)
-def _convert_data_type(input_type):
+def _convert_data_type(input_type, default_dtype=None):
+ """converts the PyTorch scalar type input_type to a TVM dtype.
+ optionally, default_dtype can be a TVM dtype that is used
+ if input_type is None (but not when it is unknown)"""
+ if input_type is None and default_dtype is not None:
+ return default_dtype
+
+ input_type = input_type.lower()
if input_type in ["double", "torch.float64"]:
return "float64"
elif input_type in ["float", "torch.float32"]:
Review comment:
Add "float32" here
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1755,12 +1809,21 @@ def _convert_data_type(input_type):
return "int8"
Review comment:
Add "int64" here
##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -2363,6 +2366,23 @@ def forward(self, *args):
t2 = torch.rand([1, 3]).float()
verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
+def test_forward_traced_function():
+ def fn(t1, t2):
+ return t1 + t2
+
+ tensor1 = torch.randn(3, 4)
+ tensor2 = torch.randn(3, 4)
+ verify_model(fn, input_data=[tensor1, tensor2])
+
+def test_forward_dtypes():
+ def fn(t1, t2):
+ return 2.5 * t1 + t2
+
+ for dt in [torch.int32, torch.int64, torch.double]:
+ tensor1 = torch.randn(3, 4).to(dtype=dt)
+ tensor2 = torch.randn(3, 4).to(dtype=dt)
+ verify_model(fn, input_data=[tensor1, tensor2])
+
Review comment:
Please add `test_forward_traced_function` and `test_forward_dtypes` to [main](https://github.com/apache/incubator-tvm/blob/f305b31d6343f207b913eb1aafc8d07782445e33/tests/python/frontend/pytorch/test_forward.py#L2528).
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -115,64 +115,70 @@ def inplace_add_to_add(op_name):
return False
+
# operator implementation
def _elemwise(name):
def _impl(inputs, input_types):
- # TODO: Figure out a better way to get typing to work for tensor + scalar
- type0 = input_types[0]
- if isinstance(inputs[1], _expr.Expr):
- type0 = input_types[1]
-
- type1 = input_types[1]
- if isinstance(inputs[0], _expr.Expr):
- type1 = input_types[0]
-
- data0 = _convert_elemwise_input(inputs[0], type0)
- data1 = _convert_elemwise_input(inputs[1], type1)
-
+ data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
Review comment:
I was trying to run the gp2 model and I got an error with the new modifications.
You can download the code from
https://gist.github.com/siju-samuel/34e63e0719e06679b5c3688bce7a0515
The error is:
```
Traceback (most recent call last):
File "gp2.py", line 26, in <module>
mod, params = relay.frontend.from_pytorch(scripted_model, input_shapes)
File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2645, in from_pytorch
ret = convert_operators(_get_operator_nodes(graph.nodes()),
File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2555, in convert_operators
relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 1694, in _impl
return _elemwise("add")(inputs, input_types)
File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 151, in _impl
return get_relay_op(name)(data0, data1)
File "/home/siju/workspace/tvm/python/tvm/relay/op/tensor.py", line 513, in add
return _make.add(lhs, rhs)
File "/home/siju/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (4) /home/siju/workspace/tvm/build/libtvm.so(TVMFuncCall+0x69) [0x7f983831cc09]
[bt] (3) /home/siju/workspace/tvm/build/libtvm.so(+0xa5be6b) [0x7f9837f31e6b]
[bt] (2) /home/siju/workspace/tvm/build/libtvm.so(tvm::runtime::TVMMovableArgValue_::operator tvm::RelayExpr<tvm::RelayExpr, void>() const+0x63) [0x7f9837e0ecf3]
[bt] (1) /home/siju/workspace/tvm/build/libtvm.so(tvm::RelayExpr tvm::runtime::TVMPODValue_::AsObjectRef<tvm::RelayExpr>() const+0x1a6) [0x7f9837ab21b6]
[bt] (0) /home/siju/workspace/tvm/build/libtvm.so(+0x5cc26b) [0x7f9837aa226b]
File "/home/siju/workspace/tvm/include/tvm/runtime/packed_func.h", line 1423
TVMError: Check failed: type_code_ == kTVMObjectHandle (0 vs. 8) : expected Object but get int
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org