You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2020/12/29 23:10:05 UTC

[tvm] branch main updated: Asymmetric padding and dilation in conv2d workload (#7142)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cfdbf0ea Asymmetric padding and dilation in conv2d workload (#7142)
cfdbf0ea is described below

commit cfdbf0eaa52504cc68cad62fd3966bd88e279061
Author: Wheest <Wh...@users.noreply.github.com>
AuthorDate: Tue Dec 29 23:09:47 2020 +0000

    Asymmetric padding and dilation in conv2d workload (#7142)
    
    * added asymmetric padding to conv2d workload
    
    * fixed depthwise conv2d padding
    
    * Added fix to include dilation in workload output width calculation
    
    * Added missing dilation to arm_cpu/conv2d_int8.py workload
    
    * Fixed dilation for x86 conv2d
    
    * Improved dilation workload integration in x86
    
    * Fixed x86 conv2d_alter_op to add dilation
    
    * Local linting not always producing same output as CI, probably my fault
    
    * Fixed bug, tested locally
    
    * Abusing CI until I can figure out how to reproduce the same behaviour of running integration tests locally.
    
    * Amended conv2d_int8 test
    
    * Updated workload, improved unit tests
    
    * Added depthwise conv2d workload test
---
 python/tvm/topi/arm_cpu/conv2d_int8.py             |  7 ++--
 python/tvm/topi/cuda/conv2d_int8.py                |  7 ++--
 python/tvm/topi/generic/conv2d.py                  | 15 ++++----
 python/tvm/topi/nn/conv2d.py                       | 43 +++++++++++++++++-----
 python/tvm/topi/nn/depthwise_conv2d.py             | 33 ++++++++++++-----
 python/tvm/topi/testing/depthwise_conv2d_python.py |  2 +-
 python/tvm/topi/x86/conv2d.py                      | 16 +++++---
 python/tvm/topi/x86/conv2d_alter_op.py             | 30 +++++++++++++--
 python/tvm/topi/x86/conv2d_avx_1x1.py              | 11 ++++--
 python/tvm/topi/x86/conv2d_avx_common.py           | 14 ++++---
 python/tvm/topi/x86/conv2d_int8.py                 | 14 ++++---
 python/tvm/topi/x86/depthwise_conv2d.py            |  9 +++--
 tests/python/topi/python/test_topi_conv2d_int8.py  | 23 +++++++++++-
 tests/python/topi/python/test_topi_conv2d_nchw.py  | 17 +++++++++
 .../topi/python/test_topi_depthwise_conv2d.py      | 23 +++++++++++-
 15 files changed, 201 insertions(+), 63 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py
index 445b9ec..fc7e403 100644
--- a/python/tvm/topi/arm_cpu/conv2d_int8.py
+++ b/python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -32,12 +32,12 @@ from .conv2d_gemm import (
 from .arm_utils import get_tiling_B_interleaved_t
 
 
-def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
+def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """
     Get default int8 schedule config for the workload
     """
-    wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
-    is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
+    wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype)
+    is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
     if is_kernel_1x1:
         conv2d_generic.fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes=2, num_int8_elements=4)
     else:
@@ -65,6 +65,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
             te.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype),
             strides,
             padding,
