You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text copy; see the original mailing-list archive page.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/10 00:26:06 UTC

[tvm] branch main updated: [FIX,STORAGE REWRITE] Rewrite buffers in let statements (#12349)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c281b7064 [FIX,STORAGE REWRITE] Rewrite buffers in let statements (#12349)
0c281b7064 is described below

commit 0c281b7064a8dba4ab1943cf1282b381ac8208ae
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Tue Aug 9 17:25:59 2022 -0700

    [FIX,STORAGE REWRITE] Rewrite buffers in let statements (#12349)
    
    Storage rewrite was missing a visitor for let statements, so buffer
    variables defined in them would still refer to the pre-rewritten
    version. This error was originally noticed when using `global.vtcm`
    buffers, which get changed into let statements by LowerVtcmAlloc.
    
    Implementing the test for this change also required adding support for
    vectorized datatypes to TVMScript. The included solution is a little
    hacky: it adds the datatypes to the `globals()` table of each module
    in which they need to be defined.
---
 python/tvm/script/tir/__init__.py                  | 12 +++-
 python/tvm/script/tir/intrin.py                    | 84 +++++-----------------
 python/tvm/script/tir/ty.py                        | 15 ++--
 src/tir/transforms/storage_rewrite.cc              |  8 +++
 .../unittest/test_tir_transform_storage_rewrite.py | 24 +++++++
 5 files changed, 65 insertions(+), 78 deletions(-)

diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py
index 2655f5bb33..2f2b4bbc25 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/tir/__init__.py
@@ -17,7 +17,15 @@
 """TVMScript for TIR"""
 
 # Type system
-from .ty import uint8, int8, int16, int32, int64, float16, float32, float64, void
-from .ty import boolean, handle, Ptr, Tuple, Buffer
+from .ty import void, boolean, handle, Ptr, Tuple, Buffer
 
 from .prim_func import prim_func
+
+# add all floating point and integer datatypes to the module
+for _dtype in ["float", "uint", "int"]:
+    for _size in ["8", "16", "32", "64"]:
+        for _lanes in ["", "x4", "x8", "x16", "x32"]:
+            from . import ty
+
+            _name = _dtype + _size + _lanes
+            globals()[_name] = getattr(ty, _name)
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index 627b89086a..382431c229 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -42,74 +42,22 @@ def bool(imm, span):
     return imm.astype("bool", span)
 
 
-@register
-def int8(imm, span):
-    return imm.astype("int8", span)
-
-
-@register
-def int16(imm, span):
-    return imm.astype("int16", span)
-
-
-@register
-def int32(imm, span):
-    return imm.astype("int32", span)
-
-
-@register
-def int64(imm, span):
-    return imm.astype("int64", span)
-
-
-@register
-def uint8(imm, span):
-    return imm.astype("uint8", span)
-
-
-@register
-def uint16(imm, span):
-    return imm.astype("uint16", span)
-
-
-@register
-def uint32(imm, span):
-    return imm.astype("uint32", span)
-
-
-@register
-def uint64(imm, span):
-    return imm.astype("uint64", span)
-
-
-@register
-def float8(imm, span):
-    return imm.astype("float8", span)
-
-
-@register
-def float16(imm, span):
-    return imm.astype("float16", span)
-
-
-@register
-def float32(imm, span):
-    return imm.astype("float32", span)
-
-
-@register
-def float64(imm, span):
-    return imm.astype("float64", span)
-
-
-@register
-def int32x16(imm, span):
-    return imm.astype("int32x16", span)
-
-
-@register
-def int32x4(imm, span):
-    return imm.astype("int32x4", span)
+# register all datatypes
+for _dtype in ["float", "uint", "int"]:
+    for _size in ["8", "16", "32", "64"]:
+        for _lanes in ["", "x4", "x8", "x16", "x32"]:
+            _name = _dtype + _size + _lanes
+
+            # nest closures so we copy the name string
+            def wrap(name):
+                def f(imm, span):
+                    return imm.astype(name, span)
+
+                f.__name__ = name
+                return f
+
+            _intrin = wrap(_name)
+            register(_intrin)
 
 
 @register
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index a64485b215..4548102a9e 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -199,14 +199,13 @@ class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods,
             )
 
 
-uint8 = ConcreteType("uint8")
-int8 = ConcreteType("int8")
-int16 = ConcreteType("int16")
-int32 = ConcreteType("int32")
-int64 = ConcreteType("int64")
-float16 = ConcreteType("float16")
-float32 = ConcreteType("float32")
-float64 = ConcreteType("float64")
+# add all floating point and integer datatypes to the module
+for _dtype in ["float", "uint", "int"]:
+    for _size in ["8", "16", "32", "64"]:
+        for _lanes in ["", "x4", "x8", "x16", "x32"]:
+            _name = _dtype + _size + _lanes
+            globals()[_name] = ConcreteType(_name)
+
 boolean = ConcreteType("bool")
 handle = ConcreteType("handle")
 void = VoidType()
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index c5f27b8de3..5a326d9fac 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1461,6 +1461,14 @@ class VectorTypeRewriter : public StmtExprMutator {
     return VisitBufferAccess(std::move(node));
   }
 
+  Stmt VisitStmt_(const LetStmtNode* op) final {
+    auto it = rewrite_map_.find(op->var.get());
+    if (it == rewrite_map_.end()) {
+      return GetRef<Stmt>(op);
+    }
+    return LetStmt(it->second.new_buffer_var, op->value, op->body);
+  }
+
   Buffer RemapBuffer(Buffer buf) {
     auto cache_key = buf.get();
 
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index df147e411f..dc84cadd54 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -671,5 +671,29 @@ def test_access_in_let_value():
     tvm.ir.assert_structural_equal(mod["main"], func_rewritten)
 
 
+class TestLetBufferRewrite(tvm.testing.CompareBeforeAfter):
+    """StorageRewrite replaces the bound var of backing allocations
+
+    If StorageRewrite replaces the backing variable of an array, such
+    as when vectorizing the storage type, the variable must be
+    replaced in the LetStmt that defines it.  Currently, StmtMutator
+    only visits usage of variables, and does not visit definitions of
+    variables, so the definition in a LetStmt must be explicitly
+    handled.
+    """
+
+    transform = tvm.tir.transform.StorageRewrite()
+
+    def before() -> None:
+        A_data: T.Ptr[T.int32] = T.call_extern("dummy_func", dtype="handle")
+        A = T.buffer_decl([8], "int32", data=A_data)
+        A[0:8] = T.broadcast(42, 8)
+
+    def expected() -> None:
+        A_data: T.Ptr[T.int32x8] = T.call_extern("dummy_func", dtype="handle")
+        A = T.buffer_decl([8], "int32", data=A_data)
+        A[0:8] = T.broadcast(42, 8)
+
+
 if __name__ == "__main__":
     tvm.testing.main()