Posted to commits@tvm.apache.org by wu...@apache.org on 2022/07/14 20:32:16 UTC

[tvm] branch main updated: [MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA (#12059)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e084791852 [MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA (#12059)
e084791852 is described below

commit e0847918526ee6da7415fa40c557a6f3d86db75f
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Jul 14 13:32:11 2022 -0700

    [MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA (#12059)
    
    * [MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA
    
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    
    * address comments
    
    * update intrin registrations
    
    * fix tests
    
    * address comments
    
    * add warning when storage align doesn't work
    
    * remove print
    
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
---
 include/tvm/meta_schedule/schedule_rule.h          |  24 ++
 include/tvm/tir/stmt.h                             |   4 +
 python/tvm/meta_schedule/schedule_rule/__init__.py |   7 +-
 .../schedule_rule/multi_level_tiling.py            |  54 ++-
 python/tvm/meta_schedule/testing/schedule_rule.py  |  34 ++
 python/tvm/tir/tensor_intrin/cuda.py               |  73 +++-
 .../postproc/rewrite_reduction_block.cc            |  17 +
 src/meta_schedule/postproc/rewrite_tensorize.cc    |   6 +-
 src/meta_schedule/schedule_rule/auto_inline.cc     |   3 +-
 .../schedule_rule/multi_level_tiling.cc            |  28 +-
 .../schedule_rule/multi_level_tiling.h             |  10 +-
 .../multi_level_tiling_tensor_core.cc              | 393 +++++++++++++++++++
 src/tir/schedule/analysis.h                        |  18 +-
 src/tir/schedule/analysis/analysis.cc              |  33 +-
 src/tir/schedule/primitive/block_annotate.cc       |   5 +-
 src/tir/schedule/primitive/cache_read_write.cc     |   7 +-
 .../schedule/primitive/layout_transformation.cc    |   7 +-
 src/tir/schedule/transform.cc                      |   6 +-
 ...est_meta_schedule_postproc_rewrite_tensorize.py |   2 +-
 ...ta_schedule_schedule_rule_multi_level_tiling.py | 426 ++++++++++++++++++++-
 20 files changed, 1092 insertions(+), 65 deletions(-)
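
For context, the new rule is exercised end-to-end by the unit tests updated in this commit. A minimal sketch of that flow, using the testing helper added in python/tvm/meta_schedule/testing/schedule_rule.py (workload and parameter choices mirror the test code further below; the test additionally passes auto_inline for the relu epilogue):

    from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
    from tvm.meta_schedule.testing import te_workload
    from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling_tensor_core
    from tvm.meta_schedule.tune_context import TuneContext
    from tvm.target import Target
    from tvm.te import create_prim_func

    target = Target("cuda", host="llvm")
    ctx = TuneContext(
        mod=create_prim_func(te_workload.matmul_relu_fp16(n=128, m=128, k=128)),
        target=target,
        space_generator=PostOrderApply(),
        sch_rules=[multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared")],
        task_name="demo",
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    print("\n".join(spaces[0].trace.as_python()))

The generated trace contains the blockized wmma compute block annotated with "meta_schedule.auto_tensorize", which the RewriteTensorize postprocessor later turns into actual tensorize calls.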

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 7e0e5bda57..5e4698db17 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -173,6 +173,30 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
       Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
 
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+   * intrinsics
+   * \param intrin_group A group of tensor core intrinsics. The map should contain the keys
+   * "init", "load_a", "load_b", "compute", and "store", which represent the tensor intrinsics for
+   * initialization, loading operand A, loading operand B, tensor core computation, and storing the
+   * result, respectively. The values of the map should be the names of tensor intrinsics, which
+   * must be registered via TensorIntrin.register(...) beforehand
+   * \param structure The tiling structure. Recommended:
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
+   * - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
+   * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
+   * NullOpt means disable vectorization
+   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
+      Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
+      Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
+      Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
+
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
    * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index ddc97549fc..2060fb7920 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1524,6 +1524,10 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensori
 
 /*! \brief Mark that a block is a preprocessor block for layout rewrite. */
 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
+/*!
+ * \brief Mark that the init statement of a block should be further rewritten using tensorization.
+ */
+constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
 
 /*!
  * \brief Mark that a block is executed by a warp. This implies the extend of threadIdx.x is
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index 18fc1de78c..dd0119b0a7 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -23,7 +23,12 @@ from .add_rfactor import AddRFactor
 from .auto_bind import AutoBind
 from .auto_inline import AutoInline
 from .cross_thread_reduction import CrossThreadReduction
-from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
+from .multi_level_tiling import (
+    MultiLevelTiling,
+    MultiLevelTilingWithIntrin,
+    ReuseType,
+    MultiLevelTilingTensorCore,
+)
 from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation
 from .schedule_rule import PyScheduleRule, ScheduleRule
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 0bad6cbb4c..71fbaee4f6 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Multi-level tiling with reuse."""
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, Mapping, NamedTuple, Optional
 
 from tvm._ffi import register_object
 
@@ -131,3 +131,55 @@ class MultiLevelTilingWithIntrin(ScheduleRule):
             reuse_read.as_dict() if reuse_read is not None else None,
             reuse_write.as_dict() if reuse_write is not None else None,
         )
+
+
+@register_object("meta_schedule.MultiLevelTilingTensorCore")
+class MultiLevelTilingTensorCore(ScheduleRule):
+    """Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+    intrinsics.
+
+    Parameters
+    ----------
+    intrin_group : Mapping[str, str]
+        A group of tensor core intrinsics. The map should contain the keys "init", "load_a",
+        "load_b", "compute", and "store", which represent the tensor intrinsics for initialization,
+        loading operand A, loading operand B, tensor core computation, and storing the result,
+        respectively. The values of the map should be the names of tensor intrinsics, which must
+        be registered via TensorIntrin.register(...) beforehand.
+    structure : str
+        The tiling structure. Recommended:
+        - 'SSSRRSRS' on GPU
+    tile_binds : Optional[List[str]]
+        For each level of tiles, the thread axis it is bound to. Recommended:
+        - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit
+    vector_load_lens : Optional[List[int]]
+        The length of vector lane in vectorized cooperative fetching.
+        None means disable vectorization
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+    """
+
+    def __init__(
+        self,
+        intrin_group: Mapping[str, str],
+        structure: str,
+        tile_binds: Optional[List[str]] = None,
+        max_innermost_factor: Optional[int] = None,
+        vector_load_lens: Optional[List[int]] = None,
+        reuse_read: Optional[ReuseType] = None,
+        reuse_write: Optional[ReuseType] = None,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleMultiLevelTilingTensorCore,  # type: ignore # pylint: disable=no-member
+            intrin_group,
+            structure,
+            tile_binds,
+            max_innermost_factor,
+            vector_load_lens,
+            reuse_read.as_dict() if reuse_read is not None else None,
+            reuse_write.as_dict() if reuse_write is not None else None,
+        )
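
For reference, constructing the rule directly from Python looks roughly like the following. This mirrors the multi_level_tiling_tensor_core testing helper added later in this commit, and assumes the wmma intrinsics are registered as a side effect of importing tvm.tir.tensor_intrin:

    from tvm.meta_schedule.schedule_rule import MultiLevelTilingTensorCore, ReuseType
    from tvm.tir import tensor_intrin  # importing this registers the wmma tensor intrinsics

    rule = MultiLevelTilingTensorCore(
        intrin_group=tensor_intrin.get_wmma_intrin_group(
            store_scope="shared", in_dtype="float16", out_dtype="float32", trans_b=False
        ),
        structure="SSSRRSRS",
        tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
        max_innermost_factor=4,
        vector_load_lens=[1, 2, 3, 4],
        reuse_read=ReuseType(req="must", levels=[4], scope="shared"),
        reuse_write=ReuseType(req="must", levels=[2], scope="shared"),
    )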
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index e159bfaaaa..717be59512 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -26,6 +26,8 @@ from tvm.meta_schedule.schedule_rule import (
     ReuseType,
     ScheduleRule,
 )
+from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore
+from tvm.tir import tensor_intrin
 from tvm.target import Target
 
 
@@ -110,6 +112,38 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
 
+def multi_level_tiling_tensor_core(
+    target: Target,
+    write_reuse_scope="shared",
+    in_dtype="float16",
+    out_dtype="float32",
+    trans_b=False,
+) -> ScheduleRule:
+    """Default schedule rules for with multi-level tiling reuse for tensor core"""
+    assert write_reuse_scope in ["shared", "global"]
+    if target.kind.name == "cuda":
+        return MultiLevelTilingTensorCore(
+            intrin_group=tensor_intrin.get_wmma_intrin_group(
+                write_reuse_scope, in_dtype, out_dtype, trans_b
+            ),
+            structure="SSSRRSRS",
+            tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
+            max_innermost_factor=4,  # 64 // tensor intrin size
+            vector_load_lens=[1, 2, 3, 4],
+            reuse_read=ReuseType(
+                req="must",
+                levels=[4],
+                scope="shared",
+            ),
+            reuse_write=ReuseType(
+                req="must" if write_reuse_scope == "shared" else "no",
+                levels=[2],
+                scope=write_reuse_scope,
+            ),
+        )
+    raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
 def random_compute_location(target: Target) -> ScheduleRule:
     """Default schedule rules for with random-compute-location"""
     if target.kind.name == "llvm":
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 909b13e35c..e7d5defcf3 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
-from typing import Tuple
+from typing import Tuple, Dict
 from tvm.script import tir as T
 from tvm.tir.function import PrimFunc
 from .. import IntImm, Cast
@@ -769,15 +769,15 @@ TensorIntrin.register(
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False),
 )
 
-WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_trans"
+WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans"
 TensorIntrin.register(
-    WMMA_LOAD_16x16x16_F16_A_INTRIN,
+    WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True),
 )
 
-WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_trans"
+WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans"
 TensorIntrin.register(
-    WMMA_LOAD_16x16x16_F16_B_INTRIN,
+    WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
 )
 
@@ -806,3 +806,66 @@ WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN = "wmma_store_16x16x16_f16_global"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "global")
 )
+
+
+def get_wmma_intrin_group(
+    store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
+) -> Dict[str, str]:
+    """Get a group of intrinsics for wmma tensor core with the given configurations
+
+    Parameters
+    ----------
+    store_scope : str
+        Must be one of ["global", "shared"]. The memory scope of the result buffer.
+
+    in_dtype : str
+        The input data type.
+
+    out_dtype : str
+        The output data type.
+
+    trans_b : bool
+        Whether the input matrix B is transposed.
+
+    Returns
+    -------
+    ret : Dict[str, str]
+        A group of tensor intrinsics.
+    """
+    assert store_scope in ["global", "shared"]
+    assert in_dtype in ["float16"]
+    assert out_dtype in ["float16", "float32"]
+
+    load_a_intrins = {"float16": WMMA_LOAD_16x16x16_F16_A_INTRIN}
+    load_b_intrins = {
+        "float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN
+        if trans_b
+        else WMMA_LOAD_16x16x16_F16_B_INTRIN
+    }
+    compute_intrins = {
+        "float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
+        if trans_b
+        else WMMA_SYNC_16x16x16_f16f16f16_INTRIN,
+        "float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN
+        if trans_b
+        else WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+    }
+    init_intrins = {
+        "float16": WMMA_FILL_16x16x16_F16_INTRIN,
+        "float32": WMMA_FILL_16x16x16_F32_INTRIN,
+    }
+    store_intrins = {
+        "float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
+        if store_scope == "shared"
+        else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN,
+        "float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
+        if store_scope == "shared"
+        else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
+    }
+    return {
+        "init": init_intrins[out_dtype],
+        "load_a": load_a_intrins[in_dtype],
+        "load_b": load_b_intrins[in_dtype],
+        "compute": compute_intrins[out_dtype],
+        "store": store_intrins[out_dtype],
+    }
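
A quick illustration of the helper: the returned values are the registered intrinsic names, e.g. the load_b entry with trans_b=True is the WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN constant defined above.

    from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group

    group = get_wmma_intrin_group(
        store_scope="shared", in_dtype="float16", out_dtype="float32", trans_b=True
    )
    assert set(group) == {"init", "load_a", "load_b", "compute", "store"}
    assert group["load_b"] == "wmma_load_16x16x16_f16_b_trans"  # WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN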
diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc
index cea1f5b93c..ea204e3061 100644
--- a/src/meta_schedule/postproc/rewrite_reduction_block.cc
+++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -135,6 +135,23 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
       tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
       Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
       tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);
+
+      // Rewrite auto tensorization related annotations
+      if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) {
+        // Remove tensorization annotation as it shouldn't be propagated to the init block.
+        sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize);
+        Optional<String> tensorize_init =
+            tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init);
+        // The tensorization annotation for the init statement should be moved to the init block
+        // created by 'DecomposeReduction'.
+        // Annotate the init block even when tensorize_init is NullOpt, so that the
+        // `RewriteTensorize` postprocessor is still hinted to process it.
+        sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
+                      tensorize_init.value_or(""));
+        if (tensorize_init.defined()) {
+          sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init);
+          sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init);
+        }
+      }
       ++rewritten;
     }
     if (rewritten == 0) {
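
In schedule terms the rewrite above does roughly the following. The sketch below is illustrative only: it uses a toy 16x16x16 matmul and made-up intrinsic names, and assumes, as the C++ code does, that DecomposeReduction copies the block annotations onto the new init block.

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (16, 16), "float32")
        B = T.match_buffer(b, (16, 16), "float32")
        C = T.match_buffer(c, (16, 16), "float32")
        for i, j, k in T.grid(16, 16, 16):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

    sch = tir.Schedule(matmul)
    block = sch.get_block("C")
    # Pretend MultiLevelTilingTensorCore annotated the block earlier (names are hypothetical).
    sch.annotate(block, "meta_schedule.auto_tensorize", "some_compute_intrin")
    sch.annotate(block, "meta_schedule.auto_tensorize_init", "some_init_intrin")
    i, j, k = sch.get_loops(block)
    init_block = sch.decompose_reduction(block, k)
    # Mirror the C++ logic: the compute intrin must not stay on the init block,
    # and the init intrin moves from the main block onto the init block.
    sch.unannotate(init_block, "meta_schedule.auto_tensorize")
    sch.annotate(init_block, "meta_schedule.auto_tensorize", "some_init_intrin")
    sch.unannotate(block, "meta_schedule.auto_tensorize_init")
    sch.unannotate(init_block, "meta_schedule.auto_tensorize_init")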
diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc
index 3df9075972..3b6c438d02 100644
--- a/src/meta_schedule/postproc/rewrite_tensorize.cc
+++ b/src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -35,10 +35,10 @@ void CollectTensorizationJobs(
   tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
     if (const auto* block = obj.as<tir::BlockNode>()) {
       tir::StmtSRef block_sref = sch->GetSRef(block);
+      std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
       if (Optional<String> intrin_name =
               tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
-        std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
-        if (block_name.find("init") == std::string::npos) {
+        if (intrin_name.value() != "") {
           jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
             try {
               sch->Tensorize(block, intrin_name.value());
@@ -46,7 +46,7 @@ void CollectTensorizationJobs(
               LOG(WARNING) << "Tensorize failed with error " << e.what();
             }
           });
-        } else if (vectorize_init_loop) {
+        } else if (block_name.find("init") != std::string::npos && vectorize_init_loop) {
           jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
             Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
             ICHECK(child_blocks.size() == 1);
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index 309f0a60ac..df4d3ac859 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -143,7 +143,8 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
     Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
     if (producer_srefs.size() == 1 &&
         tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
-        CanReverseComputeInline(state, block_sref)) {
+        CanReverseComputeInline(state, block_sref) &&
+        !GetAnn<String>(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).defined()) {
       return InlineType::kInlineIntoProducer;
     }
   }
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 2f2eb219e8..5f048dec00 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -61,6 +61,8 @@ using tir::IterVarType;
 using tir::LoopRV;
 using tir::Schedule;
 
+TVM_REGISTER_OBJECT_TYPE(StateNode);
+
 State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) {
   ObjectPtr<StateNode> node = make_object<StateNode>();
   node->sch = std::move(sch);
@@ -133,6 +135,7 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
         new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
         results.push_back(std::move(new_state));
       }
+      state->write_reuse.emplace(0, consumer_rvs[0]);
       results.push_back(state);
       return results;
     } else {
@@ -146,6 +149,7 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
   BlockRV write_cache =
       state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0,
                              /*storage_scope=*/config.scope);
+  state->write_reuse.emplace(0, write_cache);
   for (int level : levels) {
     State new_state = state->Copy();
     const LoopRV& loop_rv = new_state->tiles[level - 1].back();
@@ -247,22 +251,26 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
       Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block);
       LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim,  //
                                              buffer_loops.end()});
-      // Annotate cooperative fetching
-      if (!vector_load_lens.empty()) {
-        int n = vector_load_lens.size();
-        double prob = 1.0 / n;
-        tir::ExprRV vector_load_len =
-            sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
-                                   Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
-        sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
-                      vector_load_len);
-      }
+      AnnotateCooperativeFetching(&sch, cache_read_block);
+      new_state->read_reuse.emplace(i, cache_read_block);
     }
     results.push_back(std::move(new_state));
   }
   return results;
 }
 
+void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
+                                                       const tir::BlockRV& block) const {
+  if (!vector_load_lens.empty()) {
+    int n = vector_load_lens.size();
+    double prob = 1.0 / n;
+    tir::ExprRV vector_load_len =
+        (*sch)->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
+                                  Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
+    (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len);
+  }
+}
+
 // Constructor
 
 ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index 05179318d0..3398255674 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -22,6 +22,7 @@
 #include <tvm/meta_schedule/schedule_rule.h>
 #include <tvm/tir/schedule/schedule.h>
 
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
@@ -93,6 +94,10 @@ class StateNode : public Object {
   tir::BlockRV block_rv;
   /*! \brief The loop tiles */
   Array<Array<tir::LoopRV>> tiles;
+  /*! \brief The mapping from buffer index to read cache block. */
+  std::unordered_map<int, tir::BlockRV> read_reuse;
+  /*! \brief The mapping from buffer index to write cache block. */
+  std::unordered_map<int, tir::BlockRV> write_reuse;
 
   /*!
    * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that
@@ -148,11 +153,14 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   void InitializeWithTuneContext(const TuneContext& context) final;
 
   // Entry of the mega rule; Inherited from ScheduleRuleNode
-  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override;
 
  protected:
   virtual std::vector<State> ApplySubRules(std::vector<State> states);
 
+  // Annotate a block to use cooperative fetching
+  void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;
+
  public:
   /*!
    * \brief The tiling structure. Recommended:
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
new file mode 100644
index 0000000000..91df62fc36
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/meta_schedule/schedule_rule.h>
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "../utils.h"
+#include "./multi_level_tiling.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::BlockRV;
+using tir::LoopRV;
+using tir::Schedule;
+
+struct TensorCoreIntrinGroup {
+  String init_intrin;
+  String load_a_intrin;
+  String load_b_intrin;
+  String compute_intrin;
+  String store_intrin;
+};
+
+class TensorCoreStateNode : public StateNode {
+ public:
+  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_A;
+  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_B;
+  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_store;
+
+  State Copy() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
+};
+
+class TensorCoreState : public State {
+ public:
+  explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
+                           Array<Array<tir::LoopRV>> tiles = {});
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
+};
+
+TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode);
+
+TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array<Array<LoopRV>> tiles) {
+  ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>();
+  node->sch = std::move(sch);
+  node->block_rv = std::move(block_rv);
+  node->tiles = std::move(tiles);
+  data_ = std::move(node);
+}
+
+State TensorCoreStateNode::Copy() const {
+  ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>(*this);
+  node->sch = sch->Copy();
+  return State(node);
+}
+
+/*!
+ * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+ * intrinsics.
+ */
+class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
+ private:
+  // SubRule: Add tensorization-related transformations
+  inline std::vector<State> TransformForTensorization(TensorCoreState state) const;
+  // Subrule: Add tensorized load
+  inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) const;
+  // Subrule: Add tensorized store
+  inline std::vector<State> AddWriteReuseTensorCore(TensorCoreState state) const;
+
+  // Override ApplySubRules to apply tensorization-specific sub-rules
+  std::vector<State> ApplySubRules(std::vector<State> states) final;
+
+  // Override Apply to apply tensorization-specific analysis before applying sub-rules
+  Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final;
+
+  /*!
+   * \brief Transform and tensorize with the given tensor intrin
+   * \param state The state of the meta schedule rule
+   * \param intrin_name The name of the tensor intrin
+   * \return The loop to be tensorized. NullOpt if the workload can't be tensorized.
+   */
+  Optional<LoopRV> TransformWithTensorIntrin(TensorCoreStateNode* state,
+                                             const String& intrin_name) const;
+
+  /*!
+   * \brief Tile, blockize and annotate for tensorization with the given intrin
+   * \param block_rv The block to be tensorized
+   * \param intrin_name The name of the tensor intrin
+   */
+  void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv,
+                                const String& intrin_name) const;
+
+ public:
+  /*! \brief The tensor core intrin group to apply */
+  TensorCoreIntrinGroup intrin_group;
+  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode);
+
+ private:
+  /*!
+   * \brief The mapping info for auto tensorization
+   */
+  tir::AutoTensorizeMappingInfo mapping_info_{nullptr};
+};
+
+// Entry of the mega rule; Inherited from ScheduleRuleNode
+Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
+                                                      const BlockRV& block_rv) {
+  if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
+    return {sch};
+  }
+
+  Optional<tir::AutoTensorizeMappingInfo> mapping_info =
+      tir::GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block_rv),
+                                       tir::TensorIntrin::Get(intrin_group.compute_intrin)->desc);
+  if (!mapping_info.defined()) {
+    return {sch};
+  }
+  mapping_info_ = mapping_info.value();
+
+  // Create a copy of the schedule so that we can roll back the transformations if tensorization
+  // fails.
+  Schedule original_sch = sch->Copy();
+  sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
+
+  Array<Schedule> results;
+  for (auto&& state : ApplySubRules({TensorCoreState(sch, block_rv)})) {
+    results.push_back(std::move(state->sch));
+  }
+  if (results.empty()) {
+    return {original_sch};
+  }
+  return results;
+}
+
+std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<State> states) {
+  states = SubRule(std::move(states), [&](State state) {
+    return TransformForTensorization(Downcast<TensorCoreState>(state));
+  });
+  states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
+  states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
+  states = SubRule(std::move(states), [&](State state) {
+    return AddWriteReuseTensorCore(Downcast<TensorCoreState>(state));
+  });
+  states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); });
+  states = SubRule(std::move(states), [&](State state) {
+    return AddReadReuseTensorCore(Downcast<TensorCoreState>(state));
+  });
+  return states;
+}
+
+void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch,
+                                                              const BlockRV& block_rv,
+                                                              const String& intrin_name) const {
+  Optional<LoopRV> loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value();
+  ICHECK(loop.defined());
+  BlockRV blockized_outer = (*sch)->Blockize(loop.value());
+  (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name);
+}
+
+std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
+    TensorCoreState state) const {
+  // Add the cache write stage for Tensor Core
+  int level = r_indices_.front() - 1;
+  const LoopRV& loop = state->tiles[level].back();
+  Schedule& sch = state->sch;
+  auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator");
+  sch->ReverseComputeAt(cache_write, loop, true);
+
+  if (state->write_reuse.count(0)) {
+    AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
+  }
+  sch->ReverseComputeInline(state->tensor_core_reindex_store);
+  TileAndAnnotateTensorize(&sch, cache_write, intrin_group.store_intrin);
+  return {state};
+}
+
+std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
+    TensorCoreState state) const {
+  const Array<LoopRV>& r_tiles = state->tiles[r_indices_[1]];
+  Schedule& sch = state->sch;
+  ICHECK(!r_tiles.empty()) << "ValueError: Cannot find the suitable reduction loop in the block";
+
+  auto f_tensorize_load = [&](int read_index, String scope, String intrin_name) {
+    auto cache_read = sch->CacheRead(state->block_rv, read_index, scope);
+    state->sch->ComputeAt(cache_read, r_tiles.back(), true);
+    TileAndAnnotateTensorize(&sch, cache_read, intrin_name);
+  };
+
+  f_tensorize_load(0, "wmma.matrix_a", intrin_group.load_a_intrin);
+  f_tensorize_load(1, "wmma.matrix_b", intrin_group.load_b_intrin);
+  sch->ComputeInline(state->tensor_core_reindex_A);
+  sch->ComputeInline(state->tensor_core_reindex_B);
+
+  for (int i = 0; i < 2; ++i) {
+    const tir::BlockRV cache_read = state->read_reuse.at(i);
+    const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs<tir::BlockNode>();
+    tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer(
+        sch->state(), GetRef<tir::Block>(cache_read_block), 0, tir::BufferIndexType::kWrite);
+    const DataType& dtype = cache_read_buffer->dtype;
+    if (dtype.is_float16()) {
+      sch->StorageAlign(cache_read, 0, -2, 32, 8);
+    } else if (dtype.is_int() && dtype.bits() == 8) {
+      sch->StorageAlign(cache_read, 0, -2, 32, 16);
+    } else {
+      LOG(WARNING) << "StorageAlign is not applied for data type " << dtype
+                   << ", shared memory accesses might be inefficient.";
+    }
+  }
+  return {state};
+}
+
+Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
+    TensorCoreStateNode* state, const String& intrin_name) const {
+  BlockRV block_rv = state->block_rv;
+  tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);
+
+  // Add reindex stages
+  const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  // Hold the reference of the block before reindex
+  const tir::Block block_before_reindex = GetRef<tir::Block>(block);
+  if (block->reads.size() != 2 || block->writes.size() != 1) {
+    // only matmul-like computation is allowed
+    return NullOpt;
+  }
+  state->tensor_core_reindex_store =
+      state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kWrite);
+  state->tensor_core_reindex_A =
+      state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kRead);
+  state->tensor_core_reindex_B =
+      state->sch->ReIndex(state->block_rv, 1, tir::BufferIndexType::kRead);
+
+  // Transform the layout of reindex buffers accordingly.
+  // The index map defines the mapping for the computation block. We need to extract the sub index
+  // map to transform the load and store block.
+  ICHECK_EQ(mapping_info_->mappings.size(), 1U);  // assume only one mapping is present
+  const tir::IndexMap& index_map = mapping_info_->mappings[0];
+
+  // Find the correspondence between block iters and the iters in the index map.
+  std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> lhs_to_index_map_src;
+  std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> rhs_to_index_map_tgt;
+  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_index_map_src;
+  ICHECK_EQ(mapping_info_->lhs_iters.size(), index_map->initial_indices.size());
+  for (int i = 0; i < static_cast<int>(mapping_info_->lhs_iters.size()); ++i) {
+    lhs_to_index_map_src[mapping_info_->lhs_iters[i]->var] = index_map->initial_indices[i];
+  }
+  // The number of result iters in the index map is equal to or greater than the number of rhs
+  // (the tensor intrin) iters. When there are extra iters, these iters represent unmapped iters
+  // from the lhs. They will be skipped during pattern matching for tensorization.
+  // An example of such a case is batch matmul: the batch dimension is kept after the layout
+  // transformations and remains an outer loop after tensorization.
+  int offset = static_cast<int>(index_map->final_indices.size()) -
+               static_cast<int>(mapping_info_->rhs_iters.size());
+  ICHECK_GE(offset, 0);
+  for (int i = 0; i < offset; ++i) {
+    const tir::VarNode* var_ptr = index_map->final_indices[i].as<tir::VarNode>();
+    ICHECK(var_ptr != nullptr);
+    unmapped_index_map_src.insert(GetRef<tir::Var>(var_ptr));
+  }
+  for (int i = offset; i < static_cast<int>(index_map->final_indices.size()); ++i) {
+    rhs_to_index_map_tgt[mapping_info_->rhs_iters[i - offset]->var] = index_map->final_indices[i];
+  }
+
+  auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) {
+    std::vector<tir::Var> sub_index_map_src;
+    std::vector<PrimExpr> sub_index_map_tgt;
+    const tir::Buffer& rhs_buffer = mapping_info_->lhs_buffer_map[lhs_buffer];
+    for (const Range& range : lhs_region) {
+      ICHECK(tir::is_one(range->extent));
+      const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
+      ICHECK(var_ptr != nullptr);
+      const tir::Var& lhs_representer = lhs_to_index_map_src[GetRef<tir::Var>(var_ptr)];
+      sub_index_map_src.push_back(lhs_representer);
+      if (unmapped_index_map_src.count(lhs_representer)) {
+        sub_index_map_tgt.push_back(lhs_representer);
+      }
+    }
+    for (size_t i = 0; i < mapping_info_->rhs_buffer_indices[rhs_buffer].size(); ++i) {
+      const tir::VarNode* var = mapping_info_->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
+      ICHECK(var != nullptr);
+      sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef<tir::Var>(var)]);
+    }
+    return tir::IndexMap(sub_index_map_src, sub_index_map_tgt);
+  };
+
+  std::unordered_set<tir::Buffer, ObjectPtrHash, ObjectPtrEqual> visited_buffers;
+
+  auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) {
+    const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer(
+        state->sch->state(), block_before_reindex, buffer_index, index_type);
+    if (visited_buffers.count(lhs_buffer)) {
+      return;
+    }
+    visited_buffers.insert(lhs_buffer);
+    // Refresh block pointer (block sref is not invalidated)
+    block = TVM_SREF_TO_BLOCK(block, block_sref);
+    const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
+        state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
+    auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
+    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map);
+  };
+
+  for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
+    f_transform_buffer_layout(tir::BufferIndexType::kRead, i);
+  }
+  for (int i = 0, n = block_before_reindex->writes.size(); i < n; ++i) {
+    f_transform_buffer_layout(tir::BufferIndexType::kWrite, i);
+  }
+
+  // Transform the layout of current block and reindex blocks
+  state->sch->TransformBlockLayout(state->tensor_core_reindex_store, index_map);
+  state->sch->TransformBlockLayout(state->tensor_core_reindex_A, index_map);
+  state->sch->TransformBlockLayout(state->tensor_core_reindex_B, index_map);
+  state->sch->TransformBlockLayout(state->block_rv, index_map);
+
+  return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name);
+}
+
+inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorization(
+    TensorCoreState state) const {
+  // Do reindex and layout transformations.
+  Optional<LoopRV> transformed_loop_rv =
+      TransformWithTensorIntrin(state.operator->(), intrin_group.compute_intrin);
+  if (!transformed_loop_rv.defined()) {
+    // The workload can't be tensorized.
+    return {};
+  }
+
+  // Do blockize
+  state->block_rv = state->sch->Blockize(transformed_loop_rv.value());
+
+  // Add annotations for post processors.
+  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize,
+                       intrin_group.compute_intrin);
+  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
+                       intrin_group.init_intrin);
+  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true));
+  return {std::move(state)};
+}
+
+ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
+    Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
+    Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
+    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
+      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
+
+  auto f_initialize_intrin = [&intrin_group](String key_name, String* intrin_name) {
+    CHECK(intrin_group.count(key_name)) << "ValueError: " << key_name << " is not set.";
+    *intrin_name = intrin_group.at(key_name);
+    // Check the existence of the intrin
+    tir::TensorIntrin::Get(*intrin_name);
+  };
+  f_initialize_intrin("init", &node->intrin_group.init_intrin);
+  f_initialize_intrin("load_a", &node->intrin_group.load_a_intrin);
+  f_initialize_intrin("load_b", &node->intrin_group.load_b_intrin);
+  f_initialize_intrin("compute", &node->intrin_group.compute_intrin);
+  f_initialize_intrin("store", &node->intrin_group.store_intrin);
+
+  return ScheduleRule(node);
+}
+
+TVM_REGISTER_NODE_TYPE(MultiLevelTilingTensorCoreNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore")
+    .set_body_typed(ScheduleRule::MultiLevelTilingTensorCore);
+
+}  // namespace meta_schedule
+}  // namespace tvm
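
Regarding the StorageAlign calls in AddReadReuseTensorCore above (and the new warning when the dtype is not handled), here is a minimal standalone sketch of the primitive on a toy copy kernel, not taken from this commit, padding the second-to-last axis of a shared-memory cache the same way the rule does for float16 inputs (factor=32, offset=8):

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def copy(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (64, 64), "float16")
        B = T.match_buffer(b, (64, 64), "float16")
        for i, j in T.grid(64, 64):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

    sch = tir.Schedule(copy)
    block = sch.get_block("B")
    cache = sch.cache_read(block, 0, "shared")  # A_shared, written by the cache block
    # Align axis -2 of the cache to a multiple of 32 with an offset of 8 elements,
    # which helps avoid shared-memory bank conflicts for 16-bit data.
    sch.storage_align(cache, buffer_index=0, axis=-2, factor=32, offset=8)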
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 317b3625f0..37b3e00cc2 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -22,6 +22,7 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ir/op.h>
 #include <tvm/tir/index_map.h>
+#include <tvm/tir/schedule/schedule.h>
 #include <tvm/tir/schedule/state.h>
 
 #include <tuple>
@@ -422,11 +423,24 @@ struct ProducerConsumerSplit {
  * \param self The schedule state.
  * \param block The queried block.
  * \param n The index of the queried buffer.
- * \param is_write A boolean flag to indicate querying write buffer or read buffer.
+ * \param index_type The type of the buffer index, kRead or kWrite.
  * \return The buffer of the n-th read/write region of the block.
  * \throw ScheduleError If the buffer index is out of bound.
  */
-Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);
+Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
+                          BufferIndexType index_type);
+
+/*!
+ * \brief Get the n-th read or write buffer region of the given block.
+ * \param self The schedule state.
+ * \param block The queried block.
+ * \param n The index of the queried buffer.
+ * \param index_type The type of the buffer index, kRead or kWrite.
+ * \return The n-th read/write region of the block.
+ * \throw ScheduleError If the buffer index is out of bound.
+ */
+BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n,
+                                      BufferIndexType index_type);
 
 /*!
  * \brief Find the defining site of the buffer in the given block and its ancestors
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index ac73ac3ce2..569259d061 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1142,17 +1142,19 @@ ProducerConsumerSplit ProducerConsumerSplit::Find(
 
 /******** Block-buffer relation ********/
 
-Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) {
+BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n,
+                                      BufferIndexType index_type) {
   class BufferIndexOutOfRangeError : public ScheduleError {
    public:
-    explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, bool is_write)
+    explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index,
+                                        BufferIndexType index_type)
         : mod_(std::move(mod)),
           block_(std::move(block)),
           buffer_index_(buffer_index),
-          is_write_(is_write) {}
+          index_type_(index_type) {}
 
     String FastErrorString() const final {
-      if (is_write_) {
+      if (index_type_ == BufferIndexType::kWrite) {
         return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
                "range "
                "[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
@@ -1167,9 +1169,9 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
 
     String DetailRenderTemplate() const final {
       std::ostringstream os;
-      size_t num = is_write_ ? block_->writes.size() : block_->reads.size();
-      std::string access_type = is_write_ ? "write" : "read";
-      os << "The block {0} has " << num << " " << access_type
+      size_t num =
+          index_type_ == BufferIndexType::kWrite ? block_->writes.size() : block_->reads.size();
+      os << "The block {0} has " << num << " " << BufferIndexType2Str(index_type_)
          << " regions, so `buffer_index` is required to be in [0, " << num
          << "). However, the input `buffer_index` is " << buffer_index_
          << ", which is out of the expected range.";
@@ -1183,15 +1185,21 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
     IRModule mod_;
     Block block_;
     int buffer_index_;
-    bool is_write_;
+    BufferIndexType index_type_;
   };
 
-  const Array<BufferRegion>& access_region = is_write ? block->writes : block->reads;
+  const Array<BufferRegion>& access_region =
+      index_type == BufferIndexType::kWrite ? block->writes : block->reads;
 
   if (n < 0 || static_cast<int>(access_region.size()) <= n) {
-    throw BufferIndexOutOfRangeError(self->mod, block, n, is_write);
+    throw BufferIndexOutOfRangeError(self->mod, block, n, index_type);
   }
-  return access_region[n]->buffer;
+  return access_region[n];
+}
+
+Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
+                          BufferIndexType index_type) {
+  return GetNthAccessBufferRegion(self, block, n, index_type)->buffer;
 }
 
 std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
@@ -1942,6 +1950,9 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
 }
 
 bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) {
+  if (HasBeenMultiLevelTiled(block_sref)) {
+    return false;
+  }
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) ||
       !IsTrivialBinding(self, block_sref)) {
diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc
index ede239878a..2d876d9bf7 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -240,7 +240,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind
                   int factor, int offset) {
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
   Buffer buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, /*is_write=*/true);
+      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, BufferIndexType::kWrite);
   StorageAlignInvalidFactorError::Check(self->mod, factor);
   axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
   NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
@@ -275,7 +275,8 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind
 void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
               const String& storage_scope) {
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-  Buffer buffer = GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, true);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite);
 
   // Step 1. If `storage_scope` equals the original storage scope of the buffer, just return.
   if (buffer.scope() == storage_scope) {
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 6a7b59cfec..d28a659fd7 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -981,7 +981,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
   // Step 1. Check index, getting the target buffer and the parent scope
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   Buffer read_buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, /*is_write=*/false);
+      GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
   const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
 
@@ -1052,7 +1052,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   // Step 1. Checking index, getting the target buffer and the parent scope
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   Buffer write_buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, /*is_write=*/true);
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kWrite);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
 
   // Step 2. Creating CacheStageInfo
@@ -1094,8 +1094,7 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde
                  BufferIndexType buffer_index_type) {
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
   Block block = GetRef<Block>(block_ptr);
-  Buffer buffer =
-      GetNthAccessBuffer(self, block, buffer_index, buffer_index_type == BufferIndexType::kWrite);
+  Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
   arith::Analyzer analyzer;
 
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 639593ab3e..148b3ee033 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -136,8 +136,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
                      BufferIndexType buffer_index_type, const IndexMap& index_map) {
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
   Buffer old_buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index,
-                         buffer_index_type == BufferIndexType::kRead ? false : true);
+      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);
   Optional<StmtSRef> defining_site_sref;
   bool is_alloc;
   std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer);
@@ -491,8 +490,8 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
 void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                       BufferIndexType buffer_index_type, const Array<IntImm>& axis_separators) {
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
-  Buffer old_buffer = GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index,
-                                         buffer_index_type == BufferIndexType::kWrite);
+  Buffer old_buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);
   Optional<StmtSRef> defining_site_sref;
   bool is_alloc;
   std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer);
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 436d529abd..a739373ab3 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -278,9 +278,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block
     int64_t total = int_block_extent->value;
     int64_t inner = int_desc_extent->value;
     ICHECK_EQ(total % inner, 0);
-    int64_t outer = int_block_extent->value / int_desc_extent->value;
-    // Do the split
-    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
+    // Do the split. Leave the outer extent as NullOpt (unspecified) so that the split factors
+    // can be used for different extents (needed during tuning).
+    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {NullOpt, Integer(inner)});
     ICHECK_EQ(split.size(), 2);
     inner_loops.insert(sch->GetSRef(split[1]).operator->());
     // The inner split will be reordered to the loop domain that is tensorized
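
The effect of the NullOpt outer factor shows up directly in the updated trace expectations below (factors=[None, 4] instead of factors=[1, 4]). A minimal standalone sketch of the same kind of split on a toy loop:

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def elemwise(a: T.handle) -> None:
        A = T.match_buffer(a, (64,), "float32")
        for i in T.serial(64):
            with T.block("B"):
                vi = T.axis.remap("S", [i])
                A[vi] = A[vi] + T.float32(1)

    sch = tir.Schedule(elemwise)
    (i,) = sch.get_loops(sch.get_block("B"))
    # The None factor is inferred from the loop extent (64 // 16 = 4) when the trace is applied,
    # so the recorded trace does not hard-code the outer extent.
    outer, inner = sch.split(i, factors=[None, 16], preserve_unit_iters=True)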
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
index 6fae11c7fd..a1184c1edf 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
@@ -361,7 +361,7 @@ class DenseDP4ATensorized:
                             )
                             T.reads()
                             T.writes(compute_local[i, j])
-                            T.block_attr({"meta_schedule.auto_tensorize": "dp4a"})
+                            T.block_attr({"meta_schedule.auto_tensorize": ""})
                             with T.block("compute_init"):
                                 T.reads()
                                 T.writes(compute_local[i, j])
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 30511d6690..1ceef0afc3 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -16,11 +16,16 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import tvm
+import tvm.testing
 from tvm import te
 from tvm.meta_schedule import schedule_rule
 from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
 from tvm.meta_schedule.testing import te_workload
-from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling
+from tvm.meta_schedule.testing.schedule_rule import (
+    auto_inline,
+    multi_level_tiling,
+    multi_level_tiling_tensor_core,
+)
 from tvm.meta_schedule.testing.space_generation import check_trace
 from tvm.meta_schedule.tune_context import TuneContext
 from tvm.script import tir as T
@@ -31,11 +36,13 @@ from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
 
 
 def _create_context(mod, target, rule) -> TuneContext:
+    if not isinstance(rule, (list, tuple)):
+        rule = [rule]
     ctx = TuneContext(
         mod=mod,
         target=target,
         space_generator=PostOrderApply(),
-        sch_rules=[rule],
+        sch_rules=rule,
         task_name="test",
     )
     return ctx
@@ -366,8 +373,8 @@ def test_multi_level_tiling_conv2d_nchwc_vnni():
         """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
 sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
-l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
+l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
 l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
 sch.reorder(l21, l22, l23, l24, l25, l14, l12)
 b27 = sch.blockize(loop=l14)
@@ -401,8 +408,8 @@ sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split(
         """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
 sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
-l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
+l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
 l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
 sch.reorder(l21, l22, l23, l24, l25, l14, l12)
 b27 = sch.blockize(loop=l14)
@@ -436,8 +443,8 @@ sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split(
         """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
 sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
-l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
+l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
 l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
 sch.reorder(l21, l22, l23, l24, l25, l14, l12)
 b27 = sch.blockize(loop=l14)
@@ -517,7 +524,7 @@ def test_multi_level_tiling_dense_dpa4():
         """b0 = sch.get_block(name="compute", func_name="main")
 sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
 l1, l2, l3 = sch.get_loops(block=b0)
-l4, l5 = sch.split(loop=l3, factors=[32, 4], preserve_unit_iters=True)
+l4, l5 = sch.split(loop=l3, factors=[None, 4], preserve_unit_iters=True)
 sch.reorder(l5)
 b6 = sch.blockize(loop=l5)
 sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a")
@@ -556,11 +563,398 @@ sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_v
     check_trace(spaces, expected)
 
 
+def test_cuda_tensor_core_matmul_relu():
+    m = n = k = 128
+    target = Target("cuda", host="llvm")
+    ctx = _create_context(
+        create_prim_func(
+            te_workload.matmul_relu_fp16(
+                n=n,
+                m=m,
+                k=k,
+            )
+        ),
+        target=target,
+        rule=[
+            multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
+            auto_inline(target),
+        ],
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+
+    expected = [
+        """b0 = sch.get_block(name="C", func_name="main")
+b1 = sch.get_block(name="compute", func_name="main")
+sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+b2 = sch.reindex(block=b0, buffer=("write", 0))
+b3 = sch.reindex(block=b0, buffer=("read", 0))
+b4 = sch.reindex(block=b0, buffer=("read", 1))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
+l5, l6, l7 = sch.get_loops(block=b0)
+l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
+l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
+l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
+l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
+sch.reorder(l16, l18, l13, l11, l9)
+b20 = sch.blockize(loop=l13)
+sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
+sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
+sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
+l21, l22, l23 = sch.get_loops(block=b20)
+v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
+l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True)
+v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
+l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True)
+v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4)
+l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True)
+sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43)
+l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
+sch.bind(loop=l50, thread_axis="blockIdx.y")
+l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
+sch.bind(loop=l51, thread_axis="blockIdx.x")
+l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
+sch.bind(loop=l52, thread_axis="threadIdx.y")
+b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
+sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True)
+b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True)
+v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
+sch.reverse_compute_inline(block=b2)
+l56, l57, l58, l59, l60 = sch.get_loops(block=b54)
+l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True)
+l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True)
+l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54)
+sch.reorder(l70, l64, l62)
+b72 = sch.blockize(loop=l64)
+sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
+b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
+sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
+l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
+l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
+v81 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
+b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
+sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
+l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
+l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
+v90 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
+b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
+sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
+l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
+l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
+l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
+l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91)
+sch.reorder(l110, l102, l100)
+b112 = sch.blockize(loop=l102)
+sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
+b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
+sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True)
+l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
+l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
+l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
+l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113)
+sch.reorder(l132, l124, l122)
+b134 = sch.blockize(loop=l124)
+sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
+sch.compute_inline(block=b3)
+sch.compute_inline(block=b4)
+sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.reverse_compute_inline(block=b1)""".split(
+            "\n"
+        )
+    ]
+    check_trace(spaces, expected)
+
+    # Test that multi_level_tiling_tensor_core and multi_level_tiling can be used together,
+    # with multi_level_tiling acting as a fallback when the workload cannot be tensorized.
+    ctx = _create_context(
+        create_prim_func(
+            te_workload.matmul_relu_fp16(
+                n=n,
+                m=m,
+                k=k,
+            )
+        ),
+        target=target,
+        rule=[
+            multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
+            multi_level_tiling(target=target),
+            auto_inline(target),
+        ],
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+    check_trace(spaces, expected)
+
+
+def test_cuda_tensor_core_matmul_relu_global():
+    m = n = k = 128
+    target = Target("cuda", host="llvm")
+    ctx = _create_context(
+        create_prim_func(
+            te_workload.matmul_relu_fp16(
+                n=n,
+                m=m,
+                k=k,
+            ),
+        ),
+        target=target,
+        rule=[
+            multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"),
+            auto_inline(target),
+        ],
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+
+    expected = [
+        """b0 = sch.get_block(name="C", func_name="main")
+sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+b1 = sch.reindex(block=b0, buffer=("write", 0))
+b2 = sch.reindex(block=b0, buffer=("read", 0))
+b3 = sch.reindex(block=b0, buffer=("read", 1))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
+l4, l5, l6 = sch.get_loops(block=b0)
+l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
+l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
+l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True)
+l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0)
+sch.reorder(l15, l17, l12, l10, l8)
+b19 = sch.blockize(loop=l12)
+sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
+sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
+sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1)
+l20, l21, l22 = sch.get_loops(block=b19)
+v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4)
+l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True)
+v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
+l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True)
+v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4)
+l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True)
+sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)
+l49 = sch.fuse(l28, l38, preserve_unit_iters=True)
+sch.bind(loop=l49, thread_axis="blockIdx.y")
+l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
+sch.bind(loop=l50, thread_axis="blockIdx.x")
+l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
+sch.bind(loop=l51, thread_axis="threadIdx.y")
+b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
+sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_inline(block=b1)
+l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
+l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
+l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True)
+l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52)
+sch.reorder(l67, l61, l59)
+b69 = sch.blockize(loop=l61)
+sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
+b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
+sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
+l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
+l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
+v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
+b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
+sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
+l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
+l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
+v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
+b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
+sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
+l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
+l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
+l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
+l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88)
+sch.reorder(l107, l99, l97)
+b109 = sch.blockize(loop=l99)
+sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
+b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
+sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True)
+l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
+l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
+l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
+l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110)
+sch.reorder(l129, l121, l119)
+b131 = sch.blockize(loop=l121)
+sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
+sch.compute_inline(block=b2)
+sch.compute_inline(block=b3)
+sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split(
+            "\n"
+        )
+    ]
+    check_trace(spaces, expected)
+
+
+def test_multi_level_tiling_non_tensorizable():
+    # expected to do nothing on non-tensorizable workloads
+    m = n = k = 128
+    target = Target("cuda", host="llvm")
+    ctx = _create_context(
+        create_prim_func(
+            # dtype doesn't match tensor intrin
+            te_workload.matmul_relu(
+                n=n,
+                m=m,
+                k=k,
+            )
+        ),
+        target=target,
+        rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+
+    expected = [
+        "",  # expected to do nothing when the workload can't be tensorized
+    ]
+    check_trace(spaces, expected)
+
+
+def test_cuda_tensor_core_conv2d():
+    target = Target("cuda", host="llvm")
+    ctx = _create_context(
+        create_prim_func(
+            # fp16 conv2d whose dtype matches the wmma tensor intrin
+            te_workload.conv2d_nhwc_f16(
+                N=1, H=16, W=16, CI=32, CO=32, kernel_size=3, stride=1, padding=1
+            )
+        ),
+        target=target,
+        rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+
+    expected = [
+        """b0 = sch.get_block(name="conv2d_nhwc", func_name="main")
+sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+b1 = sch.reindex(block=b0, buffer=("write", 0))
+b2 = sch.reindex(block=b0, buffer=("read", 0))
+b3 = sch.reindex(block=b0, buffer=("read", 1))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda h, w, rh, rw, rc: (((h*16) + w), (((rh*96) + (rw*32)) + rc), ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda co, rh, rw, rc: ((((rh*96) + (rw*32)) + rc), co, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda h, w, co: (((h*16) + w), co, ))
+sch.transform_block_layout(block=b1, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
+sch.transform_block_layout(block=b2, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
+sch.transform_block_layout(block=b3, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
+sch.transform_block_layout(block=b0, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
+l4, l5, l6, l7 = sch.get_loops(block=b0)
+l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
+l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
+l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
+l14, l15, l16, l17, l18, l19, l20 = sch.get_loops(block=b0)
+sch.reorder(l17, l19, l13, l11, l9)
+b21 = sch.blockize(loop=l13)
+sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
+sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
+sch.annotate(block_or_loop=b21, ann_key="warp_execution", ann_val=1)
+l22, l23, l24, l25 = sch.get_loops(block=b21)
+v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
+l31, l32, l33, l34, l35 = sch.split(loop=l22, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True)
+v36, v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l23, n=5, max_innermost_factor=4)
+l41, l42, l43, l44, l45 = sch.split(loop=l23, factors=[v36, v37, v38, v39, v40], preserve_unit_iters=True)
+v46, v47, v48, v49, v50 = sch.sample_perfect_tile(loop=l24, n=5, max_innermost_factor=4)
+l51, l52, l53, l54, l55 = sch.split(loop=l24, factors=[v46, v47, v48, v49, v50], preserve_unit_iters=True)
+v56, v57, v58 = sch.sample_perfect_tile(loop=l25, n=3, max_innermost_factor=4)
+l59, l60, l61 = sch.split(loop=l25, factors=[v56, v57, v58], preserve_unit_iters=True)
+sch.reorder(l31, l41, l51, l32, l42, l52, l33, l43, l53, l59, l60, l34, l44, l54, l61, l35, l45, l55)
+l62 = sch.fuse(l31, l41, l51, preserve_unit_iters=True)
+sch.bind(loop=l62, thread_axis="blockIdx.y")
+l63 = sch.fuse(l32, l42, l52, preserve_unit_iters=True)
+sch.bind(loop=l63, thread_axis="blockIdx.x")
+l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True)
+sch.bind(loop=l64, thread_axis="threadIdx.y")
+b65 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="shared")
+sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True)
+b66 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="wmma.accumulator")
+sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True)
+v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch", ann_val=v67)
+sch.reverse_compute_inline(block=b1)
+l68, l69, l70, l71, l72 = sch.get_loops(block=b66)
+l73, l74 = sch.split(loop=l72, factors=[None, 16], preserve_unit_iters=True)
+l75, l76 = sch.split(loop=l71, factors=[None, 16], preserve_unit_iters=True)
+l77, l78, l79, l80, l81, l82, l83 = sch.get_loops(block=b66)
+sch.reorder(l82, l76, l74)
+b84 = sch.blockize(loop=l76)
+sch.annotate(block_or_loop=b84, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
+b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared")
+sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True)
+l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85)
+l92 = sch.fuse(l90, l91, preserve_unit_iters=True)
+v93 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93)
+b94 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="shared")
+sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True)
+l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94)
+l101 = sch.fuse(l99, l100, preserve_unit_iters=True)
+v102 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102)
+b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a")
+sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True)
+l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b103)
+l111, l112 = sch.split(loop=l110, factors=[None, 16], preserve_unit_iters=True)
+l113, l114 = sch.split(loop=l109, factors=[None, 16], preserve_unit_iters=True)
+l115, l116, l117, l118, l119, l120, l121, l122, l123 = sch.get_loops(block=b103)
+sch.reorder(l122, l114, l112)
+b124 = sch.blockize(loop=l114)
+sch.annotate(block_or_loop=b124, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
+b125 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="wmma.matrix_b")
+sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True)
+l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b125)
+l133, l134 = sch.split(loop=l132, factors=[None, 16], preserve_unit_iters=True)
+l135, l136 = sch.split(loop=l131, factors=[None, 16], preserve_unit_iters=True)
+l137, l138, l139, l140, l141, l142, l143, l144, l145 = sch.get_loops(block=b125)
+sch.reorder(l144, l136, l134)
+b146 = sch.blockize(loop=l136)
+sch.annotate(block_or_loop=b146, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
+sch.compute_inline(block=b2)
+sch.compute_inline(block=b3)
+sch.storage_align(block=b85, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.storage_align(block=b94, buffer_index=0, axis=-2, factor=32, offset=8)""".split(
+            "\n"
+        )
+    ]
+    check_trace(spaces, expected)
+
+
 if __name__ == "__main__":
-    test_cpu_matmul()
-    test_cpu_matmul_relu()
-    test_cuda_matmul()
-    test_cuda_matmul_relu()
-    test_cuda_sum_with_trivial_block_iter()
-    test_multi_level_tiling_conv2d_nchwc_vnni()
-    test_multi_level_tiling_dense_dpa4()
+    tvm.testing.main()