You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/28 19:07:52 UTC

[tvm] branch main updated: [Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator (#11882)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c433d2309 [Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator (#11882)
6c433d2309 is described below

commit 6c433d2309ffe4ca6954d3ca420027cdfa944fa0
Author: Rafael Stahl <du...@web.de>
AuthorDate: Tue Jun 28 21:07:45 2022 +0200

    [Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator (#11882)
    
    * [Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator
    
    * [Relay] give FunctionWithFields optional arguments
    
    * [lint] fix wrong line length
    
    * [lint] missing newline
    
    * [doc] add doc string to FunctionWithFields
---
 python/tvm/relay/expr_functor.py |  8 ++++++--
 python/tvm/relay/function.py     | 21 +++++++++++++++++++++
 src/relay/ir/function.cc         |  8 ++++++++
 3 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index b9ca7d0e11..ebea344b41 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -18,7 +18,7 @@
 """The expression functor of Relay."""
 from tvm.ir import Op
 
-from .function import Function
+from .function import Function, FunctionWithFields
 from .expr import Call, Let, Var, GlobalVar
 from .expr import If, Tuple, TupleGetItem, Constant
 from .expr import RefCreate, RefRead, RefWrite
@@ -204,7 +204,11 @@ class ExprMutator(ExprFunctor):
     def visit_function(self, fn):
         new_params = [self.visit(x) for x in fn.params]
         new_body = self.visit(fn.body)
-        return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+        return FunctionWithFields(
+            fn,
+            list(new_params),
+            new_body,
+        )
 
     def visit_let(self, let):
         new_var = self.visit(let.var)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index f889f1e596..6b3513cb5e 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -63,3 +63,24 @@ class Function(BaseFunc):
             Arguments.
         """
         return Call(self, args, None, None)
+
+
+@tvm._ffi.register_func("relay.FunctionWithFields")
+def FunctionWithFields(
+    function,
+    params=None,
+    body=None,
+    ret_type=None,
+    ty_params=None,
+    attrs=None,
+    virtual_device=None,
+    span=None,
+):
+    """
+    Returns function with the given properties. A None property denotes 'no change'.
+    Returns function if all properties are unchanged. Otherwise, returns a copy with the new
+    fields.
+    """
+    return _ffi_api.FunctionWithFields(
+        function, params, body, ret_type, ty_params, attrs, virtual_device, span
+    )
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 63e74144e0..1a3db9974f 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -127,6 +127,14 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
                        tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
       return Function(params, body, ret_type, ty_params, attrs);
     });
+TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields")
+    .set_body_typed([](Function function, Optional<Array<Var>> opt_params, Optional<Expr> opt_body,
+                       Optional<Type> opt_ret_type, Optional<Array<TypeVar>> opt_ty_params,
+                       Optional<DictAttrs> opt_attrs, Optional<VirtualDevice> opt_virtual_device,
+                       Optional<Span> opt_span) {
+      return WithFields(function, opt_params, opt_body, opt_ret_type, opt_ty_params, opt_attrs,
+                        opt_virtual_device, opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {