Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/26 01:41:53 UTC

[tvm] branch main updated: [Relay] Enforce static dim for non-concat axis if one or more tensors have static dim (#7487)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 63ea8e1  [Relay] Enforce static dim for non-concat axis if one or more tensors have static dim (#7487)
63ea8e1 is described below

commit 63ea8e1fc934229a7fb56cac642a588ff3337e6e
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Feb 26 10:41:13 2021 +0900

    [Relay] Enforce static dim for non-concat axis if one or more tensors have static dim (#7487)
    
    * enforce static dim for non-concat axis
    
    * assign any when all dims are dyn
    
    * add missing case
    
    * simplify
    
    * add test
    
    * only enforce static dim constraint if concat output is dynamic
    
    * more updates to concat type rel
    
    * update tests
    
    * fixed compile warning
---
 src/relay/op/tensor/transform.h | 69 +++++++++++++++++++++++++++++++----------
 tests/python/relay/test_any.py  | 21 +++++++++++++
 2 files changed, 73 insertions(+), 17 deletions(-)
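
In plain terms, the new type relation works like this: the concat-axis extent is
the sum of the input extents and becomes dynamic as soon as any input is dynamic
on that axis; on every other axis, a static extent contributed by any input is
propagated to the output, but only when the concat axis itself came out dynamic.
Below is a minimal standalone Python model of that rule (a hypothetical
illustration only, not TVM code; None stands in for Relay's Any(), and plain
equality stands in for the reporter's AssertEQ):

    def concat_output_shape(shapes, axis):
        ndim = len(shapes[0])
        # Concat axis: sum the extents; any dynamic input makes the sum dynamic.
        dims = [s[axis] for s in shapes]
        concat_dim = None if None in dims else sum(dims)
        out = [None] * ndim
        out[axis] = concat_dim
        for i in range(ndim):
            if i == axis:
                continue
            static = [s[i] for s in shapes if s[i] is not None]
            if len(set(static)) > 1:
                raise ValueError("tensors disagree on a non-concatenating axis")
            if len(static) == len(shapes):
                out[i] = static[0]  # every input is static on this axis
            elif static and concat_dim is None:
                # Enforce the static extent only when the concat axis is dynamic.
                out[i] = static[0]
        return out

    # The two cases from the commit: [(?, 3), (?, ?)] -> (?, 3)
    # and [(1, 3), (1, ?)] -> (2, ?).
    assert concat_output_shape([(None, 3), (None, None)], axis=0) == [None, 3]
    assert concat_output_shape([(1, 3), (1, None)], axis=0) == [2, None]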

diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 95a83a9..dbf8537 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -101,29 +101,64 @@ bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   }
 
   // Calculate shape
-  std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
-  int data_length = static_cast<int>(tensor_tuple->fields.size());
+  std::vector<IndexExpr> oshape(ndim);
+  const size_t data_length = tensor_tuple->fields.size();
+
+  // Accumulate the concat-axis output dim, or decide whether this is a dynamic concat.
+  bool is_dynamic_concat = false;
+  std::vector<TensorType> input_tensors;
+  IndexExpr concat_output_dim = first->shape[axis];
+  for (size_t i = 0; i < data_length; ++i) {
+    const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
+    input_tensors.push_back(e);
+    if (e->shape[axis].as<AnyNode>()) {
+      is_dynamic_concat = true;
+      concat_output_dim = Any();
+    } else if (i > 0 && !is_dynamic_concat) {
+      // accumulate axis dimension
+      concat_output_dim += e->shape[axis];
+    }
+  }
+
+  oshape[axis] = concat_output_dim;
+
   for (int i = 0; i < ndim; ++i) {
+    if (i == axis) {
+      // The concat axis is already handled above.
+      // The rest of the loop body sets the output shape for the non-concat axes.
+      continue;
+    }
     std::vector<IndexExpr> non_any;
-    for (int j = 0; j < data_length; ++j) {
-      const auto& e = Downcast<TensorType>(tensor_tuple->fields[j]);
+    for (size_t j = 0; j < data_length; ++j) {
+      const auto& e = input_tensors[j];
       if (!e->shape[i].as<AnyNode>()) {
         non_any.push_back(e->shape[i]);
-        // accumulate axis dimension
-        if (j > 0 && i == axis && !oshape[i].as<AnyNode>()) {
-          oshape[i] += e->shape[i];
-        }
       }
     }
-    int non_any_size = static_cast<int>(non_any.size());
-    if (non_any_size != data_length) oshape[i] = Any();
-    if (i != axis) {
-      for (int k = 1; k < non_any_size; k++) {
-        if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
-        throw Error(
-            "relay.concatenate requires all tensors have the same shape "
-            "on non-concatenating axes");
-      }
+    size_t non_any_size = non_any.size();
+    for (size_t k = 1; k < non_any_size; k++) {
+      if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
+      throw Error(
+          "relay.concatenate requires all tensors have the same shape "
+          "on non-concatenating axes");
+    }
+
+    if (non_any_size == data_length) {
+      // All static case
+      oshape[i] = non_any[0];
+    } else if (non_any_size > 0 && is_dynamic_concat) {
+      // For non-concat axes, we want to enforce a static shape constraint.
+      // However, if the concat axis is static, the output shape would become fully static
+      // while the inputs could still be partially dynamic. Since input shapes are not
+      // checked at runtime, that could lead to segfaults, so the static shape constraint
+      // is only enforced when the output concat axis is dynamic.
+      //
+      // Examples (both concat on the first axis):
+      // * [(?, 3), (?, ?)] -> (?, 3)
+      // * [(1, 3), (1, ?)] -> (2, ?)
+      oshape[i] = non_any[0];
+    } else {
+      oshape[i] = Any();
     }
   }
 
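The two examples in the comment above can be reproduced through the Relay Python
API (a sketch that mirrors the new test cases added below; the printed forms
assume Any() renders as "?"):

    import tvm
    from tvm import relay

    # [(?, 3), (?, ?)] concat on axis 0: the concat axis is dynamic, so the
    # static extent 3 is enforced on axis 1.
    dyn = [
        relay.var("a", shape=(relay.Any(), 3), dtype="float32"),
        relay.var("b", shape=(relay.Any(), relay.Any()), dtype="float32"),
    ]
    mod = tvm.IRModule.from_expr(relay.Function(dyn, relay.op.concatenate(dyn, axis=0)))
    print(relay.transform.InferType()(mod)["main"].body.checked_type)
    # Tensor[(?, 3), float32]

    # [(1, 3), (1, ?)] concat on axis 0: the concat axis is static (2), so
    # axis 1 is left dynamic rather than forced to 3.
    st = [
        relay.var("c", shape=(1, 3), dtype="float32"),
        relay.var("d", shape=(1, relay.Any()), dtype="float32"),
    ]
    mod = tvm.IRModule.from_expr(relay.Function(st, relay.op.concatenate(st, axis=0)))
    print(relay.transform.InferType()(mod)["main"].body.checked_type)
    # Tensor[(2, ?), float32]
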
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 9d05631..b75cc5f 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -208,6 +208,27 @@ def test_any_concat():
     ref = np.concatenate(x_np, axis=0)
     check_result(x_np, mod, ref)
 
+    def test_oshape(in_vars, axis, oshape):
+        z = relay.op.concatenate(in_vars, axis=axis)
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function(in_vars, z)
+        typed_mod = relay.transform.InferType()(mod)
+        assert typed_mod["main"].body.checked_type == relay.TensorType(oshape, dtype="float32")
+
+    x = [relay.var("x", shape=(relay.Any(), 3), dtype="float32") for _ in range(3)]
+    x.append(relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32"))
+
+    test_oshape(x, 0, (relay.Any(), 3))
+    test_oshape(x, 1, (relay.Any(), relay.Any()))
+
+    # [(1, 3), (1, ?)] -> (2, ?)
+    x = [
+        relay.var("x", shape=(1, 3), dtype="float32"),
+        relay.var("x", shape=(1, relay.Any()), dtype="float32"),
+    ]
+    test_oshape(x, 0, (2, relay.Any()))
+    test_oshape(x, 1, (1, relay.Any()))
+
 
 def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
     x = relay.var("x", shape=x_shape, dtype="float32")
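
For completeness, the equality check on the non-concatenating axes (the "throw
Error" path in the type relation above) still rejects inputs whose static
extents genuinely disagree. A sketch of that failing case; the exact Python
exception class surfaced by InferType varies across TVM versions, so this
catches broadly:

    import tvm
    from tvm import relay

    bad = [
        relay.var("a", shape=(relay.Any(), 3), dtype="float32"),
        relay.var("b", shape=(relay.Any(), 4), dtype="float32"),
    ]
    mod = tvm.IRModule.from_expr(relay.Function(bad, relay.op.concatenate(bad, axis=0)))
    try:
        relay.transform.InferType()(mod)
    except Exception:  # exact error class differs across TVM versions
        print("rejected: tensors disagree on a non-concatenating axis")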