You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/11/10 04:59:59 UTC

[tvm] branch feature/2022-11-09/printer-explicit-ir-node created (now 9cd4ad0077)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a change to branch feature/2022-11-09/printer-explicit-ir-node
in repository https://gitbox.apache.org/repos/asf/tvm.git


      at 9cd4ad0077 [TIR] Make syntax of AST nodes different than ops

This branch includes the following new commits:

     new 9cd4ad0077 [TIR] Make syntax of AST nodes different than ops

The 1 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[tvm] 01/01: [TIR] Make syntax of AST nodes different than ops

Posted by ju...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch feature/2022-11-09/printer-explicit-ir-node
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 9cd4ad0077b15bdc03dfcc2db19f99ab224c9d79
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Nov 9 19:29:17 2022 -0800

    [TIR] Make syntax of AST nodes different than ops
    
    As part of effort of more formal TIR semantics, we want to more
    explicitly differentiate TIR AST nodes (defined in `tir/expr.h`)
    and TIR ops (defined in `tir/op.h`).
    
    A naming convention is that:
    - Lowercased methods, for example, `tvm.tir.mul`, means a TIR op, which
      will be eagerly constant-folded, i.e. `mul(1, 2)` returns `2`
      immediately rather than creating an AST node.
    - Capitalized callable, for example, `Mul`, means creating an AST node
      without constant folding.
    
    This PR makes this behavior more explicit by printing `T.Mul(a, b)`
    directly when `a` and `b` are both constants, rather than sugaring it
    into `mul(a, b)` or `a * b`, so that the difference between an op and
    an AST node is clarified.
    
    Co-authored-by: Yaxing Cai <ca...@gmail.com>
---
 python/tvm/script/tir/intrin.py                    | 80 +++++++++++++++++-
 src/printer/tvmscript_printer.cc                   | 97 +++++++++++++---------
 .../test_hexagon/test_async_dma_pipeline.py        | 23 +++--
 .../test_hexagon/test_parallel_hvx_load_vtcm.py    | 49 ++++-------
 .../unittest/test_aot_legalize_packed_call.py      | 12 +--
 .../unittest/test_meta_schedule_space_cuda.py      |  2 +-
 .../test_tir_transform_inject_software_pipeline.py | 16 ++--
 .../test_tir_transform_inject_virtual_thread.py    | 17 ++--
 .../unittest/test_tir_transform_thread_sync.py     |  2 +-
 9 files changed, 185 insertions(+), 113 deletions(-)

diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index bd9aa1fdad..8e24f27325 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -17,12 +17,13 @@
 """TVM Script Parser Intrinsic Classes"""
 # pylint: disable=redefined-builtin, relative-beyond-top-level
 import builtins
-from typing import List, Any
+from typing import Any, List
 
 import tvm.tir
 from tvm.tir import FloatImm
-from ..registry import register
+
 from ...target import codegen
+from ..registry import register
 from ..utils import get_param_list, tvm_span_from_synr
 
 
@@ -229,3 +230,78 @@ def comm_reducer(lambda_io, identities, span):
 def llvm_lookup_intrinsic_id(name, span):
     # pylint: disable=unused-argument
     return codegen.llvm_lookup_intrinsic_id(name)
