You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/29 23:53:42 UTC

[tvm] branch main updated: [Relay][Pytorch] Add aten::new_ones, aten::new_full, aten::fill_, aten::pad, aten::reshape_as and aten::empty_like (#11896)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 41c94b27ef [Relay][Pytorch] Add aten::new_ones, aten::new_full, aten::fill_, aten::pad, aten::reshape_as and aten::empty_like (#11896)
41c94b27ef is described below

commit 41c94b27ef5f10ad70af211dd25c4837dad53f64
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Wed Jun 29 16:53:36 2022 -0700

    [Relay][Pytorch] Add aten::new_ones, aten::new_full, aten::fill_, aten::pad, aten::reshape_as and aten::empty_like (#11896)
    
    * add new ops
    
    * fix pad
    
    * fix pad
    
    * remove pad
    
    * fix CI
    
    * remove doc
    
    * fix fill_
    
    * add tests
---
 python/tvm/relay/frontend/pytorch.py          | 55 ++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 75 +++++++++++++++++++++++++++
 2 files changed, 130 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9558ad1b6e..6fe8c89e3c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -701,6 +701,21 @@ class PyTorchOpConverter:
 
         return out
 
+    def new_ones(self, inputs, input_types):
+        """Convert aten::new_ones: full_impl(inputs[1], 1, dtype)."""
+        size = inputs[1]
+        import torch  # function-scope import; used only for the torch.Size check
+
+        accepted = (_expr.Expr, list, tuple, torch.Size, np.ndarray)
+        if not isinstance(size, accepted):
+            raise AssertionError("Data type %s could not be parsed in ones op" % (type(size)))
+
+        # An explicit dtype (inputs[2]) wins; otherwise fall back to input_types[0].
+        explicit_dtype = inputs[2]
+        if explicit_dtype is None:
+            return self.full_impl(size, 1, input_types[0])
+        return self.full_impl(size, 1, _convert_dtype_value(explicit_dtype))
+
     def zeros(self, inputs, input_types):
         data = inputs[0]
 
@@ -765,6 +780,28 @@ class PyTorchOpConverter:
 
         return out
 
+    def new_full(self, inputs, input_types):
+        """Convert aten::new_full: full_impl(inputs[1], inputs[2], dtype)."""
+        data = inputs[1]
+        fill_value = inputs[2]
+        import torch  # function-scope import; used only for the torch.Size check
+
+        if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)):
+            msg = "Data type %s could not be parsed in full op" % (type(data))
+            raise AssertionError(msg)
+
+        if inputs[3] is not None:  # dtype given
+            dtype = _convert_dtype_value(inputs[3])
+        else:
+            # No dtype given: fall back to input_types[0]; `input[0]` indexed the builtin.
+            dtype = input_types[0]
+        return self.full_impl(data, fill_value, dtype)
+
+    def fill_(self, inputs, input_types):
+        """Convert aten::fill_ via full_impl, using the inferred shape of inputs[0]."""
+        target, scalar = inputs[0], inputs[1]
+        return self.full_impl(self.infer_shape(target), scalar, input_types[0])
+
     def linspace(self, inputs, input_types):
         start = inputs[0]
         stop = inputs[1]
@@ -1425,6 +1462,11 @@ class PyTorchOpConverter:
             new_shape = tmp_shape
         return _op.transform.reshape(data, new_shape)
 
+    def reshape_as(self, inputs, input_types):
+        """Convert aten::reshape_as: reshape inputs[0] to the inferred shape of inputs[1]."""
+        source, like = inputs[0], inputs[1]
+        return _op.transform.reshape(source, self.infer_shape(like))
+
     def pixel_shuffle(self, inputs, input_types):
         data = inputs[0]
         upscale_factor = inputs[1]
@@ -2400,6 +2442,14 @@ class PyTorchOpConverter:
         shape = inputs[0]
         return _op.zeros(shape, _convert_dtype_value(inputs[1]))
 
+    def empty_like(self, inputs, input_types):
+        """Convert aten::empty_like; the result is built with _op.zeros."""
+        like_shape = self.infer_shape(inputs[0])
+        requested = inputs[1]
+        # An explicit dtype (inputs[1]) wins; otherwise fall back to input_types[0].
+        dtype = input_types[0] if requested is None else _convert_dtype_value(requested)
+        return _op.zeros(like_shape, dtype)
+
     def bincount(self, inputs, input_types):
         data = inputs[0]
         weights = inputs[1]
@@ -3119,8 +3169,11 @@ class PyTorchOpConverter:
             "aten::ones_like": self.ones_like,
             "aten::zeros": self.zeros,
             "aten::zeros_like": self.zeros_like,
+            "aten::new_ones": self.new_ones,
             "aten::full": self.full,
             "aten::full_like": self.full_like,
+            "aten::new_full": self.new_full,
+            "aten::fill_": self.fill_,
             "aten::linspace": self.linspace,
             "aten::reciprocal": self.reciprocal,
             "aten::repeat": self.repeat,
@@ -3186,6 +3239,7 @@ class PyTorchOpConverter:
             "aten::size": self.size,
             "aten::view": self.view,
             "aten::reshape": self.reshape,
