You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text rendering).
Posted to commits@tvm.apache.org by jw...@apache.org on 2023/08/15 16:03:55 UTC

[tvm] branch unity updated: [Unity][ONNX] Improved symbolic handling and reshape functionality (#15550)

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0e4c99cdb8 [Unity][ONNX] Improved symbolic handling and reshape functionality (#15550)
0e4c99cdb8 is described below

commit 0e4c99cdb84d51a8fccd2fde9bbad1fbf2e33e73
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue Aug 15 09:03:49 2023 -0700

    [Unity][ONNX] Improved symbolic handling and reshape functionality (#15550)
    
    * Improved symbolic handling and reshape functionality
    
    * retrigger ci
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 47 ++++++++++++++++++++++++-
 src/relax/op/tensor/manipulate.cc               | 12 +++++--
 tests/python/relax/test_op_manipulate.py        |  3 ++
 3 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2ef6121002..152db73c51 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -150,7 +150,7 @@ def get_prim_expr_list(
 
     Parameters
     ----------
-    inputs : Union[relax.Constant, relax.ShapeExpr]
+    inputs : Union[relax.Constant, relax.ShapeExpr, relax.PrimValue]
         The input value to try to convert to a list of PrimExpr.
 
     Returns
@@ -165,6 +165,8 @@ def get_prim_expr_list(
         return np_value.tolist()
     elif isinstance(inputs, relax.ShapeExpr):
         return inputs.values
+    elif isinstance(inputs, relax.PrimValue):
+        return [inputs.value.value]
     else:
         raise ValueError("Cannot cast {} to list of PrimExpr".format(type(inputs)))
 
@@ -233,6 +235,19 @@ class Div(OnnxOpConverter):
         if all([isinstance(inp, relax.Constant) for inp in inputs]):
             output = inputs[0].data.numpy() / inputs[1].data.numpy()
             return relax.const(output, inputs[0].struct_info.dtype)
+        if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+            x = (
+                int(inputs[0].value)
+                if isinstance(inputs[0], relax.PrimValue)
+                else inputs[0].data.numpy()
+            )
+            y = (
+                int(inputs[1].value)
+                if isinstance(inputs[1], relax.PrimValue)
+                else inputs[1].data.numpy()
+            )
+            return relax.PrimValue(int(x / y))
+
         return relax.op.divide(inputs[0], inputs[1])
 
 
@@ -359,6 +374,19 @@ class Add(OnnxOpConverter):
         if all([isinstance(inp, relax.Constant) for inp in inputs]):
             output = inputs[0].data.numpy() + inputs[1].data.numpy()
             return relax.const(output, output.dtype)
+        # If primvalues are involved, handle them directly.
+        if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+            x = (
+                int(inputs[0].value)
+                if isinstance(inputs[0], relax.PrimValue)
+                else inputs[0].data.numpy()
+            )
+            y = (
+                int(inputs[1].value)
+                if isinstance(inputs[1], relax.PrimValue)
+                else inputs[1].data.numpy()
+            )
+            return relax.PrimValue(int(x + y))
         return relax.op.add(inputs[0], inputs[1])
 
 
@@ -367,9 +395,24 @@ class Mul(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
+        # When all inputs are constant, directly multiply.
         if all([isinstance(inp, relax.Constant) for inp in inputs]):
             output = inputs[0].data.numpy() * inputs[1].data.numpy()
             return relax.const(output, output.dtype)
+        # If primvalues are involved, handle them directly.
+        if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+            x = (
+                int(inputs[0].value)
+                if isinstance(inputs[0], relax.PrimValue)
+                else inputs[0].data.numpy()
+            )
+            y = (
+                int(inputs[1].value)
+                if isinstance(inputs[1], relax.PrimValue)
+                else inputs[1].data.numpy()
+            )
+            return relax.PrimValue(int(x * y))
+
         return relax.op.multiply(inputs[0], inputs[1])
 
 
@@ -382,6 +425,8 @@ class Cast(OnnxOpConverter):
         if isinstance(inputs[0], relax.Constant):
             output = inputs[0].data.numpy().astype(to_type)
             return relax.const(output, to_type)
+        if isinstance(inputs[0], relax.PrimValue):
+            return relax.PrimValue(inputs[0].value.astype(to_type))
         return relax.op.astype(inputs[0], to_type)
 
 
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 2d7e60c4f0..edf84e6887 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -604,11 +604,17 @@ TVM_REGISTER_OP("relax.permute_dims")
 
 /* relax.reshape */
 Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
-  if (const auto* e = shape.as<ExprNode>()) {
+  const ArrayNode* array;
+  // Treat shape expressions as constant arrays to handle special values.
+  if (const auto* e = shape.as<ShapeExprNode>()) {
+    array = e->values.as<ArrayNode>();
+    // Other non-shape expressions are used directly.
+  } else if (const auto* e = shape.as<ExprNode>()) {
     return GetRef<Expr>(e);
+    // Process special values in constants and produce an expression.
+  } else {
+    array = shape.as<ArrayNode>();
   }
-
-  const auto* array = shape.as<ArrayNode>();
   CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
                              "Array of PrimExprs. However, the given new shape is "
                           << shape;
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 07e21cc179..b0b4b98ab5 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -72,6 +72,9 @@ def test_reshape_infer_struct_info():
         bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
     )
     _check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorStructInfo((120,), "float32"))
+    _check_inference(
+        bb, relax.op.reshape(x0, relax.ShapeExpr([-1])), relax.TensorStructInfo((120,), "float32")
+    )
     _check_inference(
         bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
     )