You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/26 09:14:40 UTC
[tvm] branch main updated: [frontend][pytorch]Support aten::Tensor_split operator (#12871)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 87085b0e0d [frontend][pytorch]Support aten::Tensor_split operator (#12871)
87085b0e0d is described below
commit 87085b0e0dad2a422993472e35431d4f22fd69d8
Author: chengven027-intellif <da...@hotmail.com>
AuthorDate: Mon Sep 26 17:14:33 2022 +0800
[frontend][pytorch]Support aten::Tensor_split operator (#12871)
Support the aten::tensor_split operator in the PyTorch frontend
---
python/tvm/relay/frontend/pytorch.py | 54 +++++++++++++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 22 +++++++++++
2 files changed, 76 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c1bf69502b..1b86b120df 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -559,6 +559,59 @@ class PyTorchOpConverter:
return _op.split(data, indices, dim)
def tensor_split(self, inputs, input_types):
    """Convert ``aten::tensor_split`` into a Relay ``split``.

    Reference: https://pytorch.org/docs/stable/generated/torch.tensor_split.html

    Parameters
    ----------
    inputs : list
        ``[data, indices_or_sections, dim]``. ``indices_or_sections`` is either an
        int (number of sections), a list/tuple of split indices, or a
        ``torch.Tensor`` that is 0-d (number of sections) or 1-d (split indices).
    input_types : list
        Not used by this converter.

    Returns
    -------
    The result of ``_op.split(data, indices, dim)``.

    Raises
    ------
    AssertionError
        If ``indices_or_sections`` has an unsupported type, is a tensor with more
        than one dimension, or requests a non-positive number of sections.
    """
    import torch

    data = inputs[0]
    indices_or_sections = inputs[1]

    if not isinstance(indices_or_sections, (int, list, tuple, torch.Tensor)):
        msg = "indices_or_sections type %s could not be parsed in tensor_split op" % (
            type(indices_or_sections)
        )
        raise AssertionError(msg)

    # The original check compared a list with the int 1 (always False), which
    # rejected every 1-d tensor; compare the rank instead.
    if isinstance(indices_or_sections, torch.Tensor) and indices_or_sections.dim() > 1:
        msg = "indices_or_sections must be a zero-dimensional or one-dimensional long tensor"
        raise AssertionError(msg)

    dim = int(inputs[2])

    if isinstance(indices_or_sections, int) or (
        isinstance(indices_or_sections, torch.Tensor) and indices_or_sections.dim() == 0
    ):
        n = int(indices_or_sections)
        if n <= 0:
            raise AssertionError(
                "number of sections must be larger than 0 in tensor_split op, got %d" % n
            )
        split_size, split_rest = divmod(int(self.infer_shape(data)[dim]), n)
        # PyTorch gives the first ``split_rest`` sections one extra element, so
        # section i ends at (i + 1) * split_size + min(i + 1, split_rest).
        # Only the n - 1 interior boundaries are passed to split.
        indices = [(i + 1) * split_size + min(i + 1, split_rest) for i in range(n - 1)]
        return _op.split(data, indices, dim)

    # Explicit split indices: a 1-d tensor, list, or tuple.
    if isinstance(indices_or_sections, torch.Tensor):
        indices = indices_or_sections.cpu().numpy().tolist()
    else:
        indices = list(indices_or_sections)
    return _op.split(data, indices, dim)
+
def select(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
@@ -3484,6 +3537,7 @@ class PyTorchOpConverter:
"aten::slice": self.slice,
"aten::narrow": self.narrow,
"aten::split": self.split,
+ "aten::tensor_split": self.tensor_split,
"aten::split_with_sizes": self.split_with_sizes,
"aten::select": self.select,
"aten::take": self.take,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 33c70a4d74..3c8bd5efd8 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -959,6 +959,28 @@ def test_forward_split():
verify_model(Split([2, 3, 5], 1).float().eval(), input_data=input_data)
@tvm.testing.uses_gpu
def test_forward_tensor_split():
    """test_forward_tensor_split

    Compare TVM against PyTorch for ``torch.tensor_split`` called with an int
    section count, a 0-d tensor section count, an index list, and an index tuple.
    """
    torch.set_grad_enabled(False)

    class TensorSplit(Module):
        def __init__(self, indices_or_sections, dim):
            super().__init__()
            self.split_size_or_sections = indices_or_sections
            self.dim = dim

        def forward(self, *inputs):
            return torch.tensor_split(inputs[0], self.split_size_or_sections, self.dim)

    input_data = torch.rand([4, 10]).float()

    # (indices_or_sections, dim) pairs, verified in order.
    cases = [
        (2, 0),
        (torch.tensor(3), 1),
        ([2, 3, 5], 1),
        ((2, 3, 5), 1),
    ]
    for indices_or_sections, dim in cases:
        model = TensorSplit(indices_or_sections, dim).float().eval()
        verify_model(model, input_data=input_data)
+
@tvm.testing.uses_gpu
def test_forward_avgpool1d():
"""test_forward_avgpool1d"""