You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/26 01:41:53 UTC
[tvm] branch main updated: [Relay] Enforce static dim for non-concat axis if one or more tensors have static dim (#7487)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 63ea8e1 [Relay] Enforce static dim for non-concat axis if one or more tensors have static dim (#7487)
63ea8e1 is described below
commit 63ea8e1fc934229a7fb56cac642a588ff3337e6e
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Feb 26 10:41:13 2021 +0900
[Relay] Enforce static dim for non-concat axis if one or more tensors have static dim (#7487)
* enforce static dim for non-concat axis
* assign any when all dims are dyn
* add missing case
* simplify
* add test
* only enforce static dim constraint if concat output is dynamic
* more updates to concat type rel
* update tests
* fixed compile warning
---
src/relay/op/tensor/transform.h | 69 +++++++++++++++++++++++++++++++----------
tests/python/relay/test_any.py | 21 +++++++++++++
2 files changed, 73 insertions(+), 17 deletions(-)
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 95a83a9..dbf8537 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -101,29 +101,64 @@ bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
}
// Calculate shape
- std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
- int data_length = static_cast<int>(tensor_tuple->fields.size());
+ std::vector<IndexExpr> oshape(ndim);
+ const size_t data_length = tensor_tuple->fields.size();
+
+ // Accumulate the concat axis output dim or decide if this is dynamic concat
+ bool is_dynamic_concat = false;
+ std::vector<TensorType> input_tensors;
+ IndexExpr concat_output_dim = first->shape[axis];
+ for (size_t i = 0; i < data_length; ++i) {
+ const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
+ input_tensors.push_back(e);
+ if (e->shape[axis].as<AnyNode>()) {
+ is_dynamic_concat = true;
+ concat_output_dim = Any();
+ } else if (i > 0 && !is_dynamic_concat) {
+ // accumulate axis dimension
+ concat_output_dim += e->shape[axis];
+ }
+ }
+
+ oshape[axis] = concat_output_dim;
+
for (int i = 0; i < ndim; ++i) {
+ if (i == axis) {
+ // The concat axis is already handled above.
+ // The rest of the body sets the output shape for non-concat axes
+ continue;
+ }
std::vector<IndexExpr> non_any;
- for (int j = 0; j < data_length; ++j) {
- const auto& e = Downcast<TensorType>(tensor_tuple->fields[j]);
+ for (size_t j = 0; j < data_length; ++j) {
+ const auto& e = input_tensors[j];
if (!e->shape[i].as<AnyNode>()) {
non_any.push_back(e->shape[i]);
- // accumulate axis dimension
- if (j > 0 && i == axis && !oshape[i].as<AnyNode>()) {
- oshape[i] += e->shape[i];
- }
}
}
- int non_any_size = static_cast<int>(non_any.size());
- if (non_any_size != data_length) oshape[i] = Any();
- if (i != axis) {
- for (int k = 1; k < non_any_size; k++) {
- if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
- throw Error(
- "relay.concatenate requires all tensors have the same shape "
- "on non-concatenating axes");
- }
+ size_t non_any_size = non_any.size();
+ for (size_t k = 1; k < non_any_size; k++) {
+ if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
+ throw Error(
+ "relay.concatenate requires all tensors have the same shape "
+ "on non-concatenating axes");
+ }
+
+ if (non_any_size == data_length) {
+ // All static case
+ oshape[i] = non_any[0];
+ } else if (non_any_size > 0 && is_dynamic_concat) {
+ // For non-concat axes, we want to enforce static shape constraint.
+ // However, if the concat axis is static, the output shape would become static while
+ // the input could be partially static/dynamic. To prevent runtime segfaults due to the lack
+ // of runtime input shape checking for such cases, static shape constraint is only enforced
+ // when the output concat axis is dynamic.
+ //
+ // Examples (both concat on the first axis):
+ // * [(?, 3), (?, ?)] -> (?, 3)
+ // * [(1, 3), (1, ?)] -> (2, ?)
+ oshape[i] = non_any[0];
+ } else {
+ oshape[i] = Any();
}
}
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 9d05631..b75cc5f 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -208,6 +208,27 @@ def test_any_concat():
ref = np.concatenate(x_np, axis=0)
check_result(x_np, mod, ref)
+ def test_oshape(in_vars, axis, oshape):
+ z = relay.op.concatenate(in_vars, axis=axis)
+ mod = tvm.IRModule()
+ mod["main"] = relay.Function(in_vars, z)
+ typed_mod = relay.transform.InferType()(mod)
+ assert typed_mod["main"].body.checked_type == relay.TensorType(oshape, dtype="float32")
+
+ x = [relay.var("x", shape=(relay.Any(), 3), dtype="float32") for _ in range(3)]
+ x.append(relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32"))
+
+ test_oshape(x, 0, (relay.Any(), 3))
+ test_oshape(x, 1, (relay.Any(), relay.Any()))
+
+ # [(1, 3), (1, ?)] -> (2, ?)
+ x = [
+ relay.var("x", shape=(1, 3), dtype="float32"),
+ relay.var("x", shape=(1, relay.Any()), dtype="float32"),
+ ]
+ test_oshape(x, 0, (2, relay.Any()))
+ test_oshape(x, 1, (1, relay.Any()))
+
def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
x = relay.var("x", shape=x_shape, dtype="float32")