You are viewing a plain-text version of this content. The canonical link to it was not preserved in this copy.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/05/20 14:13:47 UTC

[tvm] branch main updated: [TIR][TVMScript] Added format/parsing of subroutine calls (#14889)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f4a7eaebd6 [TIR][TVMScript] Added format/parsing of subroutine calls (#14889)
f4a7eaebd6 is described below

commit f4a7eaebd62f68c7f59ca3d6d3f45a969bfc9bc9
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Sat May 20 09:13:37 2023 -0500

    [TIR][TVMScript] Added format/parsing of subroutine calls (#14889)
    
    * [TVMScript] Cherry-pick module.other_func syntax from unity
    
    * [TIR][TVMScript] Added format/parsing of subroutine calls
    
    Similar to the `module.relax_func(args)` syntax used when parsing Relax
    functions, this allows `module.tir_func(args)` to be used when parsing
    TIR PrimFuncs.
---
 include/tvm/node/script_printer.h                 |  7 ++++++
 python/tvm/ir/expr.py                             |  5 +++++
 python/tvm/script/parser/core/parser.py           |  5 +++--
 python/tvm/script/parser/core/utils.py            | 26 ++++++++++++++++++++++-
 python/tvm/script/parser/ir/parser.py             | 23 +++++++++++++++++++-
 python/tvm/script/parser/tir/parser.py            |  7 +++---
 python/tvm/tir/__init__.py                        |  2 +-
 python/tvm/tir/op.py                              | 12 +++++++++++
 src/script/printer/ir/ir.cc                       | 15 +++++++++++--
 src/script/printer/tir/expr.cc                    |  5 +++--
 src/script/printer/tir/function.cc                | 20 +++++++++++++++++
 tests/python/unittest/test_tvmscript_roundtrip.py | 17 +++++++++++++++
 12 files changed, 131 insertions(+), 13 deletions(-)

diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h
index f4fec04035..c65394f7b7 100644
--- a/include/tvm/node/script_printer.h
+++ b/include/tvm/node/script_printer.h
@@ -43,6 +43,11 @@ class PrinterConfigNode : public Object {
   std::string ir_prefix = "I";
   /*! \brief The prefix of TIR nodes */
   std::string tir_prefix = "T";
+  /*!
+   * \brief The alias of the current module at cross-function call
+   * \note Directly use module name if it's empty.
+   */
+  std::string module_alias = "cls";
   /*! \brief Default data type of TIR buffer */
   DataType buffer_dtype = DataType::Float(32);
   /*! \brief Default data type of integer literals */
@@ -76,6 +81,8 @@ class PrinterConfigNode : public Object {
     v->Visit("binding_names", &binding_names);
     v->Visit("show_meta", &show_meta);
     v->Visit("ir_prefix", &ir_prefix);
+    v->Visit("tir_prefix", &tir_prefix);
+    v->Visit("module_alias", &module_alias);
     v->Visit("buffer_dtype", &buffer_dtype);
     v->Visit("int_dtype", &int_dtype);
     v->Visit("float_dtype", &float_dtype);
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 3c3fefb6d6..1c775b461e 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Common expressions data structures in the IR."""
+from numbers import Number
+
 import tvm._ffi
 
 from ..runtime import Scriptable, const, convert
@@ -86,6 +88,9 @@ class GlobalVar(RelayExpr):
             from tvm import relay
 
             return relay.Call(self, args)
+        elif all(isinstance(x, (Number, PrimExpr)) for x in args):
+            return tvm.tir.call_tir(self, *args)
+
         arg_types = [type(x) for x in args]
         raise RuntimeError(
             "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index 72858a2028..c253f61c31 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -23,6 +23,7 @@ import numpy as np
 from tvm._ffi.base import TVMError
 
 from tvm.error import DiagnosticError
+from tvm.ir import GlobalVar
 
 from . import dispatch, doc
 from .diagnostics import Diagnostics, Source
@@ -504,10 +505,10 @@ class Parser(doc.NodeVisitor):
         _dispatch_wrapper(func)(self, node)
         post_func(self, node)
 
-    def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None:
+    def visit_tvm_declare_function(self, node: doc.FunctionDef) -> GlobalVar:
         token = self.get_dispatch_token(node)
         with self.with_dispatch_token(token):
-            _dispatch(self, "tvm_declare_function")(self, node)
+            return _dispatch(self, "tvm_declare_function")(self, node)
 
     def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
         """The general class definition visiting method.
diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py
index 6a693df12f..3edae3f25a 100644
--- a/python/tvm/script/parser/core/utils.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -22,6 +22,30 @@ from typing import Any, Callable, Dict, List
 from .diagnostics import findsource
 
 
+def get_func_nonlocals(func):
+    """A modified version of `inspect.getclosurevars`"""
+
+    if inspect.ismethod(func):
+        func = func.__func__
+
+    if not inspect.isfunction(func):
+        raise TypeError("{!r} is not a Python function".format(func))
+
+    code = func.__code__
+    # Nonlocal references are named in co_freevars and resolved
+    # by looking them up in __closure__ by positional index
+    nonlocal_vars = {}
+    if func.__closure__ is not None:
+        for var, cell in zip(code.co_freevars, func.__closure__):
+            try:
+                nonlocal_vars[var] = cell.cell_contents
+            except ValueError as err:
+                # cell_contents may raise ValueError if the cell is empty.
+                if "empty" not in str(err):
+                    raise
+    return nonlocal_vars
+
+
 def inspect_function_capture(func: Callable) -> Dict[str, Any]:
     """Capture function non-locals and global variables.
 
@@ -37,7 +61,7 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]:
     """
     captured = {
         **func.__globals__,  # type: ignore
-        **inspect.getclosurevars(func).nonlocals,
+        **get_func_nonlocals(func),
     }
     return captured
 
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index 201c99074f..075ca08703 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -20,6 +20,15 @@ from ...ir_builder import ir as I
 from .._core import Parser, dispatch, doc
 
 
+class ModuleWithGlobalVars:
+    """A Module that can add global vars during parsing, to support `Module.function` syntax."""
+
+    def __getattr__(self, attr):
+        # Customize the error message.
+        # NOTE: `__getattr__` is only called when the attribute access fails with an AttributeError
+        raise AttributeError(f"Cannot find the function `{attr}` in the current IRModule")
+
+
 @dispatch.register(token="ir", type_name="ClassDef")
 def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
     """The class definition visiting method for ir module.
@@ -35,13 +44,25 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
 
     with self.var_table.with_frame():
         with I.ir_module():
+            # Step 0. Add the class name to the var table
+            fake_module = ModuleWithGlobalVars()
+            self.var_table.add(node.name, fake_module)
+
+            # Step 1. Visit non-function stmts, including but not limited to
+            # 1. `I.module_attrs`
+            # 2. `I.module_global_infos`
             with self.with_dispatch_token("ir"):
                 for stmt in node.body:
                     if not isinstance(stmt, doc.FunctionDef):
                         self.visit(stmt)
+
+            # Step 2. Visit function stmts to declare the global vars
             for stmt in node.body:
                 if isinstance(stmt, doc.FunctionDef):
-                    self.visit_tvm_declare_function(stmt)
+                    global_var = self.visit_tvm_declare_function(stmt)
+                    fake_module.__setattr__(stmt.name, global_var)
+
+            # Step 3. Visit and parse the functions
             with self.with_dispatch_token("ir"):
                 for stmt in node.body:
                     if isinstance(stmt, doc.FunctionDef):
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 0a489a8f04..dfecaacdf6 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -21,7 +21,7 @@ from functools import partial
 from typing import Any
 
 import tvm
-from tvm.ir import PrimType
+from tvm.ir import GlobalVar, PrimType
 from tvm.tir import Buffer, IterVar, PrimExpr, Var
 
 from ...ir_builder import ir as I
@@ -473,7 +473,7 @@ def visit_return(self: Parser, node: doc.Return) -> None:
 
 
 @dispatch.register(token="tir", type_name="tvm_declare_function")
-def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
+def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar:
     """The function declaration step for tir
 
     Parameters
@@ -493,5 +493,4 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
 
     # Only ret_type is needed for func_signature.
     func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
-    global_var = I.decl_function(node.name, func_signature)
-    self.var_table.add(node.name, global_var)
+    return I.decl_function(node.name, func_signature)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 10e75b9151..6583af6e79 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -44,7 +44,7 @@ from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
 
 from .function import PrimFunc, TensorIntrin, IndexMap
 
-from .op import call_packed_lowered, call_cpacked_lowered
+from .op import call_packed_lowered, call_cpacked_lowered, call_tir
 from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
 from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
 from .op import tvm_check_return
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 419ab22758..90e3db4cb9 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -436,6 +436,18 @@ def undef():
     return call_intrin("int32", "tir.undef")
 
 
+def call_tir(global_var: tvm.ir.GlobalVar, *args):
+    """Performs a call into another PrimFunc in the same IRModule
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    assert isinstance(global_var, tvm.ir.GlobalVar)
+    return Call(dtype="handle", op=global_var, args=args)
+
+
 def start_profile_intrinsic(id):
     """Start profile intrinsic.
     Parameters
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index 87e7bfbcd9..7b6da42305 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -64,11 +64,23 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       std::sort(functions.begin(), functions.end());
       With<IRFrame> f(d);
       (*f)->AddDispatchToken(d, "ir");
+      IdDoc module_doc = d->Define(mod, f(), GetBindingName(d).value_or("Module"));
       if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
         (*f)->stmts.push_back(
             ExprStmtDoc(IR(d, "module_attrs")  //
                             ->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
       }
+
+      // Declare GlobalVars first
+      IdDoc module_alias = d->cfg->module_alias.empty() ? module_doc : IdDoc(d->cfg->module_alias);
+      for (const auto& entry : functions) {
+        const GlobalVar& gv = entry.gv;
+        d->Define(gv, f(), [=]() {
+          return d->AsDoc<ExprDoc>(mod, p->Attr("global_vars"))->Attr(gv->name_hint);
+        });
+      }
+      // Print functions
+
       for (const auto& entry : functions) {
         const GlobalVar& gv = entry.gv;
         const BaseFunc& func = entry.func;
@@ -84,8 +96,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           (*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
         }
       }
-      return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")),
-                                       {IR(d, "ir_module")}, (*f)->stmts));
+      return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts));
     });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index db99c24886..8de142f861 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -250,8 +250,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           dtype_print_location =
               static_cast<tir::ScriptDtypePrintLocation>(dtype_locations[op].IntValue());
         }
-      } else if (const auto* gv = call->op.as<GlobalVarNode>()) {
-        prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op"));
+      } else if (call->op.as<GlobalVarNode>()) {
+        prefix = d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
       } else {
         LOG(FATAL) << "call: " << call;
       }
@@ -261,6 +261,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) {
         args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")));
       }
+
       for (int i = 0; i < n_args; ++i) {
         args.push_back(d->AsDoc<ExprDoc>(call->args[i], call_p->Attr("args")->ArrayIndex(i)));
       }
diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc
index f40d7818d7..a8445f23df 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -176,6 +176,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
 
 TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR);
 
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tvm::GlobalVar>(                                           //
+        "tir", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc {  //
+          if (Optional<ExprDoc> doc = d->GetVarDoc(n)) {
+            return doc.value();
+          } else {
+            IdDoc ret(n->name_hint);
+            ret->source_paths.push_back(n_p);
+            return ret;
+          }
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tvm::IRModule>(                                             //
+        "tir", [](tvm::IRModule mod, ObjectPath n_p, IRDocsifier d) -> Doc {  //
+          Optional<ExprDoc> doc = d->GetVarDoc(mod);
+          ICHECK(doc) << "Unable to print IRModule before definition in TIR.";
+          return doc.value();
+        });
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 757f74ab83..7eee601358 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3792,6 +3792,22 @@ def nested_seqstmt():
     return func
 
 
+def subroutine_call():
+    """A GlobalVar may reference other functions in the module"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def main(A: T.Buffer(16, "float32")):
+            mod.subroutine(A.data, T.int32(16))
+
+        @T.prim_func
+        def subroutine(A_data: T.handle("float32"), n: T.int32):
+            T.evaluate(0)
+
+    return mod
+
+
 ir_generator = tvm.testing.parameter(
     launch_env_thread,
     opt_gemm_normalize,
@@ -3861,6 +3877,7 @@ ir_generator = tvm.testing.parameter(
     tvm_struct_set_generated_in_cpp,
     ir_module_with_attrs,
     nested_seqstmt,
+    subroutine_call,
 )