You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jw...@apache.org on 2023/03/16 00:10:20 UTC
[tvm] branch unity updated: [Unity][Op] Enable special dimension value 0 in reshape (#14311)
This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 39dc299c68 [Unity][Op] Enable special dimension value 0 in reshape (#14311)
39dc299c68 is described below
commit 39dc299c688f30e22f4d4d334d099c04696da148
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Wed Mar 15 17:10:12 2023 -0700
[Unity][Op] Enable special dimension value 0 in reshape (#14311)
[Unity] Enable special dimension value 0 in reshape
---
src/relax/op/tensor/manipulate.cc | 42 +++++++++++++++++++++++++-------
tests/python/relax/test_op_manipulate.py | 8 ++++--
2 files changed, 39 insertions(+), 11 deletions(-)
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index d90fd41e1c..c7bf051302 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -522,7 +522,8 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
"Array of PrimExprs. However, the given new shape is "
<< shape;
int dim_to_infer = -1;
- PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
+ // Keep track of which dimensions should be copied from input.
+ std::vector<int> zero_dims;
for (int i = 0; i < static_cast<int>(array->size()); ++i) {
const auto* _len = array->at(i).as<PrimExprNode>();
CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
@@ -533,7 +534,10 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
"integers. However, the given new shape is "
<< shape;
const auto* int_len = len.as<IntImmNode>();
- if (int_len != nullptr && int_len->value == -1) {
+ if (int_len != nullptr && int_len->value == 0) {
+ // Note that this dimension should be copied from the original shape.
+ zero_dims.push_back(i);
+ } else if (int_len != nullptr && int_len->value == -1) {
CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, "
"there are multiple \"-1\" in the given new shape "
<< shape;
@@ -543,15 +547,12 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
<< "Reshape requires all values in the new shape to be positive except a single \"-1\". "
"However, the given new shape is "
<< shape;
- // We expect any symbolic not to signal the intent of -1, and therefore do no check for
- // symbolic value here.
- new_shape_prod = new_shape_prod * len;
}
}
Array<PrimExpr> array_ref = GetRef<Array<PrimExpr>>(array);
// When there is no dimension to infer, just return the input array as ShapeExpr.
- if (dim_to_infer == -1) {
+ if (dim_to_infer == -1 && zero_dims.empty()) {
return ShapeExpr(array_ref);
}
@@ -569,9 +570,32 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
"to infer. However, the given input shape is "
<< data_sinfo->shape << " whose shape value is unknown.";
- arith::Analyzer analyzer;
- PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
- array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
+ // Set any 0 valued dimensions to match the corresponding input shape.
+ if (!zero_dims.empty()) {
+ for (int i : zero_dims) {
+ array_ref.Set(i, shape_sinfo->values.value()[i]);
+ }
+ }
+
+ // Set any -1 dimensions to complete the number of appropriate elements.
+ // Start by computing the shape product of all positive indices.
+ PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
+ for (int i = 0; i < static_cast<int>(array_ref.size()); ++i) {
+ PrimExpr new_dim = array_ref[i];
+ const auto* int_dim = new_dim.as<IntImmNode>();
+ // We expect any symbolic value not to signal the intent of -1, and therefore do not
+ // check for symbolic values here.
+ if (int_dim == nullptr || int_dim->value > 0) {
+ new_shape_prod = new_shape_prod * new_dim;
+ }
+ }
+
+ // Assign appropriate value to -1 dimension.
+ if (dim_to_infer != -1) {
+ arith::Analyzer analyzer;
+ PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
+ array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
+ }
return ShapeExpr(array_ref);
}
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index af20639a8e..16bbc04d26 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -179,6 +179,12 @@ def test_reshape_infer_struct_info_shape_var():
_check_inference(
bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
)
+ _check_inference(
+ bb, relax.op.reshape(x0, (2, 3, 0, 5)), relax.TensorStructInfo((2, 3, 4, 5), "float32")
+ )
+ _check_inference(
+ bb, relax.op.reshape(x0, (1, 3, 0, -1)), relax.TensorStructInfo((1, 3, 4, 10), "float32")
+ )
_check_inference(
bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
)
@@ -281,8 +287,6 @@ def test_reshape_infer_struct_info_non_positive_new_shape():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
- with pytest.raises(TVMError):
- bb.normalize(relax.op.reshape(x, (2, 0, 4, 5)))
with pytest.raises(TVMError):
bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5)))