You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text version).
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/23 08:57:33 UTC
[tvm] branch main updated: [TORCH] Implement avg_pool1d (#7694)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f09f02e [TORCH] Implement avg_pool1d (#7694)
f09f02e is described below
commit f09f02e575b2bd1d9187a4ff2eb178d49fd3dd22
Author: Christoph Gerum <ch...@uni-tuebingen.de>
AuthorDate: Tue Mar 23 09:57:15 2021 +0100
[TORCH] Implement avg_pool1d (#7694)
* [TORCH] Implement avg_pool1d
* [TORCH] Unify creation of avg_pooling operations
* [TORCH] Add tests for avg pooling with padding
* [TORCH] Make format checks happy with unified avg_pool
---
python/tvm/relay/frontend/pytorch.py | 84 +++++++++++++++------------
tests/python/frontend/pytorch/test_forward.py | 28 ++++++++-
2 files changed, 72 insertions(+), 40 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 8ae1e86..cb9ea6a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1353,47 +1353,54 @@ class PyTorchOpConverter:
beta = _expr.const(float(inputs[1]), dtype=dtype)
return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta
- def avg_pool2d(self, inputs, input_types):
- data = inputs[0]
-
- pool_size = self.convert_const_list(inputs[1])
- strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
- padding = inputs[3]
- ceil_mode = int(inputs[4])
- count_include_pad = int(inputs[5])
-
- def func(x):
- return _op.nn.avg_pool2d(
- x,
- pool_size=pool_size,
- strides=strides,
- padding=padding,
- ceil_mode=ceil_mode,
- count_include_pad=count_include_pad,
- )
+ def make_avg_pool(self, dim):  # factory: returns one average-pooling converter bound to dim (1, 2 or 3)
+ def avg_pool(inputs, input_types):  # converter closure; input_types is not used in this body
+ data = inputs[0]
- if self.is_quantized_tensor(data):
- return qnn_torch.apply_with_upcast(data, func)
+ pool_size = self.convert_const_list(inputs[1])  # the removed avg_pool3d passed inputs[1] through unconverted
+ strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)  # falsy stride input -> default to pool_size
+ padding = inputs[3]  # NOTE(review): not run through convert_const_list like pool_size/strides; confirm it is always a plain list
+ ceil_mode = int(inputs[4])
+ count_include_pad = int(inputs[5])
- return func(data)
+ def func(x):
+ if dim == 1:  # dispatch on the rank captured from make_avg_pool
+ return _op.nn.avg_pool1d(
+ x,
+ pool_size=pool_size,
+ strides=strides,
+ padding=padding,
+ ceil_mode=ceil_mode,
+ count_include_pad=count_include_pad,
+ )
+ elif dim == 2:
+ return _op.nn.avg_pool2d(
+ x,
+ pool_size=pool_size,
+ strides=strides,
+ padding=padding,
+ ceil_mode=ceil_mode,
+ count_include_pad=count_include_pad,
+ )
+ elif dim == 3:
+ return _op.nn.avg_pool3d(
+ x,
+ pool_size=pool_size,
+ strides=strides,
+ padding=padding,
+ ceil_mode=ceil_mode,
+ count_include_pad=count_include_pad,
+ )
+ else:  # dim is only checked here, when a node is converted; make_avg_pool itself accepts any value
+ msg = "Average Pooling dimension should be between 1 and 3"
+ raise RuntimeError(msg)
- def avg_pool3d(self, inputs, input_types):
- data = inputs[0]
+ if self.is_quantized_tensor(data):  # quantized input: run func via qnn_torch.apply_with_upcast (the removed avg_pool3d had no such path)
+ return qnn_torch.apply_with_upcast(data, func)
- pool_size = inputs[1]
- strides = inputs[2] if inputs[2] else pool_size
- padding = inputs[3]
- ceil_mode = int(inputs[4])
- count_include_pad = int(inputs[5])
+ return func(data)
- return _op.nn.avg_pool3d(
- data,
- pool_size=pool_size,
- strides=strides,
- padding=padding,
- ceil_mode=ceil_mode,
- count_include_pad=count_include_pad,
- )
+ return avg_pool
def linear(self, inputs, input_types):
# https://pytorch.org/docs/stable/nn.functional.html#linear
@@ -2350,8 +2357,9 @@ class PyTorchOpConverter:
"aten::log_softmax": self.log_softmax,
"aten::sigmoid": self.sigmoid,
"aten::softplus": self.softplus,
- "aten::avg_pool2d": self.avg_pool2d,
- "aten::avg_pool3d": self.avg_pool3d,
+ "aten::avg_pool1d": self.make_avg_pool(1),  # make_avg_pool is called here, so each key gets its own dim-bound closure
+ "aten::avg_pool2d": self.make_avg_pool(2),
+ "aten::avg_pool3d": self.make_avg_pool(3),
"aten::linear": self.linear,
"aten::dropout": self.dropout,
"aten::dropout_": self.dropout,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d0edfd9..572aa47 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -809,7 +809,24 @@ def test_forward_split():
@tvm.testing.uses_gpu
-def test_forward_avgpool():
+def test_forward_avgpool1d():  # new 1-D test: module form, functional form, and a stride/padding case
+ torch.set_grad_enabled(False)
+ input_shape = [1, 3, 10]
+
+ class AvgPool1D2(Module):
+ def forward(self, *args):
+ return torch.nn.functional.avg_pool1d(args[0], kernel_size=[10])
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(torch.nn.AvgPool1d(kernel_size=[10]).eval(), input_data=input_data)
+ verify_model(AvgPool1D2().float().eval(), input_data=input_data)
+ verify_model(  # stride/padding case, not covered by the kernel-only checks above
+ torch.nn.AvgPool1d(kernel_size=[5], stride=2, padding=2).eval(), input_data=input_data
+ )
+
+
+@tvm.testing.uses_gpu
+def test_forward_avgpool2d():  # renamed from test_forward_avgpool
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
@@ -820,6 +837,9 @@ def test_forward_avgpool():
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data)
verify_model(AvgPool2D2().float().eval(), input_data=input_data)
+ verify_model(  # new stride/padding case for 2-D
+ torch.nn.AvgPool2d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
+ )
@tvm.testing.uses_gpu
@@ -834,6 +854,9 @@ def test_forward_avgpool3d():
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data)
verify_model(AvgPool3D1().float().eval(), input_data=input_data)
+ verify_model(  # new stride/padding case for 3-D
+ torch.nn.AvgPool3d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
+ )
@tvm.testing.uses_gpu
@@ -3838,7 +3861,8 @@ if __name__ == "__main__":
test_forward_logsoftmax()
test_forward_sigmoid()
test_forward_dense()
- test_forward_avgpool()
+ test_forward_avgpool1d()  # new 1-D pooling test
+ test_forward_avgpool2d()  # replaces the removed test_forward_avgpool() call
test_forward_avgpool3d()
test_forward_dropout()
test_forward_slice()