You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/17 21:12:04 UTC

[tvm] branch unity updated: [Unity] BlockBuilder assigning unique tensor names in call_te (#14632)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b663c58c6b [Unity] BlockBuilder assigning unique tensor names in call_te (#14632)
b663c58c6b is described below

commit b663c58c6bf4d1136bc9dd7fe0652481661bae31
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Apr 17 17:11:57 2023 -0400

    [Unity] BlockBuilder assigning unique tensor names in call_te (#14632)
    
    This PR changes the naming behavior of call_te. With this PR, the
    input tensors that call_te creates are named in alphabetical order
    (`A`, `B`, `C`, ...). When there are more than 26 tensors, the
    remaining tensors are named `input26`, `input27`, and so forth.
    
     ### Background
    
    Prior to this PR, call_te used the default name "rxplaceholder" for
    every tensor. On one hand, this name is long and makes the printed
    TVMScript look cluttered. On the other hand, giving every tensor the
    same name is problematic from the perspective of TIR and leads to
    scheduling errors. For example, before this PR, the following code
    snippet prints a TIR function in which two cache-read blocks have
    the same name (`rxplaceholder_global`), which is not a legal TIR
    function.
    ```python
    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    y = relax.Var("y", R.Tensor((3, 4), "float32"))
    with bb.function("main", [x, y]):
        gv = bb.emit_te(topi.nn.matmul, x, y)
        bb.emit_func_output(gv)
    
    sch = tir.Schedule(bb.get())
    sch.work_on("matmul")
    sch.cache_read("T_matmul_NN", 0, "global")
    sch.cache_read("T_matmul_NN", 1, "global")
    print(sch.mod["matmul"].script())
    
     ## Output:
    @T.prim_func
    def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(4)), "float32"), T_matmul_NN: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
        # with T.block("root"):
        rxplaceholder_global = T.alloc_buffer((T.int64(2), T.int64(3)))
        rxplaceholder_global_1 = T.alloc_buffer((T.int64(3), T.int64(4)))
        for ax0, ax1 in T.grid(T.int64(3), T.int64(4)):
            with T.block("rxplaceholder_global"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(rxplaceholder_1[v0, v1])
                T.writes(rxplaceholder_global_1[v0, v1])
                rxplaceholder_global_1[v0, v1] = rxplaceholder_1[v0, v1]
        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
            with T.block("rxplaceholder_global"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(rxplaceholder[v0, v1])
                T.writes(rxplaceholder_global[v0, v1])
                rxplaceholder_global[v0, v1] = rxplaceholder[v0, v1]
        for i, j, k in T.grid(T.int64(2), T.int64(4), T.int64(3)):
            with T.block("T_matmul_NN"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(rxplaceholder_global[v_i, v_k], rxplaceholder_global_1[v_k, v_j])
                T.writes(T_matmul_NN[v_i, v_j])
                with T.init():
                    T_matmul_NN[v_i, v_j] = T.float32(0)
                T_matmul_NN[v_i, v_j] = T_matmul_NN[v_i, v_j] + rxplaceholder_global[v_i, v_k] * rxplaceholder_global_1[v_k, v_j]
    ```
---
 python/tvm/relax/utils.py                    |  7 ++++++-
 tests/python/relax/test_blockbuilder_core.py | 18 ++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 02d9b1e0d4..3998fb8863 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -348,7 +348,10 @@ def gen_call_tir_inputs(
 
             tir.stmt_functor.post_order_visit(expr, _visit_expr)
 
+        n_tensor = 0
+
         def _convert_te_arg_helper(arg):
+            nonlocal n_tensor
             if isinstance(arg, Expr):  # type: ignore
                 if isinstance(arg.struct_info, TensorStructInfo):
                     assert isinstance(
@@ -357,7 +360,9 @@ def gen_call_tir_inputs(
                     for shape_value in arg.struct_info.shape.values:
                         _copy_undefined_var(shape_value)
 
-                    arg = te_tensor(arg, tir_var_map)
+                    name = chr(ord("A") + n_tensor) if n_tensor < 26 else f"input{n_tensor}"
+                    arg = te_tensor(arg, tir_var_map, name)
+                    n_tensor += 1
                     te_args_list.append(arg)
                     return arg
                 if isinstance(arg.struct_info, ShapeStructInfo):
diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py
index 9932227854..f0b14933d1 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -333,6 +333,24 @@ def test_call_te():
     assert len(rx_func.body.blocks[0].bindings) == 1
 
 
+def test_call_te_unique_tensor_name():
+    bb = rx.BlockBuilder()
+    x = rx.Var("x", R.Tensor((2, 3), "float32"))
+    y = rx.Var("y", R.Tensor((3, 4), "float32"))
+    with bb.function("main", [x, y]):
+        gv = bb.emit_te(topi.nn.matmul, x, y)
+        bb.emit_func_output(gv)
+
+    f_matmul = bb.get()["matmul"]
+    param_A = f_matmul.params[0]
+    param_B = f_matmul.params[1]
+    buffer_A = f_matmul.buffer_map[param_A]
+    buffer_B = f_matmul.buffer_map[param_B]
+    assert param_A.name != param_B.name
+    assert buffer_A.name != buffer_B.name
+    assert buffer_A.data.name != buffer_B.data.name
+
+
 def test_call_te_with_unsupported_shape_arg():
     bb = rx.BlockBuilder()
     x = rx.Var("x", rx.TensorStructInfo((200,), "float32"))