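You are viewing a plain-text version of this content. The canonical (HTML) version is available from the original mailing-list archive page; its hyperlink is not preserved in this plain-text copy.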
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/23 08:57:33 UTC

[tvm] branch main updated: [TORCH] Implement avg_pool1d (#7694)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f09f02e  [TORCH] Implement avg_pool1d (#7694)
f09f02e is described below

commit f09f02e575b2bd1d9187a4ff2eb178d49fd3dd22
Author: Christoph Gerum <ch...@uni-tuebingen.de>
AuthorDate: Tue Mar 23 09:57:15 2021 +0100

    [TORCH] Implement avg_pool1d (#7694)
    
    * [TORCH] Implement avg_pool1d
    
    * [TORCH] Unify creation of avg_pooling operations
    
    * [TORCH] Add tests for avg pooling with padding
    
    * [TORCH] Make format checks happy with unified avg_pool
---
 python/tvm/relay/frontend/pytorch.py          | 84 +++++++++++++++------------
 tests/python/frontend/pytorch/test_forward.py | 28 ++++++++-
 2 files changed, 72 insertions(+), 40 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 8ae1e86..cb9ea6a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1353,47 +1353,54 @@ class PyTorchOpConverter:
         beta = _expr.const(float(inputs[1]), dtype=dtype)
         return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta
 
-    def avg_pool2d(self, inputs, input_types):
-        data = inputs[0]
-
-        pool_size = self.convert_const_list(inputs[1])
-        strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
-        padding = inputs[3]
-        ceil_mode = int(inputs[4])
-        count_include_pad = int(inputs[5])
-
-        def func(x):
-            return _op.nn.avg_pool2d(
-                x,
-                pool_size=pool_size,
-                strides=strides,
-                padding=padding,
-                ceil_mode=ceil_mode,
-                count_include_pad=count_include_pad,
-            )
+    def make_avg_pool(self, dim):
+        def avg_pool(inputs, input_types):
+            data = inputs[0]
 
-        if self.is_quantized_tensor(data):
-            return qnn_torch.apply_with_upcast(data, func)
+            pool_size = self.convert_const_list(inputs[1])
+            strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
+            padding = inputs[3]
+            ceil_mode = int(inputs[4])
+            count_include_pad = int(inputs[5])
 
-        return func(data)
+            def func(x):
+                if dim == 1:
+                    return _op.nn.avg_pool1d(
+                        x,
+                        pool_size=pool_size,
+                        strides=strides,
+                        padding=padding,
+                        ceil_mode=ceil_mode,
+                        count_include_pad=count_include_pad,
+                    )
+                elif dim == 2:
+                    return _op.nn.avg_pool2d(
+                        x,
+                        pool_size=pool_size,
+                        strides=strides,
+                        padding=padding,
+                        ceil_mode=ceil_mode,
+                        count_include_pad=count_include_pad,
+                    )
+                elif dim == 3:
+                    return _op.nn.avg_pool3d(
+                        x,
+                        pool_size=pool_size,
+                        strides=strides,
+                        padding=padding,
+                        ceil_mode=ceil_mode,
+                        count_include_pad=count_include_pad,
+                    )
+                else:
+                    msg = "Average Pooling dimension should be between 1 and 3"
+                    raise RuntimeError(msg)
 
-    def avg_pool3d(self, inputs, input_types):
-        data = inputs[0]
+            if self.is_quantized_tensor(data):
+                return qnn_torch.apply_with_upcast(data, func)
 
-        pool_size = inputs[1]
-        strides = inputs[2] if inputs[2] else pool_size
-        padding = inputs[3]
-        ceil_mode = int(inputs[4])
-        count_include_pad = int(inputs[5])
+            return func(data)
 
-        return _op.nn.avg_pool3d(
-            data,
-            pool_size=pool_size,
-            strides=strides,
-            padding=padding,
-            ceil_mode=ceil_mode,
-            count_include_pad=count_include_pad,
-        )
+        return avg_pool
 
     def linear(self, inputs, input_types):
         # https://pytorch.org/docs/stable/nn.functional.html#linear
@@ -2350,8 +2357,9 @@ class PyTorchOpConverter:
             "aten::log_softmax": self.log_softmax,
             "aten::sigmoid": self.sigmoid,
             "aten::softplus": self.softplus,
-            "aten::avg_pool2d": self.avg_pool2d,
-            "aten::avg_pool3d": self.avg_pool3d,
+            "aten::avg_pool1d": self.make_avg_pool(1),
+            "aten::avg_pool2d": self.make_avg_pool(2),
+            "aten::avg_pool3d": self.make_avg_pool(3),
             "aten::linear": self.linear,
             "aten::dropout": self.dropout,
             "aten::dropout_": self.dropout,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d0edfd9..572aa47 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -809,7 +809,24 @@ def test_forward_split():
 
 
 @tvm.testing.uses_gpu
-def test_forward_avgpool():
+def test_forward_avgpool1d():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10]
+
+    class AvgPool1D2(Module):
+        def forward(self, *args):
+            return torch.nn.functional.avg_pool1d(args[0], kernel_size=[10])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.AvgPool1d(kernel_size=[10]).eval(), input_data=input_data)
+    verify_model(AvgPool1D2().float().eval(), input_data=input_data)
+    verify_model(
+        torch.nn.AvgPool1d(kernel_size=[5], stride=2, padding=2).eval(), input_data=input_data
+    )
+
+
+@tvm.testing.uses_gpu
+def test_forward_avgpool2d():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
 
@@ -820,6 +837,9 @@ def test_forward_avgpool():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data)
     verify_model(AvgPool2D2().float().eval(), input_data=input_data)
+    verify_model(
+        torch.nn.AvgPool2d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
+    )
 
 
 @tvm.testing.uses_gpu
@@ -834,6 +854,9 @@ def test_forward_avgpool3d():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data)
     verify_model(AvgPool3D1().float().eval(), input_data=input_data)
+    verify_model(
+        torch.nn.AvgPool3d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
+    )
 
 
 @tvm.testing.uses_gpu
@@ -3838,7 +3861,8 @@ if __name__ == "__main__":
     test_forward_logsoftmax()
     test_forward_sigmoid()
     test_forward_dense()
-    test_forward_avgpool()
+    test_forward_avgpool1d()
+    test_forward_avgpool2d()
     test_forward_avgpool3d()
     test_forward_dropout()
     test_forward_slice()