+            dilation,
             out_dtype,
         )
     return nn.conv2d_NCHWc_int8_compute(
diff --git a/python/tvm/topi/cuda/conv2d_int8.py b/python/tvm/topi/cuda/conv2d_int8.py
index 50a0e8b..001411d 100644
--- a/python/tvm/topi/cuda/conv2d_int8.py
+++ b/python/tvm/topi/cuda/conv2d_int8.py
@@ -142,9 +142,10 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_
     pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
 
     # compute the output shape
-    out_height = (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1
-    out_width = (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1
-
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    out_height = (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
+    out_width = (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
     oshape = (batch, oc_chunk, out_height, out_width, oc_block)
 
     icc = te.reduce_axis((0, ic_chunk), name="ic_chunk")
diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index 7dd9aed..4daa84c 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -38,9 +38,10 @@ def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
         How many numbers of input int32/uint32 will be multiplied and reduced.
         This is related to input channel.
     """
-    HPAD, WPAD = wkl.hpad, wkl.wpad
-    HSTR, WSTR = wkl.hstride, wkl.wstride
-    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
+    HSTR, WSTR = wkl.stride_h, wkl.stride_w
+    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
+    out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1
 
     assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
         wkl.out_filter,
@@ -85,10 +86,10 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
         How many numbers of input int32/uint32 will be multiplied and reduced.
         This is related to input channel.
     """
-    HPAD, WPAD = wkl.hpad, wkl.wpad
-    HSTR, WSTR = wkl.hstride, wkl.wstride
-    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
-    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
+    HSTR, WSTR = wkl.stride_h, wkl.stride_w
+    out_height = (wkl.height + pt + pb - wkl.kernel_h) // HSTR + 1
+    out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1
 
     assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
         wkl.out_filter,
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 886470b..80f87f8 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -38,12 +38,16 @@ Workload = namedtuple(
         "in_filter",
         "groups",
         "out_filter",
-        "hkernel",
-        "wkernel",
-        "hpad",
-        "wpad",
-        "hstride",
-        "wstride",
+        "kernel_h",
+        "kernel_w",
+        "padt",
+        "padl",
+        "padb",
+        "padr",
+        "dilation_h",
+        "dilation_w",
+        "stride_h",
+        "stride_w",
     ],
 )
 
@@ -154,7 +158,7 @@ def conv2d_infer_layout(workload, cfg):
     raise ValueError("missing register for topi.nn.conv2d_infer_layout")
 
 
-def _get_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"):
+def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layout="NCHW"):
     """ Get the workload structure. """
     if data_layout == "NCHW":
         _, CI, IH, IW = get_const_tuple(data.shape)
@@ -170,7 +174,10 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"):
     else:
         KH, KW, CIG, CO = get_const_tuple(kernel.shape)
 
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW)))
+    pt, pl, pb, pr = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW)))
+    dilation_h, dilation_w = (
+        dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    )
     GRPS = CI // CIG
     if isinstance(stride, (tuple, list)):
         HSTR, WSTR = stride
@@ -182,7 +189,25 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"):
         '{} vs. {}".format(
         data.dtype, kernel.dtype
     )
-    return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
+    return Workload(
+        data.dtype,
+        out_dtype,
+        IH,
+        IW,
+        CI,
+        GRPS,
+        CO,
+        KH,
+        KW,
+        pt,
+        pl,
+        pb,
+        pr,
+        dilation_h,
+        dilation_w,
+        HSTR,
+        WSTR,
+    )
 
 
 def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
diff --git a/python/tvm/topi/nn/depthwise_conv2d.py b/python/tvm/topi/nn/depthwise_conv2d.py
index 7235682..052ab8b 100644
--- a/python/tvm/topi/nn/depthwise_conv2d.py
+++ b/python/tvm/topi/nn/depthwise_conv2d.py
@@ -36,22 +36,28 @@ Workload = namedtuple(
         "width",
         "in_filter",
         "out_filter",
-        "hkernel",
-        "wkernel",
-        "hpad",
-        "wpad",
-        "hstride",
-        "wstride",
+        "kernel_h",
+        "kernel_w",
+        "padt",
+        "padl",
+        "padb",
+        "padr",
+        "dilation_h",
+        "dilation_w",
+        "stride_h",
+        "stride_w",
     ],
 )
 
 
-def _get_workload(data, kernel, stride, padding, out_dtype):
+def _get_workload(data, kernel, stride, padding, dilation, out_dtype):
     """ Get the workload structure. """
     _, in_channel, height, width = [x.value for x in data.shape]
     channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape]
     out_channel = channel * channel_multiplier
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    dilation_h, dilation_w = (
+        dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    )
     if isinstance(stride, (tuple, list)):
         HSTR, WSTR = stride
     else:
