You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/04/20 03:01:32 UTC
[tvm] branch main updated: [TIR] Utility function to decide loop mapping for auto tensorization (#11050)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3823b39b8a [TIR] Utility function to decide loop mapping for auto tensorization (#11050)
3823b39b8a is described below
commit 3823b39b8a197e9e01ebb93dffaa1e710118c148
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Apr 20 12:01:24 2022 +0900
[TIR] Utility function to decide loop mapping for auto tensorization (#11050)
* [TIR] Add TensorizeInfo and GetTensorizeLoopMapping
* expose PreOrderVisit to python
* add test case
* add conv2d nchwc test
* add mma test
* add arm nhwc conv2d test
* Revert "add arm nhwc conv2d test"
This reverts commit eb147f33bb02d62a0eacc9cdfe777ac047ee1bc9.
* refine
* add doc
* update
* fixed condition
* black
* pylint
* Update python/tvm/tir/schedule/analysis.py
Co-authored-by: Ruihang Lai <la...@qq.com>
* run black
* bring back logic in original code to support loop permutation
* add comment
* simplify
* minor fix to test
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
---
python/tvm/tir/schedule/analysis.py | 33 +++-
python/tvm/tir/stmt_functor.py | 12 ++
src/tir/ir/stmt_functor.cc | 4 +
src/tir/schedule/analysis.h | 33 ++++
src/tir/schedule/analysis/analysis.cc | 167 ++++++++++++++++++-
.../python/unittest/test_tir_schedule_analysis.py | 183 +++++++++++++++++++--
6 files changed, 418 insertions(+), 14 deletions(-)
diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py
index f2fb7c4f3d..71ff024217 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -17,12 +17,16 @@
"""Analysis used in TensorIR scheduling"""
from typing import List, Optional
+import tvm._ffi
+from tvm.runtime import Object
+
from ..buffer import Buffer
from ..stmt import For
from ..expr import PrimExpr
-from ..function import IndexMap
+from ..function import IndexMap, PrimFunc
from . import _ffi_api
+from .schedule import Schedule, BlockRV
def suggest_index_map(
@@ -56,3 +60,30 @@ def suggest_index_map(
loops,
predicate,
)
+
+
+@tvm._ffi.register_object("tir.schedule.TensorizeInfo")
+class TensorizeInfo(Object):
+ """Necessary information used for tensorization."""
+
+
+def get_tensorize_loop_mapping(
+ sch: Schedule, block: BlockRV, desc_func: PrimFunc
+) -> Optional[TensorizeInfo]:
+ """Establish a mapping between loops in a target block and an intrinsic description
+
+ Parameters
+ ----------
+ sch : Schedule
+ The schedule to be tensorized
+ block : BlockRV
+ The target block to match against
+ desc_func : PrimFunc
+ The prim func describing the computation to be tensorized
+
+ Returns
+ -------
+ tensorize_info : Optional[TensorizeInfo]
+ TensorizeInfo structure if a valid mapping is found, None otherwise
+ """
+ return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore
diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py
index 56dc1c20c2..5bcf4ae802 100644
--- a/python/tvm/tir/stmt_functor.py
+++ b/python/tvm/tir/stmt_functor.py
@@ -58,6 +58,18 @@ def post_order_visit(stmt, fvisit):
return _ffi_api.PostOrderVisit(stmt, fvisit) # type: ignore
+def pre_order_visit(stmt, fvisit):
+ """Recursive pre-order visit on stmt AST, applying fvisit on each node.
+ If fvisit returns False, it won't visit the children of the node.
+
+ Parameters
+ ----------
+ fvisit: function of the signature Object -> bool
+ The visitor function.
+ """
+ return _ffi_api.PreOrderVisit(stmt, fvisit) # type: ignore
+
+
def substitute(node, vmap):
"""Substitute the var specified by vmap.
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index c4d7ad0f6c..06933c2c0d 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -792,6 +792,10 @@ TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, Pack
tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); });
});
+TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
+ tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); });
+});
+
TVM_REGISTER_GLOBAL("tir.Substitute")
.set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
if (node->IsInstance<StmtNode>()) {
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index b76d41326f..c9c3d72ae0 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -656,6 +656,39 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
const StmtSRef& dom_high_exclusive,
arith::Analyzer* analyzer);
+/*! \brief Necessary information used for tensorization */
+class TensorizeInfoNode : public Object {
+ public:
+ /*! \brief Maps loops in a target block to the ones in an intrinsic description */
+ Map<tir::StmtSRef, tir::For> loop_map;
+ /*! \brief Maps loops in an intrinsic description to its index, outer to inner */
+ Map<tir::For, Integer> desc_loop_indexer;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("loop_map", &loop_map);
+ v->Visit("desc_loop_indexer", &desc_loop_indexer);
+ }
+
+ static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
+};
+
+class TensorizeInfo : public ObjectRef {
+ public:
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
+};
+
+/*!
+ * \brief Establish a mapping between loops in a target block and an intrinsic description
+ * \param self The schedule state to be tensorized
+ * \param block_sref The target block to match against
+ * \param desc_func The prim func describing the computation to be tensorized
+ * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
+ */
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+ const tir::StmtSRef& block_sref,
+ const tir::PrimFunc& desc_func);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4a7ac401dd..4777ee2657 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/runtime/container/optional.h>
+#include <tvm/tir/expr.h>
+
#include "../utils.h"
namespace tvm {
@@ -492,8 +495,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
}
}
-std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
- const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) {
std::vector<IterVarType> results;
results.reserve(block->iter_vars.size());
for (const IterVar& iter_var : block->iter_vars) {
@@ -502,6 +504,11 @@ std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
return results;
}
+std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ return GetBlockVarTypes(block);
+}
+
bool IsWriteCache(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
if (block->writes.size() != 1) {
@@ -2028,5 +2035,161 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
}
}
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+ const tir::StmtSRef& block_sref,
+ const tir::PrimFunc& desc_func) {
+ arith::Analyzer analyzer;
+ const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+ // Step 1. Analyze desc_func, extract its block, loops and loop vars
+ const tir::BlockRealizeNode* desc_block = nullptr;
+ std::vector<const tir::ForNode*> desc_loops;
+ std::unordered_set<const tir::VarNode*> desc_loop_vars;
+ const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+ ICHECK(desc_scope_realize);
+ {
+ auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+ &analyzer](const ObjectRef& obj) -> bool {
+ // Extract the block
+ if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+ desc_block = block;
+ return false;
+ }
+ // Extract loops
+ if (const auto* loop = obj.as<tir::ForNode>()) {
+ desc_loops.push_back(loop);
+ desc_loop_vars.insert(loop->loop_var.get());
+ if (!analyzer.CanProve(loop->min == 0)) {
+ return false;
+ }
+ }
+ return true;
+ };
+ tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
+ std::reverse(desc_loops.begin(), desc_loops.end());
+ ICHECK(desc_block);
+ }
+ // Step 2. Collect loops from block_sref
+ const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+ const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+ std::vector<const tir::ForNode*> block_loops;
+ std::unordered_set<const tir::VarNode*> block_loop_vars;
+ {
+ for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
+ const auto* loop = loop_sref->StmtAs<tir::ForNode>();
+ if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
+ break;
+ }
+ block_loops.push_back(loop);
+ block_loop_vars.insert(loop->loop_var.get());
+ if (!analyzer.CanProve(loop->min == 0)) {
+ return NullOpt;
+ }
+ }
+ std::reverse(block_loops.begin(), block_loops.end());
+ }
+ // Step 3. Map from block loops to desc block loops
+ ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+ const int n_block_vars = block->iter_values.size();
+ const int n_desc_vars = desc_block->iter_values.size();
+ const int offset = n_block_vars - n_desc_vars;
+
+ if (offset < 0) {
+ return NullOpt;
+ }
+
+ const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
+ const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());
+
+ ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
+ ICHECK(block_loops.size() == iter_types_block.size());
+
+ // We assume that the orders of iter_vars in the target and the desc block are consistent.
+ // Based on that assumption, the following logic supports arbitrary permutations of a loop order,
+ // such as
+
+ // for k:
+ // for i:
+ // for j:
+ // C[i, j] += A[i, k] * B[k, j]
+
+ // or
+
+ // for i:
+ // for j:
+ // for k:
+ // C[i, j] += A[i, k] * B[k, j]
+
+ int next_block_ind = block_loops.size() - 1;
+ for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
+ // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc
+ const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
+ const tir::ForNode* desc_loop = nullptr;
+ IterVarType iter_type_desc = iter_types_desc[i_desc];
+ for (int i = 0, n = desc_loops.size(); i < n; ++i) {
+ // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
+ PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
+ if (!UsesVar(residual,
+ [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) {
+ desc_loop = desc_loops[i];
+ iter_type_desc = iter_types_desc[i];
+ break;
+ }
+ }
+ if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) {
+ return NullOpt;
+ }
+
+ const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
+
+ // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
+ PrimExpr block_bind;
+ for (int i = next_block_ind; i >= 0; --i) {
+ if (iter_types_block[i] == iter_type_desc) {
+ next_block_ind = i - 1;
+ block_bind = block->iter_values[i];
+ break;
+ }
+ }
+
+ if (!block_bind.defined()) return NullOpt;
+
+ // Step 3.3. Find the corresponding loop of the target block
+ for (int i = 0, n = block_loops.size(); i < n; ++i) {
+ // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
+ const tir::ForNode* block_loop = block_loops[i];
+ const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
+ // Skip i-th loop if it has already been mapped
+ if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue;
+
+ PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
+ if (UsesVar(residual,
+ [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
+ continue;
+
+ const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
+
+ // Check divisibility
+ if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
+ return NullOpt;
+ }
+
+ ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
+ break;
+ }
+ }
+
+ for (int i = 0, n = desc_loops.size(); i < n; ++i) {
+ ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
+ }
+ return TensorizeInfo(ret);
+}
+
+TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
+ .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
+ return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
+ });
+
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 760b412ac8..10371d3cca 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -17,18 +17,17 @@
# pylint: disable=missing-docstring
from typing import List
-from tvm.tir import (
- Evaluate,
- For,
- ForKind,
- IndexMap,
- Var,
- decl_buffer,
- floordiv,
- floormod,
-)
+import tvm
+from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc
+
+
+from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule
from tvm.tir.analysis import expr_deep_equal
-from tvm.tir.schedule.analysis import suggest_index_map
+from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, TensorizeInfo
+from tvm.script import tir as T
+from tvm.tir.stmt_functor import pre_order_visit
+from tvm.meta_schedule.testing import te_workload
+from tvm.te import create_prim_func
def _make_vars(*args: str) -> List[Var]:
@@ -102,6 +101,168 @@ def test_suggest_index_map_bijective():
_assert_equal_index_map(index_map, expected_index_map)
+@tvm.script.ir_module
+class DenseVNNIModule:
+ @T.prim_func
+ def main(
+ placeholder: T.Buffer[(1024, 1024), "uint8"],
+ placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
+ compute: T.Buffer[(1024, 1024), "int32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ for i0, i1, i2 in T.grid(1024, 1024, 1024):
+ with T.block("compute"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
+ T.writes(compute[i, j])
+ with T.init():
+ compute[i, j] = 0
+ compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
+ placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
+ )
+
+
+@tvm.script.ir_module
+class Conv2dNCHWcVNNIModule:
+ @T.prim_func
+ def main(
+ placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
+ placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
+ conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+ with T.block("conv2d_NCHWc_int8"):
+ (
+ n,
+ oc_chunk,
+ oh,
+ ow,
+ oc_block,
+ kh,
+ kw,
+ ic_outer,
+ ic_f_inner,
+ ic_s_inner,
+ ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
+ T.reads(
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
+ placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+ )
+ T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
+ with T.init():
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
+ n, oc_chunk, oh, ow, oc_block
+ ] + T.cast(
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
+ ) * T.cast(
+ placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+ "int32",
+ )
+
+
+def collect_loops(prim_func):
+ loops = []
+
+ def callback(node):
+ if isinstance(node, tvm.tir.For):
+ loops.append(node)
+ return True
+
+ pre_order_visit(prim_func.body, callback)
+
+ return loops
+
+
+def test_get_tensorize_loop_mapping_dense_vnni():
+ s = Schedule(DenseVNNIModule)
+ block = s.get_block("compute")
+
+ info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
+
+ assert isinstance(info, TensorizeInfo)
+
+ desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+ desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
+ _, loop_j, loop_k = s.get_loops(block)
+
+ assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
+ assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(loop_j)
+ assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k)
+
+
+def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
+ s = Schedule(Conv2dNCHWcVNNIModule)
+ block = s.get_block("conv2d_NCHWc_int8")
+
+ info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
+
+ desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+ desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
+
+ # i4 corresponds to the inner output channel axis of the NCHWc output tensor
+ # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+ _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block)
+
+ assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
+ assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4)
+ assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9)
+
+
+def test_get_tensorize_loop_mapping_matmul_mma():
+ @T.prim_func
+ def matmul_16x16x16xf16f16f16_desc(
+ A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+ B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+ C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+ ) -> None:
+ with T.block("root"):
+ T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
+ T.writes(C[0:16, 0:16])
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("update"):
+ vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+ C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
+
+ matmul = create_prim_func(
+ te_workload.matmul_relu(
+ n=512,
+ m=512,
+ k=512,
+ )
+ )
+
+ s = Schedule(matmul)
+ block = s.get_block("C")
+ i0, i1, i2 = s.get_loops(block)
+ desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)
+
+ for do_reorder in [False, True]:
+ # Mapping should be invariant to the loop permutation
+ if do_reorder:
+ s.reorder(i2, i0, i1)
+
+ info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc)
+ assert info is not None
+ desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+ for i in range(3):
+ assert desc_loops[i] in desc_loop_to_sref
+
+ assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
+ assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
+ assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
+
+
if __name__ == "__main__":
test_suggest_index_map_simple()
test_suggest_index_map_bijective()
+ test_get_tensorize_loop_mapping_dense_vnni()
+ test_get_tensorize_loop_mapping_conv2d_nchwc_vnni()
+ test_get_tensorize_loop_mapping_matmul_mma()