You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/12/20 10:47:55 UTC

[tvm] branch main updated: [TIRScript] fix parse StringImm value in for loop annotations (#9755)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f4af81c  [TIRScript] fix parse StringImm value in for loop annotations (#9755)
f4af81c is described below

commit f4af81c5f7bbdf70f7271c5fe54f332857165cab
Author: wrongtest <wr...@gmail.com>
AuthorDate: Mon Dec 20 18:47:17 2021 +0800

    [TIRScript] fix parse StringImm value in for loop annotations (#9755)
    
    * fix parse strimm value in for annotations
    
    * flatten buffer allow runtime.String attr value
    
    * remove unused import
    
    * rebase and ensure flattened attr order
---
 python/tvm/script/tir/scope_handler.py             | 12 +++-------
 src/tir/transforms/flatten_buffer.cc               | 26 +++++++++++++++++++---
 .../unittest/test_tir_transform_flatten_buffer.py  | 22 ++++++++++++++++++
 3 files changed, 48 insertions(+), 12 deletions(-)

diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index db3261e..fc95377 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -20,7 +20,7 @@ from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
 
 import synr
 import tvm.tir
-from tvm.runtime import Object, String
+from tvm.runtime import Object
 from tvm.ir import Span, Range
 from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
 
@@ -483,14 +483,8 @@ class ForScopeHandler(ScopeHandler):
         """
         assert self.context and self.node, "call 'exit_scope' before 'enter_scope'"
         extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
-        self.annotations: Mapping[str, Object] = {}
-        if annotations is not None:
-            self.annotations = {
-                key: String(val) if isinstance(val, str) else val
-                for key, val in annotations.items()
-            }
-
-        self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations))
+        self.annotations = annotations
+        self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, self.annotations))
 
 
 @register
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index e0ab95a..e9d99cd 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -97,11 +97,16 @@ class BufferFlattener : public StmtExprMutator {
       body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body));
     }
     // Step 4. Handle annotations
+    std::set<std::string> ordered_ann_keys;
     for (const auto& annotation : op->annotations) {
-      const String& ann_key = annotation.first;
-      const ObjectRef& ann_value = annotation.second;
+      ordered_ann_keys.insert(annotation.first);
+    }
+    for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) {
+      const std::string& ann_key = *it;
+      const ObjectRef& ann_value = op->annotations.at(ann_key);
       if (attr::IsPragmaKey(ann_key)) {
-        body = AttrStmt(op->loop_var, ann_key, Downcast<PrimExpr>(ann_value), std::move(body));
+        body =
+            AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body));
       }
     }
     return body;
@@ -154,6 +159,21 @@ class BufferFlattener : public StmtExprMutator {
                     /*body=*/std::move(body));
   }
 
+  /*! \brief Convert attr value from annotation map into PrimExpr (String becomes StringImm). */
+  PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) {
+    if (!obj.defined()) {
+      return PrimExpr();  // undefined annotation value maps to an undefined PrimExpr
+    } else if (const PrimExprNode* expr = obj.as<PrimExprNode>()) {
+      return GetRef<PrimExpr>(expr);  // already a PrimExpr: pass through unchanged
+    } else if (const StringObj* str = obj.as<StringObj>()) {
+      return StringImm(str->data);  // temporary is already an rvalue; no std::move needed
+    } else {
+      LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey()
+                 << " not supported";
+      return PrimExpr();  // presumably not reached after LOG(FATAL); keeps every path returning
+    }
+  }
+
   /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;
 };
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index eed82eb..ca3d4aa 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -247,6 +247,13 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None:
                 C.data[i0 * 64 + i1 * 16 + j] = T.load("float32", B_new, i1 * 17 + j) * 2.0
 
 
+@T.prim_func
+def annotated_loops(a: T.handle) -> None:  # one loop with str/int/float pragma annotations
+    A = T.match_buffer(a, (16,), "float32")
+    for i in range(0, 16, annotations={"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0}):
+        A[i] = 0.0
+
+
+
+
 def test_elementwise():
     _check(compacted_elementwise_func, flattened_elementwise_func)
 
@@ -284,6 +291,20 @@ def test_lower_te():
     tvm.ir.assert_structural_equal(mod, orig_mod)  # FlattenBuffer should do nothing on TE
 
 
+def test_annotated_loops():
+    mod = tvm.IRModule.from_expr(annotated_loops)
+    mod = tvm.tir.transform.FlattenBuffer()(mod)
+    # Pragma attrs are wrapped in reverse sorted key order, so pragma_1 ends up outermost.
+    attr1 = mod["main"].body
+    attr2 = attr1.body
+    attr3 = attr2.body
+    assert attr1.attr_key == "pragma_1" and attr1.value == "str_value"
+    assert attr2.attr_key == "pragma_2"
+    tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1))
+    assert attr3.attr_key == "pragma_3"
+    tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))
+
+
 if __name__ == "__main__":
     test_elementwise()
     test_gpu_workload()
@@ -293,3 +314,4 @@ if __name__ == "__main__":
     test_multi_alloc()
     test_strided_buffer()
     test_lower_te()
+    test_annotated_loops()