You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/16 04:17:38 UTC

[tvm] branch main updated: [MetaSchedule] Allow MultiLevelTilingTensorCore rule to specify multiple tensor intrin groups (#12113)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 895f79f42a [MetaSchedule] Allow MultiLevelTilingTensorCore rule to specify multiple tensor intrin groups (#12113)
895f79f42a is described below

commit 895f79f42a1c798f5b70b0689ae26ae159031302
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Jul 15 21:17:31 2022 -0700

    [MetaSchedule] Allow MultiLevelTilingTensorCore rule to specify multiple tensor intrin groups (#12113)
---
 include/tvm/meta_schedule/schedule_rule.h          |  21 +--
 .../schedule_rule/multi_level_tiling.py            |  16 +-
 python/tvm/meta_schedule/testing/schedule_rule.py  |  26 +++-
 .../multi_level_tiling_tensor_core.cc              | 145 ++++++++++++-------
 ...ta_schedule_schedule_rule_multi_level_tiling.py | 161 +++++++++++++++++----
 5 files changed, 260 insertions(+), 109 deletions(-)

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 5e4698db17..b5f4a17b69 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -174,13 +174,13 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
 
   /*!
-   * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
-   * intrinsics
-   * \param intrin_group A group of tensor core intrinsics. The map should contains key "init",
-   * "load_a", "load_b", "compute", "store", which represent the tensor intrin for initialization,
-   * loading operand A, loading operand B, tensor core computation, storing the result. The value of
-   * the map should be names of tensor intrinsics, must be registerd via TensorIntrin.register(...)
-   * beforehand
+   * \brief Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate
+   * tensor core intrinsics
+   * \param intrin_groups A list of groups of tensor core intrinsics. Each map should contain keys
+   * "init", "load_a", "load_b", "compute", "store", which represent the tensor intrin for
+   * initialization, loading operand A, loading operand B, tensor core computation, storing the
+   * result. The values of each map should be names of tensor intrinsics that must be registered via
+   * TensorIntrin.register(...) beforehand
    * \param structure The tiling structure. Recommended:
    * - 'SSSRRSRS' on GPU
    * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
@@ -193,9 +193,10 @@ class ScheduleRule : public runtime::ObjectRef {
    * \return The schedule rule created
    */
   TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
-      Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
-      Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
-      Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
+      Array<Map<String, String>> intrin_groups, String structure,
+      Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
+      Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
+      Optional<Map<String, ObjectRef>> reuse_write);
 
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 71fbaee4f6..a728a91eb7 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -135,15 +135,15 @@ class MultiLevelTilingWithIntrin(ScheduleRule):
 
 @register_object("meta_schedule.MultiLevelTilingTensorCore")
 class MultiLevelTilingTensorCore(ScheduleRule):
-    """Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
-    intrinsics.
+    """Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate tensor
+    core intrinsics.
 
     Parameters
     ----------
-    intrin_group : Mapping[str, str]
-        A group of tensor core intrinsics. The map should contains key "init", "load_a", "load_b",
-        "compute", "store", which represent the tensor intrin for initialization, loading operand A,
-        loading operand B, tensor core computation, storing the result.
+    intrin_groups : List[Mapping[str, str]]
+        A list of groups of tensor core intrinsics. Each map should contain keys "init", "load_a",
+        "load_b", "compute", "store", which represent the tensor intrin for initialization,
+        loading operand A, loading operand B, tensor core computation, storing the result.
         The value of the map should be names of tensor intrinsics, must be registerd via
         TensorIntrin.register(...) beforehand
     structure : str
@@ -165,7 +165,7 @@ class MultiLevelTilingTensorCore(ScheduleRule):
 
     def __init__(
         self,
-        intrin_group: Mapping[str, str],
+        intrin_groups: List[Mapping[str, str]],
         structure: str,
         tile_binds: Optional[List[str]] = None,
         max_innermost_factor: Optional[int] = None,
@@ -175,7 +175,7 @@ class MultiLevelTilingTensorCore(ScheduleRule):
     ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.ScheduleRuleMultiLevelTilingTensorCore,  # type: ignore # pylint: disable=no-member
-            intrin_group,
+            intrin_groups,
             structure,
             tile_binds,
             max_innermost_factor,
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index 717be59512..ea748ddc05 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Default schedule rules"""
+from typing import List, Union
 from tvm.meta_schedule.schedule_rule import (
     AddRFactor,
     AutoBind,
@@ -114,18 +115,29 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
 
 def multi_level_tiling_tensor_core(
     target: Target,
-    write_reuse_scope="shared",
-    in_dtype="float16",
-    out_dtype="float32",
-    trans_b=False,
+    write_reuse_scope: str = "shared",
+    in_dtype: Union[str, List[str]] = "float16",
+    out_dtype: Union[str, List[str]] = "float32",
+    trans_b: Union[bool, List[bool]] = False,
 ) -> ScheduleRule:
     """Default schedule rules for with multi-level tiling reuse for tensor core"""
     assert write_reuse_scope in ["shared", "global"]
+    if not isinstance(in_dtype, list):
+        in_dtype = [in_dtype]
+    if not isinstance(out_dtype, list):
+        out_dtype = [out_dtype]
+    if not isinstance(trans_b, list):
+        trans_b = [trans_b]
+
     if target.kind.name == "cuda":
+        intrin_groups = [
+            tensor_intrin.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
+            for _in_dtype in in_dtype
+            for _out_dtype in out_dtype
+            for _trans_b in trans_b
+        ]
         return MultiLevelTilingTensorCore(
-            intrin_group=tensor_intrin.get_wmma_intrin_group(
-                write_reuse_scope, in_dtype, out_dtype, trans_b
-            ),
+            intrin_groups=intrin_groups,
             structure="SSSRRSRS",
             tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
             max_innermost_factor=4,  # 64 // tensor intrin size
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 91df62fc36..6d34f7b64e 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -38,10 +38,42 @@ struct TensorCoreIntrinGroup {
   String load_b_intrin;
   String compute_intrin;
   String store_intrin;
+
+  /*! \brief Create TensorCoreIntrinGroup from config in a map. The map should contain the
+   * following keys:
+   *  - init
+   *  - load_a
+   *  - load_b
+   *  - compute
+   *  - store
+   * The values of the keys should be the names of the corresponding intrinsics and should be
+   * registered via TensorIntrin.Register beforehand.
+   */
+  static TensorCoreIntrinGroup FromConfig(const Map<String, String>& config);
 };
 
+TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map<String, String>& config) {
+  auto f_initialize_intrin = [&config](String key_name, String* intrin_name) {
+    CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set.";
+    *intrin_name = config.at(key_name);
+    // Check the existence of the intrin
+    tir::TensorIntrin::Get(*intrin_name);
+  };
+  TensorCoreIntrinGroup intrin_group;
+  f_initialize_intrin("init", &intrin_group.init_intrin);
+  f_initialize_intrin("load_a", &intrin_group.load_a_intrin);
+  f_initialize_intrin("load_b", &intrin_group.load_b_intrin);
+  f_initialize_intrin("compute", &intrin_group.compute_intrin);
+  f_initialize_intrin("store", &intrin_group.store_intrin);
+  return intrin_group;
+}
+
 class TensorCoreStateNode : public StateNode {
  public:
+  /*! \brief The tensor core intrinsic group. */
+  TensorCoreIntrinGroup intrin_group;
+  /*! \brief The auto tensorization mapping info. */
+  tir::AutoTensorizeMappingInfo mapping_info{nullptr};
   /*! \brief The Tensor Core reindex block A for Tensor Core computation */
   tir::BlockRV tensor_core_reindex_A;
   /*! \brief The Tensor Core reindex block B for Tensor Core computation */
@@ -57,16 +89,21 @@ class TensorCoreStateNode : public StateNode {
 
 class TensorCoreState : public State {
  public:
-  explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
-                           Array<Array<tir::LoopRV>> tiles = {});
+  explicit TensorCoreState(TensorCoreIntrinGroup intrin_group,
+                           tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
+                           BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
 };
 
 TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode);
 
-TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array<Array<LoopRV>> tiles) {
+TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group,
+                                 tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
+                                 BlockRV block_rv, Array<Array<LoopRV>> tiles) {
   ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>();
+  node->intrin_group = intrin_group;
+  node->mapping_info = mapping_info;
   node->sch = std::move(sch);
   node->block_rv = std::move(block_rv);
   node->tiles = std::move(tiles);
@@ -116,16 +153,12 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
                                 const String& intrin_name) const;
 
  public:
-  /*! \brief The tensor core intrin group to apply */
-  TensorCoreIntrinGroup intrin_group;
+  /*! \brief The candidate tensor core intrin groups to apply */
+  std::vector<TensorCoreIntrinGroup> intrin_groups;
   static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore";
   TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode);
 
  private:
-  /*!
-   * \brief The mapping info for auto tensorization
-   */
-  tir::AutoTensorizeMappingInfo mapping_info_{nullptr};
 };
 
 // Entry of the mega rule; Inherited from ScheduleRuleNode
@@ -135,21 +168,36 @@ Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
     return {sch};
   }
 
-  Optional<tir::AutoTensorizeMappingInfo> mapping_info =
-      tir::GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block_rv),
-                                       tir::TensorIntrin::Get(intrin_group.compute_intrin)->desc);
-  if (!mapping_info.defined()) {
+  std::vector<std::pair<int, tir::AutoTensorizeMappingInfo>> intrin_group_to_mapping_info;
+  for (int i = 0, n = intrin_groups.size(); i < n; ++i) {
+    const TensorCoreIntrinGroup& intrin_group = intrin_groups[i];
+    Optional<tir::AutoTensorizeMappingInfo> mapping_info = tir::GetAutoTensorizeMappingInfo(
+        sch->state(), sch->GetSRef(block_rv),
+        tir::TensorIntrin::Get(intrin_group.compute_intrin)->desc);
+    if (mapping_info.defined()) {
+      intrin_group_to_mapping_info.emplace_back(i, mapping_info.value());
+    }
+  }
+
+  if (intrin_group_to_mapping_info.empty()) {
+    // No tensor intrinsics can be applied.
     return {sch};
   }
-  mapping_info_ = mapping_info.value();
 
-  // Create a copy of the schedule so that we can roll back transformations if tensorization
+  // Save the original schedule so that we can roll back transformations if tensorization
   // fail.
-  Schedule original_sch = sch->Copy();
-  sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
-
+  Schedule original_sch = sch;
+
+  std::vector<State> initial_states;
+  for (const auto& kv : intrin_group_to_mapping_info) {
+    const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first];
+    const tir::AutoTensorizeMappingInfo& mapping_info = kv.second;
+    Schedule new_sch = sch->Copy();
+    new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
+    initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv));
+  }
   Array<Schedule> results;
-  for (auto&& state : ApplySubRules({TensorCoreState(sch, block_rv)})) {
+  for (auto&& state : ApplySubRules(initial_states)) {
     results.push_back(std::move(state->sch));
   }
   if (results.empty()) {
@@ -196,7 +244,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
     AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
   }
   sch->ReverseComputeInline(state->tensor_core_reindex_store);
-  TileAndAnnotateTensorize(&sch, cache_write, intrin_group.store_intrin);
+  TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin);
   return {state};
 }
 
@@ -212,8 +260,8 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
     TileAndAnnotateTensorize(&sch, cache_read, intrin_name);
   };
 
-  f_tensorize_load(0, "wmma.matrix_a", intrin_group.load_a_intrin);
-  f_tensorize_load(1, "wmma.matrix_b", intrin_group.load_b_intrin);
+  f_tensorize_load(0, "wmma.matrix_a", state->intrin_group.load_a_intrin);
+  f_tensorize_load(1, "wmma.matrix_b", state->intrin_group.load_b_intrin);
   sch->ComputeInline(state->tensor_core_reindex_A);
   sch->ComputeInline(state->tensor_core_reindex_B);
 
@@ -238,6 +286,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
 Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
     TensorCoreStateNode* state, const String& intrin_name) const {
   BlockRV block_rv = state->block_rv;
+  const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info;
   tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);
 
   // Add reindex stages
@@ -258,24 +307,24 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
   // Transform the layout of reindex buffers accordingly.
   // The index map defines the mapping for the computation block. We need to extract the sub index
   // map to transform the load and store block.
-  ICHECK_EQ(mapping_info_->mappings.size(), 1U);  // assume only one mapping is present
-  const tir::IndexMap& index_map = mapping_info_->mappings[0];
+  ICHECK_EQ(mapping_info->mappings.size(), 1U);  // assume only one mapping is present
+  const tir::IndexMap& index_map = mapping_info->mappings[0];
 
   // Find the correspondence between block iters and the iters in the index map.
   std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> lhs_to_index_map_src;
   std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> rhs_to_index_map_tgt;
   std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_index_map_src;
-  ICHECK_EQ(mapping_info_->lhs_iters.size(), index_map->initial_indices.size());
-  for (int i = 0; i < static_cast<int>(mapping_info_->lhs_iters.size()); ++i) {
-    lhs_to_index_map_src[mapping_info_->lhs_iters[i]->var] = index_map->initial_indices[i];
+  ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size());
+  for (int i = 0; i < static_cast<int>(mapping_info->lhs_iters.size()); ++i) {
+    lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i];
   }
   // The number of result iters in the index map is equal or more than the number of rhs (the
-  // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from the
-  // lhs. They will be skipped during pattern matching for tensorization.
-  // An example of such case is batch matmul, the batch dimension is kept after layout
-  // transformations and it will be kept as a outer loop after tensorization.
+  // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from
+  // the lhs. They will be skipped during pattern matching for tensorization. An example of such
+  // case is batch matmul, the batch dimension is kept after layout transformations and it will be
+  // kept as a outer loop after tensorization.
   int offset = static_cast<int>(index_map->final_indices.size()) -
-               static_cast<int>(mapping_info_->rhs_iters.size());
+               static_cast<int>(mapping_info->rhs_iters.size());
   ICHECK_GE(offset, 0);
   for (int i = 0; i < offset; ++i) {
     const tir::VarNode* var_ptr = index_map->final_indices[i].as<tir::VarNode>();
@@ -283,13 +332,13 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
     unmapped_index_map_src.insert(GetRef<tir::Var>(var_ptr));
   }
   for (int i = offset; i < static_cast<int>(index_map->final_indices.size()); ++i) {
-    rhs_to_index_map_tgt[mapping_info_->rhs_iters[i - offset]->var] = index_map->final_indices[i];
+    rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i];
   }
 
   auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) {
     std::vector<tir::Var> sub_index_map_src;
     std::vector<PrimExpr> sub_index_map_tgt;
-    const tir::Buffer& rhs_buffer = mapping_info_->lhs_buffer_map[lhs_buffer];
+    const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer];
     for (const Range& range : lhs_region) {
       ICHECK(tir::is_one(range->extent));
       const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
@@ -300,8 +349,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
         sub_index_map_tgt.push_back(lhs_representer);
       }
     }
-    for (size_t i = 0; i < mapping_info_->rhs_buffer_indices[rhs_buffer].size(); ++i) {
-      const tir::VarNode* var = mapping_info_->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
+    for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) {
+      const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
       ICHECK(var != nullptr);
       sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef<tir::Var>(var)]);
     }
@@ -345,7 +394,7 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
     TensorCoreState state) const {
   // Do reindex and layout transformations.
   Optional<LoopRV> transformed_loop_rv =
-      TransformWithTensorIntrin(state.operator->(), intrin_group.compute_intrin);
+      TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin);
   if (!transformed_loop_rv.defined()) {
     // The workload can't be tensorized.
     return {};
@@ -356,32 +405,24 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
 
   // Add annotations for post processors.
   state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize,
-                       intrin_group.compute_intrin);
+                       state->intrin_group.compute_intrin);
   state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
-                       intrin_group.init_intrin);
+                       state->intrin_group.init_intrin);
   state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true));
   return {std::move(state)};
 }
 
 ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
-    Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
+    Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
     Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
     Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
   auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
       structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
 
-  auto f_initialize_intrin = [&intrin_group](String key_name, String* intrin_name) {
-    CHECK(intrin_group.count(key_name)) << "ValueError: " << key_name << " is not set.";
-    *intrin_name = intrin_group.at(key_name);
-    // Check the existence of the intrin
-    tir::TensorIntrin::Get(*intrin_name);
-  };
-  f_initialize_intrin("init", &node->intrin_group.init_intrin);
-  f_initialize_intrin("load_a", &node->intrin_group.load_a_intrin);
-  f_initialize_intrin("load_b", &node->intrin_group.load_b_intrin);
-  f_initialize_intrin("compute", &node->intrin_group.compute_intrin);
-  f_initialize_intrin("store", &node->intrin_group.store_intrin);
-
+  node->intrin_groups.reserve(intrin_groups.size());
+  for (const auto& intrin_group_config : intrin_groups) {
+    node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
+  }
   return ScheduleRule(node);
 }
 
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 1ceef0afc3..c43645832b 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -563,25 +563,6 @@ sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_v
     check_trace(spaces, expected)
 
 
-def test_cuda_tensor_core_conv2d():
-    target = Target("cuda", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.conv2d_nhwc_f16(
-                N=1, H=16, W=16, CI=16, CO=16, kernel_size=3, stride=1, padding=1
-            )
-        ),
-        target,
-        multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-
-    expected = []
-    print("".join(spaces[0].trace.as_python()))
-    check_trace(spaces, expected)
-
-
 def test_cuda_tensor_core_matmul_relu():
     m = n = k = 128
     target = Target("cuda", host="llvm")
@@ -719,14 +700,15 @@ sch.reverse_compute_inline(block=b1)""".split(
 def test_cuda_tensor_core_matmul_relu_global():
     m = n = k = 128
     target = Target("cuda", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul_relu_fp16(
-                n=n,
-                m=m,
-                k=k,
-            ),
+    workload = create_prim_func(
+        te_workload.matmul_relu_fp16(
+            n=n,
+            m=m,
+            k=k,
         ),
+    )
+    ctx = _create_context(
+        workload,
         target=target,
         rule=[
             multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"),
@@ -822,6 +804,106 @@ sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".sp
     ]
     check_trace(spaces, expected)
 
+    ctx = _create_context(
+        workload,
+        target=target,
+        rule=[
+            multi_level_tiling_tensor_core(
+                target=target, write_reuse_scope="global", trans_b=[False, True]
+            ),
+            auto_inline(target),
+        ],
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 2
+
+    expected = [
+        expected[0],
+        """b0 = sch.get_block(name="C", func_name="main")
+sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+b1 = sch.reindex(block=b0, buffer=("write", 0))
+b2 = sch.reindex(block=b0, buffer=("read", 0))
+b3 = sch.reindex(block=b0, buffer=("read", 1))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (j, k, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
+l4, l5, l6 = sch.get_loops(block=b0)
+l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
+l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
+l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True)
+l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0)
+sch.reorder(l15, l17, l12, l10, l8)
+b19 = sch.blockize(loop=l12)
+sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32_trans")
+sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
+sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1)
+l20, l21, l22 = sch.get_loops(block=b19)
+v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4)
+l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True)
+v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
+l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True)
+v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4)
+l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True)
+sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)
+l49 = sch.fuse(l28, l38, preserve_unit_iters=True)
+sch.bind(loop=l49, thread_axis="blockIdx.y")
+l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
+sch.bind(loop=l50, thread_axis="blockIdx.x")
+l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
+sch.bind(loop=l51, thread_axis="threadIdx.y")
+b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
+sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_inline(block=b1)
+l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
+l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
+l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True)
+l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52)
+sch.reorder(l67, l61, l59)
+b69 = sch.blockize(loop=l61)
+sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
+b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
+sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
+l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
+l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
+v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
+b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
+sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
+l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
+l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
+v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
+b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
+sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
+l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
+l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
+l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
+l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88)
+sch.reorder(l107, l99, l97)
+b109 = sch.blockize(loop=l99)
+sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
+b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
+sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True)
+l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
+l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
+l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
+l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110)
+sch.reorder(l129, l121, l119)
+b131 = sch.blockize(loop=l121)
+sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b_trans")
+sch.compute_inline(block=b2)
+sch.compute_inline(block=b3)
+sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split(
+            "\n"
+        ),
+    ]
+    check_trace(spaces, expected)
+
 
 def test_multi_level_tiling_non_tensorizable():
     # expected to do nothing on non-tensorizable workloads
@@ -850,13 +932,13 @@ def test_multi_level_tiling_non_tensorizable():
 
 def test_cuda_tensor_core_conv2d():
     target = Target("cuda", host="llvm")
+    workload = create_prim_func(
+        te_workload.conv2d_nhwc_f16(
+            N=1, H=16, W=16, CI=32, CO=32, kernel_size=3, stride=1, padding=1
+        )
+    )
     ctx = _create_context(
-        create_prim_func(
-            # dtype doesn't match tensor intrin
-            te_workload.conv2d_nhwc_f16(
-                N=1, H=16, W=16, CI=32, CO=32, kernel_size=3, stride=1, padding=1
-            )
-        ),
+        workload,
         target=target,
         rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
     )
@@ -955,6 +1037,21 @@ sch.storage_align(block=b94, buffer_index=0, axis=-2, factor=32, offset=8)""".sp
     ]
     check_trace(spaces, expected)
 
+    # test adding inapplicable tensor intrinsics doesn't change the search space
+    ctx = _create_context(
+        workload,
+        target,
+        multi_level_tiling_tensor_core(
+            target=target,
+            write_reuse_scope="shared",
+            in_dtype="float16",
+            out_dtype=["float16", "float32"],
+        ),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+    check_trace(spaces, expected)
+
 
 if __name__ == "__main__":
     tvm.testing.main()