You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/11/30 14:17:57 UTC

[tvm] branch main updated: [TE][TIR] Improved naming when converting TE to schedulable TIR (#13431)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c2dd53d531 [TE][TIR] Improved naming when converting TE to schedulable TIR (#13431)
c2dd53d531 is described below

commit c2dd53d5315b3073a14ced200ab55426ac69904e
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Nov 30 08:17:44 2022 -0600

    [TE][TIR] Improved naming when converting TE to schedulable TIR (#13431)
    
    Prior to this commit, loop iterators were named `i0`, `i1`, and so on,
    while the `BlockNode::iter_vars` used the name from the TE `IterVar`.
    As a result, after `BlockNode::iter_vars` is lowered out, the
    resulting `PrimFunc` no longer contained the user-provided iterator
    names.  This commit updates the TIR conversion so that the loop
    iterators take the name of the TE `IterVar`, and the
    `BlockNode::iter_vars` are named `v_$IterVarName`.
---
 src/te/operation/create_primfunc.cc                        | 14 +++++++-------
 .../unittest/test_meta_schedule_schedule_rule_mlt.py       | 10 +++++-----
 .../test_tir_transform_inject_software_pipeline.py         |  4 ++--
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 223f8dcd5d..21456af1bd 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -153,7 +153,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
   auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array<IterVar>& iters) {
     for (IterVar iter_var : iters) {
       // Create new var
-      Var new_var(iter_var->var->name_hint, iter_var->var->dtype);
+      Var new_var("v_" + iter_var->var->name_hint, iter_var->var->dtype);
       var_map[iter_var->var.get()] = new_var;
 
       PrimExpr dom_min = analyzer->Simplify(iter_var->dom->min);
@@ -307,12 +307,12 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
   // Step 1. Creating loop vars for block bindings.
   Array<IterVar> axes = compute_op->axis;
   axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());
-  Array<PrimExpr> bindings;
-  for (size_t i = 0; i < axes.size(); ++i) {
-    const IterVar& axis = axes[i];
-    int bits = std::max(axis->dom->min.dtype().bits(), axis->dom->extent.dtype().bits());
-    bindings.push_back(Var("i" + std::to_string(i), runtime::DataType::Int(bits)));
-  }
+
+  Array<PrimExpr> bindings = axes.Map([&](IterVar iter_var) -> PrimExpr {
+    int bits = std::max(iter_var->dom->min.dtype().bits(), iter_var->dom->extent.dtype().bits());
+    return Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
+  });
+
   // Step 2. Generate block bodies.
   Array<Stmt> seq_stmt;
   if (compute_op->body[0]->IsInstance<ReduceNode>()) {
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
index 24e3430220..2c5a44d7a2 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -635,12 +635,12 @@ def test_cache_read_specify_consumer():
     )
 
     residual_block = """
-        for i0, i1 in T.grid(512, 512):
+        for ax0, ax1 in T.grid(512, 512):
             with T.block("T_add"):
-                ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                T.reads(C[ax0, ax1], A[ax0, ax1])
-                T.writes(T_add[ax0, ax1])
-                T_add[ax0, ax1] = C[ax0, ax1] + A[ax0, ax1]
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(C[v_ax0, v_ax1], A[v_ax0, v_ax1])
+                T.writes(T_add[v_ax0, v_ax1])
+                T_add[v_ax0, v_ax1] = C[v_ax0, v_ax1] + A[v_ax0, v_ax1]
     """
 
     assert residual_block in space[0].mod.script()
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index c70525b057..006b67d626 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -1507,7 +1507,7 @@ def test_async_pipelined_mma_gemm_simple():
     assert body.block.body.body[1].block.body.body.value == 3
 
     assert epilogue.block.body.body.block.body.body.attr_key == "async_wait_inflight_count"
-    assert str(epilogue.block.body.body.block.body.body.value) == "(2 - i2_0_0: int32)"
+    assert str(epilogue.block.body.body.block.body.body.value) == "(2 - k_0_0: int32)"
 
     build_and_run(sch)
 
@@ -1554,7 +1554,7 @@ def test_async_nested_pipeline_mma_gemm_ideal_annotation():
     assert body.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count"
     assert body.block.body.body[1].block.body.body.value == 2
 
-    assert str(epilogue.block.body.body[0].block.body.body.value) == "(1 - i2_0_0: int32)"
+    assert str(epilogue.block.body.body[0].block.body.body.value) == "(1 - k_0_0: int32)"
 
     build_and_run(sch)