You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/11/30 14:17:57 UTC
[tvm] branch main updated: [TE][TIR] Improved naming when converting TE to schedulable TIR (#13431)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c2dd53d531 [TE][TIR] Improved naming when converting TE to schedulable TIR (#13431)
c2dd53d531 is described below
commit c2dd53d5315b3073a14ced200ab55426ac69904e
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Nov 30 08:17:44 2022 -0600
[TE][TIR] Improved naming when converting TE to schedulable TIR (#13431)
Prior to this commit, loop iterators were named `i0`, `i1`, and so on,
while the `BlockNode::iter_vars` used the name from the TE `IterVar`.
As a result, once the `BlockNode::iter_vars` were lowered out, the
resulting `PrimFunc` no longer contained the user-provided iterator
names. This commit updates the TIR conversion so that the loop
iterators take the name of the TE `IterVar`, and the
`BlockNode::iter_vars` are named `v_$IterVarName`. For example, a TE
`IterVar` named `ax0` now produces the loop variable `ax0` and the
block variable `v_ax0`.
---
src/te/operation/create_primfunc.cc | 14 +++++++-------
.../unittest/test_meta_schedule_schedule_rule_mlt.py | 10 +++++-----
.../test_tir_transform_inject_software_pipeline.py | 4 ++--
3 files changed, 14 insertions(+), 14 deletions(-)
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 223f8dcd5d..21456af1bd 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -153,7 +153,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array<IterVar>& iters) {
for (IterVar iter_var : iters) {
// Create new var
- Var new_var(iter_var->var->name_hint, iter_var->var->dtype);
+ Var new_var("v_" + iter_var->var->name_hint, iter_var->var->dtype);
var_map[iter_var->var.get()] = new_var;
PrimExpr dom_min = analyzer->Simplify(iter_var->dom->min);
@@ -307,12 +307,12 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
// Step 1. Creating loop vars for block bindings.
Array<IterVar> axes = compute_op->axis;
axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());
- Array<PrimExpr> bindings;
- for (size_t i = 0; i < axes.size(); ++i) {
- const IterVar& axis = axes[i];
- int bits = std::max(axis->dom->min.dtype().bits(), axis->dom->extent.dtype().bits());
- bindings.push_back(Var("i" + std::to_string(i), runtime::DataType::Int(bits)));
- }
+
+ Array<PrimExpr> bindings = axes.Map([&](IterVar iter_var) -> PrimExpr {
+ int bits = std::max(iter_var->dom->min.dtype().bits(), iter_var->dom->extent.dtype().bits());
+ return Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
+ });
+
// Step 2. Generate block bodies.
Array<Stmt> seq_stmt;
if (compute_op->body[0]->IsInstance<ReduceNode>()) {
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
index 24e3430220..2c5a44d7a2 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -635,12 +635,12 @@ def test_cache_read_specify_consumer():
)
residual_block = """
- for i0, i1 in T.grid(512, 512):
+ for ax0, ax1 in T.grid(512, 512):
with T.block("T_add"):
- ax0, ax1 = T.axis.remap("SS", [i0, i1])
- T.reads(C[ax0, ax1], A[ax0, ax1])
- T.writes(T_add[ax0, ax1])
- T_add[ax0, ax1] = C[ax0, ax1] + A[ax0, ax1]
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(C[v_ax0, v_ax1], A[v_ax0, v_ax1])
+ T.writes(T_add[v_ax0, v_ax1])
+ T_add[v_ax0, v_ax1] = C[v_ax0, v_ax1] + A[v_ax0, v_ax1]
"""
assert residual_block in space[0].mod.script()
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index c70525b057..006b67d626 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -1507,7 +1507,7 @@ def test_async_pipelined_mma_gemm_simple():
assert body.block.body.body[1].block.body.body.value == 3
assert epilogue.block.body.body.block.body.body.attr_key == "async_wait_inflight_count"
- assert str(epilogue.block.body.body.block.body.body.value) == "(2 - i2_0_0: int32)"
+ assert str(epilogue.block.body.body.block.body.body.value) == "(2 - k_0_0: int32)"
build_and_run(sch)
@@ -1554,7 +1554,7 @@ def test_async_nested_pipeline_mma_gemm_ideal_annotation():
assert body.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count"
assert body.block.body.body[1].block.body.body.value == 2
- assert str(epilogue.block.body.body[0].block.body.body.value) == "(1 - i2_0_0: int32)"
+ assert str(epilogue.block.body.body[0].block.body.body.value) == "(1 - k_0_0: int32)"
build_and_run(sch)