Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/06/18 08:13:21 UTC

[GitHub] [incubator-tvm] siju-samuel commented on a change in pull request #5834: Improve type handling in PyTorch frontend

siju-samuel commented on a change in pull request #5834:
URL: https://github.com/apache/incubator-tvm/pull/5834#discussion_r442032927



##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1733,12 +1780,19 @@ def _convert_dtype_value(val):
                                0:"torch.unit8",
                                None:"torch.int64"} # Default is torch.int64
     if val in convert_torch_dtype_map:
-        return convert_torch_dtype_map[val]
+        return _convert_data_type(convert_torch_dtype_map[val])
     else:
         msg = "Torch data type value %d is not handled yet." % (val)
         raise NotImplementedError(msg)
 
-def _convert_data_type(input_type):
+def _convert_data_type(input_type, default_dtype=None):
+    """converts the PyTorch scalar type input_type to a TVM dtype.
+       optionally, default_dtype can be a TVM dtype that is used
+       if input_type is None (but not when it is unknown)"""
+    if input_type is None and default_dtype is not None:
+        return default_dtype
+
+    input_type = input_type.lower()
     if input_type in ["double", "torch.float64"]:
         return "float64"
     elif input_type in ["float", "torch.float32"]:

Review comment:
       Add "float32" here

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1755,12 +1809,21 @@ def _convert_data_type(input_type):
         return "int8"

Review comment:
       Add "int64" here

##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -2363,6 +2366,23 @@ def forward(self, *args):
     t2 = torch.rand([1, 3]).float()
     verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
 
+def test_forward_traced_function():
+    def fn(t1, t2):
+        return t1 + t2
+
+    tensor1 = torch.randn(3, 4)
+    tensor2 = torch.randn(3, 4)
+    verify_model(fn, input_data=[tensor1, tensor2])
+
+def test_forward_dtypes():
+    def fn(t1, t2):
+        return 2.5 * t1 + t2
+
+    for dt in [torch.int32, torch.int64, torch.double]:
+        tensor1 = torch.randn(3, 4).to(dtype=dt)
+        tensor2 = torch.randn(3, 4).to(dtype=dt)
+        verify_model(fn, input_data=[tensor1, tensor2])
+

Review comment:
       add `test_forward_traced_function` and `test_forward_dtypes` to [main](https://github.com/apache/incubator-tvm/blob/f305b31d6343f207b913eb1aafc8d07782445e33/tests/python/frontend/pytorch/test_forward.py#L2528)
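
   i.e., something along these lines at the bottom of test_forward.py (the
   existing invocations are elided):

   ```
   if __name__ == "__main__":
       # ... existing test invocations ...
       test_forward_traced_function()
       test_forward_dtypes()
   ```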

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -115,64 +115,70 @@ def inplace_add_to_add(op_name):
     return False
 
 
+
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        # TODO: Figure out a better way to get typing to work for tensor + scalar
-        type0 = input_types[0]
-        if isinstance(inputs[1], _expr.Expr):
-            type0 = input_types[1]
-
-        type1 = input_types[1]
-        if isinstance(inputs[0], _expr.Expr):
-            type1 = input_types[0]
-
-        data0 = _convert_elemwise_input(inputs[0], type0)
-        data1 = _convert_elemwise_input(inputs[1], type1)
-
+        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])

Review comment:
       I was trying to run a GPT-2 model and hit an error with the new modifications.
   You can download the code from
   https://gist.github.com/siju-samuel/34e63e0719e06679b5c3688bce7a0515
   
   The error is:
   
   ```
   Traceback (most recent call last):
     File "gp2.py", line 26, in <module>
       mod, params = relay.frontend.from_pytorch(scripted_model, input_shapes)
     File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2645, in from_pytorch
       ret = convert_operators(_get_operator_nodes(graph.nodes()),
     File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2555, in convert_operators
       relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
     File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 1694, in _impl
       return _elemwise("add")(inputs, input_types)
     File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 151, in _impl
       return get_relay_op(name)(data0, data1)
     File "/home/siju/workspace/tvm/python/tvm/relay/op/tensor.py", line 513, in add
       return _make.add(lhs, rhs)
     File "/home/siju/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     [bt] (4) /home/siju/workspace/tvm/build/libtvm.so(TVMFuncCall+0x69) [0x7f983831cc09]
     [bt] (3) /home/siju/workspace/tvm/build/libtvm.so(+0xa5be6b) [0x7f9837f31e6b]
     [bt] (2) /home/siju/workspace/tvm/build/libtvm.so(tvm::runtime::TVMMovableArgValue_::operator tvm::RelayExpr<tvm::RelayExpr, void>() const+0x63) [0x7f9837e0ecf3]
     [bt] (1) /home/siju/workspace/tvm/build/libtvm.so(tvm::RelayExpr tvm::runtime::TVMPODValue_::AsObjectRef<tvm::RelayExpr>() const+0x1a6) [0x7f9837ab21b6]
     [bt] (0) /home/siju/workspace/tvm/build/libtvm.so(+0x5cc26b) [0x7f9837aa226b]
     File "/home/siju/workspace/tvm/include/tvm/runtime/packed_func.h", line 1423
   TVMError: Check failed: type_code_ == kTVMObjectHandle (0 vs. 8) : expected Object but get int
   ```
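
   For context, `Check failed: type_code_ == kTVMObjectHandle (0 vs. 8)` means a
   plain Python int reached `_make.add`, which expects relay expressions on both
   sides. A minimal sketch of the kind of guard that avoids this (`_as_expr` is a
   hypothetical helper, not the PR's code):

   ```
   import tvm.relay as relay
   from tvm.relay import expr as _expr

   def _as_expr(value, dtype):
       # Relay binary ops require relay.Expr operands; a bare Python
       # scalar triggers the kTVMObjectHandle check in the traceback.
       if isinstance(value, _expr.Expr):
           return value
       return relay.const(value, dtype=dtype)

   # Both operands become Exprs before relay.add / _make.add is reached:
   lhs = _as_expr(1, "int64")
   rhs = relay.var("x", shape=(3, 4), dtype="int64")
   out = relay.add(lhs, rhs)
   ```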



