You are viewing a plain-text version of this content; the canonical link to it was a hyperlink that is not preserved in this copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/16 06:28:51 UTC

[tvm] branch main updated: [MetaSchedule][Test] MLT uses SEqual tests (#12805)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 77d0a288df [MetaSchedule][Test] MLT uses SEqual tests (#12805)
77d0a288df is described below

commit 77d0a288df4a1975784def14b316bde576fe3980
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Sep 15 23:28:45 2022 -0700

    [MetaSchedule][Test] MLT uses SEqual tests (#12805)
    
    This PR finishes the migration from `check_trace` (string-based equality
    check on the TIR trace) to `check_sketches` (SEqual-based equality check on
    TIR). Here, we split the multi-level-tiling tests into 3 files:
    - Plain multi-level tiling without any intrinsics
    - Multi-level tiling with intrinsics like VNNI, DP4a
    - Multi-level tiling with TensorCore, which requires different handling
    
    In addition, we cleaned up the testing folder and removed several helper
    functions that are no longer useful for unit tests.
---
 python/tvm/meta_schedule/testing/schedule_rule.py  |  138 +--
 .../multi_level_tiling_tensor_core.cc              |    4 +-
 src/meta_schedule/utils.h                          |   35 +-
 .../test_meta_schedule_schedule_rule_auto_bind.py  |   22 +-
 ...test_meta_schedule_schedule_rule_auto_inline.py |   19 +-
 ...chedule_schedule_rule_cross_thread_reduction.py |   17 +-
 .../test_meta_schedule_schedule_rule_mlt.py        |  529 +++++++++
 .../test_meta_schedule_schedule_rule_mlt_intrin.py |  418 +++++++
 .../test_meta_schedule_schedule_rule_mlt_tc.py     |  957 ++++++++++++++++
 ...ta_schedule_schedule_rule_multi_level_tiling.py | 1205 --------------------
 10 files changed, 1961 insertions(+), 1383 deletions(-)

diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index 12ca4200d7..f14e90b6f0 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -15,122 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 """Default schedule rules"""
-from typing import List, Union
-
-from tvm.meta_schedule.schedule_rule import (
-    AutoInline,
-    MultiLevelTiling,
-    MultiLevelTilingTensorCore,
-    ReuseType,
-    ScheduleRule,
-)
-from tvm.target import Target
-
-
-def auto_inline(target: Target) -> ScheduleRule:
-    """Default schedule rules for auto inline"""
-    if target.kind.name == "llvm":
-        return AutoInline(
-            into_producer=False,
-            into_consumer=True,
-            inline_const_tensor=True,
-            disallow_if_then_else=True,
-            require_injective=True,
-            require_ordered=True,
-            disallow_op=["tir.exp"],
-        )
-    if target.kind.name == "cuda":
-        return AutoInline(
-            into_producer=True,
-            into_consumer=True,
-            inline_const_tensor=True,
-            disallow_if_then_else=False,
-            require_injective=False,
-            require_ordered=False,
-            disallow_op=None,
-        )
-    raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
-def multi_level_tiling(target: Target) -> ScheduleRule:
-    """Default schedule rules for with multi-level tiling and reuse"""
-    if target.kind.name == "llvm":
-        return MultiLevelTiling(
-            structure="SSRSRS",
-            tile_binds=None,
-            max_innermost_factor=64,
-            vector_load_lens=None,
-            reuse_read=None,
-            reuse_write=ReuseType(
-                req="may",
-                levels=[1, 2],
-                scope="global",
-            ),
-        )
-    if target.kind.name == "cuda":
-        return MultiLevelTiling(
-            structure="SSSRRSRS",
-            tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
-            max_innermost_factor=64,
-            vector_load_lens=[1, 2, 3, 4, 8, 16],
-            reuse_read=ReuseType(
-                req="must",
-                levels=[4],
-                scope="shared",
-            ),
-            reuse_write=ReuseType(
-                req="must",
-                levels=[3],
-                scope="local",
-            ),
-        )
-    raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
-def multi_level_tiling_tensor_core(
-    target: Target,
-    write_reuse_scope: str = "shared",
-    in_dtype: Union[str, List[str]] = "float16",
-    out_dtype: Union[str, List[str]] = "float32",
-    trans_b: Union[bool, List[bool]] = False,
-    use_software_pipeline: bool = False,
-) -> ScheduleRule:
-    """Default schedule rules for with multi-level tiling reuse for tensor core"""
-    assert write_reuse_scope in ["shared", "global"]
-    if not isinstance(in_dtype, list):
-        in_dtype = [in_dtype]
-    if not isinstance(out_dtype, list):
-        out_dtype = [out_dtype]
-    if not isinstance(trans_b, list):
-        trans_b = [trans_b]
-
-    if target.kind.name == "cuda":
-        from tvm.tir.tensor_intrin import (  # pylint: disable=import-outside-toplevel
-            cuda,
-        )
-
-        intrin_groups = [
-            cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
-            for _in_dtype in in_dtype
-            for _out_dtype in out_dtype
-            for _trans_b in trans_b
-        ]
-        return MultiLevelTilingTensorCore(
-            intrin_groups=intrin_groups,
-            structure="SSSRRSRS",
-            tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
-            max_innermost_factor=4,  # 64 // tensor intrin size
-            vector_load_lens=[1, 2, 3, 4, 8, 16],
-            reuse_read=ReuseType(
-                req="must",
-                levels=[4],
-                scope="shared",
-            ),
-            reuse_write=ReuseType(
-                req="must" if write_reuse_scope == "shared" else "no",
-                levels=[2],
-                scope=write_reuse_scope,
-            ),
-            use_software_pipeline=use_software_pipeline,
-        )
-    raise NotImplementedError(f"{target.kind.name} is not supported")
+from typing import List, Tuple, Union
+
+from tvm.meta_schedule import default_config
+from tvm.meta_schedule.schedule_rule import ScheduleRule
+
+
+def get_rules(kind: str, types: Union[type, Tuple[type, ...]]) -> List[ScheduleRule]:
+    """Get default rules for ``kind`` (llvm/cuda/tensor_core) that are instances of ``types``."""
+    # pylint: disable=protected-access
+    if kind == "llvm":
+        rules = default_config._DefaultLLVM.schedule_rules()
+    elif kind == "cuda":
+        rules = default_config._DefaultCUDA.schedule_rules()
+    elif kind == "tensor_core":
+        rules = default_config._DefaultCUDATensorCore.schedule_rules()
+    else:
+        raise NotImplementedError(f"{kind} is not supported")
+    # pylint: enable=protected-access
+    return [rule for rule in rules if isinstance(rule, types)]
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 13b00fa7de..8fcb8fe503 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -328,7 +328,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
   // Add local stage and double buffering
   for (int i = 0; i < 2; ++i) {
     const tir::BlockRV cache_read = state->read_reuse.at(i);
-    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Bool(true));
+    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Integer(1));
     sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
   }
 
@@ -536,7 +536,7 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
                        state->intrin_group.compute_intrin);
   state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
                        state->intrin_group.init_intrin);
-  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true));
+  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1));
   return {std::move(state)};
 }
 
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index ad56fa7f6a..cf9a329170 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -77,33 +77,34 @@ class PyLogMessage {
     // FATAL not included
   };
 
-  PyLogMessage(const std::string& file, int lineno, PackedFunc logging_func, Level logging_level) {
-    this->logging_func = logging_func;
-    this->logging_level = logging_level;
-  }
+  explicit PyLogMessage(const char* file, int lineno, PackedFunc logging_func, Level logging_level)
+      : file_(file), lineno_(lineno), logging_func_(logging_func), logging_level_(logging_level) {}
+
   TVM_NO_INLINE ~PyLogMessage() {
-    if (this->logging_func.defined()) {
-      logging_func(static_cast<int>(logging_level), stream_.str());
+    if (this->logging_func_.defined()) {
+      logging_func_(static_cast<int>(logging_level_), stream_.str());
     } else {
-      if (logging_level == Level::INFO) {
-        LOG(INFO) << stream_.str();
-      } else if (logging_level == Level::WARNING) {
-        LOG(WARNING) << stream_.str();
-      } else if (logging_level == Level::ERROR) {
-        LOG(ERROR) << stream_.str();
-      } else if (logging_level == Level::DEBUG) {
-        DLOG(INFO) << stream_.str();
+      if (logging_level_ == Level::INFO) {
+        runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str();
+      } else if (logging_level_ == Level::WARNING) {
+        runtime::detail::LogMessage(file_, lineno_).stream() << "Warning: " << stream_.str();
+      } else if (logging_level_ == Level::ERROR) {
+        runtime::detail::LogMessage(file_, lineno_).stream() << "Error: " << stream_.str();
+      } else if (logging_level_ == Level::DEBUG) {
+        runtime::detail::LogMessage(file_, lineno_).stream() << "Debug: " << stream_.str();
       } else {
-        LOG(FATAL) << stream_.str();
+        runtime::detail::LogFatal(file_, lineno_).stream() << stream_.str();
       }
     }
   }
   std::ostringstream& stream() { return stream_; }
 
  private:
+  const char* file_;
+  int lineno_;
   std::ostringstream stream_;
-  PackedFunc logging_func;
-  Level logging_level;
+  PackedFunc logging_func_;
+  Level logging_level_;
 };
 
 /*! \brief The type of the random state */
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
index 21ad04da47..a50292df7a 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.schedule_rule import get_rules
 from tvm.meta_schedule.testing.space_generation import check_sketches
 from tvm.script import tir as T
 from tvm.target import Target
@@ -83,12 +84,7 @@ def test_cuda_element_wise():
         mod=mod,
         target=Target("nvidia/geforce-rtx-3080", host="llvm"),
         space_generator=ms.space_generator.PostOrderApply(),
-        sch_rules=[
-            ms.schedule_rule.AutoBind(
-                max_threadblocks=256,
-                thread_extents=[32, 64, 128, 256, 512, 1024],
-            )
-        ],
+        sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind),
         task_name="test",
     ).generate_design_space()
     check_sketches(
@@ -122,12 +118,7 @@ def test_cuda_reduction_loop_only():
         mod=mod,
         target=Target("nvidia/geforce-rtx-3080", host="llvm"),
         space_generator=ms.space_generator.PostOrderApply(),
-        sch_rules=[
-            ms.schedule_rule.AutoBind(
-                max_threadblocks=256,
-                thread_extents=[32, 64, 128, 256, 512, 1024],
-            )
-        ],
+        sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind),
         task_name="test",
     ).generate_design_space()
     check_sketches(
@@ -158,12 +149,7 @@ def test_cuda_zero_dim_add():
         mod=mod,
         target=Target("nvidia/geforce-rtx-3080", host="llvm"),
         space_generator=ms.space_generator.PostOrderApply(),
-        sch_rules=[
-            ms.schedule_rule.AutoBind(
-                max_threadblocks=256,
-                thread_extents=[32, 64, 128, 256, 512, 1024],
-            )
-        ],
+        sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind),
         task_name="test",
     ).generate_design_space()
     check_sketches(
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index fcf6a8571b..c0801c9d7b 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -16,9 +16,8 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import tvm
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing.schedule_rule import auto_inline
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.schedule_rule import get_rules
 from tvm.script import tir as T
 from tvm.target import Target
 
@@ -340,10 +339,10 @@ class ConstConsumer:
 
 
 def _create_context(mod, target, rule):
-    ctx = TuneContext(
+    ctx = ms.TuneContext(
         mod=mod,
         target=target,
-        space_generator=PostOrderApply(),
+        space_generator=ms.space_generator.PostOrderApply(),
         sch_rules=[rule],
         task_name="test",
     )
@@ -356,7 +355,7 @@ def test_inline_consumer_chain():
     ctx = _create_context(
         mod=mod,
         target=target,
-        rule=auto_inline(target=target),
+        rule=get_rules("llvm", ms.schedule_rule.AutoInline)[0],
     )
     (space,) = ctx.space_generator.generate_design_space(mod=mod)
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined)
@@ -368,7 +367,7 @@ def test_inline_into_cache():
     ctx = _create_context(
         mod=mod,
         target=target,
-        rule=auto_inline(target=target),
+        rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0],
     )
     (space,) = ctx.space_generator.generate_design_space(mod=mod)
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=MultiLevelTiledConv2DAfterInline)
@@ -380,7 +379,7 @@ def test_inline_into_multiple_consumers():
     ctx = _create_context(
         mod=mod,
         target=target,
-        rule=auto_inline(target=target),
+        rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0],
     )
     (space,) = ctx.space_generator.generate_design_space(mod=mod)
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline)
@@ -392,7 +391,7 @@ def test_inline_pure_spatial():
     ctx = _create_context(
         mod=mod,
         target=target,
-        rule=auto_inline(target=target),
+        rule=get_rules("llvm", ms.schedule_rule.AutoInline)[0],
     )
     (space,) = ctx.space_generator.generate_design_space(mod=mod)
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=AfterPureSpatial)
@@ -404,7 +403,7 @@ def test_inline_constant_tensor():
     ctx = _create_context(
         mod=mod,
         target=target,
-        rule=auto_inline(target=target),
+        rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0],
     )
     (space,) = ctx.space_generator.generate_design_space(mod=mod)
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer)
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index ab8df6678b..4278638a1a 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -19,6 +19,7 @@
 import tvm
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing import te_workload
+from tvm.meta_schedule.testing.schedule_rule import get_rules
 from tvm.meta_schedule.testing.space_generation import check_sketches
 from tvm.script import tir as T
 from tvm.target import Target
