You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/26 09:14:40 UTC

[tvm] branch main updated: [frontend][pytorch]Support aten::Tensor_split operator (#12871)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 87085b0e0d [frontend][pytorch]Support aten::Tensor_split operator (#12871)
87085b0e0d is described below

commit 87085b0e0dad2a422993472e35431d4f22fd69d8
Author: chengven027-intellif <da...@hotmail.com>
AuthorDate: Mon Sep 26 17:14:33 2022 +0800

    [frontend][pytorch]Support aten::Tensor_split operator (#12871)
    
    Support aten::Tensor_split operator
---
 python/tvm/relay/frontend/pytorch.py          | 54 +++++++++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 22 +++++++++++
 2 files changed, 76 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c1bf69502b..1b86b120df 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -559,6 +559,59 @@ class PyTorchOpConverter:
 
         return _op.split(data, indices, dim)
 
+    def tensor_split(self, inputs, input_types):
+        """Convert aten::tensor_split to relay split.
+
+        inputs[0] is the data, inputs[2] the split dimension, and inputs[1]
+        (indices_or_sections) either a section count (int or 0-D tensor) or
+        explicit split indices (list/tuple or 1-D tensor).
+        """
+        # Reference: https://pytorch.org/docs/stable/generated/torch.tensor_split.html
+        import torch
+
+        data = inputs[0]
+        indices_or_sections = inputs[1]
+        dim = int(inputs[2])
+
+        if not isinstance(indices_or_sections, (int, list, tuple, torch.Tensor)):
+            msg = "indices_or_sections type %s could not be parsed in tensor_split op" % (
+                type(indices_or_sections)
+            )
+            raise AssertionError(msg)
+
+        if isinstance(indices_or_sections, torch.Tensor):
+            # Validate the tensor's rank via len(shape), then normalize it to the
+            # equivalent int (0-D) or list of indices (1-D).
+            rank = len(indices_or_sections.shape)
+            if rank not in (0, 1):
+                msg = (
+                    "indices_or_sections must be a zero-dimensional or one-dimensional long tensor"
+                )
+                raise AssertionError(msg)
+            if rank == 0:
+                indices_or_sections = int(indices_or_sections)
+            else:
+                indices_or_sections = indices_or_sections.cpu().numpy().tolist()
+
+        if isinstance(indices_or_sections, (list, tuple)):
+            # Explicit indices have the same meaning in torch.tensor_split and relay split.
+            return _op.split(data, list(indices_or_sections), dim)
+
+        n = indices_or_sections
+        if n <= 0:
+            raise AssertionError("number of sections must be positive in tensor_split op")
+
+        # torch gives the first (size % n) sections size // n + 1 elements and the
+        # remaining ones size // n elements; accumulate those sizes into split indices.
+        dim_size = int(self.infer_shape(data)[dim])
+        split_size, split_rest = divmod(dim_size, n)
+        indices = []
+        split_index = 0
+        for i in range(n - 1):
+            split_index += split_size + (1 if i < split_rest else 0)
+            indices.append(split_index)
+        return _op.split(data, indices, dim)
+
     def select(self, inputs, input_types):
         data = inputs[0]
         dim = int(inputs[1])
@@ -3484,6 +3537,7 @@ class PyTorchOpConverter:
             "aten::slice": self.slice,
             "aten::narrow": self.narrow,
             "aten::split": self.split,
+            "aten::tensor_split": self.tensor_split,
             "aten::split_with_sizes": self.split_with_sizes,
             "aten::select": self.select,
             "aten::take": self.take,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 33c70a4d74..3c8bd5efd8 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -959,6 +959,28 @@ def test_forward_split():
     verify_model(Split([2, 3, 5], 1).float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_tensor_split():
+    """test_forward_tensor_split"""
+    torch.set_grad_enabled(False)
+
+    class TensorSplit(Module):
+        def __init__(self, split_size_or_sections, dim):
+            super().__init__()
+            self.split_size_or_sections = split_size_or_sections
+            self.dim = dim
+
+        def forward(self, *args):
+            return torch.tensor_split(args[0], self.split_size_or_sections, self.dim)
+
+    input_data = torch.rand([4, 10]).float()
+    verify_model(TensorSplit(2, 0).float().eval(), input_data=input_data)
+    verify_model(TensorSplit(4, 1).float().eval(), input_data=input_data)  # sizes 3, 3, 2, 2
+    verify_model(TensorSplit(torch.tensor(3), 1).float().eval(), input_data=input_data)
+    verify_model(TensorSplit([2, 3, 5], 1).float().eval(), input_data=input_data)
+    verify_model(TensorSplit((2, 3, 5), 1).float().eval(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_avgpool1d():
     """test_forward_avgpool1d"""