Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/14 04:10:38 UTC

[tvm] 03/03: support constant folding on ndarray_size

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch torchbench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit bacf3946c727682e7aad82f03e34abbbd9f120a2
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Sep 14 13:09:45 2022 +0900

    support constant folding on ndarray_size
---
 python/tvm/relay/frontend/pytorch.py  |  2 +-
 src/relay/transforms/fold_constant.cc | 10 ++++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)
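
For context: ndarray_size returns the total number of elements of its
input tensor, so whenever the input has a fully static shape the result
is known at compile time and the call can be folded to a scalar
constant. A minimal sketch of the behavior this commit is about (the
variable name and shape are illustrative only):

    import tvm
    from tvm import relay

    # A tensor with a fully static shape: ndarray_size(x) is 6 at
    # compile time, regardless of the runtime contents of x.
    x = relay.var("x", shape=(2, 3), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.ndarray_size(x)))

    # The module-level FoldConstant pass should rewrite the function
    # body to the constant 6.
    mod = relay.transform.FoldConstant()(mod)
    print(mod["main"].body)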

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index e2badaabf7..722b2889d3 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2489,7 +2489,7 @@ class PyTorchOpConverter:
         )
 
     def numel(self, inputs, input_types):
-        return _op.ndarray_size(inputs[0])
+        return fold_constant(_op.ndarray_size(inputs[0]))
 
     def empty(self, inputs, input_types):
         shape = inputs[0]
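
The frontend change above matters because torch.numel lowers to
ndarray_size, and folding it eagerly means downstream converters see a
plain constant instead of a data-dependent expression. A hypothetical
end-to-end sketch (NumelModel and the input shape are made up, and it
assumes a TVM build that includes this patch):

    import torch
    from tvm import relay

    class NumelModel(torch.nn.Module):
        def forward(self, x):
            return x.numel()  # scripted as aten::numel

    scripted = torch.jit.script(NumelModel())
    mod, params = relay.frontend.from_pytorch(scripted, [("x", (2, 3))])

    # With fold_constant applied in the numel converter, the body of
    # main should be the constant 6 rather than an ndarray_size call.
    print(mod["main"])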
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 9dec840be0..f484dfc700 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -188,8 +188,7 @@ class ConstantFolder : public MixedModeMutator {
     if (is_no_computational && (is_no_qnn_canonicalized || !fold_qnn_)) {
       return std::move(post_call);
     }
-    if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_ ||
-        op == ndarray_size_op_) {
+    if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_) {
       // We should think about potentially constant evaluation over these ops too.
       return std::move(post_call);
     }
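
Dropping ndarray_size from this early-return list lets such calls fall
through to the folder's evaluation logic instead of being returned
untouched. That is what enables the expression-level fold_constant used
by the frontends, where no type inference has run yet; the Var handling
added in the next hunk supplies the shape from the variable's type
annotation in that situation. A small sketch of that path (assuming a
build with this patch; fold_constant is the helper in
tvm.relay.frontend.common):

    from tvm import relay
    from tvm.relay.frontend.common import fold_constant

    # A free variable carrying only a type annotation; checked_type_ is
    # not populated because no type inference has run, as is the case
    # mid-way through frontend conversion.
    x = relay.var("x", shape=(2, 3), dtype="float32")

    # Expression-level folding: the shape is read from x's annotation,
    # so ndarray_size(x) should fold to the constant 6.
    print(fold_constant(relay.ndarray_size(x)))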
@@ -383,6 +382,13 @@ class ConstantFolder : public MixedModeMutator {
       // TODO(mbs): This is not necessary since we only ever ask for the shapes for
       // pre-rewritten expressions which will always have a checked_type.
       return const_node->tensor_type()->shape;
+    } else if (const auto* var = input.as<VarNode>()) {
+      // Fall back to the Var's type annotation when checked_type_ is unset.
+      auto ty = var->type_annotation;
+      if (ty.defined() && ty->IsInstance<TensorTypeNode>()) {
+        return Downcast<TensorType>(ty)->shape;
+      }
+      return {};
     } else if (input->checked_type_.defined()) {
       return input->checked_type().as<TensorTypeNode>()->shape;
     } else {