@@ -283,9 +284,7 @@ def test_gpu_softmax_mn():
         mod=mod,
         target=Target("nvidia/geforce-rtx-3090", host="llvm"),
         space_generator=ms.space_generator.PostOrderApply(),
-        sch_rules=[
-            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
-        ],
+        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
         task_name="test",
     ).generate_design_space()
     check_sketches(
@@ -481,9 +480,7 @@ def test_gpu_softmax_mn_after_inline():
         mod=mod,
         target=Target("nvidia/geforce-rtx-3090", host="llvm"),
         space_generator=ms.space_generator.PostOrderApply(),
-        sch_rules=[
-            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
-        ],
+        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
         task_name="test",
     ).generate_design_space()
     check_sketches(
@@ -559,9 +556,7 @@ def test_gpu_batch_norm_bmn():
         mod=mod,
         target=Target("nvidia/geforce-rtx-3090", host="llvm"),
         space_generator=ms.space_generator.PostOrderApply(),
-        sch_rules=[
-            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
-        ],
+        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
         task_name="test",
     ).generate_design_space()
     check_sketches(
@@ -657,9 +652,7 @@ def test_gpu_argmax():
         mod=mod,
         target=Target("nvidia/geforce-rtx-3090", host="llvm"),
         space_generator=ms.space_generator.PostOrderApply(),
-        sch_rules=[
-            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
-        ],
+        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
         task_name="test",
     ).generate_design_space()
     check_sketches(
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
new file mode 100644
index 0000000000..939ccbe54f
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -0,0 +1,529 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+from tvm import meta_schedule as ms
+from tvm import te
+from tvm.meta_schedule.testing import te_workload
+from tvm.meta_schedule.testing.schedule_rule import get_rules
+from tvm.meta_schedule.testing.space_generation import check_sketches
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def test_cpu_matmul():
+    @T.prim_func
+    def cpu_matmul_0(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        C: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # Sketch 0: compute into C_global, written back to C per 64x64 (i0_1, i1_1) tile.
+        # with T.block("root")
+        C_global = T.alloc_buffer([512, 512], dtype="float32")
+        for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 8, 8, 1):
+            for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(16, 2, 8, 32, 32, 8):
+                with T.block("C"):
+                    i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3)
+                    j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3)
+                    k = T.axis.reduce(512, i2_0 * 32 + i2_1)
+                    T.reads(A[i, k], B[k, j])
+                    T.writes(C_global[i, j])
+                    T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
+                    with T.init():
+                        C_global[i, j] = T.float32(0)
+                    C_global[i, j] = C_global[i, j] + A[i, k] * B[k, j]
+            for ax0, ax1 in T.grid(64, 64):
+                with T.block("C_global"):
+                    v0 = T.axis.spatial(512, i0_1 * 64 + ax0)
+                    v1 = T.axis.spatial(512, i1_0 * 64 + ax1)
+                    T.reads(C_global[v0, v1])
+                    T.writes(C[v0, v1])
+                    C[v0, v1] = C_global[v0, v1]
+
+    @T.prim_func
+    def cpu_matmul_1(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        C: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # Sketch 1: C_global is written back to C once per 512x64 (i0_0, i1_0) tile.
+        # with T.block("root")
+        C_global = T.alloc_buffer([512, 512], dtype="float32")
+        for i0_0, i1_0 in T.grid(1, 8):
+            for i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(8, 1, 16, 2, 8, 32, 32, 8):
+                with T.block("C"):
+                    i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3)
+                    j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3)
+                    k = T.axis.reduce(512, i2_0 * 32 + i2_1)
+                    T.reads(A[i, k], B[k, j])
+                    T.writes(C_global[i, j])
+                    T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
+                    with T.init():
+                        C_global[i, j] = T.float32(0)
+                    C_global[i, j] = C_global[i, j] + A[i, k] * B[k, j]
+            for ax0, ax1 in T.grid(512, 64):
+                with T.block("C_global"):
+                    v0 = T.axis.spatial(512, ax0)
+                    v1 = T.axis.spatial(512, i1_0 * 64 + ax1)
+                    T.reads(C_global[v0, v1])
+                    T.writes(C[v0, v1])
+                    C[v0, v1] = C_global[v0, v1]
+
+    @T.prim_func
+    def cpu_matmul_2(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        C: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # Sketch 2: no write cache; the tiled loop nest writes C directly.
+        # with T.block("root")
+        for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(
+            1, 8, 8, 1, 16, 2, 8, 32, 32, 8
+        ):
+            with T.block("C"):
+                i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3)
+                j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3)
+                k = T.axis.reduce(512, i2_0 * 32 + i2_1)
+                T.reads(A[i, k], B[k, j])
+                T.writes(C[i, j])
+                T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
+                with T.init():
+                    C[i, j] = T.float32(0)
+                C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+    decision_0 = [
+        ("SamplePerfectTile", [1, 8, 2, 32]),
+        ("SamplePerfectTile", [8, 1, 8, 8]),
+        ("SamplePerfectTile", [16, 32]),
+    ]
+    decision_1 = [
+        ("SamplePerfectTile", [1, 8, 2, 32]),
+        ("SamplePerfectTile", [8, 1, 8, 8]),
+        ("SamplePerfectTile", [16, 32]),
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [1, 8, 2, 32]),
+        ("SamplePerfectTile", [8, 1, 8, 8]),
+        ("SamplePerfectTile", [16, 32]),
+    ]
+
+    mod = te.create_prim_func(te_workload.matmul(512, 512, 512))
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=get_rules("llvm", ms.schedule_rule.MultiLevelTiling),
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[cpu_matmul_0, cpu_matmul_1, cpu_matmul_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
+def test_cpu_matmul_relu():
+    @T.prim_func
+    def cpu_matmul_relu_0(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        compute: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # Sketch 0: ReLU ("compute") stays a separate 512x512 loop nest after C.
+        # with T.block("root")
+        C = T.alloc_buffer([512, 512], dtype="float32")
+        for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(
+            256, 4, 1, 4, 64, 1, 32, 8, 2, 1
+        ):
+            with T.block("C"):
+                i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3)
+                j = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3)
+                k = T.axis.reduce(512, i2_0 * 8 + i2_1)
+                T.reads(A[i, k], B[k, j])
+                T.writes(C[i, j])
+                T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
+                with T.init():
+                    C[i, j] = T.float32(0)
+                C[i, j] = C[i, j] + A[i, k] * B[k, j]
+        for i0, i1 in T.grid(512, 512):
+            with T.block("compute"):
+                i0_4, i1_4 = T.axis.remap("SS", [i0, i1])
+                T.reads(C[i0_4, i1_4])
+                T.writes(compute[i0_4, i1_4])
+                compute[i0_4, i1_4] = T.max(C[i0_4, i1_4], T.float32(0))
+
+    @T.prim_func
+    def cpu_matmul_relu_1(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        compute: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # Sketch 1: ReLU is computed after C inside the (i0_1, i1_1) loops, 2x32 per tile.
+        # with T.block("root")
+        C = T.alloc_buffer([512, 512], dtype="float32")
+        for i0_0, i1_0, i0_1, i1_1 in T.grid(256, 4, 1, 4):
+            for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(64, 1, 32, 8, 2, 1):
+                with T.block("C"):
+                    i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3)
+                    j = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3)
+                    k = T.axis.reduce(512, i2_0 * 8 + i2_1)
+                    T.reads(A[i, k], B[k, j])
+                    T.writes(C[i, j])
+                    T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+            for ax0, ax1 in T.grid(2, 32):
+                with T.block("compute"):
+                    i0 = T.axis.spatial(512, i0_0 * 2 + ax0)
+                    i1 = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + ax1)
+                    T.reads(C[i0, i1])
+                    T.writes(compute[i0, i1])
+                    compute[i0, i1] = T.max(C[i0, i1], T.float32(0))
+
+    @T.prim_func
+    def cpu_matmul_relu_2(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        compute: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # Sketch 2: ReLU is computed after C inside the (i0_0, i1_0) loops, 2x128 per tile.
+        # with T.block("root")
+        C = T.alloc_buffer([512, 512], dtype="float32")
+        for i0_0, i1_0 in T.grid(256, 4):
+            for i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(1, 4, 64, 1, 32, 8, 2, 1):
+                with T.block("C"):
+                    i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3)
+                    j = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3)
+                    k = T.axis.reduce(512, i2_0 * 8 + i2_1)
+                    T.reads(A[i, k], B[k, j])
+                    T.writes(C[i, j])
+                    T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+            for ax0, ax1 in T.grid(2, 128):
+                with T.block("compute"):
+                    i0 = T.axis.spatial(512, i0_0 * 2 + ax0)
+                    i1 = T.axis.spatial(512, i1_0 * 128 + ax1)
+                    T.reads(C[i0, i1])
+                    T.writes(compute[i0, i1])
+                    compute[i0, i1] = T.max(C[i0, i1], T.float32(0))
+
+    decision_0 = [
+        ("SamplePerfectTile", [256, 1, 1, 2]),
+        ("SamplePerfectTile", [4, 4, 32, 1]),
+        ("SamplePerfectTile", [64, 8]),
+    ]
+    decision_1 = [
+        ("SamplePerfectTile", [256, 1, 1, 2]),
+        ("SamplePerfectTile", [4, 4, 32, 1]),
+        ("SamplePerfectTile", [64, 8]),
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [256, 1, 1, 2]),
+        ("SamplePerfectTile", [4, 4, 32, 1]),
+        ("SamplePerfectTile", [64, 8]),
+    ]
+    mod = te.create_prim_func(te_workload.matmul_relu(512, 512, 512))
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=get_rules("llvm", ms.schedule_rule.MultiLevelTiling),
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[cpu_matmul_relu_0, cpu_matmul_relu_1, cpu_matmul_relu_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
+def test_cuda_matmul():  # CUDA MultiLevelTiling sketch for te_workload.matmul(512, 512, 512)
+    @T.prim_func
+    def cuda_matmul_0(  # expected sketch: SSSRRSRS tiling, shared-memory A/B, local C cache
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        C: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
+        A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        for i0_0_i1_0_fused in T.thread_binding(128, thread="blockIdx.x"):  # 8 x 16 tiles
+            for i0_1_i1_1_fused in T.thread_binding(8, thread="vthread.x"):  # 4 x 2
+                for i0_2_i1_2_fused in T.thread_binding(4, thread="threadIdx.x"):  # 1 x 4
+                    for i2_0 in T.serial(128):  # outer k tile (k = 128 x 2 x 2)
+                        for ax0_ax1_fused in T.serial(256):  # copy a 64 x 4 tile of A
+                            with T.block("A_shared"):
+                                v0 = T.axis.spatial(
+                                    512, i0_0_i1_0_fused // 16 * 64 + ax0_ax1_fused // 4
+                                )
+                                v1 = T.axis.spatial(512, i2_0 * 4 + ax0_ax1_fused % 4)
+                                T.reads(A[v0, v1])
+                                T.writes(A_shared[v0, v1])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 2})
+                                A_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused in T.serial(128):  # copy a 4 x 32 tile of B
+                            with T.block("B_shared"):
+                                v0 = T.axis.spatial(512, i2_0 * 4 + ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(
+                                    512, i0_0_i1_0_fused % 16 * 32 + ax0_ax1_fused % 32
+                                )
+                                T.reads(B[v0, v1])
+                                T.writes(B_shared[v0, v1])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 1})
+                                B_shared[v0, v1] = B[v0, v1]
+                        for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(2, 1, 1, 2, 16, 4):
+                            with T.block("C"):  # accumulate into the local cache C_local
+                                i = T.axis.spatial(
+                                    512,
+                                    i0_0_i1_0_fused // 16 * 64
+                                    + i0_1_i1_1_fused // 2 * 16
+                                    + i0_3 * 16
+                                    + i0_4,
+                                )
+                                j = T.axis.spatial(
+                                    512,
+                                    i0_0_i1_0_fused % 16 * 32
+                                    + i0_1_i1_1_fused % 2 * 16
+                                    + i0_2_i1_2_fused * 4
+                                    + i1_3 * 4
+                                    + i1_4,
+                                )
+                                k = T.axis.reduce(512, i2_0 * 4 + i2_1 * 2 + i2_2)
+                                T.reads(A_shared[i, k], B_shared[k, j])
+                                T.writes(C_local[i, j])
+                                T.block_attr(
+                                    {
+                                        "meta_schedule.thread_extent_high_inclusive": 1024,
+                                        "meta_schedule.thread_extent_low_inclusive": 32,
+                                        "meta_schedule.tiling_structure": "SSSRRSRS",
+                                    }
+                                )
+                                with T.init():
+                                    C_local[i, j] = T.float32(0)
+                                C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
+                    for ax0, ax1 in T.grid(16, 4):  # write the per-thread 16 x 4 tile back to C
+                        with T.block("C_local"):
+                            v0 = T.axis.spatial(
+                                512, i0_0_i1_0_fused // 16 * 64 + i0_1_i1_1_fused // 2 * 16 + ax0
+                            )
+                            v1 = T.axis.spatial(
+                                512,
+                                i0_0_i1_0_fused % 16 * 32
+                                + i0_1_i1_1_fused % 2 * 16
+                                + i0_2_i1_2_fused * 4
+                                + ax1,
+                            )
+                            T.reads(C_local[v0, v1])
+                            T.writes(C[v0, v1])
+                            C[v0, v1] = C_local[v0, v1]
+
+    decision_0 = [  # i/j/k tile factors; categoricals presumably pick fetch widths (2, 1)
+        ("SamplePerfectTile", [8, 4, 1, 1, 16]),
+        ("SamplePerfectTile", [16, 2, 4, 1, 4]),
+        ("SamplePerfectTile", [128, 2, 2]),
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 0),
+    ]
+    mod = te.create_prim_func(te_workload.matmul(512, 512, 512))
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3080"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling),
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(  # compare generated sketches and decisions with the expected ones
+        mod,
+        sketches=actual,
+        expected_mods=[cuda_matmul_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_cuda_matmul_relu():  # CUDA MultiLevelTiling sketch for te_workload.matmul_relu(512^3)
+    @T.prim_func
+    def cuda_matmul_relu_0(  # expected sketch: tiled matmul into C, then an untiled ReLU loop
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        compute: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C = T.alloc_buffer([512, 512], dtype="float32")
+        C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
+        A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        for i0_0_i1_0_fused in T.thread_binding(64, thread="blockIdx.x"):  # 8 x 8 tiles
+            for i0_1_i1_1_fused in T.thread_binding(64, thread="vthread.x"):  # 8 x 8
+                for i0_2_i1_2_fused in T.thread_binding(8, thread="threadIdx.x"):  # 2 x 4
+                    for i2_0 in T.serial(8):  # outer k tile (k = 8 x 8 x 8)
+                        for ax0_ax1_fused in T.serial(4096):  # copy a 64 x 64 tile of A
+                            with T.block("A_shared"):
+                                v0 = T.axis.spatial(
+                                    512, i0_0_i1_0_fused // 8 * 64 + ax0_ax1_fused // 64
+                                )
+                                v1 = T.axis.spatial(512, i2_0 * 64 + ax0_ax1_fused % 64)
+                                T.reads(A[v0, v1])
+                                T.writes(A_shared[v0, v1])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 2})
+                                A_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused in T.serial(4096):  # copy a 64 x 64 tile of B
+                            with T.block("B_shared"):
+                                v0 = T.axis.spatial(512, i2_0 * 64 + ax0_ax1_fused // 64)
+                                v1 = T.axis.spatial(
+                                    512, i0_0_i1_0_fused % 8 * 64 + ax0_ax1_fused % 64
+                                )
+                                T.reads(B[v0, v1])
+                                T.writes(B_shared[v0, v1])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                B_shared[v0, v1] = B[v0, v1]
+                        for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(8, 2, 1, 8, 2, 2):
+                            with T.block("C"):  # accumulate into the local cache C_local
+                                i = T.axis.spatial(
+                                    512,
+                                    i0_0_i1_0_fused // 8 * 64
+                                    + i0_1_i1_1_fused // 8 * 8
+                                    + i0_2_i1_2_fused // 4 * 4
+                                    + i0_3 * 2
+                                    + i0_4,
+                                )
+                                j = T.axis.spatial(
+                                    512,
+                                    i0_0_i1_0_fused % 8 * 64
+                                    + i0_1_i1_1_fused % 8 * 8
+                                    + i0_2_i1_2_fused % 4 * 2
+                                    + i1_3 * 2
+                                    + i1_4,
+                                )
+                                k = T.axis.reduce(512, i2_0 * 64 + i2_1 * 8 + i2_2)
+                                T.reads(A_shared[i, k], B_shared[k, j])
+                                T.writes(C_local[i, j])
+                                T.block_attr(
+                                    {
+                                        "meta_schedule.thread_extent_high_inclusive": 1024,
+                                        "meta_schedule.thread_extent_low_inclusive": 32,
+                                        "meta_schedule.tiling_structure": "SSSRRSRS",
+                                    }
+                                )
+                                with T.init():
+                                    C_local[i, j] = T.float32(0)
+                                C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
+                    for ax0, ax1 in T.grid(4, 2):  # write the per-thread 4 x 2 tile back to C
+                        with T.block("C_local"):
+                            v0 = T.axis.spatial(
+                                512,
+                                i0_0_i1_0_fused // 8 * 64
+                                + i0_1_i1_1_fused // 8 * 8
+                                + i0_2_i1_2_fused // 4 * 4
+                                + ax0,
+                            )
+                            v1 = T.axis.spatial(
+                                512,
+                                i0_0_i1_0_fused % 8 * 64
+                                + i0_1_i1_1_fused % 8 * 8
+                                + i0_2_i1_2_fused % 4 * 2
+                                + ax1,
+                            )
+                            T.reads(C_local[v0, v1])
+                            T.writes(C[v0, v1])
+                            C[v0, v1] = C_local[v0, v1]
+        for i0, i1 in T.grid(512, 512):  # ReLU epilogue, left untiled in this sketch
+            with T.block("compute"):
+                i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+                T.reads(C[i0_1, i1_1])
+                T.writes(compute[i0_1, i1_1])
+                compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+
+    decision_0 = [  # i/j/k tile factors; categoricals presumably pick fetch widths (2, 4)
+        ("SamplePerfectTile", [8, 8, 2, 2, 2]),
+        ("SamplePerfectTile", [8, 8, 4, 1, 2]),
+        ("SamplePerfectTile", [8, 8, 8]),
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 3),
+    ]
+    mod = te.create_prim_func(te_workload.matmul_relu(512, 512, 512))
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3080"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling),
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(  # compare generated sketches and decisions with the expected ones
+        mod,
+        sketches=actual,
+        expected_mods=[cuda_matmul_relu_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_cuda_sum_with_trivial_block_iter():  # MultiLevelTiling must leave this reduction alone
+    @T.prim_func
+    def sum_with_trivial_block_iter(
+        A: T.Buffer[(1, 64, 768), "float32"],
+        B: T.Buffer[(1, 64, 1), "float32"],
+    ) -> None:
+        for i0, i1, i2, i3 in T.grid(1, 64, 1, 768):  # i0 and i2 are extent-1 spatial iters
+            with T.block("sum"):
+                ax0, ax1, ax2, k2 = T.axis.remap("SSSR", [i0, i1, i2, i3])
+                T.reads(A[ax0, ax1, k2])
+                T.writes(B[ax0, ax1, ax2])
+                with T.init():
+                    B[ax0, ax1, ax2] = T.float32(0)
+                B[ax0, ax1, ax2] = B[ax0, ax1, ax2] + A[ax0, ax1, k2]
+
+    # Expect nothing to happen - the rule is not supposed to be applied in this case
+    mod = sum_with_trivial_block_iter
+    (sch,) = ms.TuneContext(  # exactly one (unmodified) sketch is expected
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3080"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling),
+        task_name="test",
+    ).generate_design_space()
+    assert not sch.trace.simplified(remove_postproc=True).insts  # no primitives were applied
+
+
+if __name__ == "__main__":
+    test_cpu_matmul()
+    test_cpu_matmul_relu()
+    test_cuda_matmul()
+    test_cuda_matmul_relu()
+    test_cuda_sum_with_trivial_block_iter()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
new file mode 100644
index 0000000000..38ddb137e1
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
@@ -0,0 +1,418 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+from tvm import meta_schedule as ms
+from tvm import te
+from tvm.ir import assert_structural_equal
+from tvm.meta_schedule.testing.space_generation import check_sketches
+from tvm.script import tir as T
+from tvm.target import Target
+from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+
+
+def test_vnni_conv2d_nchwc():
+    @T.prim_func
+    def conv2d_nchwc(
+        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
+        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
+        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+            with T.block("conv2d_NCHWc_int8"):
+                (
+                    n,
+                    oc_chunk,
+                    oh,
+                    ow,
+                    oc_block,
+                    kh,
+                    kw,
+                    ic_outer,
+                    ic_f_inner,
+                    ic_s_inner,
+                ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
+                T.reads(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                )
+                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
+                with T.init():
+                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
+                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
+                    n, oc_chunk, oh, ow, oc_block
+                ] + T.cast(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
+                ) * T.cast(
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                    "int32",
+                )
+
+    # fmt: off
+    @T.prim_func
+    def vnni_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16], dtype="int32")
+        for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1):
+            for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1):
+                with T.block("conv2d_NCHWc_int8_o"):
+                    n = T.axis.spatial(1, 0)
+                    oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3)
+                    oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3)
+                    ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2)
+                    oc_block_o = T.axis.spatial(1, 0)
+                    kh = T.axis.reduce(1, 0)
+                    kw = T.axis.reduce(1, 0)
+                    ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1)
+                    ic_f_inner = T.axis.reduce(4, i8_0 + i8_1)
+                    ic_s_inner_o = T.axis.reduce(1, 0)
+                    T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4])
+                    T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16])
+                    T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
+                    with T.init():
+                        for i4_1 in T.serial(16):
+                            with T.block("conv2d_NCHWc_int8_init"):
+                                oc_block_i_init = T.axis.spatial(16, i4_1)
+                                T.reads()
+                                T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init])
+                                conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init] = 0
+                    for i4_1, i9_1 in T.grid(16, 4):
+                        with T.block("conv2d_NCHWc_int8"):
+                            oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1])
+                            T.reads(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i])
+                            T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i])
+                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i], "int32")
+            for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 1, 2, 1, 16):
+                with T.block("conv2d_NCHWc_int8_global"):
+                    v0 = T.axis.spatial(1, ax0)
+                    v1 = T.axis.spatial(16, i1_0 * 2 + i1_1 + ax1)
+                    v2 = T.axis.spatial(56, i2_0 * 2 + ax2)
+                    v3 = T.axis.spatial(56, i3_0 + ax3)
+                    v4 = T.axis.spatial(16, ax4)
+                    T.reads(conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4])
+                    T.writes(conv2d_NCHWc_int8[v0, v1, v2, v3, v4])
+                    conv2d_NCHWc_int8[v0, v1, v2, v3, v4] = conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4]
+
+    @T.prim_func
+    def vnni_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16], dtype="int32")
+        for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1):
+            for i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1):
+                with T.block("conv2d_NCHWc_int8_o"):
+                    n = T.axis.spatial(1, 0)
+                    oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3)
+                    oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3)
+                    ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2)
+                    oc_block_o = T.axis.spatial(1, 0)
+                    kh = T.axis.reduce(1, 0)
+                    kw = T.axis.reduce(1, 0)
+                    ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1)
+                    ic_f_inner = T.axis.reduce(4, i8_0 + i8_1)
+                    ic_s_inner_o = T.axis.reduce(1, 0)
+                    T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4])
+                    T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16])
+                    T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
+                    with T.init():
+                        for i4_1 in T.serial(16):
+                            with T.block("conv2d_NCHWc_int8_init"):
+                                oc_block_i_init = T.axis.spatial(16, i4_1)
+                                T.reads()
+                                T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init])
+                                conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init] = 0
+                    for i4_1, i9_1 in T.grid(16, 4):
+                        with T.block("conv2d_NCHWc_int8"):
+                            oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1])
+                            T.reads(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i])
+                            T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i])
+                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i], "int32")
+            for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 2, 2, 1, 16):
+                with T.block("conv2d_NCHWc_int8_global"):
+                    v0 = T.axis.spatial(1, ax0)
+                    v1 = T.axis.spatial(16, i1_0 * 2 + ax1)
+                    v2 = T.axis.spatial(56, i2_0 * 2 + ax2)
+                    v3 = T.axis.spatial(56, i3_0 + ax3)
+                    v4 = T.axis.spatial(16, ax4)
+                    T.reads(conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4])
+                    T.writes(conv2d_NCHWc_int8[v0, v1, v2, v3, v4])
+                    conv2d_NCHWc_int8[v0, v1, v2, v3, v4] = conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4]
+
+    @T.prim_func
+    def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1):
+            with T.block("conv2d_NCHWc_int8_o"):
+                n = T.axis.spatial(1, 0)
+                oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3)
+                oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3)
+                ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2)
+                oc_block_o = T.axis.spatial(1, 0)
+                kh = T.axis.reduce(1, 0)
+                kw = T.axis.reduce(1, 0)
+                ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1)
+                ic_f_inner = T.axis.reduce(4, i8_0 + i8_1)
+                ic_s_inner_o = T.axis.reduce(1, 0)
+                T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4])
+                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16])
+                T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
+                with T.init():
+                    for i4_1 in T.serial(16):
+                        with T.block("conv2d_NCHWc_int8_init"):
+                            oc_block_i_init = T.axis.spatial(16, i4_1)
+                            T.reads()
+                            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init])
+                            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0
+                for i4_1, i9_1 in T.grid(16, 4):
+                    with T.block("conv2d_NCHWc_int8"):
+                        oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1])
+                        T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i])
+                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i])
+                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i], "int32")
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [8, 2, 1, 1]),
+        ("SamplePerfectTile", [28, 1, 2, 1]),
+        ("SamplePerfectTile", [56, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 1]),
+        ("SamplePerfectTile", [1, 1]),
+        ("SamplePerfectTile", [1, 4]),
+        ("SamplePerfectTile", [4, 1]),
+        ("SamplePerfectTile", [1, 1]),
+    ]
+    decision_1 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [8, 2, 1, 1]),
+        ("SamplePerfectTile", [28, 1, 2, 1]),
+        ("SamplePerfectTile", [56, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 1]),
+        ("SamplePerfectTile", [1, 1]),
+        ("SamplePerfectTile", [1, 4]),
+        ("SamplePerfectTile", [4, 1]),
+        ("SamplePerfectTile", [1, 1]),
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [8, 2, 1, 1]),
+        ("SamplePerfectTile", [28, 1, 2, 1]),
+        ("SamplePerfectTile", [56, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 1]),
+        ("SamplePerfectTile", [1, 1]),
+        ("SamplePerfectTile", [1, 4]),
+        ("SamplePerfectTile", [4, 1]),
+        ("SamplePerfectTile", [1, 1]),
+    ]
+
+    mod = conv2d_nchwc
+    target = Target("llvm -mcpu=cascadelake -num-cores=4")
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target(target),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.MultiLevelTilingWithIntrin(
+                VNNI_INTRIN,
+                structure="SSRSRS",
+                tile_binds=None,
+                max_innermost_factor=64,
+                vector_load_lens=None,
+                reuse_read=None,
+                reuse_write=ms.schedule_rule.ReuseType(req="may", levels=[1, 2], scope="global"),
+            ),
+        ],
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[vnni_conv2d_nchwc_0, vnni_conv2d_nchwc_1, vnni_conv2d_nchwc_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
+def _check_dp4a_dense(m, n, k, in_dtype, out_dtype, expected_mods, expected_decisions):
+    def _dense(m, n, k, in_dtype, out_dtype):
+        X = te.placeholder((m, k), name="X", dtype=in_dtype)
+        W = te.placeholder((n, k), name="W", dtype=in_dtype)
+        ak = te.reduce_axis((0, k), name="k")
+        matmul = te.compute(
+            (m, n),
+            lambda i, j: te.sum(
+                X[i, ak].astype(out_dtype) * W[j, ak].astype(out_dtype),
+                axis=ak,
+            ),
+            name="compute",
+        )
+        return te.create_prim_func([X, W, matmul])
+
+    mod = _dense(m, n, k, in_dtype, out_dtype)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.MultiLevelTilingWithIntrin(
+                DP4A_INTRIN,
+                structure="SSSRRSRS",
+                tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
+                max_innermost_factor=64,
+                vector_load_lens=[1, 2, 3, 4],
+                reuse_read=ms.schedule_rule.ReuseType(req="must", levels=[4], scope="shared"),
+                reuse_write=ms.schedule_rule.ReuseType(req="must", levels=[3], scope="local"),
+            )
+        ],
+    ).generate_design_space()
+    if expected_mods is None:
+        assert expected_decisions is None
+        assert len(actual) == 1
+        assert_structural_equal(mod, actual[0].mod["main"])
+    else:
+        check_sketches(mod, actual, expected_mods, expected_decisions)
+
+
+def test_dp4a_dense():  # DP4A-tensorized CUDA sketch for a 128^3 int8 -> int32 dense
+    @T.prim_func
+    def dp4a_dense_0(  # expected sketch: k split as 32 x 4, the inner 4 tagged for "dp4a"
+        X: T.Buffer[(128, 128), "int8"],
+        W: T.Buffer[(128, 128), "int8"],
+        compute: T.Buffer[(128, 128), "int32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local")
+        X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
+        W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
+        for i0_0_i1_0_fused in T.thread_binding(1, thread="blockIdx.x"):  # outer tiles are 1 x 1
+            for i0_1_i1_1_fused in T.thread_binding(512, thread="vthread.x"):  # 16 x 32
+                for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"):  # 2 x 1
+                    for i2_0_0 in T.serial(1):
+                        for ax0_ax1_fused in T.serial(16384):  # stage all 128 x 128 of X
+                            with T.block("X_shared"):
+                                v0 = T.axis.spatial(128, ax0_ax1_fused // 128)
+                                v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
+                                T.reads(X[v0, v1])
+                                T.writes(X_shared[v0, v1])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 1})
+                                X_shared[v0, v1] = X[v0, v1]
+                        for ax0_ax1_fused in T.serial(16384):  # stage all 128 x 128 of W
+                            with T.block("W_shared"):
+                                v0 = T.axis.spatial(128, ax0_ax1_fused // 128)
+                                v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
+                                T.reads(W[v0, v1])
+                                T.writes(W_shared[v0, v1])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 1})
+                                W_shared[v0, v1] = W[v0, v1]
+                        for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(1, 2, 4, 32, 2, 1):
+                            with T.block("compute_o"):  # outer block over k_o (extent 32)
+                                i = T.axis.spatial(
+                                    128,
+                                    i0_1_i1_1_fused // 32 * 8
+                                    + i0_2_i1_2_fused * 4
+                                    + i0_3 * 2
+                                    + i0_4,
+                                )
+                                j = T.axis.spatial(128, i1_4 + i0_1_i1_1_fused % 32 * 4 + i1_3)
+                                k_o = T.axis.reduce(32, i2_0_0 * 32 + i2_0_1 * 32 + i2_0_2)
+                                T.reads(
+                                    X_shared[i, k_o * 4 : k_o * 4 + 4],
+                                    W_shared[j, k_o * 4 : k_o * 4 + 4],
+                                )
+                                T.writes(compute_local[i, j])
+                                T.block_attr({"meta_schedule.auto_tensorize": "dp4a"})
+                                with T.init():
+                                    with T.block("compute_init"):
+                                        T.reads()
+                                        T.writes(compute_local[i, j])
+                                        compute_local[i, j] = 0
+                                for i2_1 in T.serial(4):  # the 4-wide inner k matched by dp4a
+                                    with T.block("compute"):
+                                        k_i = T.axis.reduce(4, i2_1)
+                                        T.reads(
+                                            compute_local[i, j],
+                                            X_shared[i, k_o * 4 + k_i],
+                                            W_shared[j, k_o * 4 + k_i],
+                                        )
+                                        T.writes(compute_local[i, j])
+                                        T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                        compute_local[i, j] = compute_local[i, j] + T.cast(
+                                            X_shared[i, k_o * 4 + k_i], "int32"
+                                        ) * T.cast(W_shared[j, k_o * 4 + k_i], "int32")
+                    for ax0, ax1 in T.grid(4, 4):  # write the per-thread 4 x 4 tile back
+                        with T.block("compute_local"):
+                            v0 = T.axis.spatial(
+                                128, i0_1_i1_1_fused // 32 * 8 + i0_2_i1_2_fused * 4 + ax0
+                            )
+                            v1 = T.axis.spatial(128, i0_1_i1_1_fused % 32 * 4 + ax1)
+                            T.reads(compute_local[v0, v1])
+                            T.writes(compute[v0, v1])
+                            compute[v0, v1] = compute_local[v0, v1]
+
+    decision_0 = [  # i/j/k_o tile factors; categoricals presumably pick fetch widths (1, 1)
+        ("SamplePerfectTile", [1, 16, 2, 2, 2]),
+        ("SamplePerfectTile", [1, 32, 1, 4, 1]),
+        ("SamplePerfectTile", [1, 1, 32]),
+        ("SampleCategorical", 0),
+        ("SampleCategorical", 0),
+    ]
+    _check_dp4a_dense(
+        m=128,
+        n=128,
+        k=128,
+        in_dtype="int8",
+        out_dtype="int32",
+        expected_mods=[dp4a_dense_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_dp4a_dense_no_tensorize_1():
+    _check_dp4a_dense(
+        m=128,
+        n=128,
+        k=128,
+        in_dtype="float32",
+        out_dtype="float32",
+        expected_mods=None,
+        expected_decisions=None,
+    )
+
+
+def test_dp4a_dense_no_tensorize_2():
+    _check_dp4a_dense(
+        m=127,
+        n=127,
+        k=127,
+        in_dtype="int8",
+        out_dtype="int32",
+        expected_mods=None,
+        expected_decisions=None,
+    )
+
+
+if __name__ == "__main__":
+    test_vnni_conv2d_nchwc()
+    test_dp4a_dense()
+    test_dp4a_dense_no_tensorize_1()
+    test_dp4a_dense_no_tensorize_2()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
new file mode 100644
index 0000000000..fbb74090b1
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -0,0 +1,957 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import tvm
+from tvm import meta_schedule as ms
+from tvm import te
+from tvm.meta_schedule.testing import te_workload
+from tvm.meta_schedule.testing.schedule_rule import get_rules
+from tvm.meta_schedule.testing.space_generation import check_sketches
+from tvm.script import tir as T
+from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group
+
+
+def multi_level_tiling_tensor_core(
+    *,
+    write_reuse_scope="shared",
+    in_dtype="float16",
+    out_dtype="float32",
+    trans_b=False,
+    use_software_pipeline=False,
+) -> ms.schedule_rule.ScheduleRule:
+    """Build a MultiLevelTilingTensorCore rule over WMMA intrinsic groups.
+
+    ``in_dtype``, ``out_dtype`` and ``trans_b`` each accept a single value or a
+    list. One WMMA intrinsic group is requested for every combination, enumerated
+    with ``in_dtype`` outermost and ``trans_b`` innermost. ``write_reuse_scope``
+    must be "shared" or "global"; only "shared" makes write reuse mandatory.
+    """
+    assert write_reuse_scope in ["shared", "global"]
+
+    def _as_list(value):
+        return value if isinstance(value, list) else [value]
+
+    in_dtypes = _as_list(in_dtype)
+    out_dtypes = _as_list(out_dtype)
+    trans_b_flags = _as_list(trans_b)
+    intrin_groups = [
+        get_wmma_intrin_group(write_reuse_scope, in_dt, out_dt, transposed)
+        for in_dt in in_dtypes
+        for out_dt in out_dtypes
+        for transposed in trans_b_flags
+    ]
+    write_req = "must" if write_reuse_scope == "shared" else "no"
+    return ms.schedule_rule.MultiLevelTilingTensorCore(
+        intrin_groups=intrin_groups,
+        structure="SSSRRSRS",
+        tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
+        max_innermost_factor=4,  # 64 // tensor intrin size
+        vector_load_lens=[1, 2, 3, 4, 8, 16],
+        reuse_read=ms.schedule_rule.ReuseType(req="must", levels=[4], scope="shared"),
+        reuse_write=ms.schedule_rule.ReuseType(req=write_req, levels=[2], scope=write_reuse_scope),
+        use_software_pipeline=use_software_pipeline,
+    )
+
+
+def test_matmul_relu():
+    """TensorCore multi-level tiling on a 128x128x128 float16 -> float32 matmul + ReLU.
+
+    The CUDA design space is generated with the TensorCore rule (helper defaults:
+    float16 in, float32 out, shared write reuse) followed by AutoInline, and the
+    sketches are compared by check_sketches against ``[matmul_relu_0]`` with
+    sampling decisions ``[decision_0]``.
+    """
+    # fmt: off
+    @T.prim_func
+    def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
+                    for ax2_0_0 in T.serial(1):
+                        for ax0_ax1_fused in T.serial(4096):
+                            with T.block("A_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128)
+                                v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8})
+                                A_reindex_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused in T.serial(4096):
+                            with T.block("B_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1})
+                                B_reindex_shared[v0, v1] = B[v0, v1]
+                        for ax2_0_1 in T.serial(4):
+                            for ax0_0, ax1_0 in T.grid(2, 2):
+                                with T.block("A_reindex_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0)
+                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("A_reindex_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                with T.block("B_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused)
+                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("B_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4)
+                                    v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3)
+                                    v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2)
+                                    T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
+                                                C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                            C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
+                    for ax0_0, ax1_0 in T.grid(2, 1):
+                        with T.block("C_reindex_shared_wmma.accumulator_o"):
+                            v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
+                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused)
+                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            for ax0_1, ax1_1 in T.grid(16, 16):
+                                with T.block("C_reindex_shared_wmma.accumulator"):
+                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                    T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                for ax0, ax1 in T.grid(32, 32):
+                    with T.block("C_reindex_shared"):
+                        v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0)
+                        v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1)
+                        T.reads(C_reindex_shared[v0, v1])
+                        T.writes(compute[v0, v1])
+                        T.block_attr({"meta_schedule.cooperative_fetch":4})
+                        compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0))
+    # fmt: on
+    # Sampling decisions that reproduce matmul_relu_0. The three SamplePerfectTile
+    # entries split the i, j and k extents (8 tiles of 16 each) over the 5 spatial
+    # and 3 reduction levels of the "SSSRRSRS" structure, e.g. i = 4*1*1*1*2 and
+    # k = 1*4*2. The SampleCategorical values are compared verbatim.
+    decision_0 = [
+        ("SamplePerfectTile", [4, 1, 1, 1, 2]),
+        ("SamplePerfectTile", [2, 2, 2, 1, 1]),
+        ("SamplePerfectTile", [1, 4, 2]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 0),
+    ]
+
+    mod = te.create_prim_func(
+        te_workload.matmul_relu(
+            n=128,
+            m=128,
+            k=128,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    # Rules: the TensorCore rule with default arguments, then the CUDA AutoInline
+    # rule(s) returned by get_rules.
+    actual = ms.TuneContext(
+        mod=mod,
+        target=tvm.target.Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[multi_level_tiling_tensor_core()]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[matmul_relu_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_matmul_relu_with_fallback():
+    """Same matmul + ReLU workload with the generic MultiLevelTiling rule also registered.
+
+    The TensorCore rule is listed first; the CUDA MultiLevelTiling and AutoInline
+    rules from get_rules follow. The expected sketch list is still the single
+    TensorCore sketch ``[matmul_relu_fallback_0]`` with decisions ``[decision_0]``.
+    NOTE(review): presumably MultiLevelTiling acts as a fallback for blocks the
+    TensorCore rule does not handle -- confirm against the rule implementation.
+    """
+    # fmt: off
+    @T.prim_func
+    def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
+                    for ax2_0_0 in T.serial(2):
+                        for ax0_ax1_fused in T.serial(2048):
+                            with T.block("A_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64)
+                                v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":4})
+                                A_reindex_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused in T.serial(8192):
+                            with T.block("B_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128)
+                                v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2})
+                                B_reindex_shared[v0, v1] = B[v0, v1]
+                        for ax2_0_1 in T.serial(1):
+                            for ax0_0, ax1_0 in T.grid(2, 4):
+                                with T.block("A_reindex_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0)
+                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("A_reindex_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(4, 4):
+                                with T.block("B_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax0_0)
+                                    v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0)
+                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("B_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 4, 2, 4):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4)
+                                    v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4)
+                                    v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 4 + ax2_0_2)
+                                    T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
+                                                C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                            C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
+                    for ax0_0, ax1_0 in T.grid(2, 4):
+                        with T.block("C_reindex_shared_wmma.accumulator_o"):
+                            v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0)
+                            v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0)
+                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            for ax0_1, ax1_1 in T.grid(16, 16):
+                                with T.block("C_reindex_shared_wmma.accumulator"):
+                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                    T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                for ax0, ax1 in T.grid(32, 128):
+                    with T.block("C_reindex_shared"):
+                        v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0)
+                        v1 = T.axis.spatial(128, ax1)
+                        T.reads(C_reindex_shared[v0, v1])
+                        T.writes(compute[v0, v1])
+                        T.block_attr({"meta_schedule.cooperative_fetch":4})
+                        compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0))
+    # fmt: on
+    # Sampling decisions that reproduce matmul_relu_fallback_0. The three
+    # SamplePerfectTile entries split the i, j and k extents (8 tiles of 16 each)
+    # over the 5 spatial and 3 reduction levels of "SSSRRSRS", e.g. k = 2*1*4.
+    decision_0 = [
+        ("SamplePerfectTile", [2, 2, 1, 1, 2]),
+        ("SamplePerfectTile", [1, 1, 2, 1, 4]),
+        ("SamplePerfectTile", [2, 1, 4]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 1),
+    ]
+
+    mod = te.create_prim_func(
+        te_workload.matmul_relu(
+            n=128,
+            m=128,
+            k=128,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    # Rules: the TensorCore rule with default arguments first, then the CUDA
+    # MultiLevelTiling and AutoInline rule(s) returned by get_rules.
+    actual = ms.TuneContext(
+        mod=mod,
+        target=tvm.target.Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            multi_level_tiling_tensor_core(),
+        ]
+        + get_rules(
+            "cuda",
+            (
+                ms.schedule_rule.MultiLevelTiling,
+                ms.schedule_rule.AutoInline,
+            ),
+        ),
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[matmul_relu_fallback_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_conv2d():
+    # fmt: off
+    @T.prim_func
+    def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, 3, 32, 32), "float16"], conv2d_nhwc: T.Buffer[(1, 16, 16, 32), "float32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16")
+        conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope="shared")
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator")
+        PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope="shared")
+        weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope="shared")
+        PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a")
+        weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b")
+        for i0, i1, i2, i3 in T.grid(1, 18, 18, 32):
+            with T.block("PadInput"):
+                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
+                T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+                PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16")
+        for ax0_0_ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.y"):
+            for ax0_1_ax1_0_1_ax2_0_1_fused in T.thread_binding(16, thread="blockIdx.x"):
+                for ax0_2_ax1_0_2_ax2_0_2_fused in T.thread_binding(1, thread="threadIdx.y"):
+                    for ax3_0_0 in T.serial(1):
+                        for ax0_ax1_fused in T.serial(4608):
+                            with T.block("PadInput_reindex_shared"):
+                                v0 = T.axis.spatial(256, ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0_ax1_fused // 288)
+                                v1 = T.axis.spatial(288, ax0_ax1_fused % 288)
+                                T.reads(PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32])
+                                T.writes(PadInput_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2})
+                                PadInput_reindex_shared[v0, v1] = PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]
+                        for ax0_ax1_fused in T.serial(4608):
+                            with T.block("weight_reindex_shared"):
+                                v0 = T.axis.spatial(288, ax0_ax1_fused // 16)
+                                v1 = T.axis.spatial(32, ax0_0_ax1_0_0_ax2_0_0_fused * 16 + ax0_ax1_fused % 16)
+                                T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1])
+                                T.writes(weight_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8})
+                                weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]
+                        for ax3_0_1 in T.serial(18):
+                            for ax0_0, ax1_0 in T.grid(1, 1):
+                                with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
+                                    v0_o, v1_o = T.axis.remap("SS", [ax0_1_ax1_0_1_ax2_0_1_fused, ax3_0_1])
+                                    T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("PadInput_reindex_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(1, 1):
+                                with T.block("weight_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o, v1_o = T.axis.remap("SS", [ax3_0_1, ax0_0_ax1_0_0_ax2_0_0_fused])
+                                    T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("weight_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_3, ax1_0_3, ax2_0_3, ax3_0_2, ax0_4, ax1_0_4, ax2_0_4 in T.grid(1, 1, 1, 1, 1, 1, 1):
+                                with T.block("conv2d_nhwc_o"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1_o = T.axis.spatial(16, ax1_0_4 + ax0_1_ax1_0_1_ax2_0_1_fused + ax1_0_3)
+                                    v2_o = T.axis.spatial(2, ax0_0_ax1_0_0_ax2_0_0_fused + ax2_0_3 + ax2_0_4)
+                                    v3_o = T.axis.reduce(18, ax3_0_0 * 18 + ax3_0_1 + ax3_0_2)
+                                    T.reads(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 : v1_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v3_o * 16 : v3_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16])
+                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                    with T.init():
+                                        for ax1_1, ax2_1 in T.grid(16, 16):
+                                            with T.block("conv2d_nhwc_init"):
+                                                v1_i_init, v2_i_init = T.axis.remap("SS", [ax1_1, ax2_1])
+                                                T.reads()
+                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init])
+                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init] = T.float32(0)
+                                    for ax1_1, ax2_1, ax3_1 in T.grid(16, 16, 16):
+                                        with T.block("conv2d_nhwc"):
+                                            v1_i, v2_i, v3_i = T.axis.remap("SSR", [ax1_1, ax2_1, ax3_1])
+                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i], PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i])
+                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i])
+                                            T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i], "float32")
+                    for ax0_0, ax1_0 in T.grid(1, 1):
+                        with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+                            v0_o, v1_o = T.axis.remap("SS", [ax0_1_ax1_0_1_ax2_0_1_fused, ax0_0_ax1_0_0_ax2_0_0_fused])
+                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            for ax0_1, ax1_1 in T.grid(16, 16):
+                                with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                    T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                for ax0, ax1 in T.grid(16, 16):
+                    with T.block("conv2d_nhwc_reindex_shared"):
+                        v0 = T.axis.spatial(256, ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0)
+                        v1 = T.axis.spatial(32, ax0_0_ax1_0_0_ax2_0_0_fused * 16 + ax1)
+                        T.reads(conv2d_nhwc_reindex_shared[v0, v1])
+                        T.writes(conv2d_nhwc[0, v0 // 16, v0 % 16, v1])
+                        T.block_attr({"meta_schedule.cooperative_fetch":3})
+                        conv2d_nhwc[0, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 16, 1, 1, 1]),
+        ("SamplePerfectTile", [2, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 18, 1]),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 3),
+    ]
+    mod = te.create_prim_func(
+        te_workload.conv2d_nhwc(
+            N=1,
+            H=16,
+            W=16,
+            CI=32,
+            CO=32,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = ms.TuneContext(
+        mod=mod,
+        target=tvm.target.Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[multi_level_tiling_tensor_core()],
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[conv2d_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_conv2d_more_intrin():  # TensorCore MLT on an fp16 -> fp32 NHWC conv2d, with extra candidate intrinsics
+    # The rule is also offered float16-output intrinsics, which cannot apply here (this conv2d writes float32); the expected design space is still the single f16f16f32 WMMA sketch below.
+    # fmt: off
+    @T.prim_func
+    def conv2d_more_intrin_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, 3, 32, 32), "float16"], conv2d_nhwc: T.Buffer[(1, 16, 16, 32), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16")  # `inputs` zero-padded by 1 on H and W
+        conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope="shared")  # output as a 2-D (H*W, CO) GEMM result
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator")
+        PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope="shared")  # (H*W, KH*KW*CI) view of PadInput: GEMM operand A
+        weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope="shared")  # (KH*KW*CI, CO) view of weight: GEMM operand B
+        PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a")
+        weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b")
+        for i0, i1, i2, i3 in T.grid(1, 18, 18, 32):
+            with T.block("PadInput"):
+                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
+                T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+                PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16")
+        for ax0_0_ax1_0_0_ax2_0_0_fused in T.thread_binding(4, thread="blockIdx.y"):  # 256 GEMM rows = 4 (blockIdx.y) x 4 (blockIdx.x) x 16-row WMMA tile
+            for ax0_1_ax1_0_1_ax2_0_1_fused in T.thread_binding(4, thread="blockIdx.x"):
+                for ax0_2_ax1_0_2_ax2_0_2_fused in T.thread_binding(1, thread="threadIdx.y"):
+                    for ax3_0_0 in T.serial(3):  # reduction: 18 tiles of 16 = 3 x 2 (ax3_0_1) x 3 (ax3_0_2)
+                        for ax0_ax1_fused in T.serial(1536):
+                            with T.block("PadInput_reindex_shared"):
+                                v0 = T.axis.spatial(256, ax0_0_ax1_0_0_ax2_0_0_fused * 64 + ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0_ax1_fused // 96)
+                                v1 = T.axis.spatial(288, ax3_0_0 * 96 + ax0_ax1_fused % 96)
+                                T.reads(PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32])
+                                T.writes(PadInput_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8})
+                                PadInput_reindex_shared[v0, v1] = PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]
+                        for ax0_ax1_fused in T.serial(3072):
+                            with T.block("weight_reindex_shared"):
+                                v0 = T.axis.spatial(288, ax3_0_0 * 96 + ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(32, ax0_ax1_fused % 32)
+                                T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1])
+                                T.writes(weight_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8})
+                                weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]
+                        for ax3_0_1 in T.serial(2):
+                            for ax0_0, ax1_0 in T.grid(1, 3):
+                                with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(16, ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused)
+                                    v1_o = T.axis.spatial(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax1_0)
+                                    T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("PadInput_reindex_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(3, 2):
+                                with T.block("weight_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax0_0)
+                                    v1_o = T.axis.spatial(2, ax1_0)
+                                    T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("weight_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_3, ax1_0_3, ax2_0_3, ax3_0_2, ax0_4, ax1_0_4, ax2_0_4 in T.grid(1, 1, 2, 3, 1, 1, 1):
+                                with T.block("conv2d_nhwc_o"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1_o = T.axis.spatial(16, ax1_0_4 + ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused + ax1_0_3)
+                                    v2_o = T.axis.spatial(2, ax2_0_4 + ax2_0_3)
+                                    v3_o = T.axis.reduce(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax3_0_2)
+                                    T.reads(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 : v1_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v3_o * 16 : v3_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16])
+                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                    with T.init():
+                                        for ax1_1, ax2_1 in T.grid(16, 16):
+                                            with T.block("conv2d_nhwc_init"):
+                                                v1_i_init, v2_i_init = T.axis.remap("SS", [ax1_1, ax2_1])
+                                                T.reads()
+                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init])
+                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init] = T.float32(0)
+                                    for ax1_1, ax2_1, ax3_1 in T.grid(16, 16, 16):
+                                        with T.block("conv2d_nhwc"):
+                                            v1_i, v2_i, v3_i = T.axis.remap("SSR", [ax1_1, ax2_1, ax3_1])
+                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i], PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i])
+                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i])
+                                            T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i], "float32")
+                    for ax0_0, ax1_0 in T.grid(1, 2):
+                        with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+                            v0_o = T.axis.spatial(16, ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused)
+                            v1_o = T.axis.spatial(2, ax1_0)
+                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            for ax0_1, ax1_1 in T.grid(16, 16):
+                                with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                    T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                for ax0, ax1 in T.grid(16, 32):  # copy this block's 16x32 output tile from the shared buffer back to conv2d_nhwc
+                    with T.block("conv2d_nhwc_reindex_shared"):
+                        v0 = T.axis.spatial(256, ax0_0_ax1_0_0_ax2_0_0_fused * 64 + ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0)
+                        v1 = T.axis.spatial(32, ax1)
+                        T.reads(conv2d_nhwc_reindex_shared[v0, v1])
+                        T.writes(conv2d_nhwc[0, v0 // 16, v0 % 16, v1])
+                        T.block_attr({"meta_schedule.cooperative_fetch":3})
+                        conv2d_nhwc[0, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1]
+    # fmt: on
+    decision_0 = [  # sampled decisions passed to check_sketches alongside conv2d_more_intrin_0
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),  # ax0: the extent-1 leading axis
+        ("SamplePerfectTile", [4, 4, 1, 1, 1]),  # ax1: 16 row tiles -> blockIdx.y=4, blockIdx.x=4
+        ("SamplePerfectTile", [1, 1, 1, 2, 1]),  # ax2: 2 output-channel tiles, both in the ax2_0_3 loop
+        ("SamplePerfectTile", [3, 2, 3]),  # ax3: 18 reduction tiles -> ax3_0_0=3, ax3_0_1=2, ax3_0_2=3
+        ("SampleCategorical", 2),  # NOTE(review): assumed to be cooperative-fetch choices; confirm the order against the rule
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 3),
+    ]
+
+    mod = te.create_prim_func(  # fp16 in / fp32 out NHWC conv2d, 3x3 kernel, stride 1, padding 1
+        te_workload.conv2d_nhwc(
+            N=1,
+            H=16,
+            W=16,
+            CI=32,
+            CO=32,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = ms.TuneContext(
+        mod=mod,
+        target=tvm.target.Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            multi_level_tiling_tensor_core(
+                in_dtype="float16",
+                out_dtype=["float16", "float32"],  # the float16-output intrinsics are the inapplicable extras
+            ),
+        ],
+    ).generate_design_space()
+    check_sketches(  # SEqual comparison of the generated sketches against the single expected module
+        mod,
+        sketches=actual,
+        expected_mods=[conv2d_more_intrin_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_matmul_relu_pipeline():  # TensorCore MLT on an fp16 matmul + ReLU, with software pipelining enabled in the rule
+    # fmt: off
+    @T.prim_func
+    def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C = T.alloc_buffer([128, 128], dtype="float32")  # matmul result; the ReLU loop nest at the end reads it
+        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"):  # 4 x 4 grid of 32x32 output tiles (// 4 -> rows, % 4 -> cols)
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"):
+                    for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order":[0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1]}):  # outer K loop (32 columns of A per step), pipelined
+                        for ax0_ax1_fused in T.serial(1024):
+                            with T.block("A_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused % 32)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":4, "tir.manifest_shared_memory_local_stage":1})
+                                A_reindex_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused in T.serial(1024):
+                            with T.block("B_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":2, "tir.manifest_shared_memory_local_stage":1})
+                                B_reindex_shared[v0, v1] = B[v0, v1]
+                        for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}):  # inner K loop over 16-wide WMMA tiles, also pipelined
+                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                with T.block("A_reindex_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1)
+                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("A_reindex_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                with T.block("B_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1)
+                                    v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0)
+                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("B_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 1, 2, 2):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4)
+                                    v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4)
+                                    v2_o = T.axis.reduce(8, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2)
+                                    T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
+                                                C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                            C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
+                    for ax0_0, ax1_0 in T.grid(2, 2):
+                        with T.block("C_reindex_shared_wmma.accumulator_o"):
+                            v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0)
+                            v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0)
+                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            for ax0_1, ax1_1 in T.grid(16, 16):
+                                with T.block("C_reindex_shared_wmma.accumulator"):
+                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                    T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                for ax0, ax1 in T.grid(32, 32):  # copy this block's 32x32 tile from the shared buffer back to C
+                    with T.block("C_reindex_shared"):
+                        v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0)
+                        v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax1)
+                        T.reads(C_reindex_shared[v0, v1])
+                        T.writes(C[v0, v1])
+                        T.block_attr({"meta_schedule.cooperative_fetch":3})
+                        C[v0, v1] = C_reindex_shared[v0, v1]
+        for i0, i1 in T.grid(128, 128):  # ReLU epilogue: a separate loop nest over C
+            with T.block("compute"):
+                i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+                T.reads(C[i0_1, i1_1])
+                T.writes(compute[i0_1, i1_1])
+                compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+    # fmt: on
+    decision_0 = [  # sampled decisions passed to check_sketches alongside matmul_relu_pipeline_0
+        ("SamplePerfectTile", [1, 4, 1, 1, 2]),  # ax0 (M): 8 WMMA tiles = 4 (blockIdx.x // 4) x 2 (ax0_0_4)
+        ("SamplePerfectTile", [1, 4, 1, 1, 2]),  # ax1 (N): 8 WMMA tiles = 4 (blockIdx.x % 4) x 2 (ax1_0_4)
+        ("SamplePerfectTile", [4, 2, 1]),  # ax2 (K): 8 WMMA tiles = 4 (ax2_0_0) x 2 (ax2_0_1)
+        ("SampleCategorical", 2),  # NOTE(review): assumed to be cooperative-fetch choices; confirm the order against the rule
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 1),
+    ]
+    mod = te.create_prim_func(  # 128x128x128 matmul followed by ReLU; fp16 inputs, fp32 output
+        te_workload.matmul_relu(
+            n=128,
+            m=128,
+            k=128,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = ms.TuneContext(
+        mod=mod,
+        target=tvm.target.Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            multi_level_tiling_tensor_core(
+                use_software_pipeline=True,  # expected to produce the software_pipeline_* loop annotations above
+            ),
+        ],
+    ).generate_design_space()
+    check_sketches(  # SEqual comparison of the generated sketches against the single expected module
+        mod,
+        sketches=actual,
+        expected_mods=[matmul_relu_pipeline_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_matmul_relu_global():
+    # fmt: off
+    @T.prim_func
+    def matmul_relu_global_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C = T.alloc_buffer([128, 128], dtype="float32")
+        C_reindex_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"):
+                    for ax2_0_0 in T.serial(2):
+                        for ax0_ax1_fused in T.serial(8192):
+                            with T.block("A_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax0_ax1_fused // 64)
+                                v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1})
+                                A_reindex_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused in T.serial(8192):
+                            with T.block("B_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128)
+                                v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1})
+                                B_reindex_shared[v0, v1] = B[v0, v1]
+                        for ax2_0_1 in T.serial(2):
+                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                with T.block("A_reindex_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2)
+                                    v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax1_0)
+                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("A_reindex_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(2, 4):
+                                with T.block("B_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0)
+                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("B_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 4, 2, 1, 1):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4)
+                                    v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0_3)
+                                    v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax2_0_2)
+                                    T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
+                                                C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                            C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
+                    for ax0_0, ax1_0 in T.grid(1, 4):
+                        with T.block("C_reindex_wmma.accumulator_o"):
+                            v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2)
+                            v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0)
+                            T.reads(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_global"})
+                            for ax0_1, ax1_1 in T.grid(16, 16):
+                                with T.block("C_reindex_wmma.accumulator"):
+                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                    T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    T.writes(C[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    C[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+        for i0, i1 in T.grid(128, 128):
+            with T.block("compute"):
+                i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+                T.reads(C[i0_1, i1_1])
+                T.writes(compute[i0_1, i1_1])
+                compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 8, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 2, 4, 1]),
+        ("SamplePerfectTile", [2, 2, 2]),
+        ("SampleCategorical", 0),
+        ("SampleCategorical", 0),
+    ]
+    mod = te.create_prim_func(
+        te_workload.matmul_relu(
+            n=128,
+            m=128,
+            k=128,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = ms.TuneContext(
+        mod=mod,
+        target=tvm.target.Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[matmul_relu_global_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_matmul_relu_non_tensorizable():
+    # expected to do nothing on non-tensorizable workloads
+    mod = te.create_prim_func(
+        te_workload.matmul_relu(  # dtype doesn't match tensor intrin
+            n=128,
+            m=128,
+            k=128,
+        )
+    )
+    (sch,) = ms.TuneContext(
+        mod=mod,
+        target=tvm.target.Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    ).generate_design_space()
+    tvm.ir.assert_structural_equal(mod, sch.mod["main"])
+
+
+if __name__ == "__main__":
+    test_matmul_relu()
+    test_matmul_relu_with_fallback()
+    test_conv2d()
+    test_conv2d_more_intrin()
+    test_matmul_relu_pipeline()
+    test_matmul_relu_global()
+    test_matmul_relu_non_tensorizable()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
deleted file mode 100644
index fe1220c509..0000000000
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ /dev/null
@@ -1,1205 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
-import tvm
-import tvm.testing
-from tvm import te
-from tvm.meta_schedule import schedule_rule
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing import te_workload
-from tvm.meta_schedule.testing.schedule_rule import (
-    auto_inline,
-    multi_level_tiling,
-    multi_level_tiling_tensor_core,
-)
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
-from tvm.script import tir as T
-from tvm.target import Target
-from tvm.te import create_prim_func
-from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN
-from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
-
-
-def _create_context(mod, target, rule) -> TuneContext:
-    if not isinstance(rule, (list, tuple)):
-        rule = [rule]
-    ctx = TuneContext(
-        mod=mod,
-        target=target,
-        space_generator=PostOrderApply(),
-        sch_rules=rule,
-        task_name="test",
-    )
-    return ctx
-
-
-def test_cpu_matmul():
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
-            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
-            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
-            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
-            'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
-            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)",
-        ],
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
-            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
-            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
-            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
-            'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
-            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)",
-        ],
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
-            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
-            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
-            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
-        ],
-    ]
-    target = Target("llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul(
-                n=512,
-                m=512,
-                k=512,
-            )
-        ),
-        target=target,
-        rule=multi_level_tiling(target=target),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 3
-    check_trace(spaces, expected)
-
-
-def test_cpu_matmul_relu():
-    # pylint: disable=line-too-long
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
-            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
-            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
-            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
-            "b24, = sch.get_consumers(block=b0)",
-            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)",
-        ],
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
-            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
-            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
-            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
-            "b24, = sch.get_consumers(block=b0)",
-            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)",
-        ],
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
-            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
-            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
-            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
-        ],
-    ]
-    # pylint: enable=line-too-long
-    target = Target("llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul_relu(
-                n=512,
-                m=512,
-                k=512,
-            )
-        ),
-        target=target,
-        rule=multi_level_tiling(target=target),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 3
-    check_trace(spaces, expected)
-
-
-def test_cuda_matmul():
-    # pylint: disable=line-too-long
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
-            "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8], preserve_unit_iters=True)",
-            "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
-            "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True)",
-            "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
-            "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26], preserve_unit_iters=True)",
-            "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
-            "l30 = sch.fuse(l9, l19, preserve_unit_iters=True)",
-            'sch.bind(loop=l30, thread_axis="blockIdx.x")',
-            "l31 = sch.fuse(l10, l20, preserve_unit_iters=True)",
-            'sch.bind(loop=l31, thread_axis="vthread.x")',
-            "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)",
-            'sch.bind(loop=l32, thread_axis="threadIdx.x")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
-            'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
-            "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)",
-            'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
-            "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)",
-            "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
-            "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
-            "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
-            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
-            'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
-            "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)",
-            "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
-            "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
-            "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
-            'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
-        ]
-    ]
-    # pylint: enable=line-too-long
-    target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul(
-                n=512,
-                m=512,
-                k=512,
-            )
-        ),
-        target=target,
-        rule=multi_level_tiling(target=target),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
-
-
-def test_cuda_matmul_relu():
-    # pylint: disable=line-too-long
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
-            "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8], preserve_unit_iters=True)",
-            "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
-            "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True)",
-            "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
-            "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26], preserve_unit_iters=True)",
-            "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
-            "l30 = sch.fuse(l9, l19, preserve_unit_iters=True)",
-            'sch.bind(loop=l30, thread_axis="blockIdx.x")',
-            "l31 = sch.fuse(l10, l20, preserve_unit_iters=True)",
-            'sch.bind(loop=l31, thread_axis="vthread.x")',
-            "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)",
-            'sch.bind(loop=l32, thread_axis="threadIdx.x")',
-            'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
-            "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)",
-            'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
-            "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)",
-            "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
-            "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
-            "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
-            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
-            'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
-            "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)",
-            "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
-            "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
-            "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
-            'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
-        ]
-    ]
-    # pylint: enable=line-too-long
-    target = Target("cuda", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul_relu(
-                n=512,
-                m=512,
-                k=512,
-            )
-        ),
-        target=target,
-        rule=multi_level_tiling(target=target),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
-
-
-def test_cuda_sum_with_trivial_block_iter():
-    @T.prim_func
-    def sum_with_trivial_block_iter(
-        A: T.Buffer[(1, 64, 768), "float32"], B: T.Buffer[(1, 64, 1), "float32"]
-    ) -> None:
-        for i0, i1, i2, i3 in T.grid(1, 64, 1, 768):
-            with T.block("sum"):
-                ax0, ax1, ax2, k2 = T.axis.remap("SSSR", [i0, i1, i2, i3])
-                T.reads(A[ax0, ax1, k2])
-                T.writes(B[ax0, ax1, ax2])
-                with T.init():
-                    B[ax0, ax1, ax2] = T.float32(0)
-                B[ax0, ax1, ax2] = B[ax0, ax1, ax2] + A[ax0, ax1, k2]
-
-    # Expect nothing to happen - the rule is not supposed to be applied in this case
-    expected = [[]]
-    target = Target("cuda", host="llvm")
-    ctx = _create_context(
-        sum_with_trivial_block_iter,
-        target=target,
-        rule=multi_level_tiling(target=target),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
-
-
-@tvm.script.ir_module
-class Conv2dNCHWcVNNIModule:
-    @T.prim_func
-    def main(
-        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
-        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
-        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
-    ) -> None:
-        T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
-            with T.block("conv2d_NCHWc_int8"):
-                (
-                    n,
-                    oc_chunk,
-                    oh,
-                    ow,
-                    oc_block,
-                    kh,
-                    kw,
-                    ic_outer,
-                    ic_f_inner,
-                    ic_s_inner,
-                ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
-                T.reads(
-                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
-                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
-                )
-                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
-                with T.init():
-                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
-                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
-                    n, oc_chunk, oh, ow, oc_block
-                ] + T.cast(
-                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
-                ) * T.cast(
-                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
-                    "int32",
-                )
-
-
-def test_multi_level_tiling_conv2d_nchwc_vnni():
-    target = "llvm -mcpu=cascadelake -num-cores 4"
-    ctx = _create_context(
-        Conv2dNCHWcVNNIModule,
-        target=tvm.target.Target(target),
-        rule=schedule_rule.MultiLevelTilingWithIntrin(
-            VNNI_INTRIN,
-            structure="SSRSRS",
-            tile_binds=None,
-            max_innermost_factor=64,
-            vector_load_lens=None,
-            reuse_read=None,
-            reuse_write=schedule_rule.ReuseType(
-                req="may",
-                levels=[1, 2],
-                scope="global",
-            ),
-        ),
-    )
-
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-
-    expected = [
-        """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
-l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True)
-l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
-l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
-sch.reorder(l21, l22, l23, l24, l25, l14, l12)
-b27 = sch.blockize(loop=l14)
-sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
-l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
-v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
-v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
-v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
-v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
-v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
-v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
-v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
-v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
-v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
-v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
-sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
-b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
-sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True, index=-1)""".split(
-            "\n"
-        ),
-        """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
-l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True)
-l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
-l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
-sch.reorder(l21, l22, l23, l24, l25, l14, l12)
-b27 = sch.blockize(loop=l14)
-sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
-l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
-v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
-v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
-v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
-v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
-v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
-v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
-v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
-v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
-v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
-v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
-sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
-b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
-sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True, index=-1)""".split(
-            "\n"
-        ),
-        """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
-l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True)
-l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
-l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
-sch.reorder(l21, l22, l23, l24, l25, l14, l12)
-b27 = sch.blockize(loop=l14)
-sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
-l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
-v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
-v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
-v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
-v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
-v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
-v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
-v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
-v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
-v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
-v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
-sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)""".split(
-            "\n"
-        ),
-    ]
-
-    check_trace(spaces, expected)
-
-
-def _test_multi_level_tiling_dense_dp4a(m, n, k, in_dtype, out_dtype, expected):
-    X = te.placeholder((m, k), name="X", dtype=in_dtype)
-    W = te.placeholder((n, k), name="W", dtype=in_dtype)
-    ak = te.reduce_axis((0, k), name="k")
-
-    matmul = te.compute(
-        (m, n),
-        lambda i, j: te.sum(
-            X[i, ak].astype(out_dtype) * W[j, ak].astype(out_dtype),
-            axis=ak,
-        ),
-        name="compute",
-    )
-
-    func = te.create_prim_func([X, W, matmul])
-
-    ctx = _create_context(
-        func,
-        target=tvm.target.Target("cuda"),
-        rule=schedule_rule.MultiLevelTilingWithIntrin(
-            DP4A_INTRIN,
-            structure="SSSRRSRS",
-            tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
-            max_innermost_factor=64,
-            vector_load_lens=[1, 2, 3, 4],
-            reuse_read=schedule_rule.ReuseType(
-                req="must",
-                levels=[4],
-                scope="shared",
-            ),
-            reuse_write=schedule_rule.ReuseType(
-                req="must",
-                levels=[3],
-                scope="local",
-            ),
-        ),
-    )
-
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    check_trace(spaces, expected)
-
-
-def test_multi_level_tiling_dense_dp4a():
-    m, n, k = 128, 128, 128
-
-    expected = [
-        """b0 = sch.get_block(name="compute", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
-l1, l2, l3 = sch.get_loops(block=b0)
-l4, l5 = sch.split(loop=l3, factors=[None, 4], preserve_unit_iters=True)
-sch.reorder(l5)
-b6 = sch.blockize(loop=l5)
-sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a")
-l7, l8, l9 = sch.get_loops(block=b6)
-v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)
-l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True)
-v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)
-l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True)
-v30, v31, v32 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)
-l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32], preserve_unit_iters=True)
-sch.reorder(l15, l25, l16, l26, l17, l27, l33, l34, l18, l28, l35, l19, l29)
-l36 = sch.fuse(l15, l25, preserve_unit_iters=True)
-sch.bind(loop=l36, thread_axis="blockIdx.x")
-l37 = sch.fuse(l16, l26, preserve_unit_iters=True)
-sch.bind(loop=l37, thread_axis="vthread.x")
-l38 = sch.fuse(l17, l27, preserve_unit_iters=True)
-sch.bind(loop=l38, thread_axis="threadIdx.x")
-b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local")
-sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True, index=-1)
-b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True, index=-1)
-l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40)
-l47 = sch.fuse(l45, l46, preserve_unit_iters=True)
-v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48)
-b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True, index=-1)
-l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49)
-l56 = sch.fuse(l54, l55, preserve_unit_iters=True)
-v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v57)""".split(
-            "\n"
-        )
-    ]
-
-    _test_multi_level_tiling_dense_dp4a(m, n, k, "int8", "int32", expected)
-
-
-def test_multi_level_tiling_dense_dp4a_non_tensorizable():
-    _test_multi_level_tiling_dense_dp4a(128, 128, 128, "float32", "float32", [""])
-    _test_multi_level_tiling_dense_dp4a(127, 127, 127, "int8", "int32", [""])
-
-
-def test_cuda_tensor_core_matmul_relu():
-    m = n = k = 128
-    target = Target("cuda", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul_relu(
-                n=n,
-                m=m,
-                k=k,
-                in_dtype="float16",
-                out_dtype="float32",
-            )
-        ),
-        target=target,
-        rule=[
-            multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
-            auto_inline(target),
-        ],
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-
-    expected = [
-        """b0 = sch.get_block(name="C", func_name="main")
-b1 = sch.get_block(name="compute", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
-b2 = sch.reindex(block=b0, buffer=("write", 0))
-b3 = sch.reindex(block=b0, buffer=("read", 0))
-b4 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
-sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
-l5, l6, l7 = sch.get_loops(block=b0)
-l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
-l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
-l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
-l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
-sch.reorder(l16, l18, l13, l11, l9)
-b20 = sch.blockize(loop=l13)
-sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
-sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
-sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
-l21, l22, l23 = sch.get_loops(block=b20)
-v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
-l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True)
-v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
-l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True)
-v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4)
-l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True)
-sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43)
-l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
-sch.bind(loop=l50, thread_axis="blockIdx.y")
-l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
-sch.bind(loop=l51, thread_axis="blockIdx.x")
-l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
-sch.bind(loop=l52, thread_axis="threadIdx.y")
-b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1)
-b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1)
-v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
-sch.reverse_compute_inline(block=b2)
-l56, l57, l58, l59, l60 = sch.get_loops(block=b54)
-l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True)
-l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True)
-l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54)
-sch.reorder(l70, l64, l62)
-b72 = sch.blockize(loop=l64)
-sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
-b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1)
-l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
-l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
-v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
-b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1)
-l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
-l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
-v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
-b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1)
-l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
-l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
-l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
-l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91)
-sch.reorder(l110, l102, l100)
-b112 = sch.blockize(loop=l102)
-sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
-b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1)
-l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
-l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
-l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
-l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113)
-sch.reorder(l132, l124, l122)
-b134 = sch.blockize(loop=l124)
-sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
-sch.compute_inline(block=b3)
-sch.compute_inline(block=b4)
-sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8)
-sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8)
-sch.reverse_compute_inline(block=b1)""".split(
-            "\n"
-        )
-    ]
-    check_trace(spaces, expected)
-
-    # test multi_level_tiling_tensor_core and multi_level_tiling can be used together in order
-    # to use multi_level_tiling as a fallback when the workload can't be tensorized
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul_relu(
-                n=n,
-                m=m,
-                k=k,
-                in_dtype="float16",
-                out_dtype="float32",
-            )
-        ),
-        target=target,
-        rule=[
-            multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
-            multi_level_tiling(target=target),
-            auto_inline(target),
-        ],
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
-
-
-def test_cuda_tensor_core_software_pipeline_matmul_relu():
-    m = n = k = 128
-    target = Target("cuda", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul_relu(
-                n=n,
-                m=m,
-                k=k,
-                in_dtype="float16",
-                out_dtype="float32",
-            )
-        ),
-        target=target,
-        rule=[
-            multi_level_tiling_tensor_core(
-                target=target, write_reuse_scope="shared", use_software_pipeline=True
-            ),
-            auto_inline(target),
-        ],
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-
-    expected = [
-        """b0 = sch.get_block(name="C", func_name="main")
-b1 = sch.get_block(name="compute", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
-b2 = sch.reindex(block=b0, buffer=("write", 0))
-b3 = sch.reindex(block=b0, buffer=("read", 0))
-b4 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
-sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
-l5, l6, l7 = sch.get_loops(block=b0)
-l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
-l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
-l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
-l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
-sch.reorder(l16, l18, l13, l11, l9)
-b20 = sch.blockize(loop=l13)
-sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
-sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
-sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
-l21, l22, l23 = sch.get_loops(block=b20)
-v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
-l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True)
-v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
-l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True)
-v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4)
-l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True)
-sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43)
-l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
-sch.bind(loop=l50, thread_axis="blockIdx.y")
-l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
-sch.bind(loop=l51, thread_axis="blockIdx.x")
-l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
-sch.bind(loop=l52, thread_axis="threadIdx.y")
-b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1)
-b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1)
-v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
-sch.reverse_compute_inline(block=b2)
-l56, l57, l58, l59, l60 = sch.get_loops(block=b54)
-l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True)
-l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True)
-l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54)
-sch.reorder(l70, l64, l62)
-b72 = sch.blockize(loop=l64)
-sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
-b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1)
-l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
-l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
-v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
-b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1)
-l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
-l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
-v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
-b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1)
-l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
-l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
-l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
-l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91)
-sch.reorder(l110, l102, l100)
-b112 = sch.blockize(loop=l102)
-sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
-b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1)
-l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
-l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
-l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
-l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113)
-sch.reorder(l132, l124, l122)
-b134 = sch.blockize(loop=l124)
-sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
-sch.compute_inline(block=b3)
-sch.compute_inline(block=b4)
-sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8)
-sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8)
-sch.annotate(block_or_loop=b73, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1)
-sch.annotate(block_or_loop=b73, ann_key="double_buffer_scope", ann_val=0)
-sch.annotate(block_or_loop=b82, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1)
-sch.annotate(block_or_loop=b82, ann_key="double_buffer_scope", ann_val=0)
-sch.annotate(block_or_loop=l48, ann_key="software_pipeline_stage", ann_val=[0, 0, 1])
-sch.annotate(block_or_loop=l48, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
-sch.annotate(block_or_loop=l47, ann_key="software_pipeline_stage", ann_val=[0, 0, 0, 0, 0, 1, 1])
-sch.annotate(block_or_loop=l47, ann_key="software_pipeline_order", ann_val=[0, 3, 1, 4, 5, 2, 6])
-sch.reverse_compute_inline(block=b1)""".split(
-            "\n"
-        )
-    ]
-    check_trace(spaces, expected)
-
-
-def test_cuda_tensor_core_matmul_relu_global():
-    m = n = k = 128
-    target = Target("cuda", host="llvm")
-    workload = create_prim_func(
-        te_workload.matmul_relu(
-            n=n,
-            m=m,
-            k=k,
-            in_dtype="float16",
-            out_dtype="float32",
-        ),
-    )
-    ctx = _create_context(
-        workload,
-        target=target,
-        rule=[
-            multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"),
-            auto_inline(target),
-        ],
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-
-    expected = [
-        """b0 = sch.get_block(name="C", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
-b1 = sch.reindex(block=b0, buffer=("write", 0))
-b2 = sch.reindex(block=b0, buffer=("read", 0))
-b3 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
-sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
-l4, l5, l6 = sch.get_loops(block=b0)
-l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
-l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
-l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True)
-l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0)
-sch.reorder(l15, l17, l12, l10, l8)
-b19 = sch.blockize(loop=l12)
-sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
-sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
-sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1)
-l20, l21, l22 = sch.get_loops(block=b19)
-v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4)
-l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True)
-v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
-l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True)
-v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4)
-l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True)
-sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)
-l49 = sch.fuse(l28, l38, preserve_unit_iters=True)
-sch.bind(loop=l49, thread_axis="blockIdx.y")
-l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
-sch.bind(loop=l50, thread_axis="blockIdx.x")
-l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
-sch.bind(loop=l51, thread_axis="threadIdx.y")
-b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1)
-sch.reverse_compute_inline(block=b1)
-l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
-l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
-l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True)
-l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52)
-sch.reorder(l67, l61, l59)
-b69 = sch.blockize(loop=l61)
-sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
-b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1)
-l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
-l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
-v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
-b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1)
-l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
-l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
-v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
-b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1)
-l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
-l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
-l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
-l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88)
-sch.reorder(l107, l99, l97)
-b109 = sch.blockize(loop=l99)
-sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
-b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1)
-l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
-l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
-l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
-l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110)
-sch.reorder(l129, l121, l119)
-b131 = sch.blockize(loop=l121)
-sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
-sch.compute_inline(block=b2)
-sch.compute_inline(block=b3)
-sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8)
-sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split(
-            "\n"
-        )
-    ]
-    check_trace(spaces, expected)
-
-    ctx = _create_context(
-        workload,
-        target=target,
-        rule=[
-            multi_level_tiling_tensor_core(
-                target=target, write_reuse_scope="global", trans_b=[False, True]
-            ),
-            auto_inline(target),
-        ],
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 2
-
-    expected = [
-        expected[0],
-        """b0 = sch.get_block(name="C", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
-b1 = sch.reindex(block=b0, buffer=("write", 0))
-b2 = sch.reindex(block=b0, buffer=("read", 0))
-b3 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (j, k, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
-sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
-sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
-l4, l5, l6 = sch.get_loops(block=b0)
-l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
-l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
-l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True)
-l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0)
-sch.reorder(l15, l17, l12, l10, l8)
-b19 = sch.blockize(loop=l12)
-sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32_trans")
-sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
-sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1)
-l20, l21, l22 = sch.get_loops(block=b19)
-v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4)
-l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True)
-v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
-l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True)
-v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4)
-l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True)
-sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)
-l49 = sch.fuse(l28, l38, preserve_unit_iters=True)
-sch.bind(loop=l49, thread_axis="blockIdx.y")
-l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
-sch.bind(loop=l50, thread_axis="blockIdx.x")
-l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
-sch.bind(loop=l51, thread_axis="threadIdx.y")
-b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1)
-sch.reverse_compute_inline(block=b1)
-l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
-l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
-l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True)
-l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52)
-sch.reorder(l67, l61, l59)
-b69 = sch.blockize(loop=l61)
-sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
-b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1)
-l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
-l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
-v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
-b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1)
-l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
-l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
-v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
-b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1)
-l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
-l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
-l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
-l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88)
-sch.reorder(l107, l99, l97)
-b109 = sch.blockize(loop=l99)
-sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
-b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1)
-l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
-l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
-l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
-l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110)
-sch.reorder(l129, l121, l119)
-b131 = sch.blockize(loop=l121)
-sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b_trans")
-sch.compute_inline(block=b2)
-sch.compute_inline(block=b3)
-sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8)
-sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split(
-            "\n"
-        ),
-    ]
-    check_trace(spaces, expected)
-
-
-def test_multi_level_tiling_non_tensorizable():
-    # expected to do nothing on non-tensorizable workloads
-    m = n = k = 128
-    target = Target("cuda", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            # dtype doesn't match tensor intrin
-            te_workload.matmul_relu(
-                n=n,
-                m=m,
-                k=k,
-            )
-        ),
-        target=target,
-        rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-
-    expected = [
-        "",  # expected to do nothing when the workload can't be tensorized
-    ]
-    check_trace(spaces, expected)
-
-
-def test_cuda_tensor_core_conv2d():
-    target = Target("cuda", host="llvm")
-    workload = create_prim_func(
-        te_workload.conv2d_nhwc(
-            N=1,
-            H=16,
-            W=16,
-            CI=32,
-            CO=32,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            in_dtype="float16",
-            out_dtype="float32",
-        )
-    )
-    ctx = _create_context(
-        workload,
-        target=target,
-        rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-
-    expected = [
-        """b0 = sch.get_block(name="conv2d_nhwc", func_name="main")
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
-b1 = sch.reindex(block=b0, buffer=("write", 0))
-b2 = sch.reindex(block=b0, buffer=("read", 0))
-b3 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda h, w, rh, rw, rc: (((h*16) + w), (((rh*96) + (rw*32)) + rc), ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda co, rh, rw, rc: ((((rh*96) + (rw*32)) + rc), co, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda h, w, co: (((h*16) + w), co, ))
-sch.transform_block_layout(block=b1, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
-sch.transform_block_layout(block=b2, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
-sch.transform_block_layout(block=b3, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
-sch.transform_block_layout(block=b0, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
-l4, l5, l6, l7 = sch.get_loops(block=b0)
-l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
-l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
-l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
-l14, l15, l16, l17, l18, l19, l20 = sch.get_loops(block=b0)
-sch.reorder(l17, l19, l13, l11, l9)
-b21 = sch.blockize(loop=l13)
-sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
-sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
-sch.annotate(block_or_loop=b21, ann_key="warp_execution", ann_val=1)
-l22, l23, l24, l25 = sch.get_loops(block=b21)
-v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
-l31, l32, l33, l34, l35 = sch.split(loop=l22, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True)
-v36, v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l23, n=5, max_innermost_factor=4)
-l41, l42, l43, l44, l45 = sch.split(loop=l23, factors=[v36, v37, v38, v39, v40], preserve_unit_iters=True)
-v46, v47, v48, v49, v50 = sch.sample_perfect_tile(loop=l24, n=5, max_innermost_factor=4)
-l51, l52, l53, l54, l55 = sch.split(loop=l24, factors=[v46, v47, v48, v49, v50], preserve_unit_iters=True)
-v56, v57, v58 = sch.sample_perfect_tile(loop=l25, n=3, max_innermost_factor=4)
-l59, l60, l61 = sch.split(loop=l25, factors=[v56, v57, v58], preserve_unit_iters=True)
-sch.reorder(l31, l41, l51, l32, l42, l52, l33, l43, l53, l59, l60, l34, l44, l54, l61, l35, l45, l55)
-l62 = sch.fuse(l31, l41, l51, preserve_unit_iters=True)
-sch.bind(loop=l62, thread_axis="blockIdx.y")
-l63 = sch.fuse(l32, l42, l52, preserve_unit_iters=True)
-sch.bind(loop=l63, thread_axis="blockIdx.x")
-l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True)
-sch.bind(loop=l64, thread_axis="threadIdx.y")
-b65 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True, index=-1)
-b66 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True, index=-1)
-v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch", ann_val=v67)
-sch.reverse_compute_inline(block=b1)
-l68, l69, l70, l71, l72 = sch.get_loops(block=b66)
-l73, l74 = sch.split(loop=l72, factors=[None, 16], preserve_unit_iters=True)
-l75, l76 = sch.split(loop=l71, factors=[None, 16], preserve_unit_iters=True)
-l77, l78, l79, l80, l81, l82, l83 = sch.get_loops(block=b66)
-sch.reorder(l82, l76, l74)
-b84 = sch.blockize(loop=l76)
-sch.annotate(block_or_loop=b84, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
-b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True, index=-1)
-l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85)
-l92 = sch.fuse(l90, l91, preserve_unit_iters=True)
-v93 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93)
-b94 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True, index=-1)
-l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94)
-l101 = sch.fuse(l99, l100, preserve_unit_iters=True)
-v102 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
-sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102)
-b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True, index=-1)
-l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b103)
-l111, l112 = sch.split(loop=l110, factors=[None, 16], preserve_unit_iters=True)
-l113, l114 = sch.split(loop=l109, factors=[None, 16], preserve_unit_iters=True)
-l115, l116, l117, l118, l119, l120, l121, l122, l123 = sch.get_loops(block=b103)
-sch.reorder(l122, l114, l112)
-b124 = sch.blockize(loop=l114)
-sch.annotate(block_or_loop=b124, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
-b125 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True, index=-1)
-l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b125)
-l133, l134 = sch.split(loop=l132, factors=[None, 16], preserve_unit_iters=True)
-l135, l136 = sch.split(loop=l131, factors=[None, 16], preserve_unit_iters=True)
-l137, l138, l139, l140, l141, l142, l143, l144, l145 = sch.get_loops(block=b125)
-sch.reorder(l144, l136, l134)
-b146 = sch.blockize(loop=l136)
-sch.annotate(block_or_loop=b146, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
-sch.compute_inline(block=b2)
-sch.compute_inline(block=b3)
-sch.storage_align(block=b85, buffer_index=0, axis=-2, factor=32, offset=8)
-sch.storage_align(block=b94, buffer_index=0, axis=-2, factor=32, offset=8)""".split(
-            "\n"
-        )
-    ]
-    check_trace(spaces, expected)
-
-    # test adding unappliable tensor intrinsics doesn't change the search space
-    ctx = _create_context(
-        workload,
-        target,
-        multi_level_tiling_tensor_core(
-            target=target,
-            write_reuse_scope="shared",
-            in_dtype="float16",
-            out_dtype=["float16", "float32"],
-        ),
-    )
-    check_trace(spaces, expected)
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-
-
-if __name__ == "__main__":
-    tvm.testing.main()