You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/12/20 10:47:55 UTC
[tvm] branch main updated: [TIRScript] fix parse StringImm value in for loop annotations (#9755)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f4af81c [TIRScript] fix parse StringImm value in for loop annotations (#9755)
f4af81c is described below
commit f4af81c5f7bbdf70f7271c5fe54f332857165cab
Author: wrongtest <wr...@gmail.com>
AuthorDate: Mon Dec 20 18:47:17 2021 +0800
[TIRScript] fix parse StringImm value in for loop annotations (#9755)
* fix parsing of StringImm values in for-loop annotations
* FlattenBuffer: allow runtime.String attribute values
* remove unused import
* rebase, and ensure a deterministic order of flattened attributes
---
python/tvm/script/tir/scope_handler.py | 12 +++-------
src/tir/transforms/flatten_buffer.cc | 26 +++++++++++++++++++---
.../unittest/test_tir_transform_flatten_buffer.py | 22 ++++++++++++++++++
3 files changed, 48 insertions(+), 12 deletions(-)
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index db3261e..fc95377 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -20,7 +20,7 @@ from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
import synr
import tvm.tir
-from tvm.runtime import Object, String
+from tvm.runtime import Object
from tvm.ir import Span, Range
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
@@ -483,14 +483,8 @@ class ForScopeHandler(ScopeHandler):
"""
assert self.context and self.node, "call 'exit_scope' before 'enter_scope'"
extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
- self.annotations: Mapping[str, Object] = {}
- if annotations is not None:
- self.annotations = {
- key: String(val) if isinstance(val, str) else val
- for key, val in annotations.items()
- }
-
- self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations))
+ self.annotations = annotations
+ self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, self.annotations))
@register
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index e0ab95a..e9d99cd 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -97,11 +97,16 @@ class BufferFlattener : public StmtExprMutator {
body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body));
}
// Step 4. Handle annotations
+ std::set<std::string> ordered_ann_keys;
for (const auto& annotation : op->annotations) {
- const String& ann_key = annotation.first;
- const ObjectRef& ann_value = annotation.second;
+ ordered_ann_keys.insert(annotation.first);
+ }
+ for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) {
+ const std::string& ann_key = *it;
+ const ObjectRef& ann_value = op->annotations.at(ann_key);
if (attr::IsPragmaKey(ann_key)) {
- body = AttrStmt(op->loop_var, ann_key, Downcast<PrimExpr>(ann_value), std::move(body));
+ body =
+ AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body));
}
}
return body;
@@ -154,6 +159,21 @@ class BufferFlattener : public StmtExprMutator {
/*body=*/std::move(body));
}
+ /*! \brief Convert attr value from annotation map into PrimExpr. */
+ PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) {
+ if (!obj.defined()) {
+ return PrimExpr();
+ } else if (const PrimExprNode* expr = obj.as<PrimExprNode>()) {
+ return GetRef<PrimExpr>(expr);
+ } else if (const StringObj* str = obj.as<StringObj>()) {
+ return std::move(StringImm(str->data));
+ } else {
+ LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey()
+ << " not supported";
+ return PrimExpr();
+ }
+ }
+
/*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;
};
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index eed82eb..ca3d4aa 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -247,6 +247,13 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None:
C.data[i0 * 64 + i1 * 16 + j] = T.load("float32", B_new, i1 * 17 + j) * 2.0
+@T.prim_func
+def annotated_loops(a: T.handle) -> None:
+ A = T.match_buffer(a, (16,), "float32")
+ for i in range(0, 16, annotations={"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0}):
+ A[i] = 0.0
+
+
def test_elementwise():
_check(compacted_elementwise_func, flattened_elementwise_func)
@@ -284,6 +291,20 @@ def test_lower_te():
tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE
+def test_annotated_loops():
+ mod = tvm.IRModule.from_expr(annotated_loops)
+ mod = tvm.tir.transform.FlattenBuffer()(mod)
+ # _check(annotated_loops, compacted_annotated_loops)
+ attr1 = mod["main"].body
+ attr2 = attr1.body
+ attr3 = attr2.body
+ assert attr1.attr_key == "pragma_1" and attr1.value == "str_value"
+ assert attr2.attr_key == "pragma_2"
+ tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1))
+ assert attr3.attr_key == "pragma_3"
+ tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))
+
+
if __name__ == "__main__":
test_elementwise()
test_gpu_workload()
@@ -293,3 +314,4 @@ if __name__ == "__main__":
test_multi_alloc()
test_strided_buffer()
test_lower_te()
+ test_annotated_loops()