You are viewing a plain-text version of this content; the link to the canonical version was a hyperlink in the original page and is not preserved in this text.
Posted to commits@tvm.apache.org by jw...@apache.org on 2023/03/16 00:10:20 UTC

[tvm] branch unity updated: [Unity][Op] Enable special dimension value 0 in reshape (#14311)

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 39dc299c68 [Unity][Op] Enable special dimension value 0 in reshape (#14311)
39dc299c68 is described below

commit 39dc299c688f30e22f4d4d334d099c04696da148
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Wed Mar 15 17:10:12 2023 -0700

    [Unity][Op] Enable special dimension value 0 in reshape (#14311)
    
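    [Unity] Enable special dimension value 0 in reshape: a 0 in the new shape now copies the corresponding dimension from the input shape (e.g. reshaping a (2, 3, 4, 5) tensor to (1, 3, 0, -1) yields (1, 3, 4, 10)).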
---
 src/relax/op/tensor/manipulate.cc        | 42 +++++++++++++++++++++++++-------
 tests/python/relax/test_op_manipulate.py |  8 ++++--
 2 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index d90fd41e1c..c7bf051302 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -522,7 +522,8 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
                              "Array of PrimExprs. However, the given new shape is "
                           << shape;
   int dim_to_infer = -1;
-  PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
+  // Keep track of which dimensions should be copied from input.
+  std::vector<int> zero_dims;
   for (int i = 0; i < static_cast<int>(array->size()); ++i) {
     const auto* _len = array->at(i).as<PrimExprNode>();
     CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
@@ -533,7 +534,10 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
                                   "integers. However, the give new shape is "
                                << shape;
     const auto* int_len = len.as<IntImmNode>();
-    if (int_len != nullptr && int_len->value == -1) {
+    if (int_len != nullptr && int_len->value == 0) {
+      // Note that this dimension should be copied from the original shape.
+      zero_dims.push_back(i);
+    } else if (int_len != nullptr && int_len->value == -1) {
       CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, "
                                     "there are multiple \"-1\" in the given new shape  "
                                  << shape;
@@ -543,15 +547,12 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
           << "Reshape requires all values in the new shape to be positive except a single \"-1\". "
              "However, the given new shape is "
           << shape;
-      // We expect any symbolic not to signal the intent of -1, and therefore do no check for
-      // symbolic value here.
-      new_shape_prod = new_shape_prod * len;
     }
   }
 
   Array<PrimExpr> array_ref = GetRef<Array<PrimExpr>>(array);
   // When there is no dimension to infer, just return the input array as ShapeExpr.
-  if (dim_to_infer == -1) {
+  if (dim_to_infer == -1 && zero_dims.empty()) {
     return ShapeExpr(array_ref);
   }
 
@@ -569,9 +570,32 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
          "to infer. However, the given input shape is "
       << data_sinfo->shape << " whose shape value is unknown.";
 
-  arith::Analyzer analyzer;
-  PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
-  array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
+  // Set any 0 valued dimensions to match the corresponding input shape.
+  if (!zero_dims.empty()) {
+    for (int i : zero_dims) {
+      array_ref.Set(i, shape_sinfo->values.value()[i]);
+    }
+  }
+
+  // Set any -1 dimensions to complete the number of appropriate elements.
+  // Start by computing the shape product of all positive indices.
+  PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
+  for (int i = 0; i < static_cast<int>(array_ref.size()); ++i) {
+    PrimExpr new_dim = array_ref[i];
+    const auto* int_dim = new_dim.as<IntImmNode>();
+    // We expect any symbolic not to signal the intent of -1, and therefore do no check for
+    // symbolic value here.
+    if (int_dim == nullptr || int_dim->value > 0) {
+      new_shape_prod = new_shape_prod * new_dim;
+    }
+  }
+
+  // Assign appropriate value to -1 dimension.
+  if (dim_to_infer != -1) {
+    arith::Analyzer analyzer;
+    PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
+    array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
+  }
   return ShapeExpr(array_ref);
 }
 
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index af20639a8e..16bbc04d26 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -179,6 +179,12 @@ def test_reshape_infer_struct_info_shape_var():
     _check_inference(
         bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
     )
+    _check_inference(
+        bb, relax.op.reshape(x0, (2, 3, 0, 5)), relax.TensorStructInfo((2, 3, 4, 5), "float32")
+    )
+    _check_inference(
+        bb, relax.op.reshape(x0, (1, 3, 0, -1)), relax.TensorStructInfo((1, 3, 4, 10), "float32")
+    )
     _check_inference(
         bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
     )
@@ -281,8 +287,6 @@ def test_reshape_infer_struct_info_non_positive_new_shape():
     bb = relax.BlockBuilder()
     x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
 
-    with pytest.raises(TVMError):
-        bb.normalize(relax.op.reshape(x, (2, 0, 4, 5)))
     with pytest.raises(TVMError):
         bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5)))