You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/10 00:26:06 UTC
[tvm] branch main updated: [FIX,STORAGE REWRITE] Rewrite buffers in let statements (#12349)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0c281b7064 [FIX,STORAGE REWRITE] Rewrite buffers in let statements (#12349)
0c281b7064 is described below
commit 0c281b7064a8dba4ab1943cf1282b381ac8208ae
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Tue Aug 9 17:25:59 2022 -0700
[FIX,STORAGE REWRITE] Rewrite buffers in let statements (#12349)
Storage rewrite was missing a visitor for let statements so buffers
added in them would still refer to the pre-rewritten version. This error
was originally noticed when using `global.vtcm` buffers which get
changed to let statements by LowerVtcmAlloc.
Implementing the test for this change also required adding support for
vectorized datatypes to tvmscript. The solution included is a little
hacky and involves adding the datatypes to the `globals()` table of each
module they need to be defined in.
---
python/tvm/script/tir/__init__.py | 12 +++-
python/tvm/script/tir/intrin.py | 84 +++++-----------------
python/tvm/script/tir/ty.py | 15 ++--
src/tir/transforms/storage_rewrite.cc | 8 +++
.../unittest/test_tir_transform_storage_rewrite.py | 24 +++++++
5 files changed, 65 insertions(+), 78 deletions(-)
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py
index 2655f5bb33..2f2b4bbc25 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/tir/__init__.py
@@ -17,7 +17,15 @@
"""TVMScript for TIR"""
# Type system
-from .ty import uint8, int8, int16, int32, int64, float16, float32, float64, void
-from .ty import boolean, handle, Ptr, Tuple, Buffer
+from .ty import void, boolean, handle, Ptr, Tuple, Buffer
from .prim_func import prim_func
+
+# add all floating point and integer datatypes to the module
+for _dtype in ["float", "uint", "int"]:
+ for _size in ["8", "16", "32", "64"]:
+ for _lanes in ["", "x4", "x8", "x16", "x32"]:
+ from . import ty
+
+ _name = _dtype + _size + _lanes
+ globals()[_name] = getattr(ty, _name)
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index 627b89086a..382431c229 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -42,74 +42,22 @@ def bool(imm, span):
return imm.astype("bool", span)
-@register
-def int8(imm, span):
- return imm.astype("int8", span)
-
-
-@register
-def int16(imm, span):
- return imm.astype("int16", span)
-
-
-@register
-def int32(imm, span):
- return imm.astype("int32", span)
-
-
-@register
-def int64(imm, span):
- return imm.astype("int64", span)
-
-
-@register
-def uint8(imm, span):
- return imm.astype("uint8", span)
-
-
-@register
-def uint16(imm, span):
- return imm.astype("uint16", span)
-
-
-@register
-def uint32(imm, span):
- return imm.astype("uint32", span)
-
-
-@register
-def uint64(imm, span):
- return imm.astype("uint64", span)
-
-
-@register
-def float8(imm, span):
- return imm.astype("float8", span)
-
-
-@register
-def float16(imm, span):
- return imm.astype("float16", span)
-
-
-@register
-def float32(imm, span):
- return imm.astype("float32", span)
-
-
-@register
-def float64(imm, span):
- return imm.astype("float64", span)
-
-
-@register
-def int32x16(imm, span):
- return imm.astype("int32x16", span)
-
-
-@register
-def int32x4(imm, span):
- return imm.astype("int32x4", span)
+# register all datatypes
+for _dtype in ["float", "uint", "int"]:
+ for _size in ["8", "16", "32", "64"]:
+ for _lanes in ["", "x4", "x8", "x16", "x32"]:
+ _name = _dtype + _size + _lanes
+
+ # nest closures so we copy the name string
+ def wrap(name):
+ def f(imm, span):
+ return imm.astype(name, span)
+
+ f.__name__ = name
+ return f
+
+ _intrin = wrap(_name)
+ register(_intrin)
@register
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index a64485b215..4548102a9e 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -199,14 +199,13 @@ class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods,
)
-uint8 = ConcreteType("uint8")
-int8 = ConcreteType("int8")
-int16 = ConcreteType("int16")
-int32 = ConcreteType("int32")
-int64 = ConcreteType("int64")
-float16 = ConcreteType("float16")
-float32 = ConcreteType("float32")
-float64 = ConcreteType("float64")
+# add all floating point and integer datatypes to the module
+for _dtype in ["float", "uint", "int"]:
+ for _size in ["8", "16", "32", "64"]:
+ for _lanes in ["", "x4", "x8", "x16", "x32"]:
+ _name = _dtype + _size + _lanes
+ globals()[_name] = ConcreteType(_name)
+
boolean = ConcreteType("bool")
handle = ConcreteType("handle")
void = VoidType()
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index c5f27b8de3..5a326d9fac 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1461,6 +1461,14 @@ class VectorTypeRewriter : public StmtExprMutator {
return VisitBufferAccess(std::move(node));
}
+ Stmt VisitStmt_(const LetStmtNode* op) final {
+ auto it = rewrite_map_.find(op->var.get());
+ if (it == rewrite_map_.end()) {
+ return GetRef<Stmt>(op);
+ }
+ return LetStmt(it->second.new_buffer_var, op->value, op->body);
+ }
+
Buffer RemapBuffer(Buffer buf) {
auto cache_key = buf.get();
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index df147e411f..dc84cadd54 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -671,5 +671,29 @@ def test_access_in_let_value():
tvm.ir.assert_structural_equal(mod["main"], func_rewritten)
+class TestLetBufferRewrite(tvm.testing.CompareBeforeAfter):
+ """StorageRewrite replaces the bound var of backing allocations
+
+ If StorageRewrite replaces the backing variable of an array, such
+ as when vectorizing the storage type, the variable must be
+ replaced in the LetStmt that defines it. Currently, StmtMutator
+ only visits usage of variables, and does not visit definitions of
+ variables, so the definition in a LetStmt must be explicitly
+ handled.
+ """
+
+ transform = tvm.tir.transform.StorageRewrite()
+
+ def before() -> None:
+ A_data: T.Ptr[T.int32] = T.call_extern("dummy_func", dtype="handle")
+ A = T.buffer_decl([8], "int32", data=A_data)
+ A[0:8] = T.broadcast(42, 8)
+
+ def expected() -> None:
+ A_data: T.Ptr[T.int32x8] = T.call_extern("dummy_func", dtype="handle")
+ A = T.buffer_decl([8], "int32", data=A_data)
+ A[0:8] = T.broadcast(42, 8)
+
+
if __name__ == "__main__":
tvm.testing.main()