You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/26 21:23:00 UTC

[tvm] branch unity updated: [Unity][TVMScript] Produce var = R.ExternFunc("") statements (#15703)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 732ae53653 [Unity][TVMScript] Produce var = R.ExternFunc("") statements (#15703)
732ae53653 is described below

commit 732ae53653b6123bb61609645b946a249a1351f2
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Sep 26 16:22:54 2023 -0500

    [Unity][TVMScript] Produce var = R.ExternFunc("") statements (#15703)
    
    * [Unity][TVMScript] Produce var = R.ExternFunc("") statements
    
    Prior to this commit, any `ExternFunc` usage in a relax function would
    print the string name of the function on its own line, omitting any
    variable definition, so a later use of the variable would appear in
    the printed output without a definition.  This commit updates the
    printer so that an `R.ExternFunc` bound to a variable is printed as a
    normal relax binding, `var = R.ExternFunc("name")`.
    
    * Preserve the special handling of an `ExternFunc` used directly as a
      callee (still printed as a string literal), and test the round trip
    
    * Updated parser to handle `var = R.ExternFunc(...)` in IRModule
    
    Since this is now a representation that may be produced by the
    TVMScript printer, it must also be accepted by the parser.
---
 python/tvm/script/ir_builder/relax/ir.py           |   1 +
 python/tvm/script/parser/ir/parser.py              |   9 +-
 python/tvm/script/parser/relax/entry.py            | 110 +++++++++++++--------
 src/script/printer/ir/ir.cc                        |  15 ++-
 src/script/printer/relax/binding.cc                |   3 +-
 src/script/printer/relax/function.cc               |   2 +-
 src/script/printer/relax/struct_info.cc            |  26 ++++-
 tests/python/relax/test_tvmscript_parser.py        |  43 ++++++++
 tests/python/relax/test_tvmscript_printer_relax.py |  54 +++++++++-
 9 files changed, 208 insertions(+), 55 deletions(-)

diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 5c53836851..151a8caf40 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -627,6 +627,7 @@ __all__ = [
     "SeqExpr",
     "Then",
     "TupleGetItem",
+    "ExternFunc",
     "abs",
     "acos",
     "acosh",
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index e11fa43162..4ea57130f1 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -71,7 +71,7 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
 
 
 @dispatch.register(token="ir", type_name="Assign")
-def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+def _visit_assign(self: Parser, node: doc.Assign) -> None:
     """The assign visiting method for ir module.
 
     Parameters
@@ -82,6 +82,13 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
     node : doc.ClassDef
         The doc AST assign node.
     """
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    lhs = node.targets[0].id
+    rhs = self.eval_expr(node.value)
+
+    I.decl_function(lhs, rhs)
+    I.def_function(lhs, rhs)
 
 
 @dispatch.register(token="ir", type_name="Expr")
diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py
index d3e2342750..d5950dc66d 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -153,6 +153,35 @@ class StructInfoProxy(ObjectGeneric):
         return self.as_struct_info(None)
 
 
+############################### R.Object ################################
+
+
+class ObjectProxy(StructInfoProxy):
+    """The proxy fo ObjectStructInfo.
+
+    Parameters
+    ----------
+    values : Optional[List[PrimExpr]]
+       The symbolic shape values if known.
+
+    ndim : Optional[int]
+       The size of the shape.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_symbolic_vars(self) -> Set[str]:
+        return set()
+
+    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
+        return ObjectStructInfo()
+
+
+def Object() -> ObjectProxy:
+    return ObjectProxy()
+
+
 ############################### R.Tensor ###############################
 
 
@@ -270,30 +299,54 @@ class CallableProxy(StructInfoProxy):
 
     def __init__(
         self,
-        params: Union[StructInfoProxy, List[StructInfoProxy]],
-        ret: StructInfoProxy,
-        purity: bool = True,
+        params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None,
+        ret: Optional[StructInfoProxy] = None,
+        purity: Optional[bool] = None,
     ) -> None:
-        if not isinstance(params, (list, tuple)):
-            params = [params]
-        # convert `R.Tensor` to `R.Tensor()`
-        self.params = [param() if callable(param) else param for param in params]
+        if params is None:
+            self.params = params
+        else:
+            if not isinstance(params, (list, tuple)):
+                params = [params]
+            # convert `R.Callable` to `R.Callable()`
+            self.params = [param() if callable(param) else param for param in params]
+
+        # Mimic the C++ defaults, where an opaque function is assumed
+        # to be impure, and a non-opaque function is assumed to be
+        # pure.
+        if purity is None:
+            purity = params is not None
+
         self.ret = ret() if callable(ret) else ret
         self.purity = purity
 
     def get_symbolic_vars(self) -> Set[str]:
-        return set().union(*[p.get_symbolic_vars() for p in self.params])
+        if self.params is None:
+            return set()
+        else:
+            return set().union(*[p.get_symbolic_vars() for p in self.params])
 
     def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo:
-        params = [param.as_struct_info(dict_globals) for param in self.params]
-        ret = self.ret.as_struct_info(dict_globals)
-        return FuncStructInfo(params, ret, purity=self.purity)
+        if self.ret is None:
+            ret = None
+        else:
+            ret = self.ret.as_struct_info(dict_globals)
+
+        if self.params is None:
+            params = None
+        else:
+            params = [param.as_struct_info(dict_globals) for param in self.params]
+
+        if params is None:
+            return FuncStructInfo.opaque_func(ret=ret, purity=self.purity)
+        else:
+            return FuncStructInfo(params, ret, purity=self.purity)
 
 
 def Callable(
-    params: Union[StructInfoProxy, List[StructInfoProxy]],
-    ret: StructInfoProxy,
-    purity: bool = True,
+    params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None,
+    ret: Optional[StructInfoProxy] = None,
+    purity: Optional[bool] = None,
 ) -> CallableProxy:
     return CallableProxy(params, ret, purity=purity)
 
@@ -372,35 +425,6 @@ def Shape(values: Optional[List[PrimExpr]] = None, ndim: int = -1) -> ShapeProxy
     return ShapeProxy(values, ndim)
 
 
-############################### R.Object ################################
-
-
-class ObjectProxy(StructInfoProxy):
-    """The proxy fo ObjectStructInfo.
-
-    Parameters
-    ----------
-    values : Optional[List[PrimExpr]]
-       The symbolic shape values if known.
-
-    ndim : Optional[int]
-       The size of the shape.
-    """
-
-    def __init__(self) -> None:
-        pass
-
-    def get_symbolic_vars(self) -> Set[str]:
-        return set()
-
-    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
-        return ObjectStructInfo()
-
-
-def Object() -> ObjectProxy:
-    return ObjectProxy()
-
-
 ################################ R.Prim ################################
 
 
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index a239481d03..5295cf2e41 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -87,17 +87,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       // Print functions
       for (const auto& entry : functions) {
         const GlobalVar& gv = entry.gv;
-        const BaseFunc& func = entry.func;
+        const BaseFunc& base_func = entry.func;
         d->cfg->binding_names.push_back(gv->name_hint);
-        Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv));
+        Doc doc = d->AsDoc(base_func, p->Attr("functions")->MapValue(gv));
         d->cfg->binding_names.pop_back();
         if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
           (*f)->stmts.push_back(stmt_block->stmts.back());
           (*f)->stmts.back()->source_paths = std::move(doc->source_paths);
         } else if (auto stmt = doc.as<StmtDoc>()) {
           (*f)->stmts.push_back(stmt.value());
+        } else if (auto func = doc.as<FunctionDoc>()) {
+          (*f)->stmts.push_back(func.value());
+        } else if (auto expr = doc.as<ExprDoc>()) {
+          ExprDoc lhs = IdDoc(gv->name_hint);
+          AssignDoc assignment(lhs, expr.value(), NullOpt);
+          (*f)->stmts.push_back(assignment);
         } else {
-          (*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
+          LOG(FATAL) << "TypeError: "
+                     << "Expected IRModule to only contain functions, "
+                     << " but mod[" << gv->name_hint << "] with type  " << base_func->GetTypeKey()
+                     << " produced Doc type of " << doc->GetTypeKey();
         }
       }
       return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts));
diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc
index 8a50fe9698..395b4251fb 100644
--- a/src/script/printer/relax/binding.cc
+++ b/src/script/printer/relax/binding.cc
@@ -59,7 +59,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
             Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
             ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
             return PrintIfExpr(GetRef<relax::If>(if_), n_p->Attr("value"), d, lhs, ann);
-          } else if (n->value->IsInstance<tvm::BaseFuncNode>()) {
+          } else if (n->value->IsInstance<tvm::BaseFuncNode>() &&
+                     !n->value->IsInstance<relax::ExternFuncNode>()) {
             IdDoc lhs = DefineVar(n->var, d->frames.back(), d);
             d->cfg->binding_names.push_back(lhs->name);
             Doc ret = d->AsDoc(n->value, n_p->Attr("value"));
diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc
index bc5f12309f..5fb54c793d 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -129,7 +129,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<relax::ExternFunc>(  //
         "", [](relax::ExternFunc n, ObjectPath n_p, IRDocsifier d) -> Doc {
           // TODO(@junrushao): print more information out of extern function.
-          return ExprStmtDoc(LiteralDoc::Str(n->global_symbol, n_p));
+          return Relax(d, "ExternFunc")->Call({LiteralDoc::Str(n->global_symbol, n_p)});
         });
 
 TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax);
diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc
index 11f987c368..7043952c7c 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/script/printer/relax/struct_info.cc
@@ -152,8 +152,27 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<relax::FuncStructInfo>(  //
         "", [](relax::FuncStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          auto ret_doc = d->AsDoc<ExprDoc>(n->ret, n_p->Attr("ret"));
+          auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity"));
+
           if (n->IsOpaque()) {
-            return Relax(d, "Callable");
+            Array<String> keys;
+            Array<ExprDoc, void> values;
+
+            if (!n->ret->IsInstance<relax::ObjectStructInfoNode>()) {
+              keys.push_back("ret");
+              values.push_back(ret_doc);
+            }
+            if (n->purity) {
+              keys.push_back("purity");
+              values.push_back(purity_doc);
+            }
+
+            if (keys.size()) {
+              return Relax(d, "Callable")->Call({}, keys, values);
+            } else {
+              return Relax(d, "Callable");
+            }
           }
           // TODO(@junrushao): track symbolic shape relation
           Array<ExprDoc> params_doc;
@@ -162,10 +181,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           for (int i = 0, n_params = params.size(); i < n_params; ++i) {
             params_doc.push_back(d->AsDoc<ExprDoc>(params[i], params_p->ArrayIndex(i)));
           }
-          return Relax(d, "Callable")
-              ->Call({TupleDoc(params_doc),                         //
-                      d->AsDoc<ExprDoc>(n->ret, n_p->Attr("ret")),  //
-                      LiteralDoc::Boolean(n->purity, n_p->Attr("purity"))});
+          return Relax(d, "Callable")->Call({TupleDoc(params_doc), ret_doc, purity_doc});
         });
 
 TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax);
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index d86b1e0108..b45c3c6e4a 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1772,5 +1772,48 @@ def test_macro_no_variable_leak():
             return x  # Should be undefined here
 
 
+def test_reused_extern_func():
+    """ExternFunc lookups can become bindings in EliminateCommonSubexpr"""
+
+    @R.function(private=True)
+    def parsed(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
+        func = R.ExternFunc("extern_func")
+        gv0 = R.call_dps_packed(func, x, R.Tensor((128, 128), dtype="float32"))
+        gv1 = R.call_dps_packed(func, gv0, R.Tensor((128, 128), dtype="float32"))
+        return gv1
+
+    x = relax.Var("x", R.Tensor((128, 128), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("main", [x], private=True):
+        func = bb.emit(relax.ExternFunc("extern_func"))
+        y = bb.emit(relax.call_dps_packed(func, x, out_sinfo=R.Tensor((128, 128), "float32")))
+        z = bb.emit(relax.call_dps_packed(func, y, out_sinfo=R.Tensor((128, 128), "float32")))
+        bb.emit_func_output(z)
+
+    expected = bb.get()["main"]
+
+    _check(parsed, expected)
+
+
+def test_extern_func_in_module():
+    """Module-level parsing may produce function bindings"""
+
+    @I.ir_module
+    class parsed_module:
+        my_ext = R.ExternFunc("my_ext")
+
+        @R.function
+        def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+            return a
+
+    @R.function
+    def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+        return a
+
+    expected = tvm.IRModule({"my_ext": relax.ExternFunc("my_ext"), "func": func})
+
+    _check(parsed_module, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 2e4218b2ab..dc3334f216 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -88,7 +88,7 @@ def test_extern_func():
 
 @I.ir_module
 class Module:
-    "my_ext"
+    my_ext = R.ExternFunc("my_ext")
     @R.function
     def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
         return a
@@ -752,5 +752,57 @@ class Module:
     )
 
 
+def test_reused_extern_func():
+    """An ExternFunc used in a variable binding should be explicit"""
+
+    @R.function
+    def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+        extern_func = R.ExternFunc("extern_func")
+        y = R.call_dps_packed(extern_func, (x,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+        z = R.call_dps_packed(extern_func, (y,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+        return z
+
+    _assert_print(
+        func,
+        """
+# from tvm.script import relax as R
+
+@R.function
+def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+    extern_func: R.Callable = R.ExternFunc("extern_func")
+    y = R.call_dps_packed(extern_func, (x,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+    z = R.call_dps_packed(extern_func, (y,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+    return z
+                  """,
+    )
+
+
+def test_inline_extern_func():
+    """An ExternFunc used in-line may be printed as a string"""
+
+    @R.function
+    def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+        y = R.call_dps_packed(
+            R.ExternFunc("extern_func"), (x,), out_sinfo=R.Tensor((128, 128), dtype="float32")
+        )
+        z = R.call_dps_packed(
+            R.ExternFunc("extern_func"), (y,), out_sinfo=R.Tensor((128, 128), dtype="float32")
+        )
+        return z
+
+    _assert_print(
+        func,
+        """
+# from tvm.script import relax as R
+
+@R.function
+def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
+    y = R.call_dps_packed("extern_func", (x,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+    z = R.call_dps_packed("extern_func", (y,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
+    return z
+                  """,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()