You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/12/30 01:43:31 UTC

[tvm] branch main updated: [Relay][fix] Stack should take exprs that evaluate to tuples (#7130)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 466383a  [Relay][fix] Stack should take exprs that evaluate to tuples (#7130)
466383a is described below

commit 466383a232097b5c17733e347c2cb4a8ba14d972
Author: Steven S. Lyubomirsky <ss...@cs.washington.edu>
AuthorDate: Tue Dec 29 20:43:12 2020 -0500

    [Relay][fix] Stack should take exprs that evaluate to tuples (#7130)
    
    * Fix stack to take Relay exprs that evaluate to tuples
    
    * Doc tweak
    
    * Linting fix
---
 python/tvm/relay/op/tensor.py        |  9 +++---
 tests/python/relay/test_op_level3.py | 60 +++++++++++++++++++++++++++---------
 2 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 453a9b7..75e2987 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1105,8 +1105,8 @@ def stack(data, axis):
 
     Parameters
     ----------
-    data : Union(List[relay.Expr], Tuple(relay.Expr))
-        A list of tensors.
+    data : Union(List[relay.Expr], relay.Expr)
+        A list of tensors or a Relay expression that evaluates to a tuple of tensors.
 
     axis : int
         The axis in the result array along which the input arrays are stacked.
@@ -1116,12 +1116,13 @@ def stack(data, axis):
     ret : relay.Expr
         The stacked tensor.
     """
-    data = list(data)
     if not data:
         raise ValueError("relay.stack requires data to be non-empty.")
     if not isinstance(axis, int):
         raise ValueError("For now, we only support integer axis")
-    return _make.stack(Tuple(data), axis)
+    if not isinstance(data, Expr):
+        data = Tuple(list(data))
+    return _make.stack(data, axis)
 
 
 def copy(data):
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 668285d..5e44170 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -787,28 +787,58 @@ def test_repeat():
 
 @tvm.testing.uses_gpu
 def test_stack():
-    def verify_stack(dshapes, axis):
-        y = []
-        for shape in dshapes:
-            y.append(relay.var("input", relay.TensorType(shape, "float32")))
-        x = relay.Tuple(y)
-        z = relay.stack(x, axis=axis)
+    def produce_input_tuple(dshapes):
+        y = [relay.var("input", relay.TensorType(shape, "float32")) for shape in dshapes]
+        return relay.Tuple(y)
 
-        func = relay.Function(y, z)
-        x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
-        ref_res = np.stack(x_data, axis=axis)
+    def ref_stack(inputs, axis):
+        return np.stack(inputs, axis=axis)
+
+    def verify_stack(input_expr, relay_args, ref_res, axis):
+        z = relay.stack(input_expr, axis=axis)
+        inp_vars = relay.analysis.free_vars(z)
+        func = relay.Function(inp_vars, z)
 
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
-                op_res = intrp.evaluate(func)(*x_data)
+                op_res = intrp.evaluate(func)(*relay_args)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
-    verify_stack([(2,), (2,), (2,)], -1)
-    verify_stack([(2,), (2,), (2,)], 0)
-    verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
-    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
-    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], 4)
+    def verify_tup_lit_stack(dshapes, axis):
+        input_tuple = produce_input_tuple(dshapes)
+        input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
+        ref_res = ref_stack(input_data, axis)
+        verify_stack(input_tuple, input_data, ref_res, axis)
+
+    def verify_list_lit_stack(dshapes, axis):
+        input_list = produce_input_tuple(dshapes).fields
+        input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
+        ref_res = ref_stack(input_data, axis)
+        verify_stack(input_list, input_data, ref_res, axis)
+
+    def verify_tup_expr_stack(dshapes, axis):
+        input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
+        ref_res = ref_stack(input_data, axis)
+
+        # expression that evaluates to a tuple
+        # but is not a tuple literal
+        x = relay.Var("x")
+        input_expr = relay.Let(x, relay.Tuple([relay.const(inp) for inp in input_data]), x)
+        verify_stack(input_expr, [], ref_res, axis)
+
+    dshape_axis_combos = [
+        ([(2,), (2,), (2,)], -1),
+        ([(2,), (2,), (2,)], 0),
+        ([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1),
+        ([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1),
+        ([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], 4),
+    ]
+
+    for dshapes, axis in dshape_axis_combos:
+        verify_tup_lit_stack(dshapes, axis)
+        verify_list_lit_stack(dshapes, axis)
+        verify_tup_expr_stack(dshapes, axis)
 
 
 @tvm.testing.uses_gpu