@@ -62,6 +68,9 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
         '{} vs. {}".format(
         data.dtype, kernel.dtype
     )
+    dilated_kernel_h = (kh - 1) * dilation_h + 1
+    dilated_kernel_w = (kw - 1) * dilation_w + 1
+    pt, pl, pb, pr = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
     return Workload(
         data.dtype,
         out_dtype,
@@ -71,8 +80,12 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
         out_channel,
         kh,
         kw,
-        HPAD,
-        WPAD,
+        pt,
+        pl,
+        pb,
+        pr,
+        dilation_h,
+        dilation_w,
         HSTR,
         WSTR,
     )
diff --git a/python/tvm/topi/testing/depthwise_conv2d_python.py b/python/tvm/topi/testing/depthwise_conv2d_python.py
index 06f26ab..2239c56 100644
--- a/python/tvm/topi/testing/depthwise_conv2d_python.py
+++ b/python/tvm/topi/testing/depthwise_conv2d_python.py
@@ -65,7 +65,7 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
                     0 : (in_height - filter_height + 1) : stride_h,
                     0 : (in_width - filter_width + 1) : stride_w,
                 ]
-    if padding == "SAME":
+    elif padding == "SAME":
         out_channel = in_channel * channel_multiplier
         out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
         out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py
index a3b7e47..182454a 100644
--- a/python/tvm/topi/x86/conv2d.py
+++ b/python/tvm/topi/x86/conv2d.py
@@ -35,7 +35,7 @@ logger = logging.getLogger("topi")
 
 
 def _get_default_config(
-    cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout="NCHW"
+    cfg, data, kernel, strides, padding, dilation, out_dtype, is_depthwise=False, layout="NCHW"
 ):
     """
     Get default schedule config for the workload
@@ -48,13 +48,13 @@ def _get_default_config(
             static_data_shape.append(dim)
     data = te.placeholder(static_data_shape, dtype=data.dtype)
     if is_depthwise:
-        wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
+        wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype)
         from .depthwise_conv2d import _fallback_schedule
 
         _fallback_schedule(cfg, wkl)
     else:
-        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
-        is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
+        wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout)
+        is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
         if is_kernel_1x1:
             conv2d_avx_1x1._fallback_schedule(cfg, wkl)
         else:
@@ -69,8 +69,11 @@ def _conv2d_infer_layout(workload, cfg):
     idxdiv = tvm.tir.indexdiv
 
     pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
-    out_height = idxdiv(in_height + pt + pb - k_height, strides[0]) + 1
-    out_width = idxdiv(in_width + pl + pr - k_width, strides[1]) + 1
+    hdilation, wdilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    dilated_kernel_h = (k_height - 1) * hdilation + 1
+    dilated_kernel_w = (k_width - 1) * wdilation + 1
+    out_height = idxdiv(in_height + pt + pb - dilated_kernel_h, strides[0]) + 1
+    out_width = idxdiv(in_width + pl + pr - dilated_kernel_w, strides[1]) + 1
     tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
     in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
     in_layout = "NCHW%dc" % tile_ic
@@ -208,6 +211,7 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layo
             ),
             strides,
             padding,
+            dilation,
             out_dtype,
         )
 
diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py
index 979dc5a..f05bac8 100644
--- a/python/tvm/topi/x86/conv2d_alter_op.py
+++ b/python/tvm/topi/x86/conv2d_alter_op.py
@@ -97,7 +97,15 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         if data_layout == "NCHW" and kernel_layout == "OIHW":
             if cfg.is_fallback:
                 _get_default_config(
-                    cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False, data_layout
+                    cfg,
+                    data_tensor,
+                    kernel_tensor,
+                    strides,
+                    padding,
+                    dilation,
+                    out_dtype,
+                    False,
+                    data_layout,
                 )
             batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
             out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
@@ -142,7 +150,15 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         if cfg.is_fallback:
             _get_default_config_int8(
-                cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False, data_layout
+                cfg,
+                data_tensor,
+                kernel_tensor,
+                strides,
+                padding,
+                dilation,
+                out_dtype,
+                False,
+                data_layout,
             )
 
         batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
@@ -198,7 +214,15 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         if data_layout == "NCHW" and kernel_layout == "OIHW":
             if cfg.is_fallback:
                 _get_default_config(
-                    cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, True, data_layout
+                    cfg,
+                    data_tensor,
+                    kernel_tensor,
+                    strides,
+                    padding,
+                    dilation,
+                    out_dtype,
+                    True,
+                    data_layout,
                 )
 
             batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
diff --git a/python/tvm/topi/x86/conv2d_avx_1x1.py b/python/tvm/topi/x86/conv2d_avx_1x1.py
index 3e5a12b..afee03a 100644
--- a/python/tvm/topi/x86/conv2d_avx_1x1.py
+++ b/python/tvm/topi/x86/conv2d_avx_1x1.py
@@ -31,10 +31,13 @@ from .utils import get_fp32_len
 
 def _fallback_schedule(cfg, wkl):
     simd_width = get_fp32_len()
-    HPAD, WPAD = wkl.hpad, wkl.wpad
-    HSTR, WSTR = wkl.hstride, wkl.wstride
-    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
-    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
+    HSTR, WSTR = wkl.stride_h, wkl.stride_w
+    dilated_kernel_h = (wkl.kernel_h - 1) * wkl.dilation_h + 1
+    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
+
+    out_height = (wkl.height + pt + pb - dilated_kernel_h) // HSTR + 1
+    out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1
 
     oc_bn = 1
     for bn in range(simd_width, 0, -1):
diff --git a/python/tvm/topi/x86/conv2d_avx_common.py b/python/tvm/topi/x86/conv2d_avx_common.py
index 8d70744..5e63de3 100644
--- a/python/tvm/topi/x86/conv2d_avx_common.py
+++ b/python/tvm/topi/x86/conv2d_avx_common.py
@@ -27,9 +27,11 @@ from .utils import get_fp32_len
 
 def _fallback_schedule(cfg, wkl):
     simd_width = get_fp32_len()
-    HPAD, WPAD = wkl.hpad, wkl.wpad
-    HSTR, WSTR = wkl.hstride, wkl.wstride
-    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
+    HSTR, WSTR = wkl.stride_h, wkl.stride_w
+    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
+
+    out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1
 
     oc_bn = 1
     for bn in range(simd_width, 0, -1):
@@ -56,9 +58,9 @@ def _fallback_schedule(cfg, wkl):
 
 
 def _fallback_schedule_int8(cfg, wkl):
-    HPAD, WPAD = wkl.hpad, wkl.wpad
-    HSTR, WSTR = wkl.hstride, wkl.wstride
-    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
+    HSTR, WSTR = wkl.stride_h, wkl.stride_w
+    out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1
 
     oc_bn = 16
     assert wkl.out_filter % oc_bn == 0
diff --git a/python/tvm/topi/x86/conv2d_int8.py b/python/tvm/topi/x86/conv2d_int8.py
index 905ada6..ca0d0b8 100644
--- a/python/tvm/topi/x86/conv2d_int8.py
+++ b/python/tvm/topi/x86/conv2d_int8.py
@@ -33,7 +33,7 @@ from . import conv2d_avx_1x1, conv2d_avx_common
 
 
 def _get_default_config_int8(
-    cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout="NCHW"
+    cfg, data, kernel, strides, padding, dilation, out_dtype, is_depthwise=False, layout="NCHW"
 ):
     """
     Get default schedule config for the workload