+
+
+@register
+def FloorMod(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.FloorMod(x, y, span)
+
+
+@register
+def FloorDiv(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.FloorDiv(x, y, span)
+
+
+@register
+def Mul(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Mul(x, y, span)
+
+
+@register
+def Div(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Div(x, y, span)
+
+
+@register
+def Add(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Add(x, y, span)
+
+
+@register
+def Sub(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Sub(x, y, span)
+
+
+@register
+def LT(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.LT(x, y, span)
+
+
+@register
+def LE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.LE(x, y, span)
+
+
+@register
+def GT(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.GT(x, y, span)
+
+
+@register
+def GE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.GE(x, y, span)
+
+
+@register
+def EQ(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.EQ(x, y, span)
+
+
+@register
+def NE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.NE(x, y, span)
+
+
+@register
+def And(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.And(x, y, span)
+
+
+@register
+def Or(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Or(x, y, span)
+
+
+@register
+def Cast(dtype, value, span):  # pylint: disable=invalid-name
+    return tvm.tir.Cast(dtype, value, span)
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 64a576ef52..d7a3a406e3 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -788,7 +788,7 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_pr
 Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
   Doc doc;
-  doc << tir_prefix_ << ".cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")";
+  doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
   return doc;
 }
 
@@ -798,46 +798,61 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_preceden
   return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
 }
 
-#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpPrecedence)            \
-  Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
-    Doc doc;                                                                           \
-    ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown;                          \
-    ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown;                          \
-    /* Get children expr out_precedence */                                             \
-    Doc lhs_doc = VisitExpr(op->a, &lhs_precedence);                                   \
-    Doc rhs_doc = VisitExpr(op->b, &rhs_precedence);                                   \
-    ICHECK(lhs_precedence != ExprPrecedence::kUnknown);                                \
-    ICHECK(rhs_precedence != ExprPrecedence::kUnknown);                                \
-    /* Update out_precedence of current node. */                                       \
-    *out_precedence = OpPrecedence;                                                    \
-    if (lhs_precedence > OpPrecedence) {                                               \
-      doc << "(" << lhs_doc << ")";                                                    \
-    } else {                                                                           \
-      doc << lhs_doc;                                                                  \
-    }                                                                                  \
-    doc << OpString;                                                                   \
-    if (rhs_precedence >= OpPrecedence) {                                              \
-      doc << "(" << rhs_doc << ")";                                                    \
-    } else {                                                                           \
-      doc << rhs_doc;                                                                  \
-    }                                                                                  \
-    return doc;                                                                        \
-  }
-
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", ExprPrecedence::kAdditionSubtraction)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", ExprPrecedence::kAdditionSubtraction)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", ExprPrecedence::kEquality)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", ExprPrecedence::kEquality)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", ExprPrecedence::kAnd)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr)
+bool WillPrintConstScalar(const PrimExpr& expr) {
+  if (const auto* imm = expr.as<IntImmNode>()) {
+    DataType dtype = imm->dtype;
+    return dtype == DataType::Int(32) || dtype == DataType::Bool();
+  }
+  return false;
+}
+
+#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, OpPrecedence)              \
+  Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) {            \
+    Doc doc;                                                                                      \
+    if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) {                             \
+      *out_precedence = ExprPrecedence::kIdentity;                                                \
+      doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << Print(op->b) << ")"; \
+      return doc;                                                                                 \
+    }                                                                                             \
+    ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown;                                     \
+    ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown;                                     \
+    /* Get children expr out_precedence */                                                        \
+    Doc lhs_doc = VisitExpr(op->a, &lhs_precedence);                                              \
+    Doc rhs_doc = VisitExpr(op->b, &rhs_precedence);                                              \
+    ICHECK(lhs_precedence != ExprPrecedence::kUnknown);                                           \
+    ICHECK(rhs_precedence != ExprPrecedence::kUnknown);                                           \
+    /* Update out_precedence of current node. */                                                  \
+    *out_precedence = OpPrecedence;                                                               \
+    if (lhs_precedence > OpPrecedence) {                                                          \
+      doc << "(" << lhs_doc << ")";                                                               \
+    } else {                                                                                      \
+      doc << lhs_doc;                                                                             \
+    }                                                                                             \
+    doc << OpString;                                                                              \
+    if (rhs_precedence >= OpPrecedence) {                                                         \
+      doc << "(" << rhs_doc << ")";                                                               \
+    } else {                                                                                      \
+      doc << rhs_doc;                                                                             \
+    }                                                                                             \
+    return doc;                                                                                   \
+  }
+
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", "Mul", ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", "Div", ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", "FloorDiv",
+                                    ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", "FloorMod",
+                                    ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", "Add", ExprPrecedence::kAdditionSubtraction)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", "Sub", ExprPrecedence::kAdditionSubtraction)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", "LT", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", "LE", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", "GT", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", "GE", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", "EQ", ExprPrecedence::kEquality)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", "NE", ExprPrecedence::kEquality)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", "And", ExprPrecedence::kAnd)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", "Or", ExprPrecedence::kOr)
 
 Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index a7a05c2aa3..19b380c1bd 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -18,11 +18,10 @@
 """ Test different strategies for loading data into vtcm before running HVX workloads. """
 
 import numpy as np
-import tvm
 import pytest
-
-from tvm.script import tir as T
+import tvm
 from numpy.random import default_rng
+from tvm.script import tir as T
 
 VRMPY_SIZE_B = 128
 VRMPY_SIZE_INT32 = 32
@@ -126,12 +125,12 @@ def get_single_dma_schedule(size_a, size_w):
     @T.prim_func
     def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", mem_scope="global")
-        w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", mem_scope="global")
-        c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", mem_scope="global")
-        a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", mem_scope="global.vtcm")
-        w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", mem_scope="global.vtcm")
-        c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", mem_scope="global.vtcm")
+        a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", scope="global")
+        w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", scope="global")
+        c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", scope="global")
+        a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", scope="global.vtcm")
+        w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", scope="global.vtcm")
+        c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", scope="global.vtcm")
         T.evaluate(
             T.tvm_call_packed(
                 "device_api.hexagon.mem_copy_DLTensor",
@@ -153,7 +152,7 @@ def get_single_dma_schedule(size_a, size_w):
                     0,
                     dtype="handle",
                 ),
-                T.cast(a_bytes, dtype="int"),
+                T.Cast("int", a_bytes),
                 dtype="int32",
             )
         )
@@ -178,7 +177,7 @@ def get_single_dma_schedule(size_a, size_w):
                     0,
                     dtype="handle",
                 ),
-                T.cast(w_bytes, dtype="int"),
+                T.Cast("int", w_bytes),
                 dtype="int32",
             )
         )
@@ -222,7 +221,7 @@ def get_single_dma_schedule(size_a, size_w):
                     0,
                     dtype="handle",
                 ),
-                T.cast(a_bytes, dtype="int"),
+                T.Cast("int", a_bytes),
                 dtype="int32",
             )
         )
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
index fb398f4397..e6fc0a3c20 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
@@ -18,9 +18,8 @@
 """ Test different strategies for loading data into vtcm before running HVX workloads. """
 
 import numpy as np
-from numpy.random import default_rng
-
 import tvm
+from numpy.random import default_rng
 from tvm.script import tir as T
 
 from .infrastructure import get_hexagon_target
@@ -109,17 +108,17 @@ def preloaded_vrmpy(operations):
             [T.cast(operations, "int32") * 128],
             dtype="uint8",
             align=128,
-            mem_scope="global.vtcm",
+            scope="global.vtcm",
         )
         b_buffer = T.match_buffer(
             b,
             [T.cast(operations, "int32") * 128],
             dtype="uint8",
             align=128,
-            mem_scope="global.vtcm",
+            scope="global.vtcm",
         )
         c_buffer = T.match_buffer(
-            c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, mem_scope="global.vtcm"
+            c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, scope="global.vtcm"
         )
         for n in T.grid(operations):
             with T.block("c_buffer"):
@@ -149,21 +148,13 @@ def preallocated_vrmpy(operations):
         a: T.handle, b: T.handle, c: T.handle, a_v: T.handle, b_v: T.handle, c_v: T.handle
     ) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        a_buffer = T.match_buffer(
-            a, [operations, 128], dtype="uint8", align=128, mem_scope="global"
-        )
-        b_buffer = T.match_buffer(
-            b, [operations, 128], dtype="uint8", align=128, mem_scope="global"
-        )
-        c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, mem_scope="global")
-        a_global_vtcm = T.match_buffer(
-            a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
-        )
-        b_global_vtcm = T.match_buffer(
-            b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
-        )
+        a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global")
+        b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global")
+        c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global")
+        a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm")
+        b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm")
         c_global_vtcm = T.match_buffer(
-            c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm"
+            c_v, [out_size], dtype="int32", align=128, scope="global.vtcm"
         )
         for n, i in T.grid(operations, 128):
             with T.block("a_buffer_global.vtcm"):
@@ -212,21 +203,13 @@ def preallocated_single_dma_vrmpy(operations):
         c_v: T.handle,
     ) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        a_buffer = T.match_buffer(
-            a, [operations, 128], dtype="uint8", align=128, mem_scope="global"
-        )
-        b_buffer = T.match_buffer(
-            b, [operations, 128], dtype="uint8", align=128, mem_scope="global"
-        )
-        c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, mem_scope="global")
-        a_global_vtcm = T.match_buffer(
-            a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
-        )
-        b_global_vtcm = T.match_buffer(
-            b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
-        )
+        a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global")
+        b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global")
+        c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global")
+        a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm")
+        b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm")
         c_global_vtcm = T.match_buffer(
-            c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm"
+            c_v, [out_size], dtype="int32", align=128, scope="global.vtcm"
         )
         T.evaluate(
             T.tvm_call_packed(
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
index 9c597a55e5..cd0114d464 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
 import tvm
-from tvm.script import tir as T
-from tvm import tir
 import tvm.testing
-import pytest
+from tvm import tir
+from tvm.script import tir as T
 
 
 @tvm.script.ir_module
@@ -85,7 +85,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
@@ -94,7 +94,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
@@ -103,7 +103,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 324d8a9ec4..0a518c840d 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -856,7 +856,7 @@ def test_cuda_nrm():
                 for i0_1 in T.thread_binding(128, thread="threadIdx.x"):
                     with T.block("D"):
                         b = T.axis.spatial(1, i0_1)
-                        T.where(0 * 128 + i0_1 < 1)
+                        T.where(T.Mul(0, 128) + i0_1 < 1)
                         T.reads(C_shared[b])
                         T.writes(D[b])
                         D[b] = T.sqrt(C_shared[b], dtype="float32")
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 2a4cabc541..c70525b057 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -14,16 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
 import sys
-import numpy as np
 
+import numpy as np
+import pytest
 import tvm
 import tvm.testing
 import tvm.tir.tensor_intrin.cuda
-from tvm import tir, te, TVMError
-from tvm.script import tir as T
+from tvm import TVMError, te, tir
 from tvm.meta_schedule.testing import te_workload
+from tvm.script import tir as T
 from tvm.testing.tir import mma_schedule
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_DYN_INTRIN,
@@ -1116,7 +1116,7 @@ def test_simple_compute_async():
     mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)
 
     @T.prim_func
-    def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
+    def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
         for tx in T.thread_binding(16, thread="threadIdx.x"):
             with T.block():
                 T.reads(A[tx, 0:16])
@@ -1127,7 +1127,7 @@ def test_simple_compute_async():
                     T.writes(B[0, tx, 0])
                     with T.attr(0, "async_commit_queue_scope", 0):
                         with T.attr(0, "async_scope", 1):
-                            B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2)
+                            B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * T.float32(2)
                 with T.block():
                     T.reads(A[tx, 1:16], B[0:2, tx, 0])
                     T.writes(B[0:2, tx, 0], C[tx, 0:15])
