You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/06 23:39:09 UTC
[tvm] 02/03: [TVMScript] Expose IRModule::attrs as I.module_attrs
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ff5118f3981b0216c6c961185513d5625a37f82f
Author: Hongyi Jin <32...@qq.com>
AuthorDate: Sun Feb 26 11:05:47 2023 -0500
[TVMScript] Expose IRModule::attrs as I.module_attrs
This is an upstreaming of the non-relax portions of
https://github.com/apache/tvm/pull/14132, including a unit test
specifically to validate `I.module_attrs`.
---
include/tvm/script/ir_builder/base.h | 2 ++
include/tvm/script/ir_builder/ir/frame.h | 3 +++
python/tvm/ir/module.py | 14 ++++++++++++--
python/tvm/script/ir_builder/base.py | 11 +++++++++++
python/tvm/script/ir_builder/ir/__init__.py | 7 ++++++-
python/tvm/script/ir_builder/ir/ir.py | 14 ++++++++++++++
python/tvm/script/parser/ir/__init__.py | 4 ++--
python/tvm/script/parser/ir/parser.py | 11 +++++++++--
src/ir/module.cc | 6 ++----
src/script/ir_builder/base.cc | 6 ++++++
src/script/ir_builder/ir/frame.cc | 3 ++-
src/script/ir_builder/ir/ir.cc | 12 ++++++++++++
src/script/printer/ir/ir.cc | 5 +++++
tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++
14 files changed, 100 insertions(+), 12 deletions(-)
diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h
index 61ca3eb9f7..a00ea5768e 100644
--- a/include/tvm/script/ir_builder/base.h
+++ b/include/tvm/script/ir_builder/base.h
@@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef {
* \sa tvm::support::With
*/
static IRBuilder Current();
+ /*! \brief See if the current thread-local scope has an IRBuilder. */
+ static bool IsInScope();
/*!
* \brief Give a string name to the `obj`
* \tparam TObjectRef The type of the object to name.
diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h
index dacfc361a6..ed425cf614 100644
--- a/include/tvm/script/ir_builder/ir/frame.h
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -45,11 +45,14 @@ class IRModuleFrameNode : public IRBuilderFrameNode {
* \note Only defined functions are in the map, while declared functions are not included.
*/
Map<GlobalVar, BaseFunc> functions;
+ /*! \brief IRModule's attributes. */
+ Map<String, ObjectRef> attrs;
void VisitAttrs(tvm::AttrVisitor* v) {
IRBuilderFrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_var_map);
v->Visit("functions", &functions);
+ v->Visit("attrs", &attrs);
}
static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 3daffb2640..232c70aa93 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -37,7 +37,7 @@ class IRModule(Node, Scriptable):
Map of global var to BaseFunc
"""
- def __init__(self, functions=None, type_definitions=None):
+ def __init__(self, functions=None, type_definitions=None, attrs=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
@@ -60,7 +60,17 @@ class IRModule(Node, Scriptable):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
- self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
+
+ attrs = None if not attrs else attrs
+ if attrs is not None:
+ attrs = ast.literal_eval(str(attrs))
+ attrs = tvm.ir.make_node("DictAttrs", **attrs)
+ self.__init_handle_by_constructor__(
+ _ffi_api.IRModule,
+ functions,
+ type_definitions,
+ attrs,
+ )
def __setitem__(self, var, val):
"""Add a mapping to the module.
diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py
index b35bbd0a7d..1d5d050444 100644
--- a/python/tvm/script/ir_builder/base.py
+++ b/python/tvm/script/ir_builder/base.py
@@ -138,6 +138,17 @@ class IRBuilder(_Object):
"""
return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member
+ @staticmethod
+ def is_in_scope() -> bool:
+ """Check whether the calling thread currently has an active IRBuilder.
+
+ Returns
+ -------
+ bool
+ True if the thread-local IRBuilder stack is non-empty, False otherwise.
+ """
+ return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member
+
def get(self) -> _Object:
"""Get the constructed IR."""
return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py
index 946be263a7..b796de8113 100644
--- a/python/tvm/script/ir_builder/ir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -16,4 +16,9 @@
# under the License.
"""Package tvm.script.ir_builder.ir"""
from .frame import IRModuleFrame
-from .ir import decl_function, def_function, ir_module
+from .ir import (
+ decl_function,
+ def_function,
+ ir_module,
+ module_attrs,
+)
diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py
index 796d6f3aad..c5276f8d13 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -16,6 +16,10 @@
# under the License.
"""Package tvm.script.ir_builder.ir.ir"""
+from typing import Dict
+
+from tvm.runtime import Object as tvm_Object
+
from tvm.ir import BaseFunc, GlobalVar
from . import _ffi_api
@@ -67,3 +71,13 @@ def def_function(func_name: str, func: BaseFunc) -> None:
The given function implementation
"""
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def module_attrs(attrs: Dict[str, tvm_Object]) -> None:
+ """Set the attrs of the enclosing I.ir_module frame (ignored outside an IRBuilder scope).
+ Parameters
+ ----------
+ attrs: Dict[str, Object]
+ Attributes attached to the resulting IRModule; once set they cannot be overwritten.
+ """
+ return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py
index fedd2f0a14..adda176012 100644
--- a/python/tvm/script/parser/ir/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""The ir module parser"""
-
+from ...ir_builder.ir import * # pylint: disable=redefined-builtin
from . import parser as _parser
from .entry import ir_module
-__all__ = ["ir_module"]
+__all__ = ["ir_module", "module_attrs"]
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index 13b3e29859..201c99074f 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
with self.var_table.with_frame():
with I.ir_module():
+ with self.with_dispatch_token("ir"):
+ for stmt in node.body:
+ if not isinstance(stmt, doc.FunctionDef):
+ self.visit(stmt)
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
self.visit_tvm_declare_function(stmt)
with self.with_dispatch_token("ir"):
- self.visit_body(node.body)
+ for stmt in node.body:
+ if isinstance(stmt, doc.FunctionDef):
+ self.visit(stmt)
@dispatch.register(token="ir", type_name="Assign")
@@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
@dispatch.register(token="ir", type_name="Expr")
-def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+def _visit_expr(self: Parser, node: doc.Expr) -> None:
"""The expression visiting method for ir module.
Parameters
@@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
node : doc.ClassDef
The doc AST expression node.
"""
+ self.eval_expr(node.value)
@dispatch.register(token="default", type_name="Assign")
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 4d5bebf708..ba66a66894 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) {
TVM_REGISTER_NODE_TYPE(IRModuleNode);
TVM_REGISTER_GLOBAL("ir.IRModule")
- .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
- tvm::Map<GlobalTypeVar, TypeData> types) {
- return IRModule(funcs, types, {});
- });
+ .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
+ tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); });
TVM_REGISTER_GLOBAL("ir.Module_Add")
.set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc
index 8303efff4f..879db4f3d7 100644
--- a/src/script/ir_builder/base.cc
+++ b/src/script/ir_builder/base.cc
@@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() {
return stack->back();
}
+bool IRBuilder::IsInScope() {
+ std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+ return !stack->empty();  // a non-empty thread-local stack means a builder is active here
+}
+
namespace details {
Namer::FType& Namer::vtable() {
@@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet")
.set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc
index addf129284..92470ec653 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/frame.cc
@@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() {
}
IRBuilder builder = IRBuilder::Current();
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
- builder->result = tvm::IRModule(func_map);
+ auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
+ builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs);
}
TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index 5764e90c8d..0c34f85246 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -60,9 +60,21 @@ void DefFunction(const String& func_name, const BaseFunc& func) {
}
}
+void ModuleAttrs(Map<String, ObjectRef> attrs) {
+ if (IRBuilder::IsInScope()) {
+ // Python runs the class body before @I.ir_module parses it; skip that builder-less call.
+ IRModuleFrame frame = FindModuleFrame("I.module_attrs");
+ if (!frame->attrs.empty()) {
+ LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs;
+ }
+ frame->attrs = attrs;
+ }
+}
+
TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs);
} // namespace ir
} // namespace ir_builder
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index 065cfe5168..1c751d40f2 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -64,6 +64,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
std::sort(functions.begin(), functions.end());
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
+ if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
+ (*f)->stmts.push_back(
+ ExprStmtDoc(IR(d, "module_attrs") //
+ ->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
+ }
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
const BaseFunc& func = entry.func;
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index bbc6dd45a8..52d99550be 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3725,6 +3725,19 @@ def tvm_struct_set_generated_in_cpp():
return tvm.tir.transform.LowerTVMBuiltin()(Module)
+def ir_module_with_attrs():
+ @I.ir_module
+ class Module:
+ I.module_attrs({"attr": 10})  # module-level attrs must survive the print/parse roundtrip
+
+ @T.prim_func
+ def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")):
+ for i in range(16):
+ B[i] = A[i]
+
+ return Module
+
+
ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
@@ -3791,6 +3804,7 @@ ir_generator = tvm.testing.parameter(
if_then_else_var,
tvm_shfl_builtins,
tvm_struct_set_generated_in_cpp,
+ ir_module_with_attrs,
)