You are viewing a plain text version of this content. The hyperlink to the canonical version was not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/14 04:10:38 UTC
[tvm] 03/03: support constant folding on ndarray_size
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch torchbench
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit bacf3946c727682e7aad82f03e34abbbd9f120a2
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Sep 14 13:09:45 2022 +0900
support constant folding on ndarray_size
---
python/tvm/relay/frontend/pytorch.py | 2 +-
src/relay/transforms/fold_constant.cc | 10 ++++++++--
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index e2badaabf7..722b2889d3 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2489,7 +2489,7 @@ class PyTorchOpConverter:
)
def numel(self, inputs, input_types):
- return _op.ndarray_size(inputs[0])
+ return fold_constant(_op.ndarray_size(inputs[0]))
def empty(self, inputs, input_types):
shape = inputs[0]
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 9dec840be0..f484dfc700 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -188,8 +188,7 @@ class ConstantFolder : public MixedModeMutator {
if (is_no_computational && (is_no_qnn_canonicalized || !fold_qnn_)) {
return std::move(post_call);
}
- if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_ ||
- op == ndarray_size_op_) {
+ if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_) {
// We should think about potentially constant evaluation over these ops too.
return std::move(post_call);
}
@@ -383,6 +382,13 @@ class ConstantFolder : public MixedModeMutator {
// TODO(mbs): This is not necessary since we only ever ask for the shapes for
// pre-rewritten expressions which will always have a checked_type.
return const_node->tensor_type()->shape;
+ // } else if (auto ttype = input->type_as<TensorTypeNode>()) {
+ } else if (const auto* var = input.as<VarNode>()) {
+ auto ty = var->type_annotation;
+ if (ty->IsInstance<TensorTypeNode>()) {
+ return Downcast<TensorType>(ty)->shape;
+ }
+ return {};
} else if (input->checked_type_.defined()) {
return input->checked_type().as<TensorTypeNode>()->shape;
} else {