You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/31 20:58:56 UTC

(tvm) branch unity updated: [Unity] Avoid Emitting Redundant Bindings in TensorExpr Op (#16026)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3c8603789e [Unity] Avoid Emitting Redundant Bindings in TensorExpr Op (#16026)
3c8603789e is described below

commit 3c8603789e9117f755bc09d73a08bc3efe72abc3
Author: Junru Shao <ju...@apache.org>
AuthorDate: Tue Oct 31 13:58:49 2023 -0700

    [Unity] Avoid Emitting Redundant Bindings in TensorExpr Op (#16026)
---
 python/tvm/relax/frontend/nn/op.py        |  5 +++--
 tests/python/relax/test_frontend_nn_op.py | 10 +++++-----
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index b4be5cd1ff..03080615e9 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -33,7 +33,7 @@ from .spec import SpecBuilder
 IntExpr = Union[int, _tir.PrimExpr]
 
 
-def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Tuple[Tensor, ...]]:
+def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
     """Wrap the given relax.Expr, emit it using the current BlockBuilder,
     and automatically handle nested cases if the expr represents a Tuple.
 
@@ -50,7 +50,8 @@ def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Tuple[Tensor, ...]]:
     result : Union[Tensor, Tuple[Tensor]]
         The computed result.
     """
-    expr = BlockBuilder.current().emit(expr, name)
+    if not isinstance(expr, rx.DataflowVar):
+        expr = BlockBuilder.current().emit(expr, name)
     if isinstance(expr.struct_info_, TensorStructInfo):
         return Tensor(_expr=expr)
     if isinstance(expr.struct_info_, TupleStructInfo):
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index fd77b76f9f..57bf6d273c 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -14,10 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import io
+import sys
+
 import pytest
 import torch
-import sys
-import io
 
 import tvm
 import tvm.testing
@@ -447,7 +448,7 @@ def test_tensor_expr_op():
             tensor_expr_op_out = op.tensor_expr_op(
                 tensor_expr_func=lambda x: x + 1, name_hint="add_one", args=[x]
             )
-            return x
+            return tensor_expr_op_out
 
     # fmt: off
     @I.ir_module
@@ -478,8 +479,7 @@ def test_tensor_expr_op():
             R.func_attr({"num_input": 2})
             with R.dataflow():
                 lv1 = R.call_tir(cls.add_one, (x,), out_sinfo=R.Tensor((10, 10), dtype="float32"))
-                add_one1: R.Tensor((10, 10), dtype="float32") = lv1
-                gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
+                gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = lv1, (_io,)
                 R.output(gv1)
             return gv1
     # fmt: on