You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/29 07:53:47 UTC

[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6156: [Relay] Handle ndarray_size in FoldConstant

MarisaKirisame commented on a change in pull request #6156:
URL: https://github.com/apache/incubator-tvm/pull/6156#discussion_r461567738



##########
File path: src/relay/transforms/fold_constant.cc
##########
@@ -261,12 +263,66 @@ class ConstantFolder : public ExprMutator {
       shape = Constant(ndarray);
     }
 
+    return CastValue(shape, param->dtype);
+  }
+
+  // Evaluate a call to the ndarray_size operator for tensors with constant
+  // shapes.
+  //
+  // Returns `expr` unchanged when the input's shape is not available
+  // (GetConstantShape fails) or when any dimension is not an IntImm.
+  // Otherwise returns a constant holding the element count of the input (the
+  // product of all of its dimensions), cast to the dtype requested in
+  // NdarraySizeAttrs.
+  Expr EvaluateNdarraySize(Expr expr, Array<Expr> args, Attrs attrs) {
+    Expr input = args[0];
+    const auto* param = attrs.as<NdarraySizeAttrs>();
+    CHECK(param != nullptr);
+
+    tvm::Array<IndexExpr> ishape;
+    if (!GetConstantShape(input, &ishape)) {
+      return expr;
+    }
+
+    // Multiply out the dimensions before allocating anything, so that a
+    // non-constant dimension bails out without wasted work. The product over
+    // an empty shape is 1: a rank-0 (scalar) tensor holds exactly one element.
+    // Accumulate in 64 bits so the exact count survives until the final cast
+    // to param->dtype; an int32 accumulator would truncate counts >= 2^31
+    // even when the caller asked for an int64 result.
+    using ::tvm::tir::IntImmNode;
+    int64_t num_elements = 1;
+    for (const IndexExpr& extent : ishape) {
+      const auto* dim = extent.as<IntImmNode>();
+      if (dim == nullptr) {
+        return expr;
+      }
+      num_elements *= dim->value;
+    }
+
+    // Materialize the count as a one-element int64 NDArray on the CPU.
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    DLDataType cdtype = DataType::Int(64);
+    runtime::NDArray value = runtime::NDArray::Empty({1}, cdtype, ctx);
+    *static_cast<int64_t*>(value->data) = num_elements;
+
+    Constant size = Downcast<Constant>(ObjectToExpr(value));
+    return CastValue(size, param->dtype);
+  }
+
+  // Build a call to the cast operator (cast_op_) that converts `value` to
+  // `dtype`, and return the result of running ConstEvaluate on that call.
+  // Shared by the shape_of and ndarray_size folding paths above.
+  Expr CastValue(const Expr& value, DataType dtype) {
     // Cast the constant into correct dtype
     auto cast_attrs = make_object<CastAttrs>();
-    cast_attrs->dtype = param->dtype;
-    Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {});
+    cast_attrs->dtype = dtype;
+    Expr ret = Call(cast_op_, {value}, Attrs(cast_attrs), {});
     return ConstEvaluate(ret);
   }
+
+  bool GetConstantShape(const Expr& input, tvm::Array<IndexExpr>* ishape) {

Review comment:
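      Please use `Optional` here: have `GetConstantShape` return an `Optional<tvm::Array<IndexExpr>>` instead of a `bool` plus an output pointer.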




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org