Posted to commits@tvm.apache.org by ma...@apache.org on 2020/07/28 22:49:32 UTC

[incubator-tvm] branch master updated: [Relay] Handle ndarray_size in FoldConstant (#6156)

This is an automated email from the ASF dual-hosted git repository.

marisa pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 44ff1f3  [Relay] Handle ndarray_size in FoldConstant (#6156)
44ff1f3 is described below

commit 44ff1f3b5ed0751fee39537a0e6e3870a74c930b
Author: lixiaoquan <ra...@163.com>
AuthorDate: Wed Jul 29 06:49:21 2020 +0800

    [Relay] Handle ndarray_size in FoldConstant (#6156)
    
    * [Relay] Handle ndarray_size in FoldConstant
    
    * Use Optional
---
 src/relay/transforms/fold_constant.cc         | 75 ++++++++++++++++++++++++---
 tests/python/relay/test_pass_fold_constant.py | 22 ++++++++
 2 files changed, 90 insertions(+), 7 deletions(-)
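
With this change, FoldConstant evaluates relay.ndarray_size at compile time
whenever the input tensor's shape is fully static. A minimal sketch of the
effect (not part of the commit; it assumes a TVM build that includes this
change, with LLVM enabled for the constant evaluator, and uses only the public
Relay API):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 9, 10), dtype="float32")
    y = relay.var("y", shape=(8, 9, 10), dtype="float32")
    func = relay.Function([x, y], relay.ndarray_size(x + y, "int32"))

    # IRModule.from_expr type-checks the function, so the input shape is known.
    mod = relay.transform.FoldConstant()(tvm.IRModule.from_expr(func))
    body = mod["main"].body
    assert isinstance(body, relay.Constant)        # the ndarray_size call is gone
    assert int(body.data.asnumpy()[0]) == 720      # 8 * 9 * 10

The diff below adds this evaluation path to fold_constant.cc and a matching
regression test.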

diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 0b873bf..3f5ecaa 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -86,7 +86,8 @@ class ConstantFolder : public ExprMutator {
         shape_func_op_(Op::Get("vm.shape_func")),
         alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
         alloc_storage_op_(Op::Get("memory.alloc_storage")),
-        cast_op_(Op::Get("cast")) {}
+        cast_op_(Op::Get("cast")),
+        ndarray_size_op_(Op::Get("ndarray_size")) {}
 
   Expr VisitExpr_(const LetNode* op) final {
     Expr value = this->Mutate(op->value);
@@ -128,6 +129,10 @@ class ConstantFolder : public ExprMutator {
       return EvaluateShapeOf(res, origin_args, call->attrs);
     }
 
+    if (call->op == ndarray_size_op_) {
+      return EvaluateNdarraySize(res, origin_args, call->attrs);
+    }
+
     // We should think about potentially constant evaluation over these ops too.
     if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ ||
         call->op == alloc_storage_op_) {
@@ -173,6 +178,7 @@ class ConstantFolder : public ExprMutator {
   const Op& alloc_tensor_op_;
   const Op& alloc_storage_op_;
   const Op& cast_op_;
+  const Op& ndarray_size_op_;
 
   // Convert value to expression.
   Expr ObjectToExpr(const ObjectRef& value) {
@@ -223,10 +229,8 @@ class ConstantFolder : public ExprMutator {
     CHECK(param != nullptr);
 
     tvm::Array<IndexExpr> ishape;
-    if (const ConstantNode* op = input.as<ConstantNode>()) {
-      ishape = op->tensor_type()->shape;
-    } else if (input->checked_type_.defined()) {
-      ishape = input->checked_type().as<TensorTypeNode>()->shape;
+    if (auto opt = GetConstantShape(input)) {
+      ishape = opt.value();
     } else {
       return expr;
     }
@@ -261,12 +265,69 @@ class ConstantFolder : public ExprMutator {
       shape = Constant(ndarray);
     }
 
+    return CastValue(shape, param->dtype);
+  }
+
+  // Evaluate a call to the ndarray_size operator for tensors with constant
+  // shapes.
+  Expr EvaluateNdarraySize(Expr expr, Array<Expr> args, Attrs attrs) {
+    Expr input = args[0];
+    const auto* param = attrs.as<NdarraySizeAttrs>();
+    CHECK(param != nullptr);
+
+    tvm::Array<IndexExpr> ishape;
+    if (auto opt = GetConstantShape(input)) {
+      ishape = opt.value();
+    } else {
+      return expr;
+    }
+
+    // Get the constant size
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    runtime::NDArray value;
+    DLDataType cdtype = DataType::Int(32);
+    value = runtime::NDArray::Empty({1}, cdtype, ctx);
+    int32_t* data = static_cast<int32_t*>(value->data);
+    if (ishape.size() == 0) {
+      *data = 0;
+    } else {
+      *data = 1;
+      using ::tvm::tir::IntImmNode;
+      for (size_t i = 0; i < ishape.size(); ++i) {
+        if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
+          *data *= dim->value;
+        } else {
+          return expr;
+        }
+      }
+    }
+
+    Constant size = Downcast<Constant>(ObjectToExpr(value));
+    return CastValue(size, param->dtype);
+  }
+
+  Expr CastValue(const Expr& value, DataType dtype) {
     // Cast the constant into correct dtype
     auto cast_attrs = make_object<CastAttrs>();
-    cast_attrs->dtype = param->dtype;
-    Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {});
+    cast_attrs->dtype = dtype;
+    Expr ret = Call(cast_op_, {value}, Attrs(cast_attrs), {});
     return ConstEvaluate(ret);
   }
+
+  Optional<tvm::Array<IndexExpr>> GetConstantShape(const Expr& input) {
+    tvm::Array<IndexExpr> ishape;
+    if (const ConstantNode* op = input.as<ConstantNode>()) {
+      ishape = op->tensor_type()->shape;
+    } else if (input->checked_type_.defined()) {
+      ishape = input->checked_type().as<TensorTypeNode>()->shape;
+    } else {
+      return Optional<tvm::Array<IndexExpr>>(nullptr);
+    }
+
+    return Optional<tvm::Array<IndexExpr>>(ishape);
+  }
 };
 
 Expr FoldConstant(const Expr& expr, const IRModule& mod) {
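
The EvaluateNdarraySize helper above bails out and returns the original call
as soon as any input dimension is not a compile-time integer, so tensors with
dynamic shapes are left untouched. A sketch of that behavior (again an
illustration rather than part of the commit; relay.Any() marks the unknown
dimension):

    import tvm
    from tvm import relay

    # One dimension is unknown at compile time, so ndarray_size cannot fold.
    x = relay.var("x", shape=(relay.Any(), 3), dtype="float32")
    func = relay.Function([x], relay.ndarray_size(x, "int32"))
    mod = relay.transform.FoldConstant()(tvm.IRModule.from_expr(func))
    assert isinstance(mod["main"].body, relay.Call)  # still a call, not a constant
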
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index fcccab5..e985268 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -164,6 +164,27 @@ def test_fold_shape_of():
         assert tvm.ir.structural_equal(zz, zexpected)
 
 
+def test_fold_ndarray_size():
+    c_shape = (8, 9, 10)
+    def before(dtype):
+        x = relay.var("x", shape=c_shape, dtype="float32")
+        y = relay.var("y", shape=c_shape, dtype="float32")
+        z = relay.ndarray_size(x + y, dtype)
+        return relay.Function([x, y], z)
+
+    def expected(dtype):
+        x = relay.var("x", shape=c_shape, dtype="float32")
+        y = relay.var("y", shape=c_shape, dtype="float32")
+        z = relay.const([np.size(np.zeros(c_shape))], dtype=dtype)
+        func = relay.Function([x, y], z)
+        return func
+
+    for dtype in ["int32", "float32"]:
+        zz = run_opt_pass(before(dtype), transform.FoldConstant())
+        zexpected = run_opt_pass(expected(dtype), transform.InferType())
+        assert tvm.ir.structural_equal(zz, zexpected)
+
+
 def test_fold_full():
     c_shape = (8, 9, 10)
     def before():
@@ -228,3 +249,4 @@ if __name__ == "__main__":
     test_fold_shape_of()
     test_fold_full()
     test_fold_batch_norm()
+    test_fold_ndarray_size()
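
The new test also covers the CastValue path: the size is computed as int32 and
then cast to the dtype passed to ndarray_size. A quick interactive check of the
same behavior (a sketch under the same assumptions as above):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 9, 10), dtype="float32")
    func = relay.Function([x], relay.ndarray_size(x, "float32"))
    mod = relay.transform.FoldConstant()(tvm.IRModule.from_expr(func))
    body = mod["main"].body
    assert isinstance(body, relay.Constant)
    assert body.data.asnumpy().dtype == "float32"   # cast applied by CastValue
    assert float(body.data.asnumpy()[0]) == 720.0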