You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/28 19:07:52 UTC
[tvm] branch main updated: [Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator (#11882)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6c433d2309 [Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator (#11882)
6c433d2309 is described below
commit 6c433d2309ffe4ca6954d3ca420027cdfa944fa0
Author: Rafael Stahl <du...@web.de>
AuthorDate: Tue Jun 28 21:07:45 2022 +0200
[Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator (#11882)
* [Relay][VirtualDevice] Expose WithFields to Python to do proper copy in ExprMutator
* [Relay] give FunctionWithFields optional arguments
* [lint] fix line-length violation
* [lint] add missing newline
* [doc] add doc string to FunctionWithFields
---
python/tvm/relay/expr_functor.py | 8 ++++++--
python/tvm/relay/function.py | 21 +++++++++++++++++++++
src/relay/ir/function.cc | 8 ++++++++
3 files changed, 35 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index b9ca7d0e11..ebea344b41 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -18,7 +18,7 @@
"""The expression functor of Relay."""
from tvm.ir import Op
-from .function import Function
+from .function import Function, FunctionWithFields
from .expr import Call, Let, Var, GlobalVar
from .expr import If, Tuple, TupleGetItem, Constant
from .expr import RefCreate, RefRead, RefWrite
@@ -204,7 +204,11 @@ class ExprMutator(ExprFunctor):
def visit_function(self, fn):
new_params = [self.visit(x) for x in fn.params]
new_body = self.visit(fn.body)
- return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+ return FunctionWithFields(
+ fn,
+ list(new_params),
+ new_body,
+ )
def visit_let(self, let):
new_var = self.visit(let.var)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index f889f1e596..6b3513cb5e 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -63,3 +63,24 @@ class Function(BaseFunc):
Arguments.
"""
return Call(self, args, None, None)
+
+
+@tvm._ffi.register_func("relay.FunctionWithFields")
+def FunctionWithFields(
+ function,
+ params=None,
+ body=None,
+ ret_type=None,
+ ty_params=None,
+ attrs=None,
+ virtual_device=None,
+ span=None,
+):
+ """
+ Returns function with the given properties. A None property denotes 'no change'.
+ Returns function if all properties are unchanged. Otherwise, returns a copy with the new
+ fields.
+ """
+ return _ffi_api.FunctionWithFields(
+ function, params, body, ret_type, ty_params, attrs, virtual_device, span
+ )
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 63e74144e0..1a3db9974f 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -127,6 +127,14 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
return Function(params, body, ret_type, ty_params, attrs);
});
+TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields")
+ .set_body_typed([](Function function, Optional<Array<Var>> opt_params, Optional<Expr> opt_body,
+ Optional<Type> opt_ret_type, Optional<Array<TypeVar>> opt_ty_params,
+ Optional<DictAttrs> opt_attrs, Optional<VirtualDevice> opt_virtual_device,
+ Optional<Span> opt_span) {
+ return WithFields(function, opt_params, opt_body, opt_ret_type, opt_ty_params, opt_attrs,
+ opt_virtual_device, opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {