You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2024/02/19 03:07:11 UTC

(tvm) branch main updated: [Unity][TVMScript] Parse R.Object return type from call_pure_packed (#16593)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dd709412eb [Unity][TVMScript] Parse R.Object return type from call_pure_packed (#16593)
dd709412eb is described below

commit dd709412eb89b73e5bf41172b9d73360082816da
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Sun Feb 18 21:07:05 2024 -0600

    [Unity][TVMScript] Parse R.Object return type from call_pure_packed (#16593)
    
    Prior to this commit, `R.call_packed` and `R.call_pure_packed` had
    different normalization for the `sinfo_args` argument.  While
    `R.call_packed` checked if the struct info needed to be converted using
    `ObjectGeneric.asobject()`, `R.call_pure_packed` did not.
    
    This commit updates `R.call_pure_packed` to handle `sinfo_args`
    in the same manner as `R.call_packed`.
---
 python/tvm/relax/op/base.py                 | 13 +++++++++++++
 python/tvm/script/ir_builder/relax/ir.py    | 16 +++++++++-------
 tests/python/relax/test_tvmscript_parser.py | 14 ++++++++++++++
 3 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 92235ffb47..3effec242d 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -21,6 +21,7 @@ from typing import Dict, Union, List, Tuple, Optional, Callable
 import tvm
 import tvm.runtime
 from tvm.runtime.object import Object
+from tvm.runtime import ObjectGeneric
 
 from . import _ffi_api
 from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
@@ -709,12 +710,24 @@ def call_pure_packed(
         func = func.global_symbol
 
     op = ExternFunc(func)
+
     if sinfo_args is None:
         raise ValueError("R.call_pure_packed is required to have type_args")
+
     if isinstance(sinfo_args, tuple):  # type: ignore
         sinfo_args = list(sinfo_args)
     elif not isinstance(sinfo_args, list):
         sinfo_args = [sinfo_args]
+
+    sinfo_args = [
+        sinfo()
+        if callable(sinfo)
+        else sinfo.asobject()
+        if isinstance(sinfo, ObjectGeneric)
+        else sinfo
+        for sinfo in sinfo_args
+    ]
+
     # note: if we need attributes, we can also take them here
 
     return _ffi_api.call_pure_packed(op, args, None, sinfo_args)  # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 6447178909..3e1927290d 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -357,13 +357,15 @@ def call_packed(
         sinfo_args = list(sinfo_args)
     elif not isinstance(sinfo_args, list):
         sinfo_args = [sinfo_args]
-    for i, sinfo_arg in enumerate(sinfo_args):
-        if callable(sinfo_arg):
-            sinfo_arg = sinfo_arg()
-        # Convert possible StructInfoProxy to StructInfo
-        if isinstance(sinfo_arg, ObjectGeneric):
-            sinfo_arg = sinfo_arg.asobject()
-        sinfo_args[i] = sinfo_arg
+
+    sinfo_args = [
+        sinfo()
+        if callable(sinfo)
+        else sinfo.asobject()
+        if isinstance(sinfo, ObjectGeneric)
+        else sinfo
+        for sinfo in sinfo_args
+    ]
 
     is_default = False
     if "attrs_type_key" in kwargs:
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 71970ad965..01e71fa263 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1800,6 +1800,20 @@ def test_call_pure_packed():
     _check(foo, bb.get()["foo"])
 
 
+def test_call_pure_packed_returning_object():
+    @R.function
+    def foo() -> R.Object:
+        z = R.call_pure_packed("dummy_func", sinfo_args=R.Object)
+        return z
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", params=[]):
+        z = bb.emit(R.call_pure_packed("dummy_func", sinfo_args=[relax.ObjectStructInfo()]))
+        bb.emit_func_output(z)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_private_function():
     @I.ir_module
     class Addition: