You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/06 23:39:09 UTC

[tvm] 02/03: [TVMScript] Expose IRModule::attrs as I.module_attrs

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ff5118f3981b0216c6c961185513d5625a37f82f
Author: Hongyi Jin <32...@qq.com>
AuthorDate: Sun Feb 26 11:05:47 2023 -0500

    [TVMScript] Expose IRModule::attrs as I.module_attrs
    
    This is an upstreaming of the non-relax portions of
    https://github.com/apache/tvm/pull/14132, including a unit test
    specifically to validate `I.module_attrs`.
---
 include/tvm/script/ir_builder/base.h              |  2 ++
 include/tvm/script/ir_builder/ir/frame.h          |  3 +++
 python/tvm/ir/module.py                           | 14 ++++++++++++--
 python/tvm/script/ir_builder/base.py              | 11 +++++++++++
 python/tvm/script/ir_builder/ir/__init__.py       |  7 ++++++-
 python/tvm/script/ir_builder/ir/ir.py             | 14 ++++++++++++++
 python/tvm/script/parser/ir/__init__.py           |  4 ++--
 python/tvm/script/parser/ir/parser.py             | 11 +++++++++--
 src/ir/module.cc                                  |  6 ++----
 src/script/ir_builder/base.cc                     |  6 ++++++
 src/script/ir_builder/ir/frame.cc                 |  3 ++-
 src/script/ir_builder/ir/ir.cc                    | 12 ++++++++++++
 src/script/printer/ir/ir.cc                       |  5 +++++
 tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++
 14 files changed, 100 insertions(+), 12 deletions(-)

diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h
index 61ca3eb9f7..a00ea5768e 100644
--- a/include/tvm/script/ir_builder/base.h
+++ b/include/tvm/script/ir_builder/base.h
@@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef {
    * \sa tvm::support::With
    */
   static IRBuilder Current();
+  /*! \brief See if the current thread-local scope has an IRBuilder. */
+  static bool IsInScope();
   /*!
    * \brief Give a string name to the `obj`
    * \tparam TObjectRef The type of the object to name.
diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h
index dacfc361a6..ed425cf614 100644
--- a/include/tvm/script/ir_builder/ir/frame.h
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -45,11 +45,14 @@ class IRModuleFrameNode : public IRBuilderFrameNode {
    * \note Only defined functions are in the map, while declared functions are not included.
    */
   Map<GlobalVar, BaseFunc> functions;
+  /*! \brief IRModule's attributes. */
+  Map<String, ObjectRef> attrs;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     IRBuilderFrameNode::VisitAttrs(v);
     v->Visit("global_vars", &global_var_map);
     v->Visit("functions", &functions);
+    v->Visit("attrs", &attrs);
   }
 
   static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 3daffb2640..232c70aa93 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -37,7 +37,7 @@ class IRModule(Node, Scriptable):
         Map of global var to BaseFunc
     """
 
-    def __init__(self, functions=None, type_definitions=None):
+    def __init__(self, functions=None, type_definitions=None, attrs=None):
         if functions is None:
             functions = {}
         elif isinstance(functions, dict):
@@ -60,7 +60,17 @@ class IRModule(Node, Scriptable):
                     raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
                 mapped_type_defs[k] = v
             type_definitions = mapped_type_defs
-        self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
+
+        attrs = None if not attrs else attrs
+        if attrs is not None:
+            attrs = ast.literal_eval(str(attrs))
+            attrs = tvm.ir.make_node("DictAttrs", **attrs)
+        self.__init_handle_by_constructor__(
+            _ffi_api.IRModule,
+            functions,
+            type_definitions,
+            attrs,
+        )
 
     def __setitem__(self, var, val):
         """Add a mapping to the module.
diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py
index b35bbd0a7d..1d5d050444 100644
--- a/python/tvm/script/ir_builder/base.py
+++ b/python/tvm/script/ir_builder/base.py
@@ -138,6 +138,17 @@ class IRBuilder(_Object):
         """
         return _ffi_api.IRBuilderCurrent()  # type: ignore[attr-defined] # pylint: disable=no-member
 
+    @staticmethod
+    def is_in_scope() -> bool:
+        """See if the current thread-local scope has an IRBuilder.
+
+        Returns
+        -------
+        bool
+            Whether the current thread-local scope has an IRBuilder
+        """
+        return _ffi_api.IRBuilderIsInScope()  # type: ignore[attr-defined] # pylint: disable=no-member
+
     def get(self) -> _Object:
         """Get the constructed IR."""
         return _ffi_api.IRBuilderGet(self)  # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py
index 946be263a7..b796de8113 100644
--- a/python/tvm/script/ir_builder/ir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -16,4 +16,9 @@
 # under the License.
 """Package tvm.script.ir_builder.ir"""
 from .frame import IRModuleFrame
-from .ir import decl_function, def_function, ir_module
+from .ir import (
+    decl_function,
+    def_function,
+    ir_module,
+    module_attrs,
+)
diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py
index 796d6f3aad..c5276f8d13 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -16,6 +16,10 @@
 # under the License.
 """Package tvm.script.ir_builder.ir.ir"""
 
+from typing import Dict
+
+from tvm.runtime import Object as tvm_Object
+
 from tvm.ir import BaseFunc, GlobalVar
 
 from . import _ffi_api
@@ -67,3 +71,13 @@ def def_function(func_name: str, func: BaseFunc) -> None:
         The given function implementation
     """
     return _ffi_api.DefFunction(func_name, func)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def module_attrs(attrs: Dict[str, tvm_Object]) -> None:
+    """Specify the attrs of the ir_module frame.
+    Parameters
+    ----------
+    attrs: Dict[str, Object]
+        The module attrs.
+    """
+    return _ffi_api.ModuleAttrs(attrs)  # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py
index fedd2f0a14..adda176012 100644
--- a/python/tvm/script/parser/ir/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """The ir module parser"""
-
+from ...ir_builder.ir import *  # pylint: disable=redefined-builtin
 from . import parser as _parser
 from .entry import ir_module
 
-__all__ = ["ir_module"]
+__all__ = ["ir_module", "module_attrs"]
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index 13b3e29859..201c99074f 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
 
     with self.var_table.with_frame():
         with I.ir_module():
+            with self.with_dispatch_token("ir"):
+                for stmt in node.body:
+                    if not isinstance(stmt, doc.FunctionDef):
+                        self.visit(stmt)
             for stmt in node.body:
                 if isinstance(stmt, doc.FunctionDef):
                     self.visit_tvm_declare_function(stmt)
             with self.with_dispatch_token("ir"):
-                self.visit_body(node.body)
+                for stmt in node.body:
+                    if isinstance(stmt, doc.FunctionDef):
+                        self.visit(stmt)
 
 
 @dispatch.register(token="ir", type_name="Assign")
@@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
 
 
 @dispatch.register(token="ir", type_name="Expr")
-def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+def _visit_expr(self: Parser, node: doc.Expr) -> None:
     """The expression visiting method for ir module.
 
     Parameters
@@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
     node : doc.ClassDef
         The doc AST expression node.
     """
+    self.eval_expr(node.value)
 
 
 @dispatch.register(token="default", type_name="Assign")
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 4d5bebf708..ba66a66894 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) {
 TVM_REGISTER_NODE_TYPE(IRModuleNode);
 
 TVM_REGISTER_GLOBAL("ir.IRModule")
-    .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
-                       tvm::Map<GlobalTypeVar, TypeData> types) {
-      return IRModule(funcs, types, {});
-    });
+    .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
+                       tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); });
 
 TVM_REGISTER_GLOBAL("ir.Module_Add")
     .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc
index 8303efff4f..879db4f3d7 100644
--- a/src/script/ir_builder/base.cc
+++ b/src/script/ir_builder/base.cc
@@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() {
   return stack->back();
 }
 
+bool IRBuilder::IsInScope() {
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  return !stack->empty();
+}
+
 namespace details {
 
 Namer::FType& Namer::vtable() {
@@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope);
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet")
     .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc
index addf129284..92470ec653 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/frame.cc
@@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() {
   }
   IRBuilder builder = IRBuilder::Current();
   ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
-  builder->result = tvm::IRModule(func_map);
+  auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
+  builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs);
 }
 
 TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index 5764e90c8d..0c34f85246 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -60,9 +60,21 @@ void DefFunction(const String& func_name, const BaseFunc& func) {
   }
 }
 
+void ModuleAttrs(Map<String, ObjectRef> attrs) {
+  if (IRBuilder::IsInScope()) {
+    // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope
+    IRModuleFrame frame = FindModuleFrame("I.ModuleAttr");
+    if (!frame->attrs.empty()) {
+      LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs;
+    }
+    frame->attrs = attrs;
+  }
+}
+
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs);
 
 }  // namespace ir
 }  // namespace ir_builder
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index 065cfe5168..1c751d40f2 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -64,6 +64,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       std::sort(functions.begin(), functions.end());
       With<IRFrame> f(d);
       (*f)->AddDispatchToken(d, "ir");
+      if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
+        (*f)->stmts.push_back(
+            ExprStmtDoc(IR(d, "module_attrs")  //
+                            ->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
+      }
       for (const auto& entry : functions) {
         const GlobalVar& gv = entry.gv;
         const BaseFunc& func = entry.func;
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index bbc6dd45a8..52d99550be 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3725,6 +3725,19 @@ def tvm_struct_set_generated_in_cpp():
     return tvm.tir.transform.LowerTVMBuiltin()(Module)
 
 
+def ir_module_with_attrs():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"attr": 10})
+
+        @T.prim_func
+        def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")):
+            for i in range(16):
+                B[i] = A[i]
+
+    return Module
+
+
 ir_generator = tvm.testing.parameter(
     launch_env_thread,
     opt_gemm_normalize,
@@ -3791,6 +3804,7 @@ ir_generator = tvm.testing.parameter(
     if_then_else_var,
     tvm_shfl_builtins,
     tvm_struct_set_generated_in_cpp,
+    ir_module_with_attrs,
 )