You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/05/20 14:13:47 UTC
[tvm] branch main updated: [TIR][TVMScript] Added format/parsing of subroutine calls (#14889)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f4a7eaebd6 [TIR][TVMScript] Added format/parsing of subroutine calls (#14889)
f4a7eaebd6 is described below
commit f4a7eaebd62f68c7f59ca3d6d3f45a969bfc9bc9
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Sat May 20 09:13:37 2023 -0500
[TIR][TVMScript] Added format/parsing of subroutine calls (#14889)
* [TVMScript] Cherry-pick module.other_func syntax from unity
* [TIR][TVMScript] Added format/parsing of subroutine calls
Similar to `module.relax_func(args)` syntax used when parsing Relax
functions, this allows `module.tir_func(args)` to be used when parsing
TIR PrimFuncs.
---
include/tvm/node/script_printer.h | 7 ++++++
python/tvm/ir/expr.py | 5 +++++
python/tvm/script/parser/core/parser.py | 5 +++--
python/tvm/script/parser/core/utils.py | 26 ++++++++++++++++++++++-
python/tvm/script/parser/ir/parser.py | 23 +++++++++++++++++++-
python/tvm/script/parser/tir/parser.py | 7 +++---
python/tvm/tir/__init__.py | 2 +-
python/tvm/tir/op.py | 12 +++++++++++
src/script/printer/ir/ir.cc | 15 +++++++++++--
src/script/printer/tir/expr.cc | 5 +++--
src/script/printer/tir/function.cc | 20 +++++++++++++++++
tests/python/unittest/test_tvmscript_roundtrip.py | 17 +++++++++++++++
12 files changed, 131 insertions(+), 13 deletions(-)
diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h
index f4fec04035..c65394f7b7 100644
--- a/include/tvm/node/script_printer.h
+++ b/include/tvm/node/script_printer.h
@@ -43,6 +43,11 @@ class PrinterConfigNode : public Object {
std::string ir_prefix = "I";
/*! \brief The prefix of TIR nodes */
std::string tir_prefix = "T";
+ /*!
+ * \brief The alias used for the current module when printing cross-function calls
+ * \note If empty, the module's own name is used directly.
+ */
+ std::string module_alias = "cls";
/*! \brief Default data type of TIR buffer */
DataType buffer_dtype = DataType::Float(32);
/*! \brief Default data type of integer literals */
@@ -76,6 +81,8 @@ class PrinterConfigNode : public Object {
v->Visit("binding_names", &binding_names);
v->Visit("show_meta", &show_meta);
v->Visit("ir_prefix", &ir_prefix);
+ v->Visit("tir_prefix", &tir_prefix);
+ v->Visit("module_alias", &module_alias);
v->Visit("buffer_dtype", &buffer_dtype);
v->Visit("int_dtype", &int_dtype);
v->Visit("float_dtype", &float_dtype);
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 3c3fefb6d6..1c775b461e 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Common expressions data structures in the IR."""
+from numbers import Number
+
import tvm._ffi
from ..runtime import Scriptable, const, convert
@@ -86,6 +88,9 @@ class GlobalVar(RelayExpr):
from tvm import relay
return relay.Call(self, args)
+ elif all(isinstance(x, (Number, PrimExpr)) for x in args):
+ return tvm.tir.call_tir(self, *args)
+
arg_types = [type(x) for x in args]
raise RuntimeError(
"Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index 72858a2028..c253f61c31 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -23,6 +23,7 @@ import numpy as np
from tvm._ffi.base import TVMError
from tvm.error import DiagnosticError
+from tvm.ir import GlobalVar
from . import dispatch, doc
from .diagnostics import Diagnostics, Source
@@ -504,10 +505,10 @@ class Parser(doc.NodeVisitor):
_dispatch_wrapper(func)(self, node)
post_func(self, node)
- def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None:
+ def visit_tvm_declare_function(self, node: doc.FunctionDef) -> GlobalVar:
token = self.get_dispatch_token(node)
with self.with_dispatch_token(token):
- _dispatch(self, "tvm_declare_function")(self, node)
+ return _dispatch(self, "tvm_declare_function")(self, node)
def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
"""The general class definition visiting method.
diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py
index 6a693df12f..3edae3f25a 100644
--- a/python/tvm/script/parser/core/utils.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -22,6 +22,30 @@ from typing import Any, Callable, Dict, List
from .diagnostics import findsource
+def get_func_nonlocals(func):
+ """Return the closure variables of `func` by name; a modified `inspect.getclosurevars`."""
+
+ if inspect.ismethod(func):
+ func = func.__func__
+
+ if not inspect.isfunction(func):
+ raise TypeError("{!r} is not a Python function".format(func))
+
+ code = func.__code__
+ # Free variable names are listed in co_freevars; the matching cell for each
+ # name sits at the same position in __closure__, hence the zip below.
+ nonlocal_vars = {}
+ if func.__closure__ is not None:
+ for var, cell in zip(code.co_freevars, func.__closure__):
+ try:
+ nonlocal_vars[var] = cell.cell_contents
+ except ValueError as err:
+ # An empty cell raises ValueError on access: leave it out, re-raise anything else.
+ if "empty" not in str(err):
+ raise
+ return nonlocal_vars
+
+
def inspect_function_capture(func: Callable) -> Dict[str, Any]:
"""Capture function non-locals and global variables.
@@ -37,7 +61,7 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]:
"""
captured = {
**func.__globals__, # type: ignore
- **inspect.getclosurevars(func).nonlocals,
+ **get_func_nonlocals(func),
}
return captured
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index 201c99074f..075ca08703 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -20,6 +20,15 @@ from ...ir_builder import ir as I
from .._core import Parser, dispatch, doc
+class ModuleWithGlobalVars:
+ """Stand-in for the module under parsing; holds global vars for `Module.function` syntax."""
+
+ def __getattr__(self, attr):
+ # Replace the default AttributeError text with a parser-specific message.
+ # NOTE: `__getattr__` runs only after normal attribute lookup raised AttributeError
+ raise AttributeError(f"Cannot find the function `{attr}` in the current IRModule")
+
+
@dispatch.register(token="ir", type_name="ClassDef")
def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
"""The class definition visiting method for ir module.
@@ -35,13 +44,25 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
with self.var_table.with_frame():
with I.ir_module():
+ # Step 0. Add the class name to the var table
+ fake_module = ModuleWithGlobalVars()
+ self.var_table.add(node.name, fake_module)
+
+ # Step 1. Visit non-function stmts, including but not limited to
+ # 1. `I.module_attrs`
+ # 2. `I.module_global_infos`
with self.with_dispatch_token("ir"):
for stmt in node.body:
if not isinstance(stmt, doc.FunctionDef):
self.visit(stmt)
+
+ # Step 2. Visit function stmts to declare the global vars
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
- self.visit_tvm_declare_function(stmt)
+ global_var = self.visit_tvm_declare_function(stmt)
+ fake_module.__setattr__(stmt.name, global_var)
+
+ # Step 3. Visit and parse the functions
with self.with_dispatch_token("ir"):
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 0a489a8f04..dfecaacdf6 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -21,7 +21,7 @@ from functools import partial
from typing import Any
import tvm
-from tvm.ir import PrimType
+from tvm.ir import GlobalVar, PrimType
from tvm.tir import Buffer, IterVar, PrimExpr, Var
from ...ir_builder import ir as I
@@ -473,7 +473,7 @@ def visit_return(self: Parser, node: doc.Return) -> None:
@dispatch.register(token="tir", type_name="tvm_declare_function")
-def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
+def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar:
"""The function declaration step for tir
Parameters
@@ -493,5 +493,4 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
# Only ret_type is needed for func_signature.
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
- global_var = I.decl_function(node.name, func_signature)
- self.var_table.add(node.name, global_var)
+ return I.decl_function(node.name, func_signature)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 10e75b9151..6583af6e79 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -44,7 +44,7 @@ from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
from .function import PrimFunc, TensorIntrin, IndexMap
-from .op import call_packed_lowered, call_cpacked_lowered
+from .op import call_packed_lowered, call_cpacked_lowered, call_tir
from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
from .op import tvm_check_return
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 419ab22758..90e3db4cb9 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -436,6 +436,18 @@ def undef():
return call_intrin("int32", "tir.undef")
+def call_tir(global_var: tvm.ir.GlobalVar, *args):
+ """Performs a call into another PrimFunc in the same IRModule, passing `*args` to it
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression, built with dtype "handle" and `global_var` as its op.
+ """
+ assert isinstance(global_var, tvm.ir.GlobalVar)
+ return Call(dtype="handle", op=global_var, args=args)
+
+
def start_profile_intrinsic(id):
"""Start profile intrinsic.
Parameters
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index 87e7bfbcd9..7b6da42305 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -64,11 +64,23 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
std::sort(functions.begin(), functions.end());
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
+ IdDoc module_doc = d->Define(mod, f(), GetBindingName(d).value_or("Module"));
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(IR(d, "module_attrs") //
->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
}
+
+ // Declare GlobalVars first
+ IdDoc module_alias = d->cfg->module_alias.empty() ? module_doc : IdDoc(d->cfg->module_alias);
+ for (const auto& entry : functions) {
+ const GlobalVar& gv = entry.gv;
+ d->Define(gv, f(), [=]() {
+ return d->AsDoc<ExprDoc>(mod, p->Attr("global_vars"))->Attr(gv->name_hint);
+ });
+ }
+ // Print functions
+
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
const BaseFunc& func = entry.func;
@@ -84,8 +96,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
(*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
}
}
- return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")),
- {IR(d, "ir_module")}, (*f)->stmts));
+ return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts));
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index db99c24886..8de142f861 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -250,8 +250,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
dtype_print_location =
static_cast<tir::ScriptDtypePrintLocation>(dtype_locations[op].IntValue());
}
- } else if (const auto* gv = call->op.as<GlobalVarNode>()) {
- prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op"));
+ } else if (call->op.as<GlobalVarNode>()) {
+ prefix = d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
} else {
LOG(FATAL) << "call: " << call;
}
@@ -261,6 +261,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) {
args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")));
}
+
for (int i = 0; i < n_args; ++i) {
args.push_back(d->AsDoc<ExprDoc>(call->args[i], call_p->Attr("args")->ArrayIndex(i)));
}
diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc
index f40d7818d7..a8445f23df 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -176,6 +176,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR);
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tvm::GlobalVar>( // GlobalVar printing under the "tir" dispatch token
+ "tir", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc { // GetVarDoc(n) if present, else an IdDoc of n->name_hint
+ if (Optional<ExprDoc> doc = d->GetVarDoc(n)) {
+ return doc.value();
+ } else {
+ IdDoc ret(n->name_hint);
+ ret->source_paths.push_back(n_p);
+ return ret;
+ }
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tvm::IRModule>( // IRModule printing under the "tir" dispatch token
+ "tir", [](tvm::IRModule mod, ObjectPath n_p, IRDocsifier d) -> Doc { // returns GetVarDoc(mod); ICHECK fails when it is absent
+ Optional<ExprDoc> doc = d->GetVarDoc(mod);
+ ICHECK(doc) << "Unable to print IRModule before definition in TIR.";
+ return doc.value();
+ });
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 757f74ab83..7eee601358 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3792,6 +3792,22 @@ def nested_seqstmt():
return func
+def subroutine_call():
+ """Return a module whose `main` calls the sibling PrimFunc via `mod.subroutine(...)`"""
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(16, "float32")):
+ mod.subroutine(A.data, T.int32(16))
+
+ @T.prim_func
+ def subroutine(A_data: T.handle("float32"), n: T.int32):
+ T.evaluate(0)
+
+ return mod
+
+
ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
@@ -3861,6 +3877,7 @@ ir_generator = tvm.testing.parameter(
tvm_struct_set_generated_in_cpp,
ir_module_with_attrs,
nested_seqstmt,
+ subroutine_call,
)