You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/12 07:02:53 UTC

[tvm] branch main updated: [PyTorch] Fix pad_common for float pad_value (#12134)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 22dcf4490d [PyTorch] Fix pad_common for float pad_value (#12134)
22dcf4490d is described below

commit 22dcf4490dacc7813f5ef3d700ab0b64171c7662
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Thu Aug 11 21:02:48 2022 -1000

    [PyTorch] Fix pad_common for float pad_value (#12134)
    
    * fix pad
    
    * fix constant padding and handle float infinity
    
    * revert change to pad_width
    
    * fix constant pad value
---
 python/tvm/relay/frontend/pytorch.py          | 11 ++++-----
 tests/python/frontend/pytorch/test_forward.py | 32 +++++++++++++++++++++++++--
 2 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 0fe8d57464..ffe4b313c5 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1905,7 +1905,7 @@ class PyTorchOpConverter:
 
         # initialize paddings based on input len
         pad_len = len(self.infer_shape(data)) * 2
-        paddings = [pad_value] * pad_len
+        paddings = [0] * pad_len
 
         if len(pad_list) >= 2:
             paddings[-1] = pad_list[1]
@@ -1925,8 +1925,10 @@ class PyTorchOpConverter:
         for pad in paddings:
             const_paddings.append([])
             for p in pad:
-                if not isinstance(p, int):
+                if isinstance(p, _expr.Expr):
                     p = int(_infer_value(p, {}).numpy())
+                elif not isinstance(p, int):
+                    raise NotImplementedError("pad width should be int/expr")
                 const_paddings[-1].append(p)
                 if p != 0:
                     non_zero_found = True
@@ -1934,12 +1936,11 @@ class PyTorchOpConverter:
         if not non_zero_found:
             return data
         elif mode == "constant":
-            return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode)
+            return _op.nn.pad(data, const_paddings, pad_value=pad_value, pad_mode=mode)
         else:
             return _op.nn.pad(data, const_paddings, pad_mode=mode)
 
     def pad(self, inputs, input_types):
-
         # mode: Optional default "constant"
         if len(inputs) > 2 and inputs[2] is not None:
             mode = inputs[2]
@@ -1960,7 +1961,7 @@ class PyTorchOpConverter:
         return self.pad_common(mode, pad_value, inputs, input_types)
 
     def constant_pad_nd(self, inputs, input_types):
-        return self.pad_common("constant", 0, inputs, input_types)
+        return self.pad_common("constant", _expr.const(inputs[2]), inputs, input_types)
 
     def reflection_pad1d(self, inputs, input_types):
         return self.pad_common("reflect", 0, inputs, input_types)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index bc848f90b3..6b1eb30a56 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2010,6 +2010,34 @@ def test_forward_functional_pad():
     pad = (0, 1, 2, 1, 3, 3)
     verify_model(Pad1().float().eval(), input_data=input_data)
 
+    class Pad2(Module):
+        def forward(self, *args):
+            return torch.nn.functional.pad(args[0], pad, "constant", 1)
+
+    input_data = torch.rand((3, 3, 4, 2))
+    pad = (1, 1)
+    verify_model(Pad2().float().eval(), input_data=input_data)
+
+    pad = (1, 1, 2, 2)
+    verify_model(Pad2().float().eval(), input_data=input_data)
+
+    pad = (0, 1, 2, 1, 3, 3)
+    verify_model(Pad2().float().eval(), input_data=input_data)
+
+    class Pad3(Module):
+        def forward(self, *args):
+            return torch.nn.functional.pad(args[0], pad, "constant", 1.0)
+
+    input_data = torch.rand((3, 3, 4, 2))
+    pad = (1, 1)
+    verify_model(Pad3().float().eval(), input_data=input_data)
+
+    pad = (1, 1, 2, 2)
+    verify_model(Pad3().float().eval(), input_data=input_data)
+
+    pad = (0, 1, 2, 1, 3, 3)
+    verify_model(Pad3().float().eval(), input_data=input_data)
+
 
 @tvm.testing.uses_gpu
 def test_forward_zero_pad2d():
@@ -2021,10 +2049,10 @@ def test_forward_zero_pad2d():
 @tvm.testing.uses_gpu
 def test_forward_constant_pad1d():
     inp = torch.rand((1, 2, 4))
-    verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp)
+    verify_model(torch.nn.ConstantPad1d(2, 3.5).eval(), inp)
 
     inp = torch.rand((1, 2, 3))
-    verify_model(torch.nn.ConstantPad2d((3, 1), 3.5).eval(), inp)
+    verify_model(torch.nn.ConstantPad1d((3, 1), 3.5).eval(), inp)
 
 
 @tvm.testing.uses_gpu