You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2024/02/19 03:07:11 UTC
(tvm) branch main updated: [Unity][TVMScript] Parse R.Object return type from call_pure_packed (#16593)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dd709412eb [Unity][TVMScript] Parse R.Object return type from call_pure_packed (#16593)
dd709412eb is described below
commit dd709412eb89b73e5bf41172b9d73360082816da
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Sun Feb 18 21:07:05 2024 -0600
[Unity][TVMScript] Parse R.Object return type from call_pure_packed (#16593)
Prior to this commit, `R.call_packed` and `R.call_pure_packed` had
different normalization for the `sinfo_args` argument. While
`R.call_packed` checked whether the struct info needed to be converted using
`ObjectGeneric.asobject()`, `R.call_pure_packed` did not.
This commit updates `R.call_pure_packed` to handle `sinfo_args`
in the same manner as `R.call_packed`.
---
python/tvm/relax/op/base.py | 13 +++++++++++++
python/tvm/script/ir_builder/relax/ir.py | 16 +++++++++-------
tests/python/relax/test_tvmscript_parser.py | 14 ++++++++++++++
3 files changed, 36 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 92235ffb47..3effec242d 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -21,6 +21,7 @@ from typing import Dict, Union, List, Tuple, Optional, Callable
import tvm
import tvm.runtime
from tvm.runtime.object import Object
+from tvm.runtime import ObjectGeneric
from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
@@ -709,12 +710,24 @@ def call_pure_packed(
func = func.global_symbol
op = ExternFunc(func)
+
if sinfo_args is None:
raise ValueError("R.call_pure_packed is required to have type_args")
+
if isinstance(sinfo_args, tuple): # type: ignore
sinfo_args = list(sinfo_args)
elif not isinstance(sinfo_args, list):
sinfo_args = [sinfo_args]
+
+ sinfo_args = [
+ sinfo()
+ if callable(sinfo)
+ else sinfo.asobject()
+ if isinstance(sinfo, ObjectGeneric)
+ else sinfo
+ for sinfo in sinfo_args
+ ]
+
# note: if we need attributes, we can also take them here
return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 6447178909..3e1927290d 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -357,13 +357,15 @@ def call_packed(
sinfo_args = list(sinfo_args)
elif not isinstance(sinfo_args, list):
sinfo_args = [sinfo_args]
- for i, sinfo_arg in enumerate(sinfo_args):
- if callable(sinfo_arg):
- sinfo_arg = sinfo_arg()
- # Convert possible StructInfoProxy to StructInfo
- if isinstance(sinfo_arg, ObjectGeneric):
- sinfo_arg = sinfo_arg.asobject()
- sinfo_args[i] = sinfo_arg
+
+ sinfo_args = [
+ sinfo()
+ if callable(sinfo)
+ else sinfo.asobject()
+ if isinstance(sinfo, ObjectGeneric)
+ else sinfo
+ for sinfo in sinfo_args
+ ]
is_default = False
if "attrs_type_key" in kwargs:
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 71970ad965..01e71fa263 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1800,6 +1800,20 @@ def test_call_pure_packed():
_check(foo, bb.get()["foo"])
+def test_call_pure_packed_returning_object():
+ @R.function
+ def foo() -> R.Object:
+ z = R.call_pure_packed("dummy_func", sinfo_args=R.Object)
+ return z
+
+ bb = relax.BlockBuilder()
+ with bb.function("foo", params=[]):
+ z = bb.emit(R.call_pure_packed("dummy_func", sinfo_args=[relax.ObjectStructInfo()]))
+ bb.emit_func_output(z)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_private_function():
@I.ir_module
class Addition: