You are viewing a plain-text version of this content. The canonical link to it (originally a hyperlink on the word "here") was not preserved in this plain-text version.
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/12/28 04:01:08 UTC

[tvm] branch main updated: [Topi] fix get_pad_tuple3d bug, the conv3d kernel layout should be DHW. (#9788)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7448eab  [Topi] fix get_pad_tuple3d bug, the conv3d kernel layout should be DHW. (#9788)
7448eab is described below

commit 7448eab300c0710a5649083bec53a631b0fa2ebd
Author: fredjon <jf...@163.com>
AuthorDate: Tue Dec 28 11:59:53 2021 +0800

    [Topi] fix get_pad_tuple3d bug, the conv3d kernel layout should be DHW. (#9788)
---
 python/tvm/topi/nn/utils.py                        |  6 +-
 tests/python/topi/python/test_topi_conv3d_ncdhw.py | 66 +++++++++++++++++++---
 .../python/test_topi_conv3d_transpose_ncdhw.py     | 12 ++++
 3 files changed, 73 insertions(+), 11 deletions(-)

diff --git a/python/tvm/topi/nn/utils.py b/python/tvm/topi/nn/utils.py
index ff00441..369b62c 100644
--- a/python/tvm/topi/nn/utils.py
+++ b/python/tvm/topi/nn/utils.py
@@ -212,9 +212,9 @@ def get_pad_tuple3d(padding, kernel):
         pad_w = 0
         pad_d = 0
     elif padding == "SAME":
-        pad_h = kernel[0] - 1
-        pad_w = kernel[1] - 1
-        pad_d = kernel[2] - 1
+        pad_d = kernel[0] - 1
+        pad_h = kernel[1] - 1
+        pad_w = kernel[2] - 1
     else:
         raise ValueError("Unknown padding option %s" % padding)
     pad_top = (pad_h + 1) // 2
diff --git a/tests/python/topi/python/test_topi_conv3d_ncdhw.py b/tests/python/topi/python/test_topi_conv3d_ncdhw.py
index c45aaa1..ea94a7d 100644
--- a/tests/python/topi/python/test_topi_conv3d_ncdhw.py
+++ b/tests/python/topi/python/test_topi_conv3d_ncdhw.py
@@ -46,19 +46,41 @@ def verify_conv3d_ncdhw(
     add_bias=False,
     add_relu=False,
 ):
+    if isinstance(kernel, (tuple, list)):
+        if len(kernel) == 3:
+            kernel_d = kernel[0]
+            kernel_h = kernel[1]
+            kernel_w = kernel[2]
+        else:
+            raise ValueError("Size of kernel can only be 3")
+    elif isinstance(kernel, int):
+        kernel_d = kernel_h = kernel_w = kernel
+    else:
+        raise ValueError("Unknown kernel option %s" % kernel)
     pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
-        padding, (kernel, kernel, kernel)
+        padding, (kernel_d, kernel_h, kernel_w)
     )
     padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
     print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)"
+        % (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel_d,
+            kernel_h,
+            kernel_w,
+            stride,
+            padding_sum,
+            dilation,
+        )
     )
 
     in_depth = in_height = in_width = in_size
 
     A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name="A")
-    W = te.placeholder((num_filter, in_channel, kernel, kernel, kernel), name="W")
+    W = te.placeholder((num_filter, in_channel, kernel_d, kernel_h, kernel_w), name="W")
     bias = te.placeholder((num_filter, 1, 1, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
@@ -103,8 +125,19 @@ def verify_conv3d_ncdhw(
                 s,
                 [A, W, bias, C],
                 target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel_d,
+                    kernel_h,
+                    kernel_w,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
             )
             func(a, w, b, c)
         else:
@@ -112,8 +145,19 @@ def verify_conv3d_ncdhw(
                 s,
                 [A, W, C],
                 target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel_d,
+                    kernel_h,
+                    kernel_w,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
             )
             func(a, w, c)
         tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-6)
@@ -155,6 +199,12 @@ def test_conv3d_ncdhw():
     verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID")
     verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID")
 
+    # DHW kernel layout
+    verify_conv3d_ncdhw(1, 32, 56, 16, (3, 5, 7), 2, (1, 2, 3))
+    verify_conv3d_ncdhw(1, 3, 56, 16, (3, 7, 7), 2, (1, 2, 3, 0, 3, 2))
+    verify_conv3d_ncdhw(1, 3, 56, 16, (3, 3, 7), 2, (1, 2, 3))
+    verify_conv3d_ncdhw(1, 3, 56, 16, (3, 7, 3), 2, (1, 3, 1))
+
 
 if __name__ == "__main__":
     test_conv3d_ncdhw()
diff --git a/tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py b/tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py
index 87b8d1f..01ec2ba 100644
--- a/tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py
+++ b/tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py
@@ -132,6 +132,18 @@ def test_conv3d_transpose_ncdhw():
     verify_conv3d_transpose_ncdhw(
         1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (1, 1, 1)
     )
+    verify_conv3d_transpose_ncdhw(
+        1, 8, (32, 32, 32), 64, (3, 5, 7), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 8, (32, 32, 32), 64, (3, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 8, (32, 32, 32), 64, (3, 3, 7), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 8, (32, 32, 32), 64, (3, 5, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
+    )
 
 
 if __name__ == "__main__":