+            "aten::reshape_as": self.reshape_as,
             "aten::clone": self.clone,
             "aten::log_softmax": self.log_softmax,
             "aten::sigmoid": self.sigmoid,
@@ -3305,6 +3359,7 @@ class PyTorchOpConverter:
             "aten::tensor": self.identity,  # used for example in tensor(1.0)
             "aten::numel": self.numel,
             "aten::empty": self.empty,
+            "aten::empty_like": self.empty_like,
             "aten::bincount": self.bincount,
             "aten::scatter_add": self.scatter_add,
             "aten::__not__": self.logical_not,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 1bb4517f01..f039a00f5d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -199,6 +199,28 @@ def verify_model(
         torch.cuda.empty_cache()
 
 
+def verify_model_with_input(test_func, input_data, input_dict=None):
+    """Trace test_func, build it with Relay, and compare output 0 with the PyTorch result."""
+    baseline_outputs = test_func(*input_data)
+    trace = torch.jit.trace(test_func, [inp.clone() for inp in input_data])
+    input_names = ["input{}".format(idx) for idx, _ in enumerate(input_data)]
+    input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
+    mod, params = relay.frontend.from_pytorch(trace, input_shapes, {})
+    with tvm.transform.PassContext(opt_level=3):
+        for target in ["llvm", "cuda"]:
+            if not tvm.runtime.enabled(target):
+                continue
+            dev = tvm.device(target, 0)
+            lib = relay.build(mod, target=target, params=params)
+            relay_model = graph_executor.GraphModule(lib["default"](dev))
+            for name, value in (input_dict or {}).items():  # None means "set no inputs"
+                relay_model.set_input(name, value)
+            relay_model.run()
+            compiled_output = relay_model.get_output(0).numpy()
+            assert_shapes_match(baseline_outputs, compiled_output)
+            tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=1e-5, atol=1e-5)
+
+
 # Single operator tests
 @tvm.testing.uses_gpu
 def test_forward_pixel_shuffle():
@@ -1275,6 +1297,16 @@ def test_forward_reshape():
     verify_model(Reshape3(), input_data=torch.randn(2, 3, 4))
 
 
+@tvm.testing.uses_gpu
+def test_forward_reshape_as():
+    """Check aten::reshape_as with a [2, 1, 10, 1, 10] input and a [2, 1, 10, 10] target."""
+    source = torch.rand([2, 1, 10, 1, 10])
+    like = torch.rand([2, 1, 10, 10])
+
+    # Only "input0" is passed to the compiled module; no value is set for the second input.
+    verify_model_with_input(lambda x, y: x.reshape_as(y), [source, like], {"input0": source})
+
+
 @tvm.testing.uses_gpu
 def test_flatten():
     def _test_flatten(start_dim, end_dim):
@@ -2961,6 +2993,17 @@ def test_forward_ones_like():
     verify_model(OnesLike3().float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_new_ones():
+    """Check aten::new_ones: a [3, 10, 10] ones tensor made from a [1, 3, 10, 10] input."""
+    torch.set_grad_enabled(False)
+
+    def make_ones(source):
+        return source.new_ones([3, 10, 10])
+
+    verify_model_with_input(make_ones, [torch.rand([1, 3, 10, 10]).float()])
+
+
 @tvm.testing.uses_gpu
 def test_forward_zeros():
     torch.set_grad_enabled(False)
@@ -3034,6 +3077,24 @@ def test_forward_full_like():
     verify_model(FullLike3().float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_new_full():
+    """Check aten::new_full: a [2, 3] tensor filled with 1, made from a [1, 3, 10, 10] input."""
+    torch.set_grad_enabled(False)
+
+    def make_full(source):
+        return source.new_full([2, 3], 1)
+
+    verify_model_with_input(make_full, [torch.rand([1, 3, 10, 10]).float()])
+
+
+def test_forward_fill_():
+    """Check aten::fill_ on a random [1, 3, 10, 10] float tensor filled with 3."""
+    data = torch.rand([1, 3, 10, 10]).float()
+
+    verify_model_with_input(lambda x: x.fill_(3), [data])
+
+
 @tvm.testing.uses_gpu
 def test_forward_linspace():
     torch.set_grad_enabled(False)
@@ -3752,6 +3813,20 @@ def test_numel():
     verify_script_model(Numel(), [(3, 5, 8)], targets)
 
 
+def test_empty():
+    """Check aten::empty for shape [1, 3, 10, 10], traced with no inputs."""
+    # NOTE(review): torch.empty output is uninitialized, yet the helper compares values
+    # with assert_allclose - this may be flaky; confirm.
+    verify_model_with_input(lambda: torch.empty([1, 3, 10, 10]), [])
+
+
+def test_empty_like():
+    """Check aten::empty_like on a random [1, 3, 10, 10] float tensor."""
+    # NOTE(review): torch.empty_like output is uninitialized; value check may be flaky - confirm.
+    data = torch.rand([1, 3, 10, 10]).float()
+    verify_model_with_input(lambda tensor: torch.empty_like(tensor), [data])
+
+
 def test_forward_pretrained_bert_base_uncased():
     ######################################################################
     # This is an example how to run BERT models using TVM