You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/26 21:23:00 UTC
[tvm] branch unity updated: [Unity][TVMScript] Produce var = R.ExternFunc("") statements (#15703)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 732ae53653 [Unity][TVMScript] Produce var = R.ExternFunc("") statements (#15703)
732ae53653 is described below
commit 732ae53653b6123bb61609645b946a249a1351f2
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Sep 26 16:22:54 2023 -0500
[Unity][TVMScript] Produce var = R.ExternFunc("") statements (#15703)
* [Unity][TVMScript] Produce var = R.ExternFunc("") statements
Prior to this commit, an `ExternFunc` bound to a variable in a relax
function was printed as only the function's string name on its own line,
with no variable definition, so later uses of that variable referred to
a name that was never defined. This commit updates the printer so that
such a binding appears as a normal relax binding, `var = R.ExternFunc("...")`.
* Preserve the special handling of an `ExternFunc` used in-line as a callee (still printed as a string), and test the round-trip
* Updated parser to handle `var = R.ExternFunc(...)` in IRModule
Since the TVMScript printer may now produce this representation,
the parser must also accept it.
---
python/tvm/script/ir_builder/relax/ir.py | 1 +
python/tvm/script/parser/ir/parser.py | 9 +-
python/tvm/script/parser/relax/entry.py | 110 +++++++++++++--------
src/script/printer/ir/ir.cc | 15 ++-
src/script/printer/relax/binding.cc | 3 +-
src/script/printer/relax/function.cc | 2 +-
src/script/printer/relax/struct_info.cc | 26 ++++-
tests/python/relax/test_tvmscript_parser.py | 43 ++++++++
tests/python/relax/test_tvmscript_printer_relax.py | 54 +++++++++-
9 files changed, 208 insertions(+), 55 deletions(-)
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 5c53836851..151a8caf40 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -627,6 +627,7 @@ __all__ = [
"SeqExpr",
"Then",
"TupleGetItem",
+ "ExternFunc",
"abs",
"acos",
"acosh",
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index e11fa43162..4ea57130f1 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -71,7 +71,7 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
@dispatch.register(token="ir", type_name="Assign")
-def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+def _visit_assign(self: Parser, node: doc.Assign) -> None:
"""The assign visiting method for ir module.
Parameters
@@ -82,6 +82,13 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
node : doc.ClassDef
The doc AST assign node.
"""
+ if len(node.targets) != 1:
+ self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+ lhs = node.targets[0].id
+ rhs = self.eval_expr(node.value)
+
+ I.decl_function(lhs, rhs)
+ I.def_function(lhs, rhs)
@dispatch.register(token="ir", type_name="Expr")
diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py
index d3e2342750..d5950dc66d 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -153,6 +153,35 @@ class StructInfoProxy(ObjectGeneric):
return self.as_struct_info(None)
+############################### R.Object ################################
+
+
+class ObjectProxy(StructInfoProxy):
+ """The proxy fo ObjectStructInfo.
+
+ Parameters
+ ----------
+ values : Optional[List[PrimExpr]]
+ The symbolic shape values if known.
+
+ ndim : Optional[int]
+ The size of the shape.
+ """
+
+ def __init__(self) -> None:
+ pass
+
+ def get_symbolic_vars(self) -> Set[str]:
+ return set()
+
+ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
+ return ObjectStructInfo()
+
+
+def Object() -> ObjectProxy:
+ return ObjectProxy()
+
+
############################### R.Tensor ###############################
@@ -270,30 +299,54 @@ class CallableProxy(StructInfoProxy):
def __init__(
self,
- params: Union[StructInfoProxy, List[StructInfoProxy]],
- ret: StructInfoProxy,
- purity: bool = True,
+ params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None,
+ ret: Optional[StructInfoProxy] = None,
+ purity: Optional[bool] = None,
) -> None:
- if not isinstance(params, (list, tuple)):
- params = [params]
- # convert `R.Tensor` to `R.Tensor()`
- self.params = [param() if callable(param) else param for param in params]
+ if params is None:
+ self.params = params
+ else:
+ if not isinstance(params, (list, tuple)):
+ params = [params]
+ # convert `R.Callable` to `R.Callable()`
+ self.params = [param() if callable(param) else param for param in params]
+
+ # Mimic the C++ defaults, where an opaque function is assumed
+ # to be impure, and a non-opaque function is assumed to be
+ # pure.
+ if purity is None:
+ purity = params is not None
+
self.ret = ret() if callable(ret) else ret
self.purity = purity
def get_symbolic_vars(self) -> Set[str]:
- return set().union(*[p.get_symbolic_vars() for p in self.params])
+ if self.params is None:
+ return set()
+ else:
+ return set().union(*[p.get_symbolic_vars() for p in self.params])
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo:
- params = [param.as_struct_info(dict_globals) for param in self.params]
- ret = self.ret.as_struct_info(dict_globals)
- return FuncStructInfo(params, ret, purity=self.purity)
+ if self.ret is None:
+ ret = None
+ else:
+ ret = self.ret.as_struct_info(dict_globals)
+
+ if self.params is None:
+ params = None
+ else:
+ params = [param.as_struct_info(dict_globals) for param in self.params]
+
+ if params is None:
+ return FuncStructInfo.opaque_func(ret=ret, purity=self.purity)
+ else:
+ return FuncStructInfo(params, ret, purity=self.purity)
def Callable(
- params: Union[StructInfoProxy, List[StructInfoProxy]],
- ret: StructInfoProxy,
- purity: bool = True,
+ params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None,
+ ret: Optional[StructInfoProxy] = None,
+ purity: Optional[bool] = None,
) -> CallableProxy:
return CallableProxy(params, ret, purity=purity)
@@ -372,35 +425,6 @@ def Shape(values: Optional[List[PrimExpr]] = None, ndim: int = -1) -> ShapeProxy
return ShapeProxy(values, ndim)
-############################### R.Object ################################
-
-
-class ObjectProxy(StructInfoProxy):
- """The proxy fo ObjectStructInfo.
-
- Parameters
- ----------
- values : Optional[List[PrimExpr]]
- The symbolic shape values if known.
-
- ndim : Optional[int]
- The size of the shape.
- """
-
- def __init__(self) -> None:
- pass
-
- def get_symbolic_vars(self) -> Set[str]:
- return set()
-
- def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
- return ObjectStructInfo()
-
-
-def Object() -> ObjectProxy:
- return ObjectProxy()
-
-
################################ R.Prim ################################
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index a239481d03..5295cf2e41 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -87,17 +87,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
// Print functions
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
- const BaseFunc& func = entry.func;
+ const BaseFunc& base_func = entry.func;
d->cfg->binding_names.push_back(gv->name_hint);
- Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv));
+ Doc doc = d->AsDoc(base_func, p->Attr("functions")->MapValue(gv));
d->cfg->binding_names.pop_back();
if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
(*f)->stmts.push_back(stmt_block->stmts.back());
(*f)->stmts.back()->source_paths = std::move(doc->source_paths);
} else if (auto stmt = doc.as<StmtDoc>()) {
(*f)->stmts.push_back(stmt.value());
+ } else if (auto func = doc.as<FunctionDoc>()) {
+ (*f)->stmts.push_back(func.value());
+ } else if (auto expr = doc.as<ExprDoc>()) {
+ ExprDoc lhs = IdDoc(gv->name_hint);
+ AssignDoc assignment(lhs, expr.value(), NullOpt);
+ (*f)->stmts.push_back(assignment);
} else {
- (*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
+ LOG(FATAL) << "TypeError: "
+ << "Expected IRModule to only contain functions, "
+ << " but mod[" << gv->name_hint << "] with type " << base_func->GetTypeKey()
+ << " produced Doc type of " << doc->GetTypeKey();
}
}
return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts));
diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc
index 8a50fe9698..395b4251fb 100644
--- a/src/script/printer/relax/binding.cc
+++ b/src/script/printer/relax/binding.cc
@@ -59,7 +59,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
return PrintIfExpr(GetRef<relax::If>(if_), n_p->Attr("value"), d, lhs, ann);
- } else if (n->value->IsInstance<tvm::BaseFuncNode>()) {
+ } else if (n->value->IsInstance<tvm::BaseFuncNode>() &&
+ !n->value->IsInstance<relax::ExternFuncNode>()) {
IdDoc lhs = DefineVar(n->var, d->frames.back(), d);
d->cfg->binding_names.push_back(lhs->name);
Doc ret = d->AsDoc(n->value, n_p->Attr("value"));
diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc
index bc5f12309f..5fb54c793d 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -129,7 +129,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::ExternFunc>( //
"", [](relax::ExternFunc n, ObjectPath n_p, IRDocsifier d) -> Doc {
// TODO(@junrushao): print more information out of extern function.
- return ExprStmtDoc(LiteralDoc::Str(n->global_symbol, n_p));
+ return Relax(d, "ExternFunc")->Call({LiteralDoc::Str(n->global_symbol, n_p)});
});
TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax);
diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc
index 11f987c368..7043952c7c 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/script/printer/relax/struct_info.cc
@@ -152,8 +152,27 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::FuncStructInfo>( //
"", [](relax::FuncStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ auto ret_doc = d->AsDoc<ExprDoc>(n->ret, n_p->Attr("ret"));
+ auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity"));
+
if (n->IsOpaque()) {
- return Relax(d, "Callable");
+ Array<String> keys;
+ Array<ExprDoc, void> values;
+
+ if (!n->ret->IsInstance<relax::ObjectStructInfoNode>()) {
+ keys.push_back("ret");
+ values.push_back(ret_doc);
+ }
+ if (n->purity) {
+ keys.push_back("purity");
+ values.push_back(purity_doc);
+ }
+
+ if (keys.size()) {
+ return Relax(d, "Callable")->Call({}, keys, values);
+ } else {
+ return Relax(d, "Callable");
+ }
}
// TODO(@junrushao): track symbolic shape relation
Array<ExprDoc> params_doc;
@@ -162,10 +181,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
for (int i = 0, n_params = params.size(); i < n_params; ++i) {
params_doc.push_back(d->AsDoc<ExprDoc>(params[i], params_p->ArrayIndex(i)));
}
- return Relax(d, "Callable")
- ->Call({TupleDoc(params_doc), //
- d->AsDoc<ExprDoc>(n->ret, n_p->Attr("ret")), //
- LiteralDoc::Boolean(n->purity, n_p->Attr("purity"))});
+ return Relax(d, "Callable")->Call({TupleDoc(params_doc), ret_doc, purity_doc});
});
TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax);
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index d86b1e0108..b45c3c6e4a 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1772,5 +1772,48 @@ def test_macro_no_variable_leak():
return x # Should be undefined here
+def test_reused_extern_func():
+ """ExternFunc lookups can become bindings in EliminateCommonSubexpr"""
+
+ @R.function(private=True)
+ def parsed(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
+ func = R.ExternFunc("extern_func")
+ gv0 = R.call_dps_packed(func, x, R.Tensor((128, 128), dtype="float32"))
+ gv1 = R.call_dps_packed(func, gv0, R.Tensor((128, 128), dtype="float32"))
+ return gv1
+
+ x = relax.Var("x", R.Tensor((128, 128), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("main", [x], private=True):
+ func = bb.emit(relax.ExternFunc("extern_func"))
+ y = bb.emit(relax.call_dps_packed(func, x, out_sinfo=R.Tensor((128, 128), "float32")))
+ z = bb.emit(relax.call_dps_packed(func, y, out_sinfo=R.Tensor((128, 128), "float32")))
+ bb.emit_func_output(z)
+
+ expected = bb.get()["main"]
+
+ _check(parsed, expected)
+
+
+def test_extern_func_in_module():
+ """Module-level parsing may produce function bindings"""
+
+ @I.ir_module
+ class parsed_module:
+ my_ext = R.ExternFunc("my_ext")
+
+ @R.function
+ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+ return a
+
+ @R.function
+ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+ return a
+
+ expected = tvm.IRModule({"my_ext": relax.ExternFunc("my_ext"), "func": func})
+
+ _check(parsed_module, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 2e4218b2ab..dc3334f216 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -88,7 +88,7 @@ def test_extern_func():
@I.ir_module
class Module:
- "my_ext"
+ my_ext = R.ExternFunc("my_ext")
@R.function
def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
return a
@@ -752,5 +752,57 @@ class Module:
)
+def test_reused_extern_func():
+ """An ExternFunc used in a variable binding should be explicit"""
+
+ @R.function
+ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+ extern_func = R.ExternFunc("extern_func")
+ y = R.call_dps_packed(extern_func, (x,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+ z = R.call_dps_packed(extern_func, (y,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+ return z
+
+ _assert_print(
+ func,
+ """
+# from tvm.script import relax as R
+
+@R.function
+def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+ extern_func: R.Callable = R.ExternFunc("extern_func")
+ y = R.call_dps_packed(extern_func, (x,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+ z = R.call_dps_packed(extern_func, (y,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+ return z
+ """,
+ )
+
+
+def test_inline_extern_func():
+ """An ExternFunc used in-line may be printed as a string"""
+
+ @R.function
+ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+ y = R.call_dps_packed(
+ R.ExternFunc("extern_func"), (x,), out_sinfo=R.Tensor((128, 128), dtype="float32")
+ )
+ z = R.call_dps_packed(
+ R.ExternFunc("extern_func"), (y,), out_sinfo=R.Tensor((128, 128), dtype="float32")
+ )
+ return z
+
+ _assert_print(
+ func,
+ """
+# from tvm.script import relax as R
+
+@R.function
+def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+ y = R.call_dps_packed("extern_func", (x,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+ z = R.call_dps_packed("extern_func", (y,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+ return z
+ """,
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()