You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/04/20 03:01:32 UTC

[tvm] branch main updated: [TIR] Utility function to decide loop mapping for auto tensorization (#11050)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3823b39b8a [TIR] Utility function to decide loop mapping for auto tensorization (#11050)
3823b39b8a is described below

commit 3823b39b8a197e9e01ebb93dffaa1e710118c148
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Apr 20 12:01:24 2022 +0900

    [TIR] Utility function to decide loop mapping for auto tensorization (#11050)
    
    * [TIR] Add TensorizeInfo and GetTensorizeLoopMapping
    
    * expose PreOrderVisit to python
    
    * add test case
    
    * add conv2d nchwc test
    
    * add mma test
    
    * add arm nhwc conv2d test
    
    * Revert "add arm nhwc conv2d test"
    
    This reverts commit eb147f33bb02d62a0eacc9cdfe777ac047ee1bc9.
    
    * refine
    
    * add doc
    
    * update
    
    * fixed condition
    
    * black
    
    * pylint
    
    * Update python/tvm/tir/schedule/analysis.py
    
    Co-authored-by: Ruihang Lai <la...@qq.com>
    
    * run black
    
    * bring back logic in original code to support loop permutation
    
    * add comment
    
    * simplify
    
    * minor fix to test
    
    Co-authored-by: Ruihang Lai <la...@qq.com>
    
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
---
 python/tvm/tir/schedule/analysis.py                |  33 +++-
 python/tvm/tir/stmt_functor.py                     |  12 ++
 src/tir/ir/stmt_functor.cc                         |   4 +
 src/tir/schedule/analysis.h                        |  33 ++++
 src/tir/schedule/analysis/analysis.cc              | 167 ++++++++++++++++++-
 .../python/unittest/test_tir_schedule_analysis.py  | 183 +++++++++++++++++++--
 6 files changed, 418 insertions(+), 14 deletions(-)

diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py
index f2fb7c4f3d..71ff024217 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -17,12 +17,16 @@
 """Analysis used in TensorIR scheduling"""
 from typing import List, Optional
 
+import tvm._ffi
+from tvm.runtime import Object
+
 from ..buffer import Buffer
 from ..stmt import For
 from ..expr import PrimExpr
-from ..function import IndexMap
+from ..function import IndexMap, PrimFunc
 
 from . import _ffi_api
+from .schedule import Schedule, BlockRV
 
 
 def suggest_index_map(
@@ -56,3 +60,30 @@ def suggest_index_map(
         loops,
         predicate,
     )
+
+
+@tvm._ffi.register_object("tir.schedule.TensorizeInfo")
+class TensorizeInfo(Object):
+    """Necessary information used for tensorization."""
+
+
+def get_tensorize_loop_mapping(
+    sch: Schedule, block: BlockRV, desc_func: PrimFunc
+) -> Optional[TensorizeInfo]:
+    """Establish a mapping between loops in a target block and an intrinsic description
+
+    Parameters
+    ----------
+    sch : Schedule
+        The schedule to be tensorized
+    block : BlockRV
+        The target block to match against
+    desc_func : PrimFunc
+        The prim func describing the computation to be tensorized
+
+    Returns
+    -------
+    tensorize_info : Optional[TensorizeInfo]
+        TensorizeInfo structure if a valid mapping is found, None otherwise
+    """
+    return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func)  # type: ignore
diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py
index 56dc1c20c2..5bcf4ae802 100644
--- a/python/tvm/tir/stmt_functor.py
+++ b/python/tvm/tir/stmt_functor.py
@@ -58,6 +58,18 @@ def post_order_visit(stmt, fvisit):
     return _ffi_api.PostOrderVisit(stmt, fvisit)  # type: ignore
 
 
+def pre_order_visit(stmt, fvisit):
+    """Recursive pre-order visit on stmt AST, applying fvisit on each node.
+       If fvisit returns False, it won't visit the children of the node.
+
+    Parameters
+    ----------
+    fvisit: function of the signature Object -> bool
+        The visitor function.
+    """
+    return _ffi_api.PreOrderVisit(stmt, fvisit)  # type: ignore
+
+
 def substitute(node, vmap):
     """Substitute the var specified by vmap.
 
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index c4d7ad0f6c..06933c2c0d 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -792,6 +792,10 @@ TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, Pack
   tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); });
 });
 
+TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
+  tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); });
+});
+
 TVM_REGISTER_GLOBAL("tir.Substitute")
     .set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
       if (node->IsInstance<StmtNode>()) {
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index b76d41326f..c9c3d72ae0 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -656,6 +656,39 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
                                              const StmtSRef& dom_high_exclusive,
                                              arith::Analyzer* analyzer);
 
+/*! \brief Necessary information used for tensorization */
+class TensorizeInfoNode : public Object {
+ public:
+  /*! \brief Maps loops in a target block to the ones in an intrinsic description */
+  Map<tir::StmtSRef, tir::For> loop_map;
+  /*! \brief Maps loops in an intrinsic description to its index, outer to inner */
+  Map<tir::For, Integer> desc_loop_indexer;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("loop_map", &loop_map);
+    v->Visit("desc_loop_indexer", &desc_loop_indexer);
+  }
+
+  static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
+};
+
+class TensorizeInfo : public ObjectRef {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
+};
+
+/*!
+ * \brief Establish a mapping between loops in a target block and an intrinsic description
+ * \param self The schedule state to be tensorized
+ * \param block_sref The target block to match against
+ * \param desc_func The prim func describing the computation to be tensorized
+ * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
+ */
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func);
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4a7ac401dd..4777ee2657 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -16,6 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <tvm/runtime/container/optional.h>
+#include <tvm/tir/expr.h>
+
 #include "../utils.h"
 
 namespace tvm {
@@ -492,8 +495,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
   }
 }
 
-std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) {
   std::vector<IterVarType> results;
   results.reserve(block->iter_vars.size());
   for (const IterVar& iter_var : block->iter_vars) {
@@ -502,6 +504,11 @@ std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
   return results;
 }
 
+std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  return GetBlockVarTypes(block);
+}
+
 bool IsWriteCache(const StmtSRef& block_sref) {
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   if (block->writes.size() != 1) {
@@ -2028,5 +2035,161 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
   }
 }
 
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  const tir::BlockRealizeNode* desc_block = nullptr;
+  std::vector<const tir::ForNode*> desc_loops;
+  std::unordered_set<const tir::VarNode*> desc_loop_vars;
+  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+  ICHECK(desc_scope_realize);
+  {
+    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+                    &analyzer](const ObjectRef& obj) -> bool {
+      // Extract the block
+      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+        desc_block = block;
+        return false;
+      }
+      // Extract loops
+      if (const auto* loop = obj.as<tir::ForNode>()) {
+        desc_loops.push_back(loop);
+        desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer.CanProve(loop->min == 0)) {
+          return false;
+        }
+      }
+      return true;
+    };
+    tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
+    std::reverse(desc_loops.begin(), desc_loops.end());
+    ICHECK(desc_block);
+  }
+  // Step 2. Collect loops from block_sref
+  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  std::vector<const tir::ForNode*> block_loops;
+  std::unordered_set<const tir::VarNode*> block_loop_vars;
+  {
+    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
+      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
+      if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
+        break;
+      }
+      block_loops.push_back(loop);
+      block_loop_vars.insert(loop->loop_var.get());
+      if (!analyzer.CanProve(loop->min == 0)) {
+        return NullOpt;
+      }
+    }
+    std::reverse(block_loops.begin(), block_loops.end());
+  }
+  // Step 3. Map from block loops to desc block loops
+  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+  const int n_block_vars = block->iter_values.size();
+  const int n_desc_vars = desc_block->iter_values.size();
+  const int offset = n_block_vars - n_desc_vars;
+
+  if (offset < 0) {
+    return NullOpt;
+  }
+
+  const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
+  const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());
+
+  ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
+  ICHECK(block_loops.size() == iter_types_block.size());
+
+  // We assume that the orders of iter_vars in the target and the desc block are consistent.
+  // Based on that assumption, the following logic supports arbitrary permutations of a loop order,
+  // such as
+
+  // for k:
+  //   for i:
+  //     for j:
+  //       C[i, j] += A[i, k] * B[k, j]
+
+  // or
+
+  // for i:
+  //   for j:
+  //     for k:
+  //       C[i, j] += A[i, k] * B[k, j]
+
+  int next_block_ind = block_loops.size() - 1;
+  for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
+    // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc
+    const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
+    const tir::ForNode* desc_loop = nullptr;
+    IterVarType iter_type_desc = iter_types_desc[i_desc];
+    for (int i = 0, n = desc_loops.size(); i < n; ++i) {
+      // Check if desc_bind = desc_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
+      PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
+      if (!UsesVar(residual,
+                   [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) {
+        desc_loop = desc_loops[i];
+        iter_type_desc = iter_types_desc[i];
+        break;
+      }
+    }
+    if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) {
+      return NullOpt;
+    }
+
+    const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
+
+    // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
+    PrimExpr block_bind;
+    for (int i = next_block_ind; i >= 0; --i) {
+      if (iter_types_block[i] == iter_type_desc) {
+        next_block_ind = i - 1;
+        block_bind = block->iter_values[i];
+        break;
+      }
+    }
+
+    if (!block_bind.defined()) return NullOpt;
+
+    // Step 3.3. Find the corresponding loop of the target block
+    for (int i = 0, n = block_loops.size(); i < n; ++i) {
+      // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
+      const tir::ForNode* block_loop = block_loops[i];
+      const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
+      // Skip i-th loop if it has already been mapped
+      if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue;
+
+      PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
+      if (UsesVar(residual,
+                  [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
+        continue;
+
+      const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
+
+      // Check divisibility
+      if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
+        return NullOpt;
+      }
+
+      ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
+      break;
+    }
+  }
+
+  for (int i = 0, n = desc_loops.size(); i < n; ++i) {
+    ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
+  }
+  return TensorizeInfo(ret);
+}
+
+TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
+    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
+      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
+    });
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 760b412ac8..10371d3cca 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -17,18 +17,17 @@
 # pylint: disable=missing-docstring
 from typing import List
 
-from tvm.tir import (
-    Evaluate,
-    For,
-    ForKind,
-    IndexMap,
-    Var,
-    decl_buffer,
-    floordiv,
-    floormod,
-)
+import tvm
+from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc
+
+
+from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule
 from tvm.tir.analysis import expr_deep_equal
-from tvm.tir.schedule.analysis import suggest_index_map
+from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, TensorizeInfo
+from tvm.script import tir as T
+from tvm.tir.stmt_functor import pre_order_visit
+from tvm.meta_schedule.testing import te_workload
+from tvm.te import create_prim_func
 
 
 def _make_vars(*args: str) -> List[Var]:
@@ -102,6 +101,168 @@ def test_suggest_index_map_bijective():
     _assert_equal_index_map(index_map, expected_index_map)
 
 
+@tvm.script.ir_module
+class DenseVNNIModule:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1024, 1024), "uint8"],
+        placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
+        compute: T.Buffer[(1024, 1024), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            for i0, i1, i2 in T.grid(1024, 1024, 1024):
+                with T.block("compute"):
+                    i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                    T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
+                    T.writes(compute[i, j])
+                    with T.init():
+                        compute[i, j] = 0
+                    compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
+                        placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
+                    )
+
+
+@tvm.script.ir_module
+class Conv2dNCHWcVNNIModule:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
+        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
+        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+            with T.block("conv2d_NCHWc_int8"):
+                (
+                    n,
+                    oc_chunk,
+                    oh,
+                    ow,
+                    oc_block,
+                    kh,
+                    kw,
+                    ic_outer,
+                    ic_f_inner,
+                    ic_s_inner,
+                ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
+                T.reads(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                )
+                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
+                with T.init():
+                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
+                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
+                    n, oc_chunk, oh, ow, oc_block
+                ] + T.cast(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
+                ) * T.cast(
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                    "int32",
+                )
+
+
+def collect_loops(prim_func):
+    loops = []
+
+    def callback(node):
+        if isinstance(node, tvm.tir.For):
+            loops.append(node)
+        return True
+
+    pre_order_visit(prim_func.body, callback)
+
+    return loops
+
+
+def test_get_tensorize_loop_mapping_dense_vnni():
+    s = Schedule(DenseVNNIModule)
+    block = s.get_block("compute")
+
+    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
+
+    assert isinstance(info, TensorizeInfo)
+
+    desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
+    _, loop_j, loop_k = s.get_loops(block)
+
+    assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
+    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(loop_j)
+    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k)
+
+
+def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
+    s = Schedule(Conv2dNCHWcVNNIModule)
+    block = s.get_block("conv2d_NCHWc_int8")
+
+    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
+
+    desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
+
+    # i4 corresponds to the inner output channel axis of the NCHWc output tensor
+    # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+    _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block)
+
+    assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
+    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4)
+    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9)
+
+
+def test_get_tensorize_loop_mapping_matmul_mma():
+    @T.prim_func
+    def matmul_16x16x16xf16f16f16_desc(
+        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+    ) -> None:
+        with T.block("root"):
+            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
+            T.writes(C[0:16, 0:16])
+            for i, j, k in T.grid(16, 16, 16):
+                with T.block("update"):
+                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
+
+    matmul = create_prim_func(
+        te_workload.matmul_relu(
+            n=512,
+            m=512,
+            k=512,
+        )
+    )
+
+    s = Schedule(matmul)
+    block = s.get_block("C")
+    i0, i1, i2 = s.get_loops(block)
+    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)
+
+    for do_reorder in [False, True]:
+        # Mapping should be invariant to the loop permutation
+        if do_reorder:
+            s.reorder(i2, i0, i1)
+
+        info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc)
+        assert info is not None
+        desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+        for i in range(3):
+            assert desc_loops[i] in desc_loop_to_sref
+
+        assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
+        assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
+        assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
+
+
 if __name__ == "__main__":
     test_suggest_index_map_simple()
     test_suggest_index_map_bijective()
+    test_get_tensorize_loop_mapping_dense_vnni()
+    test_get_tensorize_loop_mapping_conv2d_nchwc_vnni()
+    test_get_tensorize_loop_mapping_matmul_mma()