You are viewing a plain-text version of this content; the canonical link to it was not preserved in this copy.
Posted to commits@tvm.apache.org by jw...@apache.org on 2023/08/15 16:03:55 UTC
[tvm] branch unity updated: [Unity][ONNX] Improved symbolic handling and reshape functionality (#15550)
This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0e4c99cdb8 [Unity][ONNX] Improved symbolic handling and reshape functionality (#15550)
0e4c99cdb8 is described below
commit 0e4c99cdb84d51a8fccd2fde9bbad1fbf2e33e73
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue Aug 15 09:03:49 2023 -0700
[Unity][ONNX] Improved symbolic handling and reshape functionality (#15550)
* Improved symbolic handling and reshape functionality
* retrigger ci
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 47 ++++++++++++++++++++++++-
src/relax/op/tensor/manipulate.cc | 12 +++++--
tests/python/relax/test_op_manipulate.py | 3 ++
3 files changed, 58 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2ef6121002..152db73c51 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -150,7 +150,7 @@ def get_prim_expr_list(
Parameters
----------
- inputs : Union[relax.Constant, relax.ShapeExpr]
+ inputs : Union[relax.Constant, relax.ShapeExpr, relax.PrimValue]
The input value to try to convert to a list of PrimExpr.
Returns
@@ -165,6 +165,8 @@ def get_prim_expr_list(
return np_value.tolist()
elif isinstance(inputs, relax.ShapeExpr):
return inputs.values
+ elif isinstance(inputs, relax.PrimValue):
+ return [inputs.value.value]
else:
raise ValueError("Cannot cast {} to list of PrimExpr".format(type(inputs)))
@@ -233,6 +235,19 @@ class Div(OnnxOpConverter):
if all([isinstance(inp, relax.Constant) for inp in inputs]):
output = inputs[0].data.numpy() / inputs[1].data.numpy()
return relax.const(output, inputs[0].struct_info.dtype)
+ if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+ x = (
+ int(inputs[0].value)
+ if isinstance(inputs[0], relax.PrimValue)
+ else inputs[0].data.numpy()
+ )
+ y = (
+ int(inputs[1].value)
+ if isinstance(inputs[1], relax.PrimValue)
+ else inputs[1].data.numpy()
+ )
+ return relax.PrimValue(int(x / y))
+
return relax.op.divide(inputs[0], inputs[1])
@@ -359,6 +374,19 @@ class Add(OnnxOpConverter):
if all([isinstance(inp, relax.Constant) for inp in inputs]):
output = inputs[0].data.numpy() + inputs[1].data.numpy()
return relax.const(output, output.dtype)
+ # If primvalues are involved, handle them directly.
+ if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+ x = (
+ int(inputs[0].value)
+ if isinstance(inputs[0], relax.PrimValue)
+ else inputs[0].data.numpy()
+ )
+ y = (
+ int(inputs[1].value)
+ if isinstance(inputs[1], relax.PrimValue)
+ else inputs[1].data.numpy()
+ )
+ return relax.PrimValue(int(x + y))
return relax.op.add(inputs[0], inputs[1])
@@ -367,9 +395,24 @@ class Mul(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
+ # When all inputs are constant, directly multiply.
if all([isinstance(inp, relax.Constant) for inp in inputs]):
output = inputs[0].data.numpy() * inputs[1].data.numpy()
return relax.const(output, output.dtype)
+ # If primvalues are involved, handle them directly.
+ if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+ x = (
+ int(inputs[0].value)
+ if isinstance(inputs[0], relax.PrimValue)
+ else inputs[0].data.numpy()
+ )
+ y = (
+ int(inputs[1].value)
+ if isinstance(inputs[1], relax.PrimValue)
+ else inputs[1].data.numpy()
+ )
+ return relax.PrimValue(int(x * y))
+
return relax.op.multiply(inputs[0], inputs[1])
@@ -382,6 +425,8 @@ class Cast(OnnxOpConverter):
if isinstance(inputs[0], relax.Constant):
output = inputs[0].data.numpy().astype(to_type)
return relax.const(output, to_type)
+ if isinstance(inputs[0], relax.PrimValue):
+ return relax.PrimValue(inputs[0].value.astype(to_type))
return relax.op.astype(inputs[0], to_type)
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 2d7e60c4f0..edf84e6887 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -604,11 +604,17 @@ TVM_REGISTER_OP("relax.permute_dims")
/* relax.reshape */
Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
- if (const auto* e = shape.as<ExprNode>()) {
+ const ArrayNode* array;
+ // Treat shape expressions as constant arrays to handle special values.
+ if (const auto* e = shape.as<ShapeExprNode>()) {
+ array = e->values.as<ArrayNode>();
+ // Other non-shape expressions are used directly.
+ } else if (const auto* e = shape.as<ExprNode>()) {
return GetRef<Expr>(e);
+ // Process special values in constants and produce an expression.
+ } else {
+ array = shape.as<ArrayNode>();
}
-
- const auto* array = shape.as<ArrayNode>();
CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
"Array of PrimExprs. However, the given new shape is "
<< shape;
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 07e21cc179..b0b4b98ab5 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -72,6 +72,9 @@ def test_reshape_infer_struct_info():
bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
)
_check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorStructInfo((120,), "float32"))
+ _check_inference(
+ bb, relax.op.reshape(x0, relax.ShapeExpr([-1])), relax.TensorStructInfo((120,), "float32")
+ )
_check_inference(
bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
)