You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/29 23:53:42 UTC
[tvm] branch main updated: [Relay][Pytorch] Add aten::new_ones, aten::new_full, aten::fill_, aten::pad, aten::reshape_as and aten::empty_like (#11896)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 41c94b27ef [Relay][Pytorch] Add aten::new_ones, aten::new_full, aten::fill_, aten::pad, aten::reshape_as and aten::empty_like (#11896)
41c94b27ef is described below
commit 41c94b27ef5f10ad70af211dd25c4837dad53f64
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Wed Jun 29 16:53:36 2022 -0700
[Relay][Pytorch] Add aten::new_ones, aten::new_full, aten::fill_, aten::pad, aten::reshape_as and aten::empty_like (#11896)
* add new ops
* fix pad
* fix pad
* remove pad
* fix CI
* remove doc
* fix fill_
* add tests
---
python/tvm/relay/frontend/pytorch.py | 55 ++++++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 75 +++++++++++++++++++++++++++
2 files changed, 130 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9558ad1b6e..6fe8c89e3c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -701,6 +701,21 @@ class PyTorchOpConverter:
return out
+    def new_ones(self, inputs, input_types):
+        """Convert ``aten::new_ones`` to a tensor of ones.
+
+        ``inputs[1]`` is the requested size and ``inputs[2]`` the optional
+        dtype. When no dtype is supplied, the dtype of the source tensor
+        (``input_types[0]``) is used.
+        """
+        import torch
+
+        shape_arg = inputs[1]
+        accepted = (_expr.Expr, list, tuple, torch.Size, np.ndarray)
+        if not isinstance(shape_arg, accepted):
+            raise AssertionError(
+                "Data type %s could not be parsed in ones op" % (type(shape_arg))
+            )
+
+        requested_dtype = inputs[2]
+        dtype = (
+            input_types[0]
+            if requested_dtype is None
+            else _convert_dtype_value(requested_dtype)
+        )
+        return self.full_impl(shape_arg, 1, dtype)
+
def zeros(self, inputs, input_types):
data = inputs[0]
@@ -765,6 +780,28 @@ class PyTorchOpConverter:
return out
+    def new_full(self, inputs, input_types):
+        """Convert ``aten::new_full`` to a tensor filled with a constant.
+
+        ``inputs[1]`` is the requested size, ``inputs[2]`` the fill value and
+        ``inputs[3]`` the optional dtype. When no dtype is supplied, the dtype
+        of the source tensor (``input_types[0]``) is used, as in ``new_ones``.
+
+        Raises
+        ------
+        AssertionError
+            If the size argument has an unsupported type.
+        """
+        import torch
+
+        size = inputs[1]
+        fill_value = inputs[2]
+
+        if not isinstance(size, (_expr.Expr, list, tuple, torch.Size, np.ndarray)):
+            msg = "Data type %s could not be parsed in full op" % (type(size))
+            raise AssertionError(msg)
+
+        if inputs[3] is not None:  # dtype given
+            dtype = _convert_dtype_value(inputs[3])
+        else:
+            # full_impl needs a dtype string; input_types[0] is the source
+            # tensor's dtype string (indexing the builtin ``input`` or asking
+            # infer_type for a Relay type would not give that).
+            dtype = input_types[0]
+
+        return self.full_impl(size, fill_value, dtype)
+
+    def fill_(self, inputs, input_types):
+        """Convert ``aten::fill_`` to a tensor shaped like ``inputs[0]``.
+
+        Every element of the result is ``inputs[1]``; shape comes from
+        ``self.infer_shape(inputs[0])`` and dtype from ``input_types[0]``.
+        The value returned by ``full_impl`` stands in for the filled tensor.
+        """
+        target_shape = self.infer_shape(inputs[0])
+        return self.full_impl(target_shape, inputs[1], input_types[0])
+
def linspace(self, inputs, input_types):
start = inputs[0]
stop = inputs[1]
@@ -1425,6 +1462,11 @@ class PyTorchOpConverter:
new_shape = tmp_shape
return _op.transform.reshape(data, new_shape)
+    def reshape_as(self, inputs, input_types):
+        """Convert ``aten::reshape_as``: reshape ``inputs[0]`` to ``inputs[1]``'s shape.
+
+        Only the shape of ``inputs[1]`` (as returned by ``self.infer_shape``)
+        is used; its values are ignored.
+        """
+        target_shape = self.infer_shape(inputs[1])
+        return _op.transform.reshape(inputs[0], target_shape)
+
def pixel_shuffle(self, inputs, input_types):
data = inputs[0]
upscale_factor = inputs[1]
@@ -2400,6 +2442,14 @@ class PyTorchOpConverter:
shape = inputs[0]
return _op.zeros(shape, _convert_dtype_value(inputs[1]))
+    def empty_like(self, inputs, input_types):
+        """Convert ``aten::empty_like`` to a zero-filled tensor.
+
+        Like ``empty`` in this converter, the result is built with
+        ``_op.zeros``. Its shape is that of ``inputs[0]``; its dtype is
+        ``inputs[1]`` when given, otherwise ``input_types[0]``.
+        """
+        shape = self.infer_shape(inputs[0])
+        dtype = (
+            input_types[0] if inputs[1] is None else _convert_dtype_value(inputs[1])
+        )
+        return _op.zeros(shape, dtype)
+
def bincount(self, inputs, input_types):
data = inputs[0]
weights = inputs[1]
@@ -3119,8 +3169,11 @@ class PyTorchOpConverter:
"aten::ones_like": self.ones_like,
"aten::zeros": self.zeros,
"aten::zeros_like": self.zeros_like,
+ "aten::new_ones": self.new_ones,
"aten::full": self.full,
"aten::full_like": self.full_like,
+ "aten::new_full": self.new_full,
+ "aten::fill_": self.fill_,
"aten::linspace": self.linspace,
"aten::reciprocal": self.reciprocal,
"aten::repeat": self.repeat,
@@ -3186,6 +3239,7 @@ class PyTorchOpConverter:
"aten::size": self.size,
"aten::view": self.view,
"aten::reshape": self.reshape,
+ "aten::reshape_as": self.reshape_as,
"aten::clone": self.clone,
"aten::log_softmax": self.log_softmax,
"aten::sigmoid": self.sigmoid,
@@ -3305,6 +3359,7 @@ class PyTorchOpConverter:
"aten::tensor": self.identity, # used for example in tensor(1.0)
"aten::numel": self.numel,
"aten::empty": self.empty,
+ "aten::empty_like": self.empty_like,
"aten::bincount": self.bincount,
"aten::scatter_add": self.scatter_add,
"aten::__not__": self.logical_not,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 1bb4517f01..f039a00f5d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -199,6 +199,28 @@ def verify_model(
torch.cuda.empty_cache()
+def verify_model_with_input(test_func, input_data, input_dict=None, assert_shape_only=False):
+    """Trace ``test_func``, build it with Relay and compare against PyTorch.
+
+    Parameters
+    ----------
+    test_func : callable
+        Function of torch tensors; traced with ``torch.jit.trace``.
+    input_data : list of torch.Tensor
+        Example inputs, exposed to Relay as ``input0``, ``input1``, ...
+    input_dict : dict, optional
+        Values bound by name with ``set_input`` before each run. Inputs not
+        listed here are not set explicitly.
+    assert_shape_only : bool, optional
+        If True, compare only output shapes (for ops such as ``torch.empty``
+        whose values are unspecified).
+    """
+    # Fresh dict per call instead of a shared mutable default argument.
+    if input_dict is None:
+        input_dict = {}
+    # Run the eager baseline on copies so in-place ops (e.g. ``fill_``) do
+    # not modify ``input_data``, which is reused for tracing and binding.
+    baseline_outputs = test_func(*[inp.clone() for inp in input_data])
+    if isinstance(baseline_outputs, torch.Tensor):
+        baseline_outputs = baseline_outputs.detach().cpu().numpy()
+    trace = torch.jit.trace(test_func, [inp.clone() for inp in input_data])
+    input_names = ["input{}".format(idx) for idx in range(len(input_data))]
+    input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
+    mod, params = relay.frontend.from_pytorch(trace, input_shapes, {})
+    with tvm.transform.PassContext(opt_level=3):
+        for target in ["llvm", "cuda"]:
+            if not tvm.runtime.enabled(target):
+                continue
+            dev = tvm.device(target, 0)
+            lib = relay.build(mod, target=target, params=params)
+            relay_model = graph_executor.GraphModule(lib["default"](dev))
+            for name, value in input_dict.items():
+                relay_model.set_input(name, value)
+            relay_model.run()
+
+            compiled_output = relay_model.get_output(0).numpy()
+            assert_shapes_match(baseline_outputs, compiled_output)
+            if not assert_shape_only:
+                tvm.testing.assert_allclose(
+                    baseline_outputs, compiled_output, rtol=1e-5, atol=1e-5
+                )
+
+
# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
@@ -1275,6 +1297,16 @@ def test_forward_reshape():
verify_model(Reshape3(), input_data=torch.randn(2, 3, 4))
+@tvm.testing.uses_gpu
+def test_forward_reshape_as():
+    """aten::reshape_as: reshape a 5-D tensor to the shape of a 4-D one."""
+
+    def test_func(input_tensor, other_tensor):
+        return input_tensor.reshape_as(other_tensor)
+
+    source = torch.rand([2, 1, 10, 1, 10])
+    template = torch.rand([2, 1, 10, 10])
+    input_data = [source, template]
+
+    # Only input0 supplies values; the converter uses input1 just for its shape.
+    verify_model_with_input(test_func, input_data, {"input0": source})
+
+
@tvm.testing.uses_gpu
def test_flatten():
def _test_flatten(start_dim, end_dim):
@@ -2961,6 +2993,17 @@ def test_forward_ones_like():
verify_model(OnesLike3().float().eval(), input_data=input_data)
+@tvm.testing.uses_gpu
+def test_forward_new_ones():
+    """aten::new_ones: a tensor of ones with a new shape."""
+    torch.set_grad_enabled(False)
+
+    def test_func(input_tensor):
+        return input_tensor.new_ones([3, 10, 10])
+
+    example = torch.rand([1, 3, 10, 10]).float()
+    verify_model_with_input(test_func, [example])
+
+
@tvm.testing.uses_gpu
def test_forward_zeros():
torch.set_grad_enabled(False)
@@ -3034,6 +3077,24 @@ def test_forward_full_like():
verify_model(FullLike3().float().eval(), input_data=input_data)
+@tvm.testing.uses_gpu
+def test_forward_new_full():
+    """aten::new_full: a 2x3 tensor filled with 1, derived from the input."""
+    torch.set_grad_enabled(False)
+
+    def test_func(input_tensor):
+        return input_tensor.new_full([2, 3], 1)
+
+    example = torch.rand([1, 3, 10, 10]).float()
+    verify_model_with_input(test_func, [example])
+
+
+@tvm.testing.uses_gpu
+def test_forward_fill_():
+    """aten::fill_: every element of the input is replaced by 3."""
+
+    def test_func(x):
+        return x.fill_(3)
+
+    verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])
+
+
@tvm.testing.uses_gpu
def test_forward_linspace():
torch.set_grad_enabled(False)
@@ -3752,6 +3813,20 @@ def test_numel():
verify_script_model(Numel(), [(3, 5, 8)], targets)
+@tvm.testing.uses_gpu
+def test_empty():
+    """aten::empty: check that it converts and yields the right shape.
+
+    ``torch.empty`` returns uninitialized memory, so its raw values cannot be
+    compared with TVM's zero-filled result. Filling the tensor makes the
+    comparison deterministic while still requiring ``aten::empty`` to be
+    converted with the correct shape.
+    """
+
+    def test_func():
+        return torch.empty([1, 3, 10, 10]).fill_(0.0)
+
+    verify_model_with_input(test_func, [])
+
+
+@tvm.testing.uses_gpu
+def test_empty_like():
+    """aten::empty_like: check that it converts and yields the input's shape.
+
+    ``torch.empty_like`` returns uninitialized memory, so the result is filled
+    before comparison to keep the test deterministic.
+    """
+
+    def test_func(data):
+        return torch.empty_like(data).fill_(0.0)
+
+    verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])
+
+
def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM