You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/06 23:39:08 UTC
[tvm] 01/03: [TVMScript] IRModule TVMScript Parser.
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 11c13ace0b5cef71f50193248ecaac7e845ee25e
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Wed Feb 8 22:31:47 2023 +0800
[TVMScript] IRModule TVMScript Parser.
This PR adds the TVMScript parser/ir_builder support based on the
blockbuilder. This commit contains the non-relax portions from
https://github.com/apache/tvm/pull/13932.
Co-authored-by: Ruihang Lai <ru...@cs.cmu.edu>
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Tianqi Chen <ti...@gmail.com>
Co-authored-by: Yuchen Jin <yu...@cs.washington.edu>
Co-authored-by: Steven S. Lyubomirsky <sl...@gmail.com>
Co-authored-by: Yong Wu <yo...@gmail.com>
---
include/tvm/script/ir_builder/ir/frame.h | 11 ++++--
include/tvm/script/ir_builder/ir/ir.h | 17 +++++++++
python/tvm/script/ir_builder/base.py | 6 ++--
python/tvm/script/ir_builder/ir/__init__.py | 2 +-
python/tvm/script/ir_builder/ir/ir.py | 45 +++++++++++++++++++++++
python/tvm/script/parser/core/diagnostics.py | 2 +-
python/tvm/script/parser/core/evaluator.py | 2 +-
python/tvm/script/parser/core/parser.py | 50 ++++++++++++++++++--------
python/tvm/script/parser/ir/parser.py | 4 +++
python/tvm/script/parser/tir/entry.py | 4 +--
python/tvm/script/parser/tir/parser.py | 26 ++++++++++++++
src/script/ir_builder/ir/frame.cc | 12 ++++---
src/script/ir_builder/ir/ir.cc | 32 ++++++++++++++++-
src/script/ir_builder/ir/{frame.cc => utils.h} | 30 +++++++++-------
src/script/ir_builder/tir/frame.cc | 15 ++++++--
src/script/ir_builder/tir/utils.h | 2 +-
16 files changed, 213 insertions(+), 47 deletions(-)
diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h
index 887981ccff..dacfc361a6 100644
--- a/include/tvm/script/ir_builder/ir/frame.h
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -38,12 +38,17 @@ namespace ir {
*/
class IRModuleFrameNode : public IRBuilderFrameNode {
public:
- Array<GlobalVar> global_vars;
- Array<BaseFunc> functions;
+ /*! \brief A map from string names to global variables that ensures global uniqueness. */
+ Map<String, GlobalVar> global_var_map;
+ /*!
+ * \brief A map from GlobalVar to all global functions.
+ * \note Only defined functions are in the map, while declared functions are not included.
+ */
+ Map<GlobalVar, BaseFunc> functions;
void VisitAttrs(tvm::AttrVisitor* v) {
IRBuilderFrameNode::VisitAttrs(v);
- v->Visit("global_vars", &global_vars);
+ v->Visit("global_vars", &global_var_map);
v->Visit("functions", &functions);
}
diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h
index f0e7cc6f5c..49bdcf60e6 100644
--- a/include/tvm/script/ir_builder/ir/ir.h
+++ b/include/tvm/script/ir_builder/ir/ir.h
@@ -37,6 +37,23 @@ namespace ir {
*/
TVM_DLL IRModuleFrame IRModule();
+/*!
+ * \brief Declare a Function without giving the specific function implementation.
+ * \note It is usually used in cross-function calls. We can specify the function by `DefFunction`.
+ * \param func_name The function unique name.
+ * \param func_signature A Function w/o body, which is used to specify the function signature
+ * (i.e. func params and func return type/shape).
+ * \return The corresponding GlobalVar.
+ */
+TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature);
+
+/*!
+ * \brief Define the function which is declared before.
+ * \param func_name The function unique name.
+ * \param func The given function implementation
+ */
+TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func);
+
} // namespace ir
} // namespace ir_builder
} // namespace script
diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py
index 7aa33ee49c..b35bbd0a7d 100644
--- a/python/tvm/script/ir_builder/base.py
+++ b/python/tvm/script/ir_builder/base.py
@@ -64,8 +64,10 @@ class IRBuilderFrame(_Object):
_ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member
return self
- def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
- _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member
+ def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument
+ if exc_type is None and exc_value is None:
+ # Do not execute `FrameExit` if the with scope exits because of exceptions
+ _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member
def add_callback(self, callback: Callable[[], None]) -> None:
"""Add a callback method invoked when exiting the with-scope.
diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py
index ebb9728737..946be263a7 100644
--- a/python/tvm/script/ir_builder/ir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -16,4 +16,4 @@
# under the License.
"""Package tvm.script.ir_builder.ir"""
from .frame import IRModuleFrame
-from .ir import ir_module
+from .ir import decl_function, def_function, ir_module
diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py
index 213180463c..796d6f3aad 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -16,9 +16,54 @@
# under the License.
"""Package tvm.script.ir_builder.ir.ir"""
+from tvm.ir import BaseFunc, GlobalVar
+
from . import _ffi_api
from .frame import IRModuleFrame
def ir_module() -> IRModuleFrame:
+ """Start a ir_module frame.
+ Returns
+ -------
+ frame: IRModuleFrame
+ The constructed frame.
+ """
return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
+ """Declare a Function without giving the specific function implementation.
+ Parameters
+ ----------
+ func_name : str
+ The function unique name.
+
+ func_signature: Optional[BaseFunc]
+ A Function w/o body, which is used to specify the function signature
+ (i.e. func params and func return type/shape).
+
+ Note
+ ----
+ It is usually used in cross-function calls. We can specify the function by `def_function`.
+ Returns
+ -------
+ gv : GlobalVar
+ The corresponding GlobalVar.
+ """
+
+ return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member
+ func_name, func_signature
+ )
+
+
+def def_function(func_name: str, func: BaseFunc) -> None:
+ """Define the function which is declared before.
+ Parameters
+ ----------
+ func_name : str
+ The function unique name.
+ func: BaseFunc
+ The given function implementation
+ """
+ return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py
index ad7ae50347..2767a97f60 100644
--- a/python/tvm/script/parser/core/diagnostics.py
+++ b/python/tvm/script/parser/core/diagnostics.py
@@ -220,7 +220,7 @@ class Diagnostics:
level : diagnostics.DiagnosticLevel
The diagnostic level.
"""
- lineno = node.lineno or self.source.start_line
+ lineno = node.lineno or 1
col_offset = node.col_offset or self.source.start_column
end_lineno = node.end_lineno or lineno
end_col_offset = node.end_col_offset or col_offset
diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
index 3a72a3c331..075aedd891 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -203,7 +203,7 @@ class ExprEvaluator:
else:
value = self._eval_expr(node.__class__(**fields))
except Exception as e: # pylint: disable=broad-except,invalid-name
- self.parser.report_error(node, str(e))
+ self.parser.report_error(node, e)
return self._add_intermediate_result(value)
def _eval_lambda(self, node: doc.Lambda) -> Any:
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index fdccabcd23..837b7cce5d 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -60,6 +60,10 @@ def _deferred(exit_f: Callable[[], None]):
return context()
+def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument
+ pass
+
+
class VarTableFrame:
"""The variable table frame.
A frame of variable table stores the variables created in one block or scope.
@@ -260,6 +264,17 @@ class Parser(doc.NodeVisitor):
node = self.diag.source.as_ast()
self.visit(node)
+ def get_dispatch_token(self, node: doc.FunctionDef) -> str:
+ if not isinstance(node, doc.FunctionDef):
+ self.report_error(node, "Only can get dispatch token for function.")
+ if not node.decorator_list:
+ self.report_error(node, "Function must be decorated")
+ # TODO: only the last decorator is parsed
+ decorator = self.eval_expr(node.decorator_list[-1])
+ if not hasattr(decorator, "dispatch_token"):
+ self.report_error(node, "The parser does not understand the decorator")
+ return decorator.dispatch_token
+
def with_dispatch_token(self, token: str):
"""Add a new dispatching token as with statement.
@@ -389,6 +404,8 @@ class Parser(doc.NodeVisitor):
# Only take the last line of the error message
if isinstance(err, TVMError):
msg = list(filter(None, str(err).split("\n")))[-1]
+ elif isinstance(err, KeyError):
+ msg = "KeyError: " + str(err)
else:
msg = str(err)
self.diag.error(node, msg)
@@ -458,30 +475,33 @@ class Parser(doc.NodeVisitor):
"""
return _dispatch(self, "tvm_annotation")(self, node)
- def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name
- """The general function definition visiting method.
+ def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name
+ """The general function definition visit method.
Parameters
----------
node : doc.FunctionDef
- The doc AST function definition node.
-
- Returns
- -------
- res : Any
- The visiting result.
+ The doc FunctionDef node.
"""
- if not node.decorator_list:
- self.report_error(node, "Function must be decorated")
- # TODO: only the last decorator is parsed
- decorator = self.eval_expr(node.decorator_list[-1])
- if not hasattr(decorator, "dispatch_token"):
- self.report_error(node, "The parser does not understand the decorator")
- token = decorator.dispatch_token
+ token = self.get_dispatch_token(node)
+ current_token = self.dispatch_tokens[-1]
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
if func is None:
self.report_error(node, "The parser does not understand the decorator")
+ pre_func = dispatch.get(
+ token=current_token, type_name="pre_token_switch", default=_do_nothing
+ )
+ post_func = dispatch.get(
+ token=current_token, type_name="post_token_switch", default=_do_nothing
+ )
+ pre_func(self, node)
_dispatch_wrapper(func)(self, node)
+ post_func(self, node)
+
+ def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None:
+ token = self.get_dispatch_token(node)
+ with self.with_dispatch_token(token):
+ _dispatch(self, "tvm_declare_function")(self, node)
def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
"""The general class definition visiting method.
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index e0268412d2..13b3e29859 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
node : doc.ClassDef
The doc AST class definition node.
"""
+
with self.var_table.with_frame():
with I.ir_module():
+ for stmt in node.body:
+ if isinstance(stmt, doc.FunctionDef):
+ self.visit_tvm_declare_function(stmt)
with self.with_dispatch_token("ir"):
self.visit_body(node.body)
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
index 411a7f8f3c..649f817411 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -83,7 +83,7 @@ class BufferProxy:
return self(keys)
if len(keys) >= 2 and not isinstance(keys[1], str):
return self(keys)
- return self(*keys) # pylint: disable=no-member # type: ignore
+ return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member
class PtrProxy:
@@ -93,7 +93,7 @@ class PtrProxy:
def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
- return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore
+ return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member
@deprecated("T.Ptr[...]", "T.handle(...)")
def __getitem__(self, keys):
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 8a067267a3..63171f6722 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -24,6 +24,7 @@ import tvm
from tvm.ir import PrimType
from tvm.tir import Buffer, IterVar, PrimExpr, Var
+from ...ir_builder import ir as I
from ...ir_builder import tir as T
from ...ir_builder.base import IRBuilder
from ...ir_builder.base import IRBuilderFrame as Frame
@@ -473,3 +474,28 @@ def visit_return(self: Parser, node: doc.Return) -> None:
The doc AST return node.
"""
self.report_error(node, "Return is not allowed.")
+
+
+@dispatch.register(token="tir", type_name="tvm_declare_function")
+def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
+ """The function declaration step for tir
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.FunctionDef
+ The doc AST function definition node.
+ """
+
+ ret_type = None
+ if node.returns is not None:
+ ret_type = self.eval_expr(node.returns)
+ if callable(ret_type):
+ ret_type = PrimType(ret_type().dtype)
+
+ # Only ret_type is needed for func_signature.
+ func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
+ global_var = I.decl_function(node.name, func_signature)
+ self.var_table.add(node.name, global_var)
diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc
index a81c56922d..addf129284 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/frame.cc
@@ -26,11 +26,15 @@ namespace ir_builder {
namespace ir {
void IRModuleFrameNode::ExitWithScope() {
- ICHECK_EQ(functions.size(), global_vars.size());
- int n = functions.size();
Map<GlobalVar, BaseFunc> func_map;
- for (int i = 0; i < n; ++i) {
- func_map.Set(global_vars[i], functions[i]);
+ CHECK_EQ(functions.size(), global_var_map.size())
+ << "All functions must be defined in the IRModule. Got " << global_var_map.size()
+ << "declared function(s), but only " << functions.size() << "defined function(s).";
+ for (const auto& kv : functions) {
+ const GlobalVar& gv = kv.first;
+ const BaseFunc& func = kv.second;
+ CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined";
+ func_map.Set(gv, func);
}
IRBuilder builder = IRBuilder::Current();
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index a8cc452e4f..5764e90c8d 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -20,6 +20,8 @@
#include <tvm/runtime/registry.h>
#include <tvm/script/ir_builder/ir/ir.h>
+#include "./utils.h"
+
namespace tvm {
namespace script {
namespace ir_builder {
@@ -27,12 +29,40 @@ namespace ir {
IRModuleFrame IRModule() {
ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
- n->global_vars.clear();
+ n->global_var_map.clear();
n->functions.clear();
return IRModuleFrame(n);
}
+GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) {
+ IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
+ CHECK(!frame->global_var_map.count(func_name))
+ << "ValueError: function " << func_name << " already exists";
+ GlobalVar gv = GlobalVar(func_name);
+ CHECK(frame->functions.find(gv) == frame->functions.end())
+ << "ValueError: function " << func_name << " has already been defined.";
+ frame->global_var_map.Set(func_name, gv);
+ if (func_signature.defined()) {
+ frame->functions.Set(gv, func_signature);
+ }
+ return gv;
+}
+
+void DefFunction(const String& func_name, const BaseFunc& func) {
+ IRModuleFrame frame = FindModuleFrame("I.DefFunction");
+ auto it = frame->global_var_map.find(func_name);
+ CHECK(it != frame->global_var_map.end())
+ << "ValueError: function " << func_name << " does not exist, please declare it first.";
+ const GlobalVar& gv = (*it).second;
+ frame->functions.Set(gv, func);
+ if (func->checked_type_.defined()) {
+ gv->checked_type_ = func->checked_type_;
+ }
+}
+
TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
} // namespace ir
} // namespace ir_builder
diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/utils.h
similarity index 59%
copy from src/script/ir_builder/ir/frame.cc
copy to src/script/ir_builder/ir/utils.h
index a81c56922d..58d5e53f70 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/utils.h
@@ -16,8 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include <tvm/ir/module.h>
-#include <tvm/runtime/registry.h>
+#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
+#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
+
#include <tvm/script/ir_builder/ir/frame.h>
namespace tvm {
@@ -25,21 +26,24 @@ namespace script {
namespace ir_builder {
namespace ir {
-void IRModuleFrameNode::ExitWithScope() {
- ICHECK_EQ(functions.size(), global_vars.size());
- int n = functions.size();
- Map<GlobalVar, BaseFunc> func_map;
- for (int i = 0; i < n; ++i) {
- func_map.Set(global_vars[i], functions[i]);
- }
+inline IRModuleFrame FindModuleFrame(const String& method) {
IRBuilder builder = IRBuilder::Current();
- ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
- builder->result = tvm::IRModule(func_map);
+ if (Optional<IRModuleFrame> frame = builder->FindFrame<IRModuleFrame>()) {
+ const Optional<IRModuleFrame>& last_module_frame = builder->GetLastFrame<IRModuleFrame>();
+ if (last_module_frame.defined() && last_module_frame.value() == frame) {
+ return frame.value();
+ }
+ } else {
+ LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method
+ << "' is called under I.ir_module()";
+ }
+ LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()";
+ throw;
}
-TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
-
} // namespace ir
} // namespace ir_builder
} // namespace script
} // namespace tvm
+
+#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 1e63201a40..dd8d3c2ed3 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/script/ir_builder/ir/ir.h>
#include <tvm/script/ir_builder/tir/frame.h>
#include <tvm/tir/function.h>
@@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() {
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
builder->result = func;
} else if (Optional<ir::IRModuleFrame> opt_frame = builder->FindFrame<ir::IRModuleFrame>()) {
- ir::IRModuleFrame frame = opt_frame.value();
- frame->global_vars.push_back(GlobalVar(name.value_or("")));
- frame->functions.push_back(func);
+ CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the "
+ "function scope, if it's defined in a Module";
+ const ir::IRModuleFrame& frame = opt_frame.value();
+ const String& func_name = name.value_or("");
+ if (!frame->global_var_map.count(func_name)) {
+ // Case. First time visiting the function.
+ ir::DeclFunction(func_name, func);
+ }
+ // Define the function.
+ // Note we do checks to disallow redefinition of functions inside the `DefFunction`.
+ ir::DefFunction(func_name, func);
} else {
LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
}
diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h
index 7ccc132fa1..f3b547532c 100644
--- a/src/script/ir_builder/tir/utils.h
+++ b/src/script/ir_builder/tir/utils.h
@@ -87,7 +87,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
* \return The top frame of BlockFrame.
*/
inline BlockFrame FindBlockFrame(const String& method) {
- if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+ if (Optional<BlockFrame> frame = IRBuilder::Current()->FindFrame<BlockFrame>()) {
return frame.value();
} else if (Optional<BlockFrame> frame = IRBuilder::Current()->FindFrame<BlockFrame>()) {
LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). "