@@ -45,8 +45,8 @@ def _get_default_config_int8(
 
         _fallback_schedule(cfg, wkl)
     else:
-        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
-        is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
+        wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout)
+        is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
         if is_kernel_1x1:
             conv2d_generic.fallback_schedule_cpu_1x1_int8(
                 cfg, wkl, int32_lanes=16, num_int8_elements=4
@@ -138,8 +138,11 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
     is_kernel_1x1 = kernel_height == 1 and kernel_width == 1
     pt, pl, pb, pr = get_pad_tuple(padding, (kernel_height, kernel_width))
     sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    oh = (ih - kernel_height + pt + pb) // sh + 1
-    ow = (iw - kernel_width + pl + pr) // sw + 1
+    dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    dilated_kernel_h = (kernel_height - 1) * dh + 1
+    dilated_kernel_w = (kernel_width - 1) * dw + 1
+    oh = (ih - dilated_kernel_h + pt + pb) // sh + 1
+    ow = (iw - dilated_kernel_w + pl + pr) // sw + 1
 
     cfg.define_split("tile_ic", in_channel, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
     cfg.define_split("tile_oc", num_filter, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0)
@@ -159,6 +162,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
             ),
             strides,
             padding,
+            dilation,
             out_dtype,
         )
 
diff --git a/python/tvm/topi/x86/depthwise_conv2d.py b/python/tvm/topi/x86/depthwise_conv2d.py
index badba1a..a0225ef 100644
--- a/python/tvm/topi/x86/depthwise_conv2d.py
+++ b/python/tvm/topi/x86/depthwise_conv2d.py
@@ -42,9 +42,11 @@ def _fallback_schedule(cfg, wkl):
     """
     simd_width = get_fp32_len()
 
-    HPAD, WPAD = wkl.hpad, wkl.wpad
-    HSTR, WSTR = wkl.hstride, wkl.wstride
-    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
+    HSTR, WSTR = wkl.stride_h, wkl.stride_w
+    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
+
+    out_width = (wkl.width - dilated_kernel_w + pl + pr) // WSTR + 1
 
     oc_bn = 1
     for bn in range(simd_width, 0, -1):
@@ -165,6 +167,7 @@ def depthwise_conv2d_NCHWc(
         ),
         strides,
         (pad_top, pad_down),
+        dilation,
         out_dtype,
     )
     if cfg.is_fallback:
diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py
index 1bf83eb..a934e3e 100644
--- a/tests/python/topi/python/test_topi_conv2d_int8.py
+++ b/tests/python/topi/python/test_topi_conv2d_int8.py
@@ -27,6 +27,8 @@ from tvm.contrib.pickle_memoize import memoize
 from tvm.topi.nn.utils import get_pad_tuple
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.arm_cpu.conv2d_gemm import is_aarch64_arm
+from tvm.topi.nn.conv2d import _get_workload
+from tvm.topi.generic.conv2d import fallback_schedule_cpu_common_int8
 
 from common import Int8Fallback
 import tvm.testing
@@ -112,7 +114,7 @@ def compile_conv2d_NHWC_gemm_int8_arm(
                 s,
                 [A, W, bias, C],
                 device,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                name="relu_%dnnn_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
         else:
@@ -385,6 +387,22 @@ def verify_conv2d_nchw_int8(
 
     a_np, w_np, b_np, c_np = get_ref_data()
 
+    def verify_workload_padding():
+        _, _, out_height, out_width = get_const_tuple(c_np.shape)
+        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+        # for testing functionality,
+        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+        # regardless of the performance.
+        int32_lanes, num_int8_elements = num_filter, in_channel
+
+        # check if tile_ow candidates are the factors of the right output weight.
+        cfg = autotvm.get_config()
+        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+        ow_tile = np.prod(cfg["tile_ow"].size)
+
+        tvm.testing.assert_allclose(ow_tile, out_width)
+
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
@@ -436,6 +454,8 @@ def verify_conv2d_nchw_int8(
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
+    verify_workload_padding()
+
     for device in ["cuda"]:
         check_device(device)
 
@@ -547,6 +567,7 @@ def test_conv2d_nchw():
         verify_conv2d_nchw_int8(1, 32, 149, 32, 3, 1, 0)
         verify_conv2d_nchw_int8(7, 32, 149, 32, 3, 1, 0)
         verify_conv2d_nchw_int8(1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
+        verify_conv2d_nchw_int8(1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
 
 
 def test_conv2d_nhwc():
diff --git a/tests/python/topi/python/test_topi_conv2d_nchw.py b/tests/python/topi/python/test_topi_conv2d_nchw.py
index 1b75752..07ad45c 100644
--- a/tests/python/topi/python/test_topi_conv2d_nchw.py
+++ b/tests/python/topi/python/test_topi_conv2d_nchw.py
@@ -25,6 +25,8 @@ import tvm.topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from tvm.topi.nn.utils import get_pad_tuple
 from tvm.topi.utils import get_const_tuple
+from tvm.topi.nn.conv2d import _get_workload
+from tvm.topi.x86.conv2d_avx_common import _fallback_schedule
 
 import tvm.testing
 
@@ -76,6 +78,17 @@ def verify_conv2d_nchw(
 
     a_np, w_np, b_np, c_np = get_ref_data()
 
+    def verify_workload_padding():
+        _, _, out_height, out_width = get_const_tuple(c_np.shape)
+        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+        # check if tile_ow candidates are the factors of the right output weight.
+        cfg = autotvm.get_config()
+        _fallback_schedule(cfg, wkl)
+        ow_tile = np.prod(cfg["tile_ow"].size)
+
+        tvm.testing.assert_allclose(ow_tile, out_width)
+
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
@@ -101,6 +114,9 @@ def verify_conv2d_nchw(
                 C = topi.nn.relu(C)
             s = fschedule([C])
 
+            if "llvm" in device:
+                verify_workload_padding()
+
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(b_np, ctx)
@@ -242,6 +258,7 @@ def test_conv2d_nchw():
     verify_conv2d_nchw(1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
     verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "VALID", add_bias=True, add_relu=True)
     verify_conv2d_nchw(1, 64, 8, 64, 24, 1, "SAME", add_bias=True, add_relu=True)
+    verify_conv2d_nchw(1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
 
 
 if __name__ == "__main__":
diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py
index 55d2fe0..804c486 100644
--- a/tests/python/topi/python/test_topi_depthwise_conv2d.py
+++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py
@@ -23,6 +23,8 @@ import numpy as np
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.nn.utils import get_pad_tuple
 from tvm.contrib.pickle_memoize import memoize
+from tvm.topi.nn.depthwise_conv2d import _get_workload
+from tvm.topi.x86.depthwise_conv2d import _fallback_schedule
 
 import tvm.testing
 
@@ -116,8 +118,8 @@ def depthwise_conv2d_with_workload_nchw(
     if dilation == 1:
         # here we transform the padding argument from 'str' to  'tuple' ,
         # because we need this to match the "workload" tuple to the records in TopHub
-        pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
-        padding_args = (pad_h, pad_w)
+        padt, padl, padb, padr = get_pad_tuple(padding, (filter_height, filter_width))
+        padding_args = (padt, padl, padb, padr)
     else:
         padding_args = padding
 
@@ -205,6 +207,23 @@ def depthwise_conv2d_with_workload_nchw(
                 relu_scipy,
             ) = get_ref_data()
 
+            def verify_workload_padding():
+                _, _, out_height, out_width = get_const_tuple(depthwise_conv2d_scipy.shape)
+                wkl = _get_workload(
+                    Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype
+                )
+
+                # check if tile_ow candidates are the factors of the right output weight.
+                with tvm.target.Target(device):
+                    cfg = autotvm.get_config()
+                    _fallback_schedule(cfg, wkl)
+                    ow_tile = np.prod(cfg["tile_ow"].size)
+
+                    tvm.testing.assert_allclose(ow_tile, out_width)
+
+            if "llvm" in device:
+                verify_workload_padding()
+
             input_tvm = tvm.nd.array(input_np, ctx)
             filter_tvm = tvm.nd.array(filter_np, ctx)
             scale_tvm = tvm.nd.array(scale_np, ctx)