You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/31 20:58:56 UTC
(tvm) branch unity updated: [Unity] Avoid Emitting Redundant Bindings in TensorExpr Op (#16026)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3c8603789e [Unity] Avoid Emitting Redundant Bindings in TensorExpr Op (#16026)
3c8603789e is described below
commit 3c8603789e9117f755bc09d73a08bc3efe72abc3
Author: Junru Shao <ju...@apache.org>
AuthorDate: Tue Oct 31 13:58:49 2023 -0700
[Unity] Avoid Emitting Redundant Bindings in TensorExpr Op (#16026)
---
python/tvm/relax/frontend/nn/op.py | 5 +++--
tests/python/relax/test_frontend_nn_op.py | 10 +++++-----
2 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index b4be5cd1ff..03080615e9 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -33,7 +33,7 @@ from .spec import SpecBuilder
IntExpr = Union[int, _tir.PrimExpr]
-def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Tuple[Tensor, ...]]:
+def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
"""Wrap the given relax.Expr, emit it using the current BlockBuilder,
and automatically handle nested cases if the expr represents a Tuple.
@@ -50,7 +50,8 @@ def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Tuple[Tensor, ...]]:
result : Union[Tensor, Tuple[Tensor]]
The computed result.
"""
- expr = BlockBuilder.current().emit(expr, name)
+ if not isinstance(expr, rx.DataflowVar):
+ expr = BlockBuilder.current().emit(expr, name)
if isinstance(expr.struct_info_, TensorStructInfo):
return Tensor(_expr=expr)
if isinstance(expr.struct_info_, TupleStructInfo):
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index fd77b76f9f..57bf6d273c 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import io
+import sys
+
import pytest
import torch
-import sys
-import io
import tvm
import tvm.testing
@@ -447,7 +448,7 @@ def test_tensor_expr_op():
tensor_expr_op_out = op.tensor_expr_op(
tensor_expr_func=lambda x: x + 1, name_hint="add_one", args=[x]
)
- return x
+ return tensor_expr_op_out
# fmt: off
@I.ir_module
@@ -478,8 +479,7 @@ def test_tensor_expr_op():
R.func_attr({"num_input": 2})
with R.dataflow():
lv1 = R.call_tir(cls.add_one, (x,), out_sinfo=R.Tensor((10, 10), dtype="float32"))
- add_one1: R.Tensor((10, 10), dtype="float32") = lv1
- gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
+ gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = lv1, (_io,)
R.output(gv1)
return gv1
# fmt: on