You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/10/07 12:45:20 UTC
[tvm] branch main updated: Adding annotations for tir.allocate
(#9168)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2dae303 Adding annotations for tir.allocate (#9168)
2dae303 is described below
commit 2dae30372c6ef1b148f545ede4952dc498ecb13f
Author: Manupa Karunaratne <ma...@arm.com>
AuthorDate: Thu Oct 7 13:44:59 2021 +0100
Adding annotations for tir.allocate (#9168)
* Adding annotation for tir.allocate
This commit adds annotations to the tir.allocate
node, to be used as hints for future transformations.
Change-Id: I02a3a875c38c3edd449385da5b741ef4958bb47f
* Adding annotation for tir.allocate
* adding tvmscript support
* adding tir text printing support
Change-Id: Id0b6725b2e79c23f6b8ff192772f1ea4125a27c2
---
include/tvm/tir/stmt.h | 14 ++++++++--
python/tvm/script/tir/scope_handler.py | 16 ++++++++---
python/tvm/tir/stmt.py | 16 +++++++++--
src/printer/tir_text_printer.cc | 12 +++++++--
src/printer/tvmscript_printer.cc | 10 +++++++
src/tir/ir/stmt.cc | 7 ++---
tests/python/unittest/test_tir_nodes.py | 33 +++++++++++++++++++++++
tests/python/unittest/test_tvmscript_roundtrip.py | 27 +++++++++++++++++++
8 files changed, 122 insertions(+), 13 deletions(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 94ba853..5cd860b 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -521,6 +521,13 @@ class AllocateNode : public StmtNode {
PrimExpr condition;
/*! \brief The body to be executed. */
Stmt body;
+ /*!
+ * \brief Additional annotations about the allocation.
+ *
+ * These annotations can be used as auxiliary hint
+ * to future transformations.
+ */
+ Map<String, ObjectRef> annotations;
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
@@ -528,13 +535,14 @@ class AllocateNode : public StmtNode {
v->Visit("extents", &extents);
v->Visit("condition", &condition);
v->Visit("body", &body);
+ v->Visit("annotations", &annotations);
v->Visit("span", &span);
}
bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
equal(extents, other->extents) && equal(condition, other->condition) &&
- equal(body, other->body);
+ equal(body, other->body) && equal(annotations, other->annotations);
}
void SHashReduce(SHashReducer hash_reduce) const {
@@ -543,6 +551,7 @@ class AllocateNode : public StmtNode {
hash_reduce(extents);
hash_reduce(condition);
hash_reduce(body);
+ hash_reduce(annotations);
}
/*!
@@ -570,7 +579,8 @@ class AllocateNode : public StmtNode {
class Allocate : public Stmt {
public:
TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
- Stmt body, Span span = Span());
+ Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
+ Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 1072809..487a71d 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -104,14 +104,20 @@ class WithScopeHandler(ScopeHandler):
@register
class Allocate(WithScopeHandler):
- """With scope handler T.allocate(extents, dtype, scope, condition)"""
+ """With scope handler T.allocate(extents, dtype, scope, condition, annotations)"""
def __init__(self):
- def allocate(extents, dtype, scope, condition=True, span=None):
+ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
condition = tvm.runtime.convert(condition)
scope = tvm.runtime.convert(scope)
return tvm.tir.Allocate(
- self.buffer_var, dtype, extents, condition, self.body, span=span
+ self.buffer_var,
+ dtype,
+ extents,
+ condition,
+ self.body,
+ annotations=annotations,
+ span=span,
)
super().__init__(allocate, concise_scope=True, def_symbol=True)
@@ -137,7 +143,9 @@ class Allocate(WithScopeHandler):
else:
raise Exception("Internal Bug")
- def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
+ def setup_buffer_var(
+ extents, dtype, scope, condition=True, annotations=None, span: Span = None
+ ):
"""Setup buffer var for a given type."""
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index d57077f..de200d5 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -318,13 +318,25 @@ class Allocate(Stmt):
body : Stmt
The body statement.
+ annotations: Optional[Mapping[str, Object]]
+ Additional annotation hints
+
span : Optional[Span]
The location of this itervar in the source code.
"""
- def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
+ def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None):
+ if annotations is None:
+ annotations = dict()
self.__init_handle_by_constructor__(
- _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span # type: ignore
+ _ffi_api.Allocate, # type: ignore
+ buffer_var,
+ dtype,
+ extents,
+ condition,
+ body,
+ annotations,
+ span,
)
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index f232994..fa132f0 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -449,8 +449,16 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
Doc doc;
auto scope = GetPtrStorageScope(op->buffer_var);
- doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
- << Print(op->extents) << "), storage_scope = " << scope;
+ doc << "allocate(" << Print(op->buffer_var) << ", ";
+ doc << PrintDType(op->dtype) << ", ";
+ doc << Print(op->extents) << "), storage_scope = " << scope;
+ if (!op->annotations.empty()) {
+ std::vector<Doc> attr_docs;
+ for (const auto& it : op->annotations) {
+ attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+ }
+ doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
+ }
if (!is_one(op->condition)) {
doc << " if " << Print(op->condition);
}
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index fdafdbf..fa74e56 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -769,6 +769,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
+ if (!op->annotations.empty()) {
+ doc << ", annotations={";
+ doc << PrintAnnotations(op->annotations);
+ doc << "}";
+ }
doc << ") as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
} else {
@@ -777,6 +782,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
+ if (!op->annotations.empty()) {
+ doc << ", annotations={";
+ doc << PrintAnnotations(op->annotations);
+ doc << "}";
+ }
doc << ")" << Doc::NewLine() << PrintBody(op->body);
}
TryDeallocVar(op->buffer_var);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d59c94d..0d42c20 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
// Allocate
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
- Stmt body, Span span) {
+ Stmt body, Map<String, ObjectRef> annotations, Span span) {
CHECK(IsPointerType(buffer_var->type_annotation, dtype))
<< "The allocated data type (" << dtype
<< ") does not match the type annotation of the buffer " << buffer_var << " ("
@@ -354,6 +354,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, Prim
node->extents = std::move(extents);
node->condition = std::move(condition);
node->body = std::move(body);
+ node->annotations = std::move(annotations);
node->span = std::move(span);
data_ = std::move(node);
}
@@ -375,8 +376,8 @@ int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
TVM_REGISTER_GLOBAL("tir.Allocate")
.set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
- Stmt body, Span span) {
- return Allocate(buffer_var, type, extents, condition, body, span);
+ Stmt body, Map<String, ObjectRef> annotations, Span span) {
+ return Allocate(buffer_var, type, extents, condition, body, annotations, span);
});
TVM_REGISTER_NODE_TYPE(AllocateNode);
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index de94464..fe719ee 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -473,5 +473,38 @@ def test_block_blockrealize():
assert output.find("with init()") != -1
+def test_tir_allocate():
+ dtype = "int8"
+ storage_scope = "global"
+ ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
+ a = te.var("buffer", ptype)
+ allocate = tvm.tir.Allocate(
+ buffer_var=a,
+ dtype=dtype,
+ extents=[2, 2],
+ condition=tvm.get_global_func("tir.const_true")(dtype, None),
+ body=tvm.tir.Evaluate(2 + 1),
+ annotations={
+ "attr1": "foo",
+ "attr2": "bar",
+ },
+ )
+ assert allocate.buffer_var == a
+ assert allocate.dtype == "int8"
+ assert list(allocate.extents) == [2, 2]
+ assert allocate.annotations["attr1"] == "foo"
+ assert allocate.annotations["attr2"] == "bar"
+
+ # make sure we can print using TIRTextPrinter
+ func = tvm.tir.PrimFunc([], allocate)
+ output = func.astext()
+ assert (
+ output.find(
+ 'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})'
+ )
+ != -1
+ )
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 94d4bed..8058b96 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3059,5 +3059,32 @@ def test_while_loop():
tvm.ir.assert_structural_equal(while_loop, rt_func)
+# fmt: off
+@T.prim_func
+def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
+ placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+ T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+ # body
+ tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"})
+ for ax0_ax1_fused_4 in T.serial(0, 56):
+ for ax2_4 in T.serial(0, 56):
+ for ax3_init in T.serial(0, 64):
+ T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
+ for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+ T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype= [...]
+ for ax0_ax1_fused_5 in T.serial(0, 56):
+ for ax2_5, ax3_3 in T.grid(56, 64):
+ T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
+# fmt: on
+
+
+def test_primfunc_with_allocate_annotations():
+ func = primfunc_with_allocate_annotations
+ rt_func = tvm.script.from_source(func.script(show_meta=True))
+ tvm.ir.assert_structural_equal(func, rt_func, True)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))