You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/11/10 04:59:59 UTC
[tvm] branch feature/2022-11-09/printer-explicit-ir-node created (now 9cd4ad0077)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a change to branch feature/2022-11-09/printer-explicit-ir-node
in repository https://gitbox.apache.org/repos/asf/tvm.git
at 9cd4ad0077 [TIR] Make syntax of AST nodes different than ops
This branch includes the following new commits:
new 9cd4ad0077 [TIR] Make syntax of AST nodes different than ops
The 1 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails. The revisions
listed as "add" were already present in the repository and have only
been added to this reference.
[tvm] 01/01: [TIR] Make syntax of AST nodes different than ops
Posted by ju...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch feature/2022-11-09/printer-explicit-ir-node
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 9cd4ad0077b15bdc03dfcc2db19f99ab224c9d79
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Nov 9 19:29:17 2022 -0800
[TIR] Make syntax of AST nodes different than ops
As part of the effort toward more formal TIR semantics, we want to more
explicitly differentiate TIR AST nodes (defined in `tir/expr.h`)
and TIR ops (defined in `tir/op.h`).
A naming convention is that:
- Lowercased methods, for example, `tvm.tir.mul`, means a TIR op, which
will be eagerly constant-folded, i.e. `mul(1, 2)` returns `3`
immediately rather than creating an AST node.
- Capitalized callable, for example, `Mul`, means creating an AST node
without constant folding.
This PR makes this behavior more explicit by printing `T.Mul(a, b)`
directly when `a` and `b` are both constants, rather than sugaring it
into `mul(a, b)` or `a * b`, so that the difference between an op and
an AST node is clarified.
Co-authored-by: Yaxing Cai <ca...@gmail.com>
---
python/tvm/script/tir/intrin.py | 80 +++++++++++++++++-
src/printer/tvmscript_printer.cc | 97 +++++++++++++---------
.../test_hexagon/test_async_dma_pipeline.py | 23 +++--
.../test_hexagon/test_parallel_hvx_load_vtcm.py | 49 ++++-------
.../unittest/test_aot_legalize_packed_call.py | 12 +--
.../unittest/test_meta_schedule_space_cuda.py | 2 +-
.../test_tir_transform_inject_software_pipeline.py | 16 ++--
.../test_tir_transform_inject_virtual_thread.py | 17 ++--
.../unittest/test_tir_transform_thread_sync.py | 2 +-
9 files changed, 185 insertions(+), 113 deletions(-)
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index bd9aa1fdad..8e24f27325 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -17,12 +17,13 @@
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
import builtins
-from typing import List, Any
+from typing import Any, List
import tvm.tir
from tvm.tir import FloatImm
-from ..registry import register
+
from ...target import codegen
+from ..registry import register
from ..utils import get_param_list, tvm_span_from_synr
@@ -229,3 +230,78 @@ def comm_reducer(lambda_io, identities, span):
def llvm_lookup_intrinsic_id(name, span):
# pylint: disable=unused-argument
return codegen.llvm_lookup_intrinsic_id(name)
+
+
+@register
+def FloorMod(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.FloorMod(x, y, span)
+
+
+@register
+def FloorDiv(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.FloorDiv(x, y, span)
+
+
+@register
+def Mul(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.Mul(x, y, span)
+
+
+@register
+def Div(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.Div(x, y, span)
+
+
+@register
+def Add(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.Add(x, y, span)
+
+
+@register
+def Sub(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.Sub(x, y, span)
+
+
+@register
+def LT(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.LT(x, y, span)
+
+
+@register
+def LE(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.LE(x, y, span)
+
+
+@register
+def GT(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.GT(x, y, span)
+
+
+@register
+def GE(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.GE(x, y, span)
+
+
+@register
+def EQ(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.EQ(x, y, span)
+
+
+@register
+def NE(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.NE(x, y, span)
+
+
+@register
+def And(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.And(x, y, span)
+
+
+@register
+def Or(x, y, span): # pylint: disable=invalid-name
+ return tvm.tir.Or(x, y, span)
+
+
+@register
+def Cast(dtype, value, span): # pylint: disable=invalid-name
+ return tvm.tir.Cast(dtype, value, span)
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 64a576ef52..d7a3a406e3 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -788,7 +788,7 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_pr
Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
Doc doc;
- doc << tir_prefix_ << ".cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")";
+ doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
return doc;
}
@@ -798,46 +798,61 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_preceden
return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
}
-#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpPrecedence) \
- Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
- Doc doc; \
- ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
- ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
- /* Get children expr out_precedence */ \
- Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
- Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
- ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
- ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
- /* Update out_precedence of current node. */ \
- *out_precedence = OpPrecedence; \
- if (lhs_precedence > OpPrecedence) { \
- doc << "(" << lhs_doc << ")"; \
- } else { \
- doc << lhs_doc; \
- } \
- doc << OpString; \
- if (rhs_precedence >= OpPrecedence) { \
- doc << "(" << rhs_doc << ")"; \
- } else { \
- doc << rhs_doc; \
- } \
- return doc; \
- }
-
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", ExprPrecedence::kAdditionSubtraction)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", ExprPrecedence::kAdditionSubtraction)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", ExprPrecedence::kEquality)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", ExprPrecedence::kEquality)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", ExprPrecedence::kAnd)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr)
+bool WillPrintConstScalar(const PrimExpr& expr) {
+ if (const auto* imm = expr.as<IntImmNode>()) {
+ DataType dtype = imm->dtype;
+ return dtype == DataType::Int(32) || dtype == DataType::Bool();
+ }
+ return false;
+}
+
+#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, OpPrecedence) \
+ Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
+ Doc doc; \
+ if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) { \
+ *out_precedence = ExprPrecedence::kIdentity; \
+ doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << Print(op->b) << ")"; \
+ return doc; \
+ } \
+ ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
+ ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
+ /* Get children expr out_precedence */ \
+ Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
+ Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
+ ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
+ ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
+ /* Update out_precedence of current node. */ \
+ *out_precedence = OpPrecedence; \
+ if (lhs_precedence > OpPrecedence) { \
+ doc << "(" << lhs_doc << ")"; \
+ } else { \
+ doc << lhs_doc; \
+ } \
+ doc << OpString; \
+ if (rhs_precedence >= OpPrecedence) { \
+ doc << "(" << rhs_doc << ")"; \
+ } else { \
+ doc << rhs_doc; \
+ } \
+ return doc; \
+ }
+
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", "Mul", ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", "Div", ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", "FloorDiv",
+ ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", "FloorMod",
+ ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", "Add", ExprPrecedence::kAdditionSubtraction)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", "Sub", ExprPrecedence::kAdditionSubtraction)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", "LT", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", "LE", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", "GT", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", "GE", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", "EQ", ExprPrecedence::kEquality)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", "NE", ExprPrecedence::kEquality)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", "And", ExprPrecedence::kAnd)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", "Or", ExprPrecedence::kOr)
Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index a7a05c2aa3..19b380c1bd 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -18,11 +18,10 @@
""" Test different strategies for loading data into vtcm before running HVX workloads. """
import numpy as np
-import tvm
import pytest
-
-from tvm.script import tir as T
+import tvm
from numpy.random import default_rng
+from tvm.script import tir as T
VRMPY_SIZE_B = 128
VRMPY_SIZE_INT32 = 32
@@ -126,12 +125,12 @@ def get_single_dma_schedule(size_a, size_w):
@T.prim_func
def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", mem_scope="global")
- w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", mem_scope="global")
- c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", mem_scope="global")
- a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", mem_scope="global.vtcm")
- w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", mem_scope="global.vtcm")
- c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", mem_scope="global.vtcm")
+ a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", scope="global")
+ w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", scope="global")
+ c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", scope="global")
+ a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", scope="global.vtcm")
+ w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", scope="global.vtcm")
+ c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", scope="global.vtcm")
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.mem_copy_DLTensor",
@@ -153,7 +152,7 @@ def get_single_dma_schedule(size_a, size_w):
0,
dtype="handle",
),
- T.cast(a_bytes, dtype="int"),
+ T.Cast("int", a_bytes),
dtype="int32",
)
)
@@ -178,7 +177,7 @@ def get_single_dma_schedule(size_a, size_w):
0,
dtype="handle",
),
- T.cast(w_bytes, dtype="int"),
+ T.Cast("int", w_bytes),
dtype="int32",
)
)
@@ -222,7 +221,7 @@ def get_single_dma_schedule(size_a, size_w):
0,
dtype="handle",
),
- T.cast(a_bytes, dtype="int"),
+ T.Cast("int", a_bytes),
dtype="int32",
)
)
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
index fb398f4397..e6fc0a3c20 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
@@ -18,9 +18,8 @@
""" Test different strategies for loading data into vtcm before running HVX workloads. """
import numpy as np
-from numpy.random import default_rng
-
import tvm
+from numpy.random import default_rng
from tvm.script import tir as T
from .infrastructure import get_hexagon_target
@@ -109,17 +108,17 @@ def preloaded_vrmpy(operations):
[T.cast(operations, "int32") * 128],
dtype="uint8",
align=128,
- mem_scope="global.vtcm",
+ scope="global.vtcm",
)
b_buffer = T.match_buffer(
b,
[T.cast(operations, "int32") * 128],
dtype="uint8",
align=128,
- mem_scope="global.vtcm",
+ scope="global.vtcm",
)
c_buffer = T.match_buffer(
- c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, mem_scope="global.vtcm"
+ c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, scope="global.vtcm"
)
for n in T.grid(operations):
with T.block("c_buffer"):
@@ -149,21 +148,13 @@ def preallocated_vrmpy(operations):
a: T.handle, b: T.handle, c: T.handle, a_v: T.handle, b_v: T.handle, c_v: T.handle
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- a_buffer = T.match_buffer(
- a, [operations, 128], dtype="uint8", align=128, mem_scope="global"
- )
- b_buffer = T.match_buffer(
- b, [operations, 128], dtype="uint8", align=128, mem_scope="global"
- )
- c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, mem_scope="global")
- a_global_vtcm = T.match_buffer(
- a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
- )
- b_global_vtcm = T.match_buffer(
- b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
- )
+ a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global")
+ b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global")
+ c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global")
+ a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm")
+ b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm")
c_global_vtcm = T.match_buffer(
- c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm"
+ c_v, [out_size], dtype="int32", align=128, scope="global.vtcm"
)
for n, i in T.grid(operations, 128):
with T.block("a_buffer_global.vtcm"):
@@ -212,21 +203,13 @@ def preallocated_single_dma_vrmpy(operations):
c_v: T.handle,
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- a_buffer = T.match_buffer(
- a, [operations, 128], dtype="uint8", align=128, mem_scope="global"
- )
- b_buffer = T.match_buffer(
- b, [operations, 128], dtype="uint8", align=128, mem_scope="global"
- )
- c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, mem_scope="global")
- a_global_vtcm = T.match_buffer(
- a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
- )
- b_global_vtcm = T.match_buffer(
- b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
- )
+ a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global")
+ b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global")
+ c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global")
+ a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm")
+ b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm")
c_global_vtcm = T.match_buffer(
- c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm"
+ c_v, [out_size], dtype="int32", align=128, scope="global.vtcm"
)
T.evaluate(
T.tvm_call_packed(
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
index 9c597a55e5..cd0114d464 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
import tvm
-from tvm.script import tir as T
-from tvm import tir
import tvm.testing
-import pytest
+from tvm import tir
+from tvm.script import tir as T
@tvm.script.ir_module
@@ -85,7 +85,7 @@ class Expected:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
- T.cast(0, dtype="float32"),
+ T.Cast("float32", 0),
0,
dtype="handle",
),
@@ -94,7 +94,7 @@ class Expected:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
- T.cast(0, dtype="float32"),
+ T.Cast("float32", 0),
0,
dtype="handle",
),
@@ -103,7 +103,7 @@ class Expected:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
- T.cast(0, dtype="float32"),
+ T.Cast("float32", 0),
0,
dtype="handle",
),
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 324d8a9ec4..0a518c840d 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -856,7 +856,7 @@ def test_cuda_nrm():
for i0_1 in T.thread_binding(128, thread="threadIdx.x"):
with T.block("D"):
b = T.axis.spatial(1, i0_1)
- T.where(0 * 128 + i0_1 < 1)
+ T.where(T.Mul(0, 128) + i0_1 < 1)
T.reads(C_shared[b])
T.writes(D[b])
D[b] = T.sqrt(C_shared[b], dtype="float32")
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 2a4cabc541..c70525b057 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -14,16 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import pytest
import sys
-import numpy as np
+import numpy as np
+import pytest
import tvm
import tvm.testing
import tvm.tir.tensor_intrin.cuda
-from tvm import tir, te, TVMError
-from tvm.script import tir as T
+from tvm import TVMError, te, tir
from tvm.meta_schedule.testing import te_workload
+from tvm.script import tir as T
from tvm.testing.tir import mma_schedule
from tvm.tir.tensor_intrin.cuda import (
LDMATRIX_16x16_A_DYN_INTRIN,
@@ -1116,7 +1116,7 @@ def test_simple_compute_async():
mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)
@T.prim_func
- def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
+ def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block():
T.reads(A[tx, 0:16])
@@ -1127,7 +1127,7 @@ def test_simple_compute_async():
T.writes(B[0, tx, 0])
with T.attr(0, "async_commit_queue_scope", 0):
with T.attr(0, "async_scope", 1):
- B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2)
+ B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * T.float32(2)
with T.block():
T.reads(A[tx, 1:16], B[0:2, tx, 0])
T.writes(B[0:2, tx, 0], C[tx, 0:15])
@@ -1147,11 +1147,11 @@ def test_simple_compute_async():
with T.attr(0, "async_wait_inflight_count", 1):
C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1)
with T.block():
- T.reads(B[15 % 2, tx, 0])
+ T.reads(B[T.FloorMod(15, 2), tx, 0])
T.writes(C[tx, 15])
with T.attr(0, "async_wait_queue_scope", 0):
with T.attr(0, "async_wait_inflight_count", 0):
- C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1)
+ C[tx, 15] = B[T.FloorMod(15, 2), tx, 0] + T.float32(1)
tvm.ir.assert_structural_equal(mod["main"], ref, True)
diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index 548f3bc8d1..b4ea4e712d 100644
--- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -16,7 +16,6 @@
# under the License.
import tvm
from tvm import te
-
from tvm.script import tir as T
vthread_name = tvm.testing.parameter("vthread", "cthread")
@@ -155,10 +154,10 @@ def test_vthread_simplified():
B = T.buffer_decl([16], "int32", data=B_data, scope="shared")
# The indices for B should each be a single Ramp node, and
# should not be the sum of a Ramp and Broadcast node.
- B[0 * 4 : 0 * 4 + 4] = T.broadcast(0, 4)
- B[1 * 4 : 1 * 4 + 4] = T.broadcast(1, 4)
- B[2 * 4 : 2 * 4 + 4] = T.broadcast(2, 4)
- B[3 * 4 : 3 * 4 + 4] = T.broadcast(3, 4)
+ B[T.Mul(0, 4) : T.Mul(0, 4) + 4] = T.broadcast(0, 4)
+ B[T.Mul(1, 4) : T.Mul(1, 4) + 4] = T.broadcast(1, 4)
+ B[T.Mul(2, 4) : T.Mul(2, 4) + 4] = T.broadcast(2, 4)
+ B[T.Mul(3, 4) : T.Mul(3, 4) + 4] = T.broadcast(3, 4)
before_mod = tvm.IRModule.from_expr(before_func)
after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
@@ -182,10 +181,10 @@ def test_vthread_vectorized():
def expected_func():
B_data = T.allocate([4], "int32x4", "shared")
B = T.buffer_decl([4], "int32x4", data=B_data, scope="shared")
- B[0 * 4 / 4] = T.broadcast(0, 4)
- B[1 * 4 / 4] = T.broadcast(1, 4)
- B[2 * 4 / 4] = T.broadcast(2, 4)
- B[3 * 4 / 4] = T.broadcast(3, 4)
+ B[T.Mul(0, 4) / 4] = T.broadcast(0, 4)
+ B[T.Mul(1, 4) / 4] = T.broadcast(1, 4)
+ B[T.Mul(2, 4) / 4] = T.broadcast(2, 4)
+ B[T.Mul(3, 4) / 4] = T.broadcast(3, 4)
before_mod = tvm.IRModule.from_expr(before_func)
intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 18607ca1a0..c80cd55ea2 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -102,9 +102,9 @@ def test_sync_read_thread_id_independent_location():
threadIdx_x = T.env_thread("threadIdx.x")
blockIdx_x = T.env_thread("blockIdx.x")
T.preflattened_buffer(p0, [1, 2, 1, 1], dtype="float32", data=p0.data)
- T.launch_thread(blockIdx_x, 8)
result_local = T.alloc_buffer([1], dtype="float32", scope="local")
temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
+ T.launch_thread(blockIdx_x, 8)
T.launch_thread(threadIdx_x, 4)
result_local[0] = T.float32(0)
if threadIdx_x < 1: