You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/10/07 12:45:20 UTC

[tvm] branch main updated: Adding annotations for tir.allocate (#9168)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2dae303  Adding annotations for tir.allocate (#9168)
2dae303 is described below

commit 2dae30372c6ef1b148f545ede4952dc498ecb13f
Author: Manupa Karunaratne <ma...@arm.com>
AuthorDate: Thu Oct 7 13:44:59 2021 +0100

    Adding annotations for tir.allocate (#9168)
    
    * Adding annotation for tir.allocate
    
    This commit adds annotations to the tir.allocate
    node, to be used as hints for future transformations.
    
    Change-Id: I02a3a875c38c3edd449385da5b741ef4958bb47f
    
    * Adding annotation for tir.allocate
    
    * adding tvmscript support
    * adding tir text printing support
    
    Change-Id: Id0b6725b2e79c23f6b8ff192772f1ea4125a27c2
---
 include/tvm/tir/stmt.h                            | 14 ++++++++--
 python/tvm/script/tir/scope_handler.py            | 16 ++++++++---
 python/tvm/tir/stmt.py                            | 16 +++++++++--
 src/printer/tir_text_printer.cc                   | 12 +++++++--
 src/printer/tvmscript_printer.cc                  | 10 +++++++
 src/tir/ir/stmt.cc                                |  7 ++---
 tests/python/unittest/test_tir_nodes.py           | 33 +++++++++++++++++++++++
 tests/python/unittest/test_tvmscript_roundtrip.py | 27 +++++++++++++++++++
 8 files changed, 122 insertions(+), 13 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 94ba853..5cd860b 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -521,6 +521,13 @@ class AllocateNode : public StmtNode {
   PrimExpr condition;
   /*! \brief The body to be executed. */
   Stmt body;
+  /*!
+   * \brief Additional annotations about the allocation.
+   *
+   *  These annotations can be used as auxiliary hints
+   *  for future transformations.
+   */
+  Map<String, ObjectRef> annotations;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
@@ -528,13 +535,14 @@ class AllocateNode : public StmtNode {
     v->Visit("extents", &extents);
     v->Visit("condition", &condition);
     v->Visit("body", &body);
+    v->Visit("annotations", &annotations);
     v->Visit("span", &span);
   }
 
   bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
     return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
            equal(extents, other->extents) && equal(condition, other->condition) &&
-           equal(body, other->body);
+           equal(body, other->body) && equal(annotations, other->annotations);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
@@ -543,6 +551,7 @@ class AllocateNode : public StmtNode {
     hash_reduce(extents);
     hash_reduce(condition);
     hash_reduce(body);
+    hash_reduce(annotations);
   }
 
   /*!
@@ -570,7 +579,8 @@ class AllocateNode : public StmtNode {
 class Allocate : public Stmt {
  public:
   TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
-                   Stmt body, Span span = Span());
+                   Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
+                   Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
 };
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 1072809..487a71d 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -104,14 +104,20 @@ class WithScopeHandler(ScopeHandler):
 
 @register
 class Allocate(WithScopeHandler):
-    """With scope handler T.allocate(extents, dtype, scope, condition)"""
+    """With scope handler T.allocate(extents, dtype, scope, condition, annotations)"""
 
     def __init__(self):
-        def allocate(extents, dtype, scope, condition=True, span=None):
+        def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
             condition = tvm.runtime.convert(condition)
             scope = tvm.runtime.convert(scope)
             return tvm.tir.Allocate(
-                self.buffer_var, dtype, extents, condition, self.body, span=span
+                self.buffer_var,
+                dtype,
+                extents,
+                condition,
+                self.body,
+                annotations=annotations,
+                span=span,
             )
 
         super().__init__(allocate, concise_scope=True, def_symbol=True)
@@ -137,7 +143,9 @@ class Allocate(WithScopeHandler):
         else:
             raise Exception("Internal Bug")
 
-        def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
+        def setup_buffer_var(
+            extents, dtype, scope, condition=True, annotations=None, span: Span = None
+        ):
             """Setup buffer var for a given type."""
             buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
             self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index d57077f..de200d5 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -318,13 +318,25 @@ class Allocate(Stmt):
     body : Stmt
         The body statement.
 
+    annotations: Optional[Mapping[str, Object]]
+        Additional annotation hints
+
     span : Optional[Span]
         The location of this itervar in the source code.
     """
 
-    def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
+    def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None):
+        if annotations is None:
+            annotations = dict()
         self.__init_handle_by_constructor__(
-            _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span  # type: ignore
+            _ffi_api.Allocate,  # type: ignore
+            buffer_var,
+            dtype,
+            extents,
+            condition,
+            body,
+            annotations,
+            span,
         )
 
 
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index f232994..fa132f0 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -449,8 +449,16 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
 Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
   Doc doc;
   auto scope = GetPtrStorageScope(op->buffer_var);
-  doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
-      << Print(op->extents) << "), storage_scope = " << scope;
+  doc << "allocate(" << Print(op->buffer_var) << ", ";
+  doc << PrintDType(op->dtype) << ", ";
+  doc << Print(op->extents) << "), storage_scope = " << scope;
+  if (!op->annotations.empty()) {
+    std::vector<Doc> attr_docs;
+    for (const auto& it : op->annotations) {
+      attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+    }
+    doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
+  }
   if (!is_one(op->condition)) {
     doc << " if " << Print(op->condition);
   }
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index fdafdbf..fa74e56 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -769,6 +769,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
     if (!is_one(op->condition)) {
       doc << ", " << Print(op->condition);
     }
+    if (!op->annotations.empty()) {
+      doc << ", annotations={";
+      doc << PrintAnnotations(op->annotations);
+      doc << "}";
+    }
     doc << ") as " << Print(op->buffer_var) << ":";
     doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
   } else {
@@ -777,6 +782,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
     if (!is_one(op->condition)) {
       doc << ", " << Print(op->condition);
     }
+    if (!op->annotations.empty()) {
+      doc << ", annotations={";
+      doc << PrintAnnotations(op->annotations);
+      doc << "}";
+    }
     doc << ")" << Doc::NewLine() << PrintBody(op->body);
   }
   TryDeallocVar(op->buffer_var);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d59c94d..0d42c20 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 // Allocate
 Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
-                   Stmt body, Span span) {
+                   Stmt body, Map<String, ObjectRef> annotations, Span span) {
   CHECK(IsPointerType(buffer_var->type_annotation, dtype))
       << "The allocated data type (" << dtype
       << ") does not match the type annotation of the buffer " << buffer_var << " ("
@@ -354,6 +354,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, Prim
   node->extents = std::move(extents);
   node->condition = std::move(condition);
   node->body = std::move(body);
+  node->annotations = std::move(annotations);
   node->span = std::move(span);
   data_ = std::move(node);
 }
@@ -375,8 +376,8 @@ int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
 
 TVM_REGISTER_GLOBAL("tir.Allocate")
     .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
-                       Stmt body, Span span) {
-      return Allocate(buffer_var, type, extents, condition, body, span);
+                       Stmt body, Map<String, ObjectRef> annotations, Span span) {
+      return Allocate(buffer_var, type, extents, condition, body, annotations, span);
     });
 
 TVM_REGISTER_NODE_TYPE(AllocateNode);
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index de94464..fe719ee 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -473,5 +473,38 @@ def test_block_blockrealize():
     assert output.find("with init()") != -1
 
 
+def test_tir_allocate():
+    dtype = "int8"
+    storage_scope = "global"
+    ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
+    a = te.var("buffer", ptype)
+    allocate = tvm.tir.Allocate(
+        buffer_var=a,
+        dtype=dtype,
+        extents=[2, 2],
+        condition=tvm.get_global_func("tir.const_true")(dtype, None),
+        body=tvm.tir.Evaluate(2 + 1),
+        annotations={
+            "attr1": "foo",
+            "attr2": "bar",
+        },
+    )
+    assert allocate.buffer_var == a
+    assert allocate.dtype == "int8"
+    assert list(allocate.extents) == [2, 2]
+    assert allocate.annotations["attr1"] == "foo"
+    assert allocate.annotations["attr2"] == "bar"
+
+    # make sure we can print using TIRTextPrinter
+    func = tvm.tir.PrimFunc([], allocate)
+    output = func.astext()
+    assert (
+        output.find(
+            'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})'
+        )
+        != -1
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 94d4bed..8058b96 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3059,5 +3059,32 @@ def test_while_loop():
     tvm.ir.assert_structural_equal(while_loop, rt_func)
 
 
+# fmt: off
+@T.prim_func
+def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
+    # function attr dict
+    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
+    placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+    T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+    # body
+    tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"})
+    for ax0_ax1_fused_4 in T.serial(0, 56):
+        for ax2_4 in T.serial(0, 56):
+            for ax3_init in T.serial(0, 64):
+                T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
+            for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+                T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype= [...]
+    for ax0_ax1_fused_5 in T.serial(0, 56):
+        for ax2_5, ax3_3 in T.grid(56, 64):
+            T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
+# fmt: on
+
+
+def test_primfunc_with_allocate_annotations():
+    func = primfunc_with_allocate_annotations
+    rt_func = tvm.script.from_source(func.script(show_meta=True))
+    tvm.ir.assert_structural_equal(func, rt_func, True)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))