You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/12/30 01:43:31 UTC
[tvm] branch main updated: [Relay][fix] Stack should take exprs that evaluate to tuples (#7130)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 466383a [Relay][fix] Stack should take exprs that evaluate to tuples (#7130)
466383a is described below
commit 466383a232097b5c17733e347c2cb4a8ba14d972
Author: Steven S. Lyubomirsky <ss...@cs.washington.edu>
AuthorDate: Tue Dec 29 20:43:12 2020 -0500
[Relay][fix] Stack should take exprs that evaluate to tuples (#7130)
* Fix stack to take Relay exprs that evaluate to tuples
* Doc tweak
* Linting fix
---
python/tvm/relay/op/tensor.py | 9 +++---
tests/python/relay/test_op_level3.py | 60 +++++++++++++++++++++++++++---------
2 files changed, 50 insertions(+), 19 deletions(-)
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 453a9b7..75e2987 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1105,8 +1105,8 @@ def stack(data, axis):
Parameters
----------
- data : Union(List[relay.Expr], Tuple(relay.Expr))
- A list of tensors.
+ data : Union(List[relay.Expr], relay.Expr)
+ A list of tensors or a Relay expression that evaluates to a tuple of tensors.
axis : int
The axis in the result array along which the input arrays are stacked.
@@ -1116,12 +1116,13 @@ def stack(data, axis):
ret : relay.Expr
The stacked tensor.
"""
- data = list(data)
if not data:
raise ValueError("relay.stack requires data to be non-empty.")
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
- return _make.stack(Tuple(data), axis)
+ if not isinstance(data, Expr):
+ data = Tuple(list(data))
+ return _make.stack(data, axis)
def copy(data):
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 668285d..5e44170 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -787,28 +787,58 @@ def test_repeat():
@tvm.testing.uses_gpu
def test_stack():
- def verify_stack(dshapes, axis):
- y = []
- for shape in dshapes:
- y.append(relay.var("input", relay.TensorType(shape, "float32")))
- x = relay.Tuple(y)
- z = relay.stack(x, axis=axis)
+ def produce_input_tuple(dshapes):
+ y = [relay.var("input", relay.TensorType(shape, "float32")) for shape in dshapes]
+ return relay.Tuple(y)
- func = relay.Function(y, z)
- x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
- ref_res = np.stack(x_data, axis=axis)
+ def ref_stack(inputs, axis):
+ return np.stack(inputs, axis=axis)
+
+ def verify_stack(input_expr, relay_args, ref_res, axis):
+ z = relay.stack(input_expr, axis=axis)
+ inp_vars = relay.analysis.free_vars(z)
+ func = relay.Function(inp_vars, z)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
- op_res = intrp.evaluate(func)(*x_data)
+ op_res = intrp.evaluate(func)(*relay_args)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
- verify_stack([(2,), (2,), (2,)], -1)
- verify_stack([(2,), (2,), (2,)], 0)
- verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
- verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
- verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], 4)
+ def verify_tup_lit_stack(dshapes, axis):
+ input_tuple = produce_input_tuple(dshapes)
+ input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
+ ref_res = ref_stack(input_data, axis)
+ verify_stack(input_tuple, input_data, ref_res, axis)
+
+ def verify_list_lit_stack(dshapes, axis):
+ input_list = produce_input_tuple(dshapes).fields
+ input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
+ ref_res = ref_stack(input_data, axis)
+ verify_stack(input_list, input_data, ref_res, axis)
+
+ def verify_tup_expr_stack(dshapes, axis):
+ input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
+ ref_res = ref_stack(input_data, axis)
+
+ # expression that evaluates to a tuple
+ # but is not a tuple literal
+ x = relay.Var("x")
+ input_expr = relay.Let(x, relay.Tuple([relay.const(inp) for inp in input_data]), x)
+ verify_stack(input_expr, [], ref_res, axis)
+
+ dshape_axis_combos = [
+ ([(2,), (2,), (2,)], -1),
+ ([(2,), (2,), (2,)], 0),
+ ([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1),
+ ([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1),
+ ([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], 4),
+ ]
+
+ for dshapes, axis in dshape_axis_combos:
+ verify_tup_lit_stack(dshapes, axis)
+ verify_list_lit_stack(dshapes, axis)
+ verify_tup_expr_stack(dshapes, axis)
@tvm.testing.uses_gpu