Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/12/28 09:56:52 UTC

[GitHub] [tvm] ekalda opened a new pull request, #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

ekalda opened a new pull request, #13669:
URL: https://github.com/apache/tvm/pull/13669

   topi.arm_cpu.schedule_conv2d_NHWC_quantized_native failed to compile when the number of input channels divided by 4 was less than 4.
   
   This was because we were splitting this axis by a factor of 4 to create an appropriate loop nest for tensorize, but tensorize then assumed that the outer axis bound was divisible by 4.
   
   If the outer bound was less than 4, compilation failed; if it was greater than 4 but not divisible by 4, we were occasionally accessing data outside of the tensor, which luckily was padded due to alignment (I think).
   
   So here we make sure that we explicitly pad the input axis such that the outer loop will always be divisible by 4.
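A small pure-Python model of the padding arithmetic described above may help. It does not call TVM: `pad_amount` and `outer_bound` are hypothetical helpers, the split factor of 4 and the requirement that the outer bound be a multiple of 4 come from the description, and the values of K are arbitrary. In this model, padding K up to a multiple of 4 * 4 = 16 makes the outer bound of the split a multiple of 4.

```python
def pad_amount(extent: int, multiple: int) -> int:
    """Elements to append so that `extent` becomes a multiple of `multiple`."""
    remainder = extent % multiple
    return 0 if remainder == 0 else multiple - remainder


def outer_bound(extent: int, factor: int) -> int:
    """Outer loop bound after splitting an axis of length `extent` by `factor`."""
    return -(-extent // factor)  # ceiling division


SPLIT_FACTOR = 4    # size of the inner loop handed to tensorize
OUTER_MULTIPLE = 4  # tensorize assumes the outer bound is a multiple of this

for k in (8, 20, 64):
    padded_k = k + pad_amount(k, SPLIT_FACTOR * OUTER_MULTIPLE)
    print(
        f"K={k}: outer bound {outer_bound(k, SPLIT_FACTOR)} -> "
        f"padded K={padded_k}, outer bound {outer_bound(padded_k, SPLIT_FACTOR)}"
    )
```

For K=8 the outer bound goes from 2 to 4 after padding, and for K=20 from 5 to 8; K=64 already satisfies the constraint and is left unchanged.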
   
   There are also some refactors to test_topi_conv2d_int8.py:
   - decouple the tests using pytest.mark.parametrize
   - extend the NHWC int8 schedules test to test against Arm targets and various schedules. When these schedules were initially added, we didn't have Arm CI, so only compilation was tested; now we can also run the workloads on Arm targets.
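Whether a workload is only compiled or also executed depends on the host. The updated test decides this with `tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")`; the sketch below mirrors that expression, replacing the TVM target query with a plain boolean so it runs without TVM (the `build_only` function here is hypothetical).

```python
import platform
from typing import Optional


def build_only(target_is_aarch64: bool, host_machine: Optional[str] = None) -> bool:
    """Return True when an AArch64 workload can only be compiled, not executed.

    `target_is_aarch64` stands in for `tvm_target.features.is_aarch64`;
    `host_machine` defaults to the machine this script runs on.
    """
    host = platform.machine() if host_machine is None else host_machine
    return target_is_aarch64 and host != "aarch64"


print(build_only(True, "x86_64"))   # cross-compiling from an x86 host: build only
print(build_only(True, "aarch64"))  # Arm CI host: compile and run
print(build_only(False, "x86_64"))  # non-Arm target: compile and run on the host
```

Because the comparison is against the literal string "aarch64", a host that reports a different name for the same architecture (macOS on Apple silicon reports "arm64") would be treated as build-only.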


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] Mousius commented on pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
Mousius commented on PR #13669:
URL: https://github.com/apache/tvm/pull/13669#issuecomment-1367515561

   LGTM @ekalda, thanks for making great strides improving the tests here 😸 I'll leave it open a little longer but otherwise I think this is good to go




[GitHub] [tvm] ekalda commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058903606


##########
python/tvm/topi/nn/conv2d.py:
##########
@@ -606,8 +606,8 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
     if N % tile_rows != 0:
         pad_N = tile_rows - (N % tile_rows)
 
-    if K % tile_cols != 0:
-        pad_K = tile_cols - (K % tile_cols)
+    if K % (tile_cols * 4) != 0:
+        pad_K = (tile_cols * 4) - (K % (tile_cols * 4))

Review Comment:
   Done





[GitHub] [tvm] ekalda commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058903893


##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -35,261 +35,138 @@
 import platform
 
 
-def compile_conv2d_NHWC_gemm_int8_arm(
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-    A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int8")
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W", dtype="int8")
-    bias = te.placeholder((num_filter,), name="bias", dtype="int8")
-    dtype = "int32"
-    devices = [
-        (
-            "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
-            topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
-            topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
-        ),
-        (
-            "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
-            topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
-            topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
-        ),
-        (
-            "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
-            topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
-            topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
-        ),
-        # TODO(giuseros) Need LLVM-11 in order to compile with +i8mm extension
-        # (
-        #   "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
-        #   topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
-        #   topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
-        # ),
-    ]
-
-    for device_tuple in devices:
-        target = device_tuple[0]
-        compute = device_tuple[1]
-        schedule = device_tuple[2]
-
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        print("Compiling on arm AArch64 target: %s" % target)
-        with tvm.target.Target(target) as tvm_target:
-            assert tvm_target.features.is_aarch64, "AArch64 target not recognized"
+devices = [
+    (
+        "llvm",
+        topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+        topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+    ),
+    (
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
+        topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+        topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+    ),
+    (
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
+        topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+        topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+    ),
+    (
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
+        topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
+        topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
+    ),
+    # TODO(giuseros) We need LLVM-11 in order to compile with +i8mm extension
+    # (
+    # "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
+    # topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+    # topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+    # ),
+]
+
+
+@tvm.testing.requires_llvm
+@pytest.mark.parametrize("device", devices)
+@pytest.mark.parametrize(
+    "params",
+    [
+        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
+        (1, 3, 299, 32, 3, 2, "SAME", 1, False, False),
+        (1, 32, 149, 32, 3, 1, "SAME", 2, False, False),
+        (4, 32, 147, 64, 3, 1, "SAME", 1, False, False),
+        (1, 64, 73, 80, 1, 1, "SAME", 1, False, False),
+        (1, 80, 73, 192, 3, 1, "SAME", 1, False, False),
+        (1, 192, 35, 48, 1, 1, "SAME", 1, False, False),
+        (1, 192, 35, 64, 1, 1, "VALID", 1, False, False),
+        (1, 192, 35, 32, 1, 1, "SAME", 1, False, False),
+        (1, 48, 35, 64, 5, 1, "SAME", 1, False, False),
+        (1, 96, 35, 96, 3, 1, "SAME", 1, False, False),
+        (1, 256, 35, 48, 1, 1, "SAME", 1, False, False),
+        (1, 256, 35, 64, 1, 1, "SAME", 1, False, False),
+        (1, 288, 35, 64, 1, 1, "SAME", 1, False, False),
+        (1, 288, 35, 48, 1, 1, "SAME", 1, False, False),
+        (1, 96, 35, 96, 3, 2, "SAME", 1, False, False),
+        (1, 128, 17, 192, 7, 1, "SAME", 2, False, False),
+        (1, 160, 17, 160, 7, 1, "SAME", 1, False, False),
+        (1, 160, 17, 192, 1, 1, "VALID", 1, False, False),
+        (1, 192, 17, 192, 1, 1, "SAME", 1, False, False),
+        (1, 768, 5, 128, 1, 1, "SAME", 1, False, False),
+        (1, 192, 17, 320, 3, 2, "SAME", 1, False, False),
+        (1, 192, 17, 192, 3, 2, "SAME", 1, False, False),
+        (1, 1280, 8, 192, 1, 1, "SAME", 1, False, False),
+        (1, 1280, 8, 384, 1, 1, "SAME", 1, False, False),
+        (1, 1280, 8, 320, 1, 1, "SAME", 1, False, False),
+        (1, 1280, 8, 448, 1, 1, "SAME", 1, False, False),
+        (1, 384, 8, 384, 1, 1, "SAME", 1, False, False),
+        (1, 384, 8, 384, 3, 1, "SAME", 1, False, False),
+        (1, 448, 8, 384, 3, 1, "VALID", 1, False, False),
+        (1, 2048, 8, 320, 1, 1, "SAME", 1, False, False),
+        (1, 2048, 8, 448, 1, 1, "SAME", 1, True, True),
+        (1, 2048, 8, 192, 1, 1, "SAME", 1, True, False),
+        # A trouble case for native schedule
+        (1, 8, 1, 24, 1, 1, "SAME", 1, False, False),
+    ],
+)
+def test_conv2d_NHWC_gemm_int8(params, device):
 
-            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
-            if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = schedule([C])
+    with Int8Fallback():
+        target, compute, schedule = device
 
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%dnnn_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-        else:
-            func = tvm.build(
-                s,
-                [A, W, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+
+        # TODO(ekalda): These combinations hang during compilation
+        failing_cases = [
+            (devices[1], (1, 128, 17, 192, 7, 1, "SAME", 2, False, False)),
+            (devices[1], (1, 160, 17, 160, 7, 1, "SAME", 1, False, False)),
+            (
+                devices[1],
+                (1, 448, 8, 384, 3, 1, "VALID", 1, False, False),
+            ),  # this one passes but is just incredibly slow
+        ]
+        if (device, params) in failing_cases:
+            return

Review Comment:
   I added `pytest.skip` for these cases so that they will show up as skipped tests in the pytest log.
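A minimal sketch of the pattern, assuming hypothetical device strings and shortened workload tuples rather than the real ones: calling `pytest.skip` inside a parametrized test reports that combination as skipped, whereas an early `return` makes it count as a pass.

```python
import pytest

# Hypothetical stand-ins for the real device strings and workload tuples.
DEVICES = ["llvm", "llvm --device arm_cpu --mtriple aarch64-linux-gnu"]
WORKLOADS = [(1, 8, 1, 24), (1, 160, 17, 160)]

# Combinations that should appear as "skipped" instead of silently passing.
KNOWN_SLOW = [(DEVICES[1], (1, 160, 17, 160))]


@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("params", WORKLOADS)
def test_workload(device, params):
    if (device, params) in KNOWN_SLOW:
        # pytest.skip() raises, so nothing below runs and the case is reported as skipped
        pytest.skip("known to hang during compilation")
    batch, in_channel, in_size, num_filter = params
    assert batch >= 1 and in_channel >= 1 and in_size >= 1 and num_filter >= 1
```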





[GitHub] [tvm] ekalda commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058904002


##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return

Review Comment:
   Done





[GitHub] [tvm] Mousius commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
Mousius commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058354955


##########
python/tvm/topi/nn/conv2d.py:
##########
@@ -606,8 +606,8 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
     if N % tile_rows != 0:
         pad_N = tile_rows - (N % tile_rows)
 
-    if K % tile_cols != 0:
-        pad_K = tile_cols - (K % tile_cols)
+    if K % (tile_cols * 4) != 0:
+        pad_K = (tile_cols * 4) - (K % (tile_cols * 4))

Review Comment:
   We could maybe spell this out in code to make it clearer for people coming to it later 😸 
   
   I think this is roughly what's happening here?
   ```suggestion
       tile_size = 4
       untiled_cols = tile_cols * tile_size
       misaligned_K = K % untiled_cols
       if misaligned_K != 0:
           pad_K = untiled_cols - misaligned_K
   ```
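A quick self-check, assuming plain Python integers for `K` and `tile_cols` and that `pad_K` starts at 0 before this hunk (that initialization is not shown in the diff): the spelled-out suggestion computes the same `pad_K` as the diff, and `K + pad_K` always lands on a multiple of `tile_cols * 4`.

```python
def pad_k_original(K: int, tile_cols: int) -> int:
    # Expression from the diff above
    pad_K = 0
    if K % (tile_cols * 4) != 0:
        pad_K = (tile_cols * 4) - (K % (tile_cols * 4))
    return pad_K


def pad_k_spelled_out(K: int, tile_cols: int) -> int:
    # Suggested rewrite with named intermediates
    pad_K = 0
    tile_size = 4
    untiled_cols = tile_cols * tile_size
    misaligned_K = K % untiled_cols
    if misaligned_K != 0:
        pad_K = untiled_cols - misaligned_K
    return pad_K


# Exhaustively compare the two forms over a small range of shapes.
for tile_cols in (1, 4, 16):
    for K in range(1, 513):
        assert pad_k_original(K, tile_cols) == pad_k_spelled_out(K, tile_cols)
        assert (K + pad_k_spelled_out(K, tile_cols)) % (tile_cols * 4) == 0
print("equivalent")
```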



##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -35,261 +35,138 @@
 import platform
 
 
-def compile_conv2d_NHWC_gemm_int8_arm(
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-    A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int8")
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W", dtype="int8")
-    bias = te.placeholder((num_filter,), name="bias", dtype="int8")
-    dtype = "int32"
-    devices = [
-        (
-            "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
-            topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
-            topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
-        ),
-        (
-            "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
-            topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
-            topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
-        ),
-        (
-            "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
-            topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
-            topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
-        ),
-        # TODO(giuseros) Need LLVM-11 in order to compile with +i8mm extension
-        # (
-        #   "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
-        #   topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
-        #   topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
-        # ),
-    ]
-
-    for device_tuple in devices:
-        target = device_tuple[0]
-        compute = device_tuple[1]
-        schedule = device_tuple[2]
-
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        print("Compiling on arm AArch64 target: %s" % target)
-        with tvm.target.Target(target) as tvm_target:
-            assert tvm_target.features.is_aarch64, "AArch64 target not recognized"
+devices = [
+    (
+        "llvm",
+        topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+        topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+    ),
+    (
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
+        topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+        topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+    ),
+    (
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
+        topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+        topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+    ),
+    (
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
+        topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
+        topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
+    ),
+    # TODO(giuseros) We need LLVM-11 in order to compile with +i8mm extension
+    # (
+    # "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
+    # topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+    # topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+    # ),
+]
+
+
+@tvm.testing.requires_llvm
+@pytest.mark.parametrize("device", devices)
+@pytest.mark.parametrize(
+    "params",
+    [
+        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
+        (1, 3, 299, 32, 3, 2, "SAME", 1, False, False),
+        (1, 32, 149, 32, 3, 1, "SAME", 2, False, False),
+        (4, 32, 147, 64, 3, 1, "SAME", 1, False, False),
+        (1, 64, 73, 80, 1, 1, "SAME", 1, False, False),
+        (1, 80, 73, 192, 3, 1, "SAME", 1, False, False),
+        (1, 192, 35, 48, 1, 1, "SAME", 1, False, False),
+        (1, 192, 35, 64, 1, 1, "VALID", 1, False, False),
+        (1, 192, 35, 32, 1, 1, "SAME", 1, False, False),
+        (1, 48, 35, 64, 5, 1, "SAME", 1, False, False),
+        (1, 96, 35, 96, 3, 1, "SAME", 1, False, False),
+        (1, 256, 35, 48, 1, 1, "SAME", 1, False, False),
+        (1, 256, 35, 64, 1, 1, "SAME", 1, False, False),
+        (1, 288, 35, 64, 1, 1, "SAME", 1, False, False),
+        (1, 288, 35, 48, 1, 1, "SAME", 1, False, False),
+        (1, 96, 35, 96, 3, 2, "SAME", 1, False, False),
+        (1, 128, 17, 192, 7, 1, "SAME", 2, False, False),
+        (1, 160, 17, 160, 7, 1, "SAME", 1, False, False),
+        (1, 160, 17, 192, 1, 1, "VALID", 1, False, False),
+        (1, 192, 17, 192, 1, 1, "SAME", 1, False, False),
+        (1, 768, 5, 128, 1, 1, "SAME", 1, False, False),
+        (1, 192, 17, 320, 3, 2, "SAME", 1, False, False),
+        (1, 192, 17, 192, 3, 2, "SAME", 1, False, False),
+        (1, 1280, 8, 192, 1, 1, "SAME", 1, False, False),
+        (1, 1280, 8, 384, 1, 1, "SAME", 1, False, False),
+        (1, 1280, 8, 320, 1, 1, "SAME", 1, False, False),
+        (1, 1280, 8, 448, 1, 1, "SAME", 1, False, False),
+        (1, 384, 8, 384, 1, 1, "SAME", 1, False, False),
+        (1, 384, 8, 384, 3, 1, "SAME", 1, False, False),
+        (1, 448, 8, 384, 3, 1, "VALID", 1, False, False),
+        (1, 2048, 8, 320, 1, 1, "SAME", 1, False, False),
+        (1, 2048, 8, 448, 1, 1, "SAME", 1, True, True),
+        (1, 2048, 8, 192, 1, 1, "SAME", 1, True, False),
+        # A trouble case for native schedule
+        (1, 8, 1, 24, 1, 1, "SAME", 1, False, False),
+    ],
+)
+def test_conv2d_NHWC_gemm_int8(params, device):
 
-            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
-            if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = schedule([C])
+    with Int8Fallback():
+        target, compute, schedule = device
 
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%dnnn_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-        else:
-            func = tvm.build(
-                s,
-                [A, W, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+
+        # TODO(ekalda): These combinations hang during compilation
+        failing_cases = [
+            (devices[1], (1, 128, 17, 192, 7, 1, "SAME", 2, False, False)),
+            (devices[1], (1, 160, 17, 160, 7, 1, "SAME", 1, False, False)),
+            (
+                devices[1],
+                (1, 448, 8, 384, 3, 1, "VALID", 1, False, False),
+            ),  # this one passes but is just incredibly slow
+        ]
+        if (device, params) in failing_cases:
+            return

Review Comment:
   This will make it look like the test passed rather than being skipped. We should add these as parameters and either expect a failure or mark them as slow tests, so that future generations can see them in the pytest output:
   
   ```
   [
   ... other cases ...,
           pytest.param(devices[1], (1, 128, 17, 192, 7, 1, "SAME", 2, False, False), marks=pytest.mark.xfail),
           pytest.param(devices[1], (1, 160, 17, 160, 7, 1, "SAME", 1, False, False), marks=pytest.mark.xfail),
           pytest.param(devices[1], (1, 448, 8, 384, 3, 1, "VALID", 1, False, False), marks=tvm.testing.slow),
   ]
   ```
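The suggestion above, written out as a self-contained sketch. It assumes the two `parametrize` decorators are merged into a single one over `"device,params"`, because a `pytest.param` can only carry marks for the values of the decorator it belongs to. `TARGET_A`, `TARGET_B`, `make_cases` and the workload subsets are hypothetical stand-ins for `devices` and the real parameter list. A plain `pytest.mark.slow` marker is used here in place of whatever slow-test marker the suite actually relies on. Since the first two cases hang rather than fail, the sketch uses `xfail(run=False)`, so pytest reports them as xfail without executing them.

```python
import pytest

# Hypothetical stand-ins for two entries of the test module's `devices` list.
TARGET_A = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"
TARGET_B = "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod"

# Subsets of the NHWC workload list (tuples copied from the diff).
HANGING = [
    (1, 128, 17, 192, 7, 1, "SAME", 2, False, False),
    (1, 160, 17, 160, 7, 1, "SAME", 1, False, False),
]
SLOW = [(1, 448, 8, 384, 3, 1, "VALID", 1, False, False)]
REGULAR = [(1, 8, 1, 24, 1, 1, "SAME", 1, False, False)]
ALL_WORKLOADS = HANGING + SLOW + REGULAR


def make_cases(devices, workloads, problem_device):
    """Cross devices with workloads, marking the known problem cases on one device."""
    cases = []
    for device in devices:
        for params in workloads:
            marks = ()
            if device == problem_device and params in HANGING:
                # run=False: report as xfail without executing (the case would hang).
                marks = pytest.mark.xfail(run=False, reason="hangs during compilation")
            elif device == problem_device and params in SLOW:
                marks = pytest.mark.slow
            cases.append(pytest.param(device, params, marks=marks))
    return cases


@pytest.mark.parametrize(
    "device,params", make_cases([TARGET_A, TARGET_B], ALL_WORKLOADS, TARGET_B)
)
def test_conv2d_NHWC_gemm_int8(device, params):
    ...  # test body as in the PR, minus the failing_cases early return
```

Once `slow` is registered as a marker in the pytest configuration, these cases can be deselected with `-m "not slow"`, and the xfail entries show up in the summary instead of being counted as passes.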



##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return

Review Comment:
   It would be good to inform the test runner about this; even better would be to mark these as skipped earlier, in the parametrize.
   
   ```suggestion
                   pytest.skip(reason="Skip because %s is not enabled" % target)
   ```
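A minimal sketch of both options. It assumes a hypothetical `device_enabled` helper in place of `tvm.testing.device_enabled`; the hard-coded `ENABLED_TARGETS` set is illustrative only. `check_target` shows the runtime variant from the suggestion (a `pytest.skip` call instead of a bare `return`), and `target_params` shows the collection-time variant that attaches a `skipif` mark to each parameter.

```python
import pytest

# Hypothetical stand-in for tvm.testing.device_enabled(); a real test would
# ask the TVM runtime instead of consulting this hard-coded set.
ENABLED_TARGETS = {"llvm"}


def device_enabled(target):
    return target.split()[0] in ENABLED_TARGETS


def check_target(target):
    # Runtime skip: pytest.skip raises, so the test is reported as SKIPPED
    # rather than silently PASSED as with a bare `return`.
    if not device_enabled(target):
        pytest.skip("Skip because %s is not enabled" % target)
    return "ran on %s" % target


def target_params(targets):
    # Collection-time skip: each parameter carries a skipif mark, so the
    # decision is made before the test body runs.
    return [
        pytest.param(
            t,
            marks=pytest.mark.skipif(not device_enabled(t), reason="%s is not enabled" % t),
        )
        for t in targets
    ]
```

With the collection-time variant, `pytest -rs` lists the skip reason for every disabled target without the test body ever being entered.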



##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            bias = te.placeholder(
+                (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
+            )
+            bias_shape = get_const_tuple(bias.shape)
 
-    def check_target(target):
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+            def get_ref_data():
+                a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
+                w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+                b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+                c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
+                    out_dtype
+                )
+
+                # convert to NCHWc
+                _, _, out_height, out_width = c_np.shape
+                c_np = c_np.reshape(
+                    (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+                ).transpose(0, 1, 3, 4, 2)
+
+                if add_bias:
+                    b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                    c_np += b_np
+                if add_relu:
+                    c_np = np.maximum(c_np, 0)
+
+                return a_np, w_np, b_np, c_np
+
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            with tvm.target.Target(target):
+                C = compute(
+                    A,
+                    W,
+                    (stride, stride),
+                    padding,
+                    (dilation, dilation),
+                    "NCHW",
+                    "NCHW",
+                    out_dtype,
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = schedule([C])
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            C = topi.cuda.conv2d_nchw_int8(
-                A, W, (stride, stride), padding, (dilation, dilation), dtype
-            )
             if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = topi.cuda.schedule_conv2d_nchw_int8([C])
-
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func(a, w, b, c)
-        else:
+                compile_args = [A, W, bias, C]
+                run_args = [a, w, b, c]
+            else:
+                compile_args = [A, W, C]
+                run_args = [a, w, c]
+
             func = tvm.build(
                 s,
-                [A, W, C],
+                compile_args,
                 target,
                 name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
-            func(a, w, c)
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    verify_workload_padding()
+            if build_only:
+                return
 
-    for target in ["cuda"]:
-        check_target(target)
+            print("Running on target: %s" % target)
 
+            func(*run_args)
 
-@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
-def test_conv2d_nchw(in_dtype):
-    with Int8Fallback():
-        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 7, 512, 3, 1, 1)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        # bias, relu
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+        targets = [
+            (
+                "cuda",
+                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+                topi.cuda.schedule_conv2d_NCHWc_int8,
+                4,
+                False,
+            ),
+            # Disable on CI since it does not support spirv int8 dot product
+            # (
+            #     "vulkan -from_device=0",
+            #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            #     topi.cuda.schedule_conv2d_NCHWc_int8,
+            #     4,
+            #     False,
+            # ),
+        ]
 
-        # dilation = 2
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
+        build_only_aarch64 = platform.machine() != "aarch64"
 
-        # batch size
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                build_only_aarch64,
+            )
+        )
 
-        # weird workloads
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 4, 4, 8, 4, 4, 4)
+        if in_dtype == "int8":
+            targets += [
+                (
+                    "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                    topi.arm_cpu.conv2d_NCHWc_int8,
+                    topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                    8,
+                    build_only_aarch64,
+                ),
+                (
+                    "rocm -mattr=+dotprod",
+                    lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(
+                        a, w, s, p, d, l, o
+                    ),
+                    topi.cuda.schedule_conv2d_NCHWc_int8,
+                    4,
+                    False,
+                ),
+            ]
+
+        for target, compute, schedule, oc_block_factor, build_only in targets:
+            check_target(target, compute, schedule, oc_block_factor, build_only)
+
+
+# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
+# performing basic testing - one test for all different scenarios - batch, dilation etc..
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (4, 4, 4, 4, 4, 4, 4, 1, False, False),
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 2, 2), 1, False, False),
+    ],
+)
+def test_conv2d_nchw_int8(in_dtype, params):
+    with Int8Fallback():
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        )
 
-        # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 147, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 73, 80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 80, 73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 48, 35, 64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 448, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1024, 19, 88, 3, 1, 1)
+        in_height = in_width = in_size
 
-        # batch > 1
-        verify_conv2d_NCHWc_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 8, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 32, 32, 149, 32, 3, 1, 0)
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+        bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
 
-        # Asymmetric padding
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 17, 192, 1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 8, 384, 3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 392, 8, 64, 3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 19, 64, 1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 16, 32, 2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True
-        )
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True
-        )
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        bias_shape = get_const_tuple(bias.shape)
+        dtype = A.dtype
+
+        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+            w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+            c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
 
-        # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
-        # performing basic testing - one test for all different scenarios - batch, dilation etc..
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
-        verify_conv2d_nchw_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
+            if add_bias:
+                b_np = np.random.uniform(size=bias_shape).astype(dtype)
+                c_np += b_np
+            if add_relu:
+                c_np = np.maximum(c_np, 0)
 
+            return a_np, w_np, b_np, c_np
 
-def test_conv2d_nhwc():
-    with Int8Fallback():
-        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
-
-        # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        def verify_workload_padding():
+            _, _, out_height, out_width = get_const_tuple(c_np.shape)
+            wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+            # for testing functionality,
+            # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+            # regardless of the performance.
+            int32_lanes, num_int8_elements = num_filter, in_channel
+
+            # check if tile_ow candidates are the factors of the right output weight.
+            cfg = autotvm.get_config()
+            fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+            ow_tile = np.prod(cfg["tile_ow"].size)
+
+            tvm.testing.assert_allclose(ow_tile, out_width)
+
+        def check_target(target):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return

Review Comment:
   Same as above: use `pytest.skip` here. Can we reuse the same function by hoisting it out of the test?
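A minimal sketch of what the suggested hoisted helper might look like. The name `skip_if_target_unsupported` is hypothetical. The availability checks are passed in as callables so the helper can be exercised without TVM. In the test file they would presumably be `tvm.testing.device_enabled` and a wrapper around `tvm.contrib.nvcc.have_int8(dev.compute_version)`:

```python
import pytest


def skip_if_target_unsupported(target, device_enabled, have_int8=None):
    """Skip the calling test when `target` cannot run the int8 workload.

    Hypothetical helper: `device_enabled` stands in for `tvm.testing.device_enabled`,
    and `have_int8` for a zero-argument check such as
    `lambda: tvm.contrib.nvcc.have_int8(dev.compute_version)`.
    """
    if not device_enabled(target):
        # pytest.skip raises, so the caller's remaining code never runs and the
        # test is reported as skipped instead of silently passing after a print().
        pytest.skip(f"{target} is not enabled")
    if target == "cuda" and have_int8 is not None and not have_int8():
        pytest.skip("int8 intrinsics are not available")
```

Each `check_target` could then replace its two `print(...); return` blocks with a single call to this helper. One caveat: `pytest.skip` ends the whole test. If it is called inside a loop over several targets, the remaining targets are skipped as well. This pattern therefore fits best once `target` is itself a `pytest.mark.parametrize` argument.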



##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            bias = te.placeholder(
+                (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
+            )
+            bias_shape = get_const_tuple(bias.shape)
 
-    def check_target(target):
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+            def get_ref_data():
+                a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
+                w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+                b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+                c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
+                    out_dtype
+                )
+
+                # convert to NCHWc
+                _, _, out_height, out_width = c_np.shape
+                c_np = c_np.reshape(
+                    (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+                ).transpose(0, 1, 3, 4, 2)
+
+                if add_bias:
+                    b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                    c_np += b_np
+                if add_relu:
+                    c_np = np.maximum(c_np, 0)
+
+                return a_np, w_np, b_np, c_np
+
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            with tvm.target.Target(target):
+                C = compute(
+                    A,
+                    W,
+                    (stride, stride),
+                    padding,
+                    (dilation, dilation),
+                    "NCHW",
+                    "NCHW",
+                    out_dtype,
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = schedule([C])
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            C = topi.cuda.conv2d_nchw_int8(
-                A, W, (stride, stride), padding, (dilation, dilation), dtype
-            )
             if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = topi.cuda.schedule_conv2d_nchw_int8([C])
-
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func(a, w, b, c)
-        else:
+                compile_args = [A, W, bias, C]
+                run_args = [a, w, b, c]
+            else:
+                compile_args = [A, W, C]
+                run_args = [a, w, c]
+
             func = tvm.build(
                 s,
-                [A, W, C],
+                compile_args,
                 target,
                 name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
-            func(a, w, c)
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    verify_workload_padding()
+            if build_only:
+                return
 
-    for target in ["cuda"]:
-        check_target(target)
+            print("Running on target: %s" % target)
 
+            func(*run_args)
 
-@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
-def test_conv2d_nchw(in_dtype):
-    with Int8Fallback():
-        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 7, 512, 3, 1, 1)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        # bias, relu
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+        targets = [
+            (
+                "cuda",
+                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+                topi.cuda.schedule_conv2d_NCHWc_int8,
+                4,
+                False,
+            ),
+            # Disable on CI since it does not support spirv int8 dot product
+            # (
+            #     "vulkan -from_device=0",
+            #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            #     topi.cuda.schedule_conv2d_NCHWc_int8,
+            #     4,
+            #     False,
+            # ),
+        ]
 
-        # dilation = 2
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
+        build_only_aarch64 = platform.machine() != "aarch64"
 
-        # batch size
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                build_only_aarch64,
+            )
+        )
 
-        # weird workloads
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 4, 4, 8, 4, 4, 4)
+        if in_dtype == "int8":
+            targets += [
+                (
+                    "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                    topi.arm_cpu.conv2d_NCHWc_int8,
+                    topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                    8,
+                    build_only_aarch64,
+                ),
+                (
+                    "rocm -mattr=+dotprod",
+                    lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(
+                        a, w, s, p, d, l, o
+                    ),
+                    topi.cuda.schedule_conv2d_NCHWc_int8,
+                    4,
+                    False,
+                ),
+            ]
+
+        for target, compute, schedule, oc_block_factor, build_only in targets:
+            check_target(target, compute, schedule, oc_block_factor, build_only)
+
+
+# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
+# performing basic testing - one test for all different scenarios - batch, dilation etc..
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (4, 4, 4, 4, 4, 4, 4, 1, False, False),
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 2, 2), 1, False, False),
+    ],
+)
+def test_conv2d_nchw_int8(in_dtype, params):
+    with Int8Fallback():
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        )
 
-        # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 147, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 73, 80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 80, 73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 48, 35, 64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 448, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1024, 19, 88, 3, 1, 1)
+        in_height = in_width = in_size
 
-        # batch > 1
-        verify_conv2d_NCHWc_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 8, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 32, 32, 149, 32, 3, 1, 0)
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+        bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
 
-        # Asymmetric padding
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 17, 192, 1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 8, 384, 3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 392, 8, 64, 3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 19, 64, 1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 16, 32, 2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True
-        )
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True
-        )
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        bias_shape = get_const_tuple(bias.shape)
+        dtype = A.dtype
+
+        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+            w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+            c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
 
-        # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
-        # performing basic testing - one test for all different scenarios - batch, dilation etc..
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
-        verify_conv2d_nchw_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
+            if add_bias:
+                b_np = np.random.uniform(size=bias_shape).astype(dtype)
+                c_np += b_np
+            if add_relu:
+                c_np = np.maximum(c_np, 0)
 
+            return a_np, w_np, b_np, c_np
 
-def test_conv2d_nhwc():
-    with Int8Fallback():
-        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
-
-        # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        def verify_workload_padding():
+            _, _, out_height, out_width = get_const_tuple(c_np.shape)
+            wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+            # for testing functionality,
+            # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+            # regardless of the performance.
+            int32_lanes, num_int8_elements = num_filter, in_channel
+
+            # check if tile_ow candidates are the factors of the right output weight.
+            cfg = autotvm.get_config()
+            fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+            ow_tile = np.prod(cfg["tile_ow"].size)
+
+            tvm.testing.assert_allclose(ow_tile, out_width)
+
+        def check_target(target):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            print("Running on target: %s" % target)
+            with tvm.target.Target(target):
+                C = topi.cuda.conv2d_nchw_int8(
+                    A, W, (stride, stride), padding, (dilation, dilation), dtype
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = topi.cuda.schedule_conv2d_nchw_int8([C])
+
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+            if add_bias:
+                func = tvm.build(
+                    s,
+                    [A, W, bias, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, b, c)
+            else:
+                func = tvm.build(
+                    s,
+                    [A, W, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, c)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+
+        verify_workload_padding()
+
+        for target in ["cuda"]:
+            check_target(target)

Review Comment:
   ```suggestion
           check_target("cuda")
   ```
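
In the parametrized tests in this diff, each case unpacks a 10-element workload tuple. `get_pad_tuple` then turns the `padding` entry into `(top, left, bottom, right)`, and the sum of those four values becomes part of the `tvm.build` kernel name. The sketch below is a standalone reimplementation of the padding rules. It is based on our reading of TVM's `get_pad_tuple` and is an assumption, so check it against the TVM source. It lets you work out the name for a given workload tuple without installing TVM. `pad_tuple` and `kernel_name` are hypothetical helpers and do not exist in TVM or in the test file.

```python
from typing import Sequence, Tuple, Union

Padding = Union[int, str, Sequence[int]]


def pad_tuple(padding: Padding, kernel: Tuple[int, int]) -> Tuple[int, int, int, int]:
    """Hypothetical mirror of TVM's get_pad_tuple; returns (top, left, bottom, right)."""
    if isinstance(padding, (tuple, list)):
        if len(padding) == 4:
            # Explicit (top, left, bottom, right) padding is used as-is.
            return tuple(padding)
        if len(padding) != 2:
            raise ValueError("padding tuple must have 2 or 4 elements")
        # (pad_h, pad_w): the same amount is applied on both sides of each axis.
        pad_h, pad_w = padding[0] * 2, padding[1] * 2
    elif isinstance(padding, int):
        pad_h = pad_w = padding * 2
    elif padding == "VALID":
        pad_h = pad_w = 0
    elif padding == "SAME":
        # Total padding along each axis is kernel - 1 (stride-1 "SAME" semantics).
        pad_h, pad_w = kernel[0] - 1, kernel[1] - 1
    else:
        raise ValueError(f"unknown padding {padding!r}")
    # Any odd remainder goes to the top/left side.
    top, left = (pad_h + 1) // 2, (pad_w + 1) // 2
    return top, left, pad_h - top, pad_w - left


def kernel_name(params) -> str:
    """Build the function name used by tvm.build in the tests from a params tuple."""
    batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, _, _ = params
    padding_sum = sum(pad_tuple(padding, (kernel, kernel)))
    return "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation
    )


if __name__ == "__main__":
    # One of the asymmetric-padding workloads from the parameter list.
    print(kernel_name((1, 64, 17, 192, 1, 1, (1, 2), 1, False, False)))
```

The name depends only on `padding_sum`. The total padding per axis is the same however the odd remainder is split between the two sides, so the name comes out the same even if TVM rounds that split differently from this sketch.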



##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return

Review Comment:
   ```suggestion
                   pytest.skip(reason="Skip because int8 intrinsics are not available")
   ```
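Here is a minimal sketch of why the suggestion matters, assuming pytest is the test runner. `pytest.skip` raises `pytest.skip.Exception`, so pytest reports the case as SKIPPED. A `print(...)` followed by `return` inside the test instead makes the case count as PASSED. `have_int8_intrinsics` and `check_cuda_target` are hypothetical helpers. The 6.1 compute-capability threshold is an assumption standing in for `tvm.contrib.nvcc.have_int8`, not a copy of it.

```python
import pytest


def have_int8_intrinsics(compute_version):
    # Hypothetical stand-in for tvm.contrib.nvcc.have_int8; assumes int8
    # dot-product instructions need compute capability 6.1 or newer.
    major, minor = (int(part) for part in compute_version.split("."))
    return (major, minor) >= (6, 1)


def check_cuda_target(compute_version):
    if not have_int8_intrinsics(compute_version):
        # Raises pytest.skip.Exception: the case shows up as SKIPPED
        # instead of silently passing after an early return.
        pytest.skip("Skip because int8 intrinsics are not available")
    return "ran"
```

The message is passed positionally here. I believe the `reason=` keyword used in the suggestion requires pytest 7.0 or newer, while the positional form also works on older releases.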



##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255

Review Comment:
   I think we can use `get_dtype_range` from https://github.com/apache/tvm/blob/main/python/tvm/testing/aot.py#L904 here; we can move it out of `aot.py` later. I think we can also be a bit clearer with our variable naming:
   
   ```suggestion
           input_min, input_max = get_dtype_range(in_dtype)
   ```
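The following sketches what the suggested change could look like. It assumes a helper that returns the inclusive (min, max) of an integer dtype. `get_dtype_range` below is a local stand-in built on `np.iinfo`, not the actual function in `tvm/testing/aot.py`, whose signature may differ. `random_inputs` is a hypothetical name. The sketch also shows a detail worth keeping in mind when the range feeds `np.random.randint`: `high` is exclusive. As a result, the current `np.random.randint(low=lo, high=hi, ...)` never produces 127 for int8 or 255 for uint8.

```python
import numpy as np


def get_dtype_range(dtype):
    """Return the inclusive (min, max) of an integer dtype.

    Local stand-in for the helper linked in the review comment; the real
    function in tvm/testing/aot.py may differ.
    """
    info = np.iinfo(dtype)
    return int(info.min), int(info.max)


def random_inputs(shape, in_dtype):
    """Hypothetical helper: random integers spanning the full dtype range."""
    input_min, input_max = get_dtype_range(in_dtype)
    # np.random.randint excludes `high`, so add 1 to make input_max reachable.
    values = np.random.randint(low=input_min, high=input_max + 1, size=shape)
    return values.astype(in_dtype)
```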


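The new `get_ref_data` converts the NCHW reference output to NCHWc with `c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)).transpose(0, 1, 3, 4, 2)`. It then adds a bias of shape `(num_filter // oc_block_factor, 1, 1, oc_block_factor)`. The sketch below checks that layout change on a toy array using NumPy alone. `nchw_to_nchwc` is a hypothetical helper name, and the shapes are illustrative.

```python
import numpy as np


def nchw_to_nchwc(x, block):
    """Hypothetical helper: (N, C, H, W) -> (N, C // block, H, W, block)."""
    n, c, h, w = x.shape
    assert c % block == 0, "channel count must be a multiple of the block factor"
    # Split C into (C // block, block), then move the inner block axis last.
    return x.reshape(n, c // block, block, h, w).transpose(0, 1, 3, 4, 2)


# Toy reference output: batch 2, 8 output channels, 3x3 spatial.
c_nchw = np.arange(2 * 8 * 3 * 3).reshape(2, 8, 3, 3)
c_nchwc = nchw_to_nchwc(c_nchw, block=4)

# A bias laid out as (C // block, 1, 1, block), like the new placeholder,
# broadcasts against (N, C // block, H, W, block) without extra reshaping.
bias = np.arange(8).reshape(2, 1, 1, 4)
out = c_nchwc + bias
```

Channel `ch` of the NCHW array ends up at block `ch // 4`, lane `ch % 4`. This is why the bias must use the same split.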

##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            bias = te.placeholder(
+                (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
+            )
+            bias_shape = get_const_tuple(bias.shape)
 
-    def check_target(target):
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+            def get_ref_data():
+                a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
+                w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+                b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+                c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
+                    out_dtype
+                )
+
+                # convert to NCHWc
+                _, _, out_height, out_width = c_np.shape
+                c_np = c_np.reshape(
+                    (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+                ).transpose(0, 1, 3, 4, 2)
+
+                if add_bias:
+                    b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                    c_np += b_np
+                if add_relu:
+                    c_np = np.maximum(c_np, 0)
+
+                return a_np, w_np, b_np, c_np
+
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            with tvm.target.Target(target):
+                C = compute(
+                    A,
+                    W,
+                    (stride, stride),
+                    padding,
+                    (dilation, dilation),
+                    "NCHW",
+                    "NCHW",
+                    out_dtype,
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = schedule([C])
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            C = topi.cuda.conv2d_nchw_int8(
-                A, W, (stride, stride), padding, (dilation, dilation), dtype
-            )
             if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = topi.cuda.schedule_conv2d_nchw_int8([C])
-
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func(a, w, b, c)
-        else:
+                compile_args = [A, W, bias, C]
+                run_args = [a, w, b, c]
+            else:
+                compile_args = [A, W, C]
+                run_args = [a, w, c]
+
             func = tvm.build(
                 s,
-                [A, W, C],
+                compile_args,
                 target,
                 name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
-            func(a, w, c)
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    verify_workload_padding()
+            if build_only:
+                return
 
-    for target in ["cuda"]:
-        check_target(target)
+            print("Running on target: %s" % target)
 
+            func(*run_args)
 
-@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
-def test_conv2d_nchw(in_dtype):
-    with Int8Fallback():
-        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 7, 512, 3, 1, 1)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        # bias, relu
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+        targets = [
+            (
+                "cuda",
+                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+                topi.cuda.schedule_conv2d_NCHWc_int8,
+                4,
+                False,
+            ),
+            # Disable on CI since it does not support spirv int8 dot product
+            # (
+            #     "vulkan -from_device=0",
+            #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            #     topi.cuda.schedule_conv2d_NCHWc_int8,
+            #     4,
+            #     False,
+            # ),
+        ]
 
-        # dilation = 2
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
+        build_only_aarch64 = platform.machine() != "aarch64"
 
-        # batch size
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                build_only_aarch64,
+            )
+        )
 
-        # weird workloads
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 4, 4, 8, 4, 4, 4)
+        if in_dtype == "int8":
+            targets += [
+                (
+                    "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                    topi.arm_cpu.conv2d_NCHWc_int8,
+                    topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                    8,
+                    build_only_aarch64,
+                ),
+                (
+                    "rocm -mattr=+dotprod",
+                    lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(
+                        a, w, s, p, d, l, o
+                    ),
+                    topi.cuda.schedule_conv2d_NCHWc_int8,
+                    4,
+                    False,
+                ),
+            ]
+
+        for target, compute, schedule, oc_block_factor, build_only in targets:
+            check_target(target, compute, schedule, oc_block_factor, build_only)
+
+
+# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
+# performing basic testing - one test for all different scenarios - batch, dilation etc..
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (4, 4, 4, 4, 4, 4, 4, 1, False, False),
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 2, 2), 1, False, False),
+    ],
+)
+def test_conv2d_nchw_int8(in_dtype, params):
+    with Int8Fallback():
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        )
 
-        # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 147, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 73, 80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 80, 73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 48, 35, 64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 448, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1024, 19, 88, 3, 1, 1)
+        in_height = in_width = in_size
 
-        # batch > 1
-        verify_conv2d_NCHWc_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 8, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 32, 32, 149, 32, 3, 1, 0)
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+        bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
 
-        # Asymmetric padding
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 17, 192, 1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 8, 384, 3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 392, 8, 64, 3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 19, 64, 1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 16, 32, 2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True
-        )
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True
-        )
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        bias_shape = get_const_tuple(bias.shape)
+        dtype = A.dtype
+
+        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+            w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+            c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
 
-        # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
-        # performing basic testing - one test for all different scenarios - batch, dilation etc..
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
-        verify_conv2d_nchw_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
+            if add_bias:
+                b_np = np.random.uniform(size=bias_shape).astype(dtype)
+                c_np += b_np
+            if add_relu:
+                c_np = np.maximum(c_np, 0)
 
+            return a_np, w_np, b_np, c_np
 
-def test_conv2d_nhwc():
-    with Int8Fallback():
-        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
-
-        # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        def verify_workload_padding():
+            _, _, out_height, out_width = get_const_tuple(c_np.shape)
+            wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+            # for testing functionality,
+            # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+            # regardless of the performance.
+            int32_lanes, num_int8_elements = num_filter, in_channel
+
+            # check if tile_ow candidates are the factors of the right output weight.
+            cfg = autotvm.get_config()
+            fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+            ow_tile = np.prod(cfg["tile_ow"].size)
+
+            tvm.testing.assert_allclose(ow_tile, out_width)
+
+        def check_target(target):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            print("Running on target: %s" % target)
+            with tvm.target.Target(target):
+                C = topi.cuda.conv2d_nchw_int8(
+                    A, W, (stride, stride), padding, (dilation, dilation), dtype
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = topi.cuda.schedule_conv2d_nchw_int8([C])
+
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+            if add_bias:
+                func = tvm.build(
+                    s,
+                    [A, W, bias, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, b, c)
+            else:
+                func = tvm.build(
+                    s,
+                    [A, W, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, c)

Review Comment:
   I think only the argument lists differ between these two branches, so we can remove most of the duplication.
   
   ```suggestion
               build_args = [A, W, bias, C] if add_bias else [A, W, C]
               func = tvm.build(
                   s,
                   build_args,
                   target,
                   name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                   % (
                       batch,
                       in_channel,
                       in_size,
                       num_filter,
                       kernel,
                       stride,
                       padding_sum,
                       dilation,
                   ),
               )
               if add_bias:
                   func(a, w, b, c)
               else:
                   func(a, w, c)
   ```
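   The suggestion above can be taken one step further. If the build-time placeholders and the call-time arrays are selected together, the call site needs no branch either; the later revision of this diff adopts a similar `build_inputs`/`inference_inputs` pair. Below is a minimal, TVM-free sketch of that idea. `select_args` and `fake_kernel` are hypothetical helpers for illustration only, not part of TVM or of this PR, and `fake_kernel` merely stands in for the function that `tvm.build` returns.

```python
from typing import Any, List, Sequence, Tuple


def select_args(
    add_bias: bool, placeholders: Sequence[Any], arrays: Sequence[Any]
) -> Tuple[List[Any], List[Any]]:
    """Hypothetical helper: return matching (build-time, call-time) argument lists.

    Both inputs are ordered as (input, weight, bias, output). The bias entry
    is dropped from *both* lists when add_bias is False, so the two lists
    always line up position by position.
    """
    keep = (0, 1, 2, 3) if add_bias else (0, 1, 3)
    return [placeholders[i] for i in keep], [arrays[i] for i in keep]


def fake_kernel(*args: Any) -> int:
    """Hypothetical stand-in for the function returned by tvm.build.

    It only reports how many arguments it was called with, so the
    pattern can be exercised without TVM installed.
    """
    return len(args)


if __name__ == "__main__":
    for add_bias in (False, True):
        build_args, run_args = select_args(
            add_bias, ("A", "W", "bias", "C"), ("a", "w", "b", "c")
        )
        # In the real test this would be: func = tvm.build(s, build_args, target, name=...)
        func = fake_kernel
        # A single call site; no if/else on add_bias is needed here.
        print(add_bias, build_args, func(*run_args))
```

   Selecting both lists in one place means a new optional input (for example a fused requantize parameter) only has to be added once, and the compiled signature and the runtime call cannot drift apart.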



##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            bias = te.placeholder(
+                (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
+            )
+            bias_shape = get_const_tuple(bias.shape)
 
-    def check_target(target):
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+            def get_ref_data():
+                a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
+                w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+                b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+                c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
+                    out_dtype
+                )
+
+                # convert to NCHWc
+                _, _, out_height, out_width = c_np.shape
+                c_np = c_np.reshape(
+                    (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+                ).transpose(0, 1, 3, 4, 2)
+
+                if add_bias:
+                    b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                    c_np += b_np
+                if add_relu:
+                    c_np = np.maximum(c_np, 0)
+
+                return a_np, w_np, b_np, c_np
+
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            with tvm.target.Target(target):
+                C = compute(
+                    A,
+                    W,
+                    (stride, stride),
+                    padding,
+                    (dilation, dilation),
+                    "NCHW",
+                    "NCHW",
+                    out_dtype,
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = schedule([C])
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            C = topi.cuda.conv2d_nchw_int8(
-                A, W, (stride, stride), padding, (dilation, dilation), dtype
-            )
             if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = topi.cuda.schedule_conv2d_nchw_int8([C])
-
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func(a, w, b, c)
-        else:
+                compile_args = [A, W, bias, C]
+                run_args = [a, w, b, c]
+            else:
+                compile_args = [A, W, C]
+                run_args = [a, w, c]
+
             func = tvm.build(
                 s,
-                [A, W, C],
+                compile_args,
                 target,
                 name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
-            func(a, w, c)
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    verify_workload_padding()
+            if build_only:
+                return
 
-    for target in ["cuda"]:
-        check_target(target)
+            print("Running on target: %s" % target)
 
+            func(*run_args)
 
-@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
-def test_conv2d_nchw(in_dtype):
-    with Int8Fallback():
-        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 7, 512, 3, 1, 1)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        # bias, relu
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+        targets = [
+            (
+                "cuda",
+                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+                topi.cuda.schedule_conv2d_NCHWc_int8,
+                4,
+                False,
+            ),
+            # Disable on CI since it does not support spirv int8 dot product
+            # (
+            #     "vulkan -from_device=0",
+            #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            #     topi.cuda.schedule_conv2d_NCHWc_int8,
+            #     4,
+            #     False,
+            # ),
+        ]
 
-        # dilation = 2
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
+        build_only_aarch64 = platform.machine() != "aarch64"
 
-        # batch size
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                build_only_aarch64,
+            )
+        )
 
-        # weird workloads
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 4, 4, 8, 4, 4, 4)
+        if in_dtype == "int8":
+            targets += [
+                (
+                    "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                    topi.arm_cpu.conv2d_NCHWc_int8,
+                    topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                    8,
+                    build_only_aarch64,
+                ),
+                (
+                    "rocm -mattr=+dotprod",
+                    lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(
+                        a, w, s, p, d, l, o
+                    ),
+                    topi.cuda.schedule_conv2d_NCHWc_int8,
+                    4,
+                    False,
+                ),
+            ]
+
+        for target, compute, schedule, oc_block_factor, build_only in targets:
+            check_target(target, compute, schedule, oc_block_factor, build_only)
+
+
+# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
+# performing basic testing - one test for all different scenarios - batch, dilation etc..
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (4, 4, 4, 4, 4, 4, 4, 1, False, False),
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 2, 2), 1, False, False),
+    ],
+)
+def test_conv2d_nchw_int8(in_dtype, params):
+    with Int8Fallback():
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        )
 
-        # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 147, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 73, 80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 80, 73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 48, 35, 64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 448, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1024, 19, 88, 3, 1, 1)
+        in_height = in_width = in_size
 
-        # batch > 1
-        verify_conv2d_NCHWc_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 8, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 32, 32, 149, 32, 3, 1, 0)
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+        bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
 
-        # Asymmetric padding
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 17, 192, 1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 8, 384, 3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 392, 8, 64, 3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 19, 64, 1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 16, 32, 2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True
-        )
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True
-        )
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        bias_shape = get_const_tuple(bias.shape)
+        dtype = A.dtype
+
+        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+            w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+            c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
 
-        # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
-        # performing basic testing - one test for all different scenarios - batch, dilation etc..
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
-        verify_conv2d_nchw_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
+            if add_bias:
+                b_np = np.random.uniform(size=bias_shape).astype(dtype)
+                c_np += b_np
+            if add_relu:
+                c_np = np.maximum(c_np, 0)
 
+            return a_np, w_np, b_np, c_np
 
-def test_conv2d_nhwc():
-    with Int8Fallback():
-        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
-
-        # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        def verify_workload_padding():
+            _, _, out_height, out_width = get_const_tuple(c_np.shape)
+            wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+            # for testing functionality,
+            # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+            # regardless of the performance.
+            int32_lanes, num_int8_elements = num_filter, in_channel
+
+            # check if tile_ow candidates are the factors of the right output weight.
+            cfg = autotvm.get_config()
+            fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+            ow_tile = np.prod(cfg["tile_ow"].size)
+
+            tvm.testing.assert_allclose(ow_tile, out_width)
+
+        def check_target(target):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            print("Running on target: %s" % target)
+            with tvm.target.Target(target):
+                C = topi.cuda.conv2d_nchw_int8(
+                    A, W, (stride, stride), padding, (dilation, dilation), dtype
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = topi.cuda.schedule_conv2d_nchw_int8([C])
+
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+            if add_bias:
+                func = tvm.build(
+                    s,
+                    [A, W, bias, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, b, c)
+            else:
+                func = tvm.build(
+                    s,
+                    [A, W, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, c)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+
+        verify_workload_padding()
+
+        for target in ["cuda"]:
+            check_target(target)

Review Comment:
   ```suggestion
           check_target("cuda")
   ```
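The diff above turns positional `verify_conv2d_nchw_int8(...)` calls into 10-element `params` tuples that are unpacked inside the test. Judging from the converted cases, the last three fields take the values `dilation=1`, `add_bias=False` and `add_relu=False` when the old call omitted them. For example, `(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)` becomes `(1, 64, 56, 64, 3, 1, 1, 1, False, True)`. Below is a minimal sketch for checking such conversions. It uses a hypothetical `Conv2dParams` NamedTuple, which is not part of TVM or this PR, and its defaults are inferred from the diff rather than taken from it:

```python
from typing import NamedTuple, Tuple, Union

# Padding may be an int, a tuple of ints, or "SAME"/"VALID", as in the diff.
Padding = Union[int, str, Tuple[int, ...]]


class Conv2dParams(NamedTuple):
    """Hypothetical named view of the `params` tuples used with pytest.mark.parametrize."""

    batch: int
    in_channel: int
    in_size: int
    num_filter: int
    kernel: int
    stride: int
    padding: Padding
    # Defaults inferred from how the old keyword arguments were converted.
    dilation: int = 1
    add_bias: bool = False
    add_relu: bool = False


# Old style: verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
old_style = Conv2dParams(1, 64, 56, 64, 3, 1, 1, add_relu=True)

# New style: the literal tuple from the parametrize list in the diff.
new_style = (1, 64, 56, 64, 3, 1, 1, 1, False, True)

# A NamedTuple is a tuple subclass, so it compares equal to the plain tuple.
assert old_style == new_style
```

Inside a test, `Conv2dParams(*params)` would give named access such as `p.dilation` instead of unpacking all ten fields positionally. This is only an illustration. The PR itself unpacks the tuple directly.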




[GitHub] [tvm] ekalda commented on pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
ekalda commented on PR #13669:
URL: https://github.com/apache/tvm/pull/13669#issuecomment-1366519663

   cc @leandron @Mousius 



[GitHub] [tvm] tvm-bot commented on pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #13669:
URL: https://github.com/apache/tvm/pull/13669#issuecomment-1366519043

   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
    * No users to tag found in teams: `topi`, `bugfix` <sub>See [#10317](https://github.com/apache/tvm/issues/10317) for details</sub>
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>



[GitHub] [tvm] ekalda commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058905667


##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            bias = te.placeholder(
+                (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
+            )
+            bias_shape = get_const_tuple(bias.shape)
 
-    def check_target(target):
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+            def get_ref_data():
+                a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
+                w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+                b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+                c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
+                    out_dtype
+                )
+
+                # convert to NCHWc
+                _, _, out_height, out_width = c_np.shape
+                c_np = c_np.reshape(
+                    (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+                ).transpose(0, 1, 3, 4, 2)
+
+                if add_bias:
+                    b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                    c_np += b_np
+                if add_relu:
+                    c_np = np.maximum(c_np, 0)
+
+                return a_np, w_np, b_np, c_np
+
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            with tvm.target.Target(target):
+                C = compute(
+                    A,
+                    W,
+                    (stride, stride),
+                    padding,
+                    (dilation, dilation),
+                    "NCHW",
+                    "NCHW",
+                    out_dtype,
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = schedule([C])
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            C = topi.cuda.conv2d_nchw_int8(
-                A, W, (stride, stride), padding, (dilation, dilation), dtype
-            )
             if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = topi.cuda.schedule_conv2d_nchw_int8([C])
-
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func(a, w, b, c)
-        else:
+                compile_args = [A, W, bias, C]
+                run_args = [a, w, b, c]
+            else:
+                compile_args = [A, W, C]
+                run_args = [a, w, c]
+
             func = tvm.build(
                 s,
-                [A, W, C],
+                compile_args,
                 target,
                 name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
-            func(a, w, c)
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    verify_workload_padding()
+            if build_only:
+                return
 
-    for target in ["cuda"]:
-        check_target(target)
+            print("Running on target: %s" % target)
 
+            func(*run_args)
 
-@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
-def test_conv2d_nchw(in_dtype):
-    with Int8Fallback():
-        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 7, 512, 3, 1, 1)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        # bias, relu
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+        targets = [
+            (
+                "cuda",
+                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+                topi.cuda.schedule_conv2d_NCHWc_int8,
+                4,
+                False,
+            ),
+            # Disable on CI since it does not support spirv int8 dot product
+            # (
+            #     "vulkan -from_device=0",
+            #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            #     topi.cuda.schedule_conv2d_NCHWc_int8,
+            #     4,
+            #     False,
+            # ),
+        ]
 
-        # dilation = 2
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
+        build_only_aarch64 = platform.machine() != "aarch64"
 
-        # batch size
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                build_only_aarch64,
+            )
+        )
 
-        # weird workloads
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 4, 4, 8, 4, 4, 4)
+        if in_dtype == "int8":
+            targets += [
+                (
+                    "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                    topi.arm_cpu.conv2d_NCHWc_int8,
+                    topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                    8,
+                    build_only_aarch64,
+                ),
+                (
+                    "rocm -mattr=+dotprod",
+                    lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(
+                        a, w, s, p, d, l, o
+                    ),
+                    topi.cuda.schedule_conv2d_NCHWc_int8,
+                    4,
+                    False,
+                ),
+            ]
+
+        for target, compute, schedule, oc_block_factor, build_only in targets:
+            check_target(target, compute, schedule, oc_block_factor, build_only)
+
+
+# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
+# performing basic testing - one test for all different scenarios - batch, dilation etc..
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (4, 4, 4, 4, 4, 4, 4, 1, False, False),
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 2, 2), 1, False, False),
+    ],
+)
+def test_conv2d_nchw_int8(in_dtype, params):
+    with Int8Fallback():
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        )
 
-        # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 147, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 73, 80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 80, 73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 48, 35, 64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 448, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1024, 19, 88, 3, 1, 1)
+        in_height = in_width = in_size
 
-        # batch > 1
-        verify_conv2d_NCHWc_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 8, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 32, 32, 149, 32, 3, 1, 0)
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+        bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
 
-        # Asymmetric padding
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 17, 192, 1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 8, 384, 3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 392, 8, 64, 3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 19, 64, 1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 16, 32, 2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True
-        )
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True
-        )
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        bias_shape = get_const_tuple(bias.shape)
+        dtype = A.dtype
+
+        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+            w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+            c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
 
-        # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
-        # performing basic testing - one test for all different scenarios - batch, dilation etc..
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
-        verify_conv2d_nchw_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
+            if add_bias:
+                b_np = np.random.uniform(size=bias_shape).astype(dtype)
+                c_np += b_np
+            if add_relu:
+                c_np = np.maximum(c_np, 0)
 
+            return a_np, w_np, b_np, c_np
 
-def test_conv2d_nhwc():
-    with Int8Fallback():
-        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
-
-        # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        def verify_workload_padding():
+            _, _, out_height, out_width = get_const_tuple(c_np.shape)
+            wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+            # for testing functionality,
+            # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+            # regardless of the performance.
+            int32_lanes, num_int8_elements = num_filter, in_channel
+
+            # check if tile_ow candidates are the factors of the right output weight.
+            cfg = autotvm.get_config()
+            fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+            ow_tile = np.prod(cfg["tile_ow"].size)
+
+            tvm.testing.assert_allclose(ow_tile, out_width)
+
+        def check_target(target):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return

Review Comment:
   I added the `pytest.skip`. I also experimented with hoisting the nested functions out, since all the tests in that file do something similar. Unfortunately, the functions are all subtly different and depend on almost every parameter passed to or defined in the test, as well as on the compute definitions, schedules, utility functions, etc. All of these would need to be passed in as arguments, so it didn't seem worth it.
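   One hypothetical way to reduce the argument plumbing described above is to bundle each `params` tuple into a small frozen dataclass, so that a hoisted helper takes one workload object instead of ten positional values. The sketch below assumes the same ten-field tuple layout used by the `pytest.mark.parametrize` lists in this diff. `Conv2DWorkload`, `skip_unless_enabled` and the `enabled` flag are illustrative names, not part of TVM; the flag stands in for a check such as `tvm.testing.device_enabled(target)`.

```python
from dataclasses import dataclass
from typing import Tuple, Union

import pytest

# Padding may be an int, a 2- or 4-tuple, or "SAME"/"VALID", as in the workload lists.
Padding = Union[int, str, Tuple[int, ...]]


@dataclass(frozen=True)
class Conv2DWorkload:
    """Hypothetical bundle of the ten values carried by each `params` tuple."""

    batch: int
    in_channel: int
    in_size: int
    num_filter: int
    kernel: int
    stride: int
    padding: Padding
    dilation: int
    add_bias: bool
    add_relu: bool

    @classmethod
    def from_params(cls, params: tuple) -> "Conv2DWorkload":
        # Positional order matches the tuples passed to pytest.mark.parametrize.
        return cls(*params)


def skip_unless_enabled(target: str, enabled: bool) -> None:
    """Report an unavailable target as a skipped test instead of a silent pass."""
    if not enabled:
        pytest.skip(f"{target} is not enabled")


@pytest.mark.parametrize(
    "params",
    [
        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
    ],
)
def test_workload_unpacking(params):
    wkl = Conv2DWorkload.from_params(params)
    # A hoisted helper would receive `wkl` rather than ten separate arguments.
    assert wkl.in_channel == params[1]
    assert wkl.padding == params[6]
```

   The trade-off from the comment still applies: helpers that also need the compute function, schedule and reference data would take those as extra arguments, so bundling the shape parameters removes only part of the plumbing.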




[GitHub] [tvm] ekalda commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058904165


##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255

Review Comment:
   Good point! Done 
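The dtype-dependent bounds introduced in the hunk above (`lo`/`hi`: [-128, 127] for int8, [0, 255] for uint8) can be exercised with a small NumPy sketch. It assumes the bounds are meant to be inclusive. NumPy's `randint` and `Generator.integers` exclude `high` by default, so the sketch passes `endpoint=True` to make the top value reachable. The helper name `make_int_data` is hypothetical and is not part of `test_topi_conv2d_int8.py`.

```python
import numpy as np


def make_int_data(in_dtype, shape, seed=0):
    """Hypothetical helper: random data covering the full range of in_dtype.

    Mirrors the lo/hi selection in the test diff:
    [-128, 127] for "int8", [0, 255] for "uint8".
    """
    lo = -128 if in_dtype == "int8" else 0
    hi = 127 if in_dtype == "int8" else 255
    rng = np.random.default_rng(seed)
    # endpoint=True makes `hi` inclusive; by default `high` is excluded.
    return rng.integers(lo, hi, size=shape, endpoint=True).astype(in_dtype)


a_np = make_int_data("int8", (1, 4, 8, 8))
w_np = make_int_data("uint8", (8, 4, 3, 3))
```

Generating the data in the target dtype up front means `tvm.nd.array(a_np, dev)` receives an array that already matches the placeholder dtype, with no extra `astype` at the call site.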



##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return

Review Comment:
   Done





[GitHub] [tvm] Mousius commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
Mousius commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058367284


##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            bias = te.placeholder(
+                (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
+            )
+            bias_shape = get_const_tuple(bias.shape)
 
-    def check_target(target):
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+            def get_ref_data():
+                a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
+                w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+                b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+                c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
+                    out_dtype
+                )
+
+                # convert to NCHWc
+                _, _, out_height, out_width = c_np.shape
+                c_np = c_np.reshape(
+                    (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+                ).transpose(0, 1, 3, 4, 2)
+
+                if add_bias:
+                    b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                    c_np += b_np
+                if add_relu:
+                    c_np = np.maximum(c_np, 0)
+
+                return a_np, w_np, b_np, c_np
+
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            with tvm.target.Target(target):
+                C = compute(
+                    A,
+                    W,
+                    (stride, stride),
+                    padding,
+                    (dilation, dilation),
+                    "NCHW",
+                    "NCHW",
+                    out_dtype,
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = schedule([C])
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            C = topi.cuda.conv2d_nchw_int8(
-                A, W, (stride, stride), padding, (dilation, dilation), dtype
-            )
             if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = topi.cuda.schedule_conv2d_nchw_int8([C])
-
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func(a, w, b, c)
-        else:
+                compile_args = [A, W, bias, C]
+                run_args = [a, w, b, c]
+            else:
+                compile_args = [A, W, C]
+                run_args = [a, w, c]
+
             func = tvm.build(
                 s,
-                [A, W, C],
+                compile_args,
                 target,
                 name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
-            func(a, w, c)
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    verify_workload_padding()
+            if build_only:
+                return
 
-    for target in ["cuda"]:
-        check_target(target)
+            print("Running on target: %s" % target)
 
+            func(*run_args)
 
-@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
-def test_conv2d_nchw(in_dtype):
-    with Int8Fallback():
-        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 7, 512, 3, 1, 1)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        # bias, relu
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+        targets = [
+            (
+                "cuda",
+                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+                topi.cuda.schedule_conv2d_NCHWc_int8,
+                4,
+                False,
+            ),
+            # Disable on CI since it does not support spirv int8 dot product
+            # (
+            #     "vulkan -from_device=0",
+            #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            #     topi.cuda.schedule_conv2d_NCHWc_int8,
+            #     4,
+            #     False,
+            # ),
+        ]
 
-        # dilation = 2
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
+        build_only_aarch64 = platform.machine() != "aarch64"
 
-        # batch size
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                build_only_aarch64,
+            )
+        )
 
-        # weird workloads
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 4, 4, 8, 4, 4, 4)
+        if in_dtype == "int8":
+            targets += [
+                (
+                    "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                    topi.arm_cpu.conv2d_NCHWc_int8,
+                    topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                    8,
+                    build_only_aarch64,
+                ),
+                (
+                    "rocm -mattr=+dotprod",
+                    lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(
+                        a, w, s, p, d, l, o
+                    ),
+                    topi.cuda.schedule_conv2d_NCHWc_int8,
+                    4,
+                    False,
+                ),
+            ]
+
+        for target, compute, schedule, oc_block_factor, build_only in targets:
+            check_target(target, compute, schedule, oc_block_factor, build_only)
+
+
+# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
+# performing basic testing - one test for all different scenarios - batch, dilation etc..
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (4, 4, 4, 4, 4, 4, 4, 1, False, False),
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 2, 2), 1, False, False),
+    ],
+)
+def test_conv2d_nchw_int8(in_dtype, params):
+    with Int8Fallback():
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        )
 
-        # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 147, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 73, 80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 80, 73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 48, 35, 64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 448, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1024, 19, 88, 3, 1, 1)
+        in_height = in_width = in_size
 
-        # batch > 1
-        verify_conv2d_NCHWc_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 8, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 32, 32, 149, 32, 3, 1, 0)
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+        bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
 
-        # Asymmetric padding
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 17, 192, 1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 8, 384, 3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 392, 8, 64, 3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 19, 64, 1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 16, 32, 2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True
-        )
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True
-        )
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        bias_shape = get_const_tuple(bias.shape)
+        dtype = A.dtype
+
+        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+            w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+            c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
 
-        # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
-        # performing basic testing - one test for all different scenarios - batch, dilation etc..
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
-        verify_conv2d_nchw_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
+            if add_bias:
+                b_np = np.random.uniform(size=bias_shape).astype(dtype)
+                c_np += b_np
+            if add_relu:
+                c_np = np.maximum(c_np, 0)
 
+            return a_np, w_np, b_np, c_np
 
-def test_conv2d_nhwc():
-    with Int8Fallback():
-        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
-
-        # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        def verify_workload_padding():
+            _, _, out_height, out_width = get_const_tuple(c_np.shape)
+            wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+            # for testing functionality,
+            # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+            # regardless of the performance.
+            int32_lanes, num_int8_elements = num_filter, in_channel
+
+            # check if tile_ow candidates are the factors of the right output weight.
+            cfg = autotvm.get_config()
+            fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+            ow_tile = np.prod(cfg["tile_ow"].size)
+
+            tvm.testing.assert_allclose(ow_tile, out_width)
+
+        def check_target(target):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            print("Running on target: %s" % target)
+            with tvm.target.Target(target):
+                C = topi.cuda.conv2d_nchw_int8(
+                    A, W, (stride, stride), padding, (dilation, dilation), dtype
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = topi.cuda.schedule_conv2d_nchw_int8([C])
+
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+            if add_bias:
+                func = tvm.build(
+                    s,
+                    [A, W, bias, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, b, c)
+            else:
+                func = tvm.build(
+                    s,
+                    [A, W, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, c)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+
+        verify_workload_padding()
+
+        for target in ["cuda"]:
+            check_target(target)

Review Comment:
   ```suggestion
           check_target("cuda")
   ```
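The `# convert to NCHWc` step in the NCHWc test above reshapes the NCHW reference output to `(batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)`. It then applies `transpose(0, 1, 3, 4, 2)`. As a result, output channel `c` ends up at outer block `c // oc_block_factor` and inner lane `c % oc_block_factor`.

The block below is a minimal, dependency-free sketch of that index mapping. `nchw_to_nchwc` is a hypothetical helper written for illustration and is not a TVM or NumPy API. It assumes the channel count is divisible by the block factor.

```python
def nchw_to_nchwc(data, block):
    """Rearrange a 4-D NCHW nested list into 5-D NCHW[block]c layout.

    Mirrors numpy's reshape((N, C // block, block, H, W)).transpose(0, 1, 3, 4, 2):
    out[n][co][h][w][ci] == data[n][co * block + ci][h][w]
    """
    batch = len(data)
    channels = len(data[0])
    if channels % block != 0:
        # The sketch does not pad; the channel axis must split evenly.
        raise ValueError("channel count must be divisible by the block factor")
    height = len(data[0][0])
    width = len(data[0][0][0])
    return [
        [
            [
                [
                    # Innermost axis: the `block` channels that share one outer block.
                    [data[n][co * block + ci][h][w] for ci in range(block)]
                    for w in range(width)
                ]
                for h in range(height)
            ]
            for co in range(channels // block)
        ]
        for n in range(batch)
    ]


# 1 image, 8 channels, 1x2 spatial; value encodes (channel, column) as 10 * c + w.
data = [[[[10 * c + w for w in range(2)]] for c in range(8)]]
out = nchw_to_nchwc(data, 4)
print(out[0][0][0][0])  # channels 0..3 at (h=0, w=0): [0, 10, 20, 30]
```

This is the same layout that the test's bias placeholder of shape `(num_filter // oc_block_factor, 1, 1, oc_block_factor)` is written to broadcast against.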





[GitHub] [tvm] ekalda commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058905849


##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            bias = te.placeholder(
+                (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
+            )
+            bias_shape = get_const_tuple(bias.shape)
 
-    def check_target(target):
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+            def get_ref_data():
+                a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
+                w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+                b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+                c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
+                    out_dtype
+                )
+
+                # convert to NCHWc
+                _, _, out_height, out_width = c_np.shape
+                c_np = c_np.reshape(
+                    (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+                ).transpose(0, 1, 3, 4, 2)
+
+                if add_bias:
+                    b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                    c_np += b_np
+                if add_relu:
+                    c_np = np.maximum(c_np, 0)
+
+                return a_np, w_np, b_np, c_np
+
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            with tvm.target.Target(target):
+                C = compute(
+                    A,
+                    W,
+                    (stride, stride),
+                    padding,
+                    (dilation, dilation),
+                    "NCHW",
+                    "NCHW",
+                    out_dtype,
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = schedule([C])
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            C = topi.cuda.conv2d_nchw_int8(
-                A, W, (stride, stride), padding, (dilation, dilation), dtype
-            )
             if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = topi.cuda.schedule_conv2d_nchw_int8([C])
-
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func(a, w, b, c)
-        else:
+                compile_args = [A, W, bias, C]
+                run_args = [a, w, b, c]
+            else:
+                compile_args = [A, W, C]
+                run_args = [a, w, c]
+
             func = tvm.build(
                 s,
-                [A, W, C],
+                compile_args,
                 target,
                 name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
-            func(a, w, c)
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    verify_workload_padding()
+            if build_only:
+                return
 
-    for target in ["cuda"]:
-        check_target(target)
+            print("Running on target: %s" % target)
 
+            func(*run_args)
 
-@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
-def test_conv2d_nchw(in_dtype):
-    with Int8Fallback():
-        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 7, 512, 3, 1, 1)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        # bias, relu
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+        targets = [
+            (
+                "cuda",
+                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+                topi.cuda.schedule_conv2d_NCHWc_int8,
+                4,
+                False,
+            ),
+            # Disable on CI since it does not support spirv int8 dot product
+            # (
+            #     "vulkan -from_device=0",
+            #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            #     topi.cuda.schedule_conv2d_NCHWc_int8,
+            #     4,
+            #     False,
+            # ),
+        ]
 
-        # dilation = 2
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
+        build_only_aarch64 = platform.machine() != "aarch64"
 
-        # batch size
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                build_only_aarch64,
+            )
+        )
 
-        # weird workloads
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 4, 4, 8, 4, 4, 4)
+        if in_dtype == "int8":
+            targets += [
+                (
+                    "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                    topi.arm_cpu.conv2d_NCHWc_int8,
+                    topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                    8,
+                    build_only_aarch64,
+                ),
+                (
+                    "rocm -mattr=+dotprod",
+                    lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(
+                        a, w, s, p, d, l, o
+                    ),
+                    topi.cuda.schedule_conv2d_NCHWc_int8,
+                    4,
+                    False,
+                ),
+            ]
+
+        for target, compute, schedule, oc_block_factor, build_only in targets:
+            check_target(target, compute, schedule, oc_block_factor, build_only)
+
+
+# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
+# performing basic testing - one test for all different scenarios - batch, dilation etc..
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (4, 4, 4, 4, 4, 4, 4, 1, False, False),
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 2, 2), 1, False, False),
+    ],
+)
+def test_conv2d_nchw_int8(in_dtype, params):
+    with Int8Fallback():
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        )
 
-        # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 147, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 73, 80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 80, 73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 48, 35, 64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 448, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1024, 19, 88, 3, 1, 1)
+        in_height = in_width = in_size
 
-        # batch > 1
-        verify_conv2d_NCHWc_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 8, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 32, 32, 149, 32, 3, 1, 0)
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+        bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
 
-        # Asymmetric padding
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 17, 192, 1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 8, 384, 3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 392, 8, 64, 3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 19, 64, 1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 16, 32, 2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True
-        )
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True
-        )
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        bias_shape = get_const_tuple(bias.shape)
+        dtype = A.dtype
+
+        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+            w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+            c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
 
-        # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
-        # performing basic testing - one test for all different scenarios - batch, dilation etc..
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
-        verify_conv2d_nchw_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
+            if add_bias:
+                b_np = np.random.uniform(size=bias_shape).astype(dtype)
+                c_np += b_np
+            if add_relu:
+                c_np = np.maximum(c_np, 0)
 
+            return a_np, w_np, b_np, c_np
 
-def test_conv2d_nhwc():
-    with Int8Fallback():
-        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
-
-        # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        def verify_workload_padding():
+            _, _, out_height, out_width = get_const_tuple(c_np.shape)
+            wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+            # for testing functionality,
+            # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+            # regardless of the performance.
+            int32_lanes, num_int8_elements = num_filter, in_channel
+
+            # check if tile_ow candidates are the factors of the right output weight.
+            cfg = autotvm.get_config()
+            fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+            ow_tile = np.prod(cfg["tile_ow"].size)
+
+            tvm.testing.assert_allclose(ow_tile, out_width)
+
+        def check_target(target):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            print("Running on target: %s" % target)
+            with tvm.target.Target(target):
+                C = topi.cuda.conv2d_nchw_int8(
+                    A, W, (stride, stride), padding, (dilation, dilation), dtype
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = topi.cuda.schedule_conv2d_nchw_int8([C])
+
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+            if add_bias:
+                func = tvm.build(
+                    s,
+                    [A, W, bias, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, b, c)
+            else:
+                func = tvm.build(
+                    s,
+                    [A, W, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, c)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+
+        verify_workload_padding()
+
+        for target in ["cuda"]:
+            check_target(target)

Review Comment:
   Done
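   When building reference data, `test_conv2d_NCHWc_int8` in the diff above turns the NCHW output of `conv2d_nchw_python` into the blocked NCHW[x]c layout. It does this with a reshape followed by `transpose(0, 1, 3, 4, 2)`. The NumPy-only sketch below isolates that step. `nchw_to_nchwc` and `int_range` are hypothetical helpers written for illustration; they are not part of TVM. The sketch also shows that `np.random.randint` treats `high` as exclusive, so with `hi = 127` (or `255`) as in the diff, the maximum value itself is never drawn.

```python
import numpy as np


def nchw_to_nchwc(x, block):
    """Hypothetical helper: split C into (C // block, block) and move the inner
    block to the last axis, mirroring the reshape/transpose in the test."""
    n, c, h, w = x.shape
    assert c % block == 0, "channel count must be a multiple of the block factor"
    return x.reshape(n, c // block, block, h, w).transpose(0, 1, 3, 4, 2)


def int_range(in_dtype):
    """Hypothetical helper: inclusive (min, max) for an integer dtype."""
    info = np.iinfo(in_dtype)
    return int(info.min), int(info.max)


# 8 output channels with oc_block_factor=4 -> 2 blocks of 4 channels each.
c_nchw = np.arange(1 * 8 * 2 * 2).reshape(1, 8, 2, 2)
c_nchwc = nchw_to_nchwc(c_nchw, 4)
print(c_nchwc.shape)  # (1, 2, 2, 2, 4)

# Element [n, k, y, x, j] of the blocked array is channel k * block + j of the input.
lo, hi = int_range("int8")
# randint's upper bound is exclusive, so pass hi + 1 if hi itself should be possible.
sample = np.random.randint(low=lo, high=hi + 1, size=16)
```

   Using `np.iinfo` keeps the bounds tied to the input dtype rather than hard-coding `-128`/`127` and `0`/`255`. This is only a design suggestion; the tests in the diff use literal values.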



[GitHub] [tvm] ekalda commented on a diff in pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058905748


##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and (platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            bias = te.placeholder(
+                (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype=out_dtype
+            )
+            bias_shape = get_const_tuple(bias.shape)
 
-    def check_target(target):
-        dev = tvm.device(target, 0)
-        if not tvm.testing.device_enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
+            @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+            def get_ref_data():
+                a_np = np.random.randint(low=lo, high=hi, size=a_shape).astype(out_dtype)
+                w_np = np.random.randint(low=lo, high=hi, size=w_shape).astype(out_dtype)
+                b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+                c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(
+                    out_dtype
+                )
+
+                # convert to NCHWc
+                _, _, out_height, out_width = c_np.shape
+                c_np = c_np.reshape(
+                    (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+                ).transpose(0, 1, 3, 4, 2)
+
+                if add_bias:
+                    b_np = np.random.uniform(size=bias_shape).astype(out_dtype)
+                    c_np += b_np
+                if add_relu:
+                    c_np = np.maximum(c_np, 0)
+
+                return a_np, w_np, b_np, c_np
+
+            a_np, w_np, b_np, c_np = get_ref_data()
+
+            with tvm.target.Target(target):
+                C = compute(
+                    A,
+                    W,
+                    (stride, stride),
+                    padding,
+                    (dilation, dilation),
+                    "NCHW",
+                    "NCHW",
+                    out_dtype,
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = schedule([C])
+
+            a = tvm.nd.array(a_np.astype(dtype), dev)
+            w = tvm.nd.array(w_np.astype(dtype), dev)
+            b = tvm.nd.array(b_np.astype(out_dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            C = topi.cuda.conv2d_nchw_int8(
-                A, W, (stride, stride), padding, (dilation, dilation), dtype
-            )
             if add_bias:
-                C = topi.add(C, bias)
-            if add_relu:
-                C = topi.nn.relu(C)
-            s = topi.cuda.schedule_conv2d_nchw_int8([C])
-
-        a = tvm.nd.array(a_np, dev)
-        w = tvm.nd.array(w_np, dev)
-        b = tvm.nd.array(b_np, dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
-        if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func(a, w, b, c)
-        else:
+                compile_args = [A, W, bias, C]
+                run_args = [a, w, b, c]
+            else:
+                compile_args = [A, W, C]
+                run_args = [a, w, c]
+
             func = tvm.build(
                 s,
-                [A, W, C],
+                compile_args,
                 target,
                 name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                 % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
             )
-            func(a, w, c)
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    verify_workload_padding()
+            if build_only:
+                return
 
-    for target in ["cuda"]:
-        check_target(target)
+            print("Running on target: %s" % target)
 
+            func(*run_args)
 
-@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
-def test_conv2d_nchw(in_dtype):
-    with Int8Fallback():
-        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 7, 512, 3, 1, 1)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        # bias, relu
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+        targets = [
+            (
+                "cuda",
+                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+                topi.cuda.schedule_conv2d_NCHWc_int8,
+                4,
+                False,
+            ),
+            # Disable on CI since it does not support spirv int8 dot product
+            # (
+            #     "vulkan -from_device=0",
+            #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            #     topi.cuda.schedule_conv2d_NCHWc_int8,
+            #     4,
+            #     False,
+            # ),
+        ]
 
-        # dilation = 2
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
+        build_only_aarch64 = platform.machine() != "aarch64"
 
-        # batch size
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                build_only_aarch64,
+            )
+        )
 
-        # weird workloads
-        verify_conv2d_NCHWc_int8(in_dtype, 4, 4, 4, 8, 4, 4, 4)
+        if in_dtype == "int8":
+            targets += [
+                (
+                    "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                    topi.arm_cpu.conv2d_NCHWc_int8,
+                    topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                    8,
+                    build_only_aarch64,
+                ),
+                (
+                    "rocm -mattr=+dotprod",
+                    lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(
+                        a, w, s, p, d, l, o
+                    ),
+                    topi.cuda.schedule_conv2d_NCHWc_int8,
+                    4,
+                    False,
+                ),
+            ]
+
+        for target, compute, schedule, oc_block_factor, build_only in targets:
+            check_target(target, compute, schedule, oc_block_factor, build_only)
+
+
+# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
+# performing basic testing - one test for all different scenarios - batch, dilation etc..
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (4, 4, 4, 4, 4, 4, 4, 1, False, False),
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 32, 35, 64, 7, 2, (0, 0, 2, 2), 1, False, False),
+    ],
+)
+def test_conv2d_nchw_int8(in_dtype, params):
+    with Int8Fallback():
+        (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+        )
 
-        # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 147, 64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 73, 80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 80, 73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 48, 35, 64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 35, 32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 256, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 288, 35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 96, 35, 96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 768, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 160, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 192, 17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 384, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 448, 8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1280, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 2048, 8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 1024, 19, 88, 3, 1, 1)
+        in_height = in_width = in_size
 
-        # batch > 1
-        verify_conv2d_NCHWc_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 8, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(in_dtype, 32, 32, 149, 32, 3, 1, 0)
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype=in_dtype)
+        bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
 
-        # Asymmetric padding
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 17, 192, 1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 128, 8, 384, 3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 392, 8, 64, 3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 512, 19, 64, 1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 16, 32, 2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True
-        )
-        verify_conv2d_NCHWc_int8(
-            in_dtype, 1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True
-        )
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        bias_shape = get_const_tuple(bias.shape)
+        dtype = A.dtype
+
+        @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+            w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+            c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
 
-        # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
-        # performing basic testing - one test for all different scenarios - batch, dilation etc..
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1, dilation=2)
-        verify_conv2d_nchw_int8(in_dtype, 9, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw_int8(in_dtype, 4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 7, 32, 149, 32, 3, 1, 0)
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
-        verify_conv2d_nchw_int8(in_dtype, 1, 32, 35, 64, 7, 2, (0, 0, 2, 2))
+            if add_bias:
+                b_np = np.random.uniform(size=bias_shape).astype(dtype)
+                c_np += b_np
+            if add_relu:
+                c_np = np.maximum(c_np, 0)
 
+            return a_np, w_np, b_np, c_np
 
-def test_conv2d_nhwc():
-    with Int8Fallback():
-        # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
-
-        # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        def verify_workload_padding():
+            _, _, out_height, out_width = get_const_tuple(c_np.shape)
+            wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
+
+            # for testing functionality,
+            # we choose arbitrary int32_lanes and num_int8_elements can divide the channel,
+            # regardless of the performance.
+            int32_lanes, num_int8_elements = num_filter, in_channel
+
+            # check if tile_ow candidates are the factors of the right output weight.
+            cfg = autotvm.get_config()
+            fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements)
+            ow_tile = np.prod(cfg["tile_ow"].size)
+
+            tvm.testing.assert_allclose(ow_tile, out_width)
+
+        def check_target(target):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return
+            if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
+                print("Skip because int8 intrinsics are not available")
+                return
+
+            print("Running on target: %s" % target)
+            with tvm.target.Target(target):
+                C = topi.cuda.conv2d_nchw_int8(
+                    A, W, (stride, stride), padding, (dilation, dilation), dtype
+                )
+                if add_bias:
+                    C = topi.add(C, bias)
+                if add_relu:
+                    C = topi.nn.relu(C)
+                s = topi.cuda.schedule_conv2d_nchw_int8([C])
+
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+            if add_bias:
+                func = tvm.build(
+                    s,
+                    [A, W, bias, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, b, c)
+            else:
+                func = tvm.build(
+                    s,
+                    [A, W, C],
+                    target,
+                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                    % (
+                        batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        kernel,
+                        stride,
+                        padding_sum,
+                        dilation,
+                    ),
+                )
+                func(a, w, c)

Review Comment:
   Done
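
A minimal sketch of the invariant that `verify_workload_padding` checks: the `tile_ow` split factors chosen by the fallback schedule should multiply back to the convolution's output width, which in turn depends on the workload's padding. The helpers `conv_output_width` and `pick_tile_ow` below are hypothetical and are not TVM APIs. The "largest factor of the output width not exceeding a cap" rule is an assumption, loosely modeled on how CPU fallback schedules pick a register block size, and is not necessarily what `fallback_schedule_cpu_common_int8` does.

```python
import numpy as np


def conv_output_width(in_width, kernel, stride, pad_left, pad_right, dilation=1):
    """Standard convolution output-size formula along the width axis."""
    dilated_kernel = (kernel - 1) * dilation + 1
    return (in_width + pad_left + pad_right - dilated_kernel) // stride + 1


def pick_tile_ow(out_width, reg_n_max=16):
    """Hypothetical fallback tiling: split out_width as [outer, inner].

    The inner factor is the largest divisor of out_width that does not exceed
    reg_n_max, so outer * inner == out_width always holds. The loop stops at 1
    in the worst case (e.g. a prime width larger than reg_n_max).
    """
    for reg_n in range(min(reg_n_max, out_width), 0, -1):
        if out_width % reg_n == 0:
            return [out_width // reg_n, reg_n]
    raise ValueError("out_width must be a positive integer")


# An 8-wide input, 3x3 kernel, stride 1, padded by 1 on each side ("SAME")
out_w = conv_output_width(8, 3, 1, 1, 1)
tile = pick_tile_ow(out_w)
# The same check the test performs: the split must cover the output width exactly
assert np.prod(tile) == out_w
```

If the workload reported the wrong padding, `out_width` would be computed from the wrong input extent and the product of the tile factors would no longer match the real output width, which is the kind of mismatch this check is presumably meant to catch.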


[GitHub] [tvm] Mousius merged pull request #13669: [TOPI][bugfix] Fix a bug in arm_cpu int8 dotprod schedule and modernize tests

Posted by GitBox <gi...@apache.org>.
Mousius merged PR #13669:
URL: https://github.com/apache/tvm/pull/13669