@@ -1147,11 +1147,11 @@ def test_simple_compute_async():
                                 with T.attr(0, "async_wait_inflight_count", 1):
                                     C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1)
                 with T.block():
-                    T.reads(B[15 % 2, tx, 0])
+                    T.reads(B[T.FloorMod(15, 2), tx, 0])
                     T.writes(C[tx, 15])
                     with T.attr(0, "async_wait_queue_scope", 0):
                         with T.attr(0, "async_wait_inflight_count", 0):
-                            C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1)
+                            C[tx, 15] = B[T.FloorMod(15, 2), tx, 0] + T.float32(1)
 
     tvm.ir.assert_structural_equal(mod["main"], ref, True)
 
diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index 548f3bc8d1..b4ea4e712d 100644
--- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -16,7 +16,6 @@
 # under the License.
 import tvm
 from tvm import te
-
 from tvm.script import tir as T
 
 vthread_name = tvm.testing.parameter("vthread", "cthread")
@@ -155,10 +154,10 @@ def test_vthread_simplified():
         B = T.buffer_decl([16], "int32", data=B_data, scope="shared")
         # The indices for B should each be a single Ramp node, and
         # should not be the sum of a Ramp and Broadcast node.
-        B[0 * 4 : 0 * 4 + 4] = T.broadcast(0, 4)
-        B[1 * 4 : 1 * 4 + 4] = T.broadcast(1, 4)
-        B[2 * 4 : 2 * 4 + 4] = T.broadcast(2, 4)
-        B[3 * 4 : 3 * 4 + 4] = T.broadcast(3, 4)
+        B[T.Mul(0, 4) : T.Mul(0, 4) + 4] = T.broadcast(0, 4)
+        B[T.Mul(1, 4) : T.Mul(1, 4) + 4] = T.broadcast(1, 4)
+        B[T.Mul(2, 4) : T.Mul(2, 4) + 4] = T.broadcast(2, 4)
+        B[T.Mul(3, 4) : T.Mul(3, 4) + 4] = T.broadcast(3, 4)
 
     before_mod = tvm.IRModule.from_expr(before_func)
     after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
@@ -182,10 +181,10 @@ def test_vthread_vectorized():
     def expected_func():
         B_data = T.allocate([4], "int32x4", "shared")
         B = T.buffer_decl([4], "int32x4", data=B_data, scope="shared")
-        B[0 * 4 / 4] = T.broadcast(0, 4)
-        B[1 * 4 / 4] = T.broadcast(1, 4)
-        B[2 * 4 / 4] = T.broadcast(2, 4)
-        B[3 * 4 / 4] = T.broadcast(3, 4)
+        B[T.Mul(0, 4) / 4] = T.broadcast(0, 4)
+        B[T.Mul(1, 4) / 4] = T.broadcast(1, 4)
+        B[T.Mul(2, 4) / 4] = T.broadcast(2, 4)
+        B[T.Mul(3, 4) / 4] = T.broadcast(3, 4)
 
     before_mod = tvm.IRModule.from_expr(before_func)
     intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 18607ca1a0..c80cd55ea2 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -102,9 +102,9 @@ def test_sync_read_thread_id_independent_location():
         threadIdx_x = T.env_thread("threadIdx.x")
         blockIdx_x = T.env_thread("blockIdx.x")
         T.preflattened_buffer(p0, [1, 2, 1, 1], dtype="float32", data=p0.data)
-        T.launch_thread(blockIdx_x, 8)
         result_local = T.alloc_buffer([1], dtype="float32", scope="local")
         temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
+        T.launch_thread(blockIdx_x, 8)
         T.launch_thread(threadIdx_x, 4)
         result_local[0] = T.float32(0)
         if threadIdx_x < 1: