You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/07/25 15:16:19 UTC
[incubator-tvm] branch master updated: [Relay] Fix interpreter for
dynamic shape input of ndarray_size (#6086)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 959cff1 [Relay] Fix interpreter for dynamic shape input of ndarray_size (#6086)
959cff1 is described below
commit 959cff1c786e0eb33b99007be66de61d2275d7a5
Author: lixiaoquan <ra...@163.com>
AuthorDate: Sat Jul 25 23:16:06 2020 +0800
[Relay] Fix interpreter for dynamic shape input of ndarray_size (#6086)
---
src/relay/backend/interpreter.cc | 14 ++------------
tests/python/relay/test_any.py | 22 ++++++++++++++++++++--
2 files changed, 22 insertions(+), 14 deletions(-)
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 9a75c0a..08c5a7c 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -213,11 +213,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
Interpreter(IRModule mod, DLContext context, Target target)
- : mod_(mod),
- context_(context),
- target_(target),
- debug_op_(Op::Get("debug")),
- shape_of_op_(Op::Get("shape_of")) {
+ : mod_(mod), context_(context), target_(target), debug_op_(Op::Get("debug")) {
engine_ = CompileEngine::Global();
}
@@ -481,12 +477,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
Array<Shape> out_shapes;
auto ret_type = func->body->checked_type();
- bool is_dyn = IsDynamic(func->checked_type());
- if (call_node->op == shape_of_op_) {
- // The output shape of shape_of must be static since Relay doesn't support
- // dynamic rank tensors.
- is_dyn = false;
- }
+ bool is_dyn = IsDynamic(ret_type);
if (is_dyn) {
CHECK(func->HasNonzeroAttr(attr::kPrimitive));
@@ -722,7 +713,6 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
CompileEngine engine_;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
- const Op& shape_of_op_;
};
TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, DLContext context, Target target) {
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index bf28ee1..0e8a328 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -814,7 +814,7 @@ def test_mixed_input_type():
assert result.asnumpy().shape == ref_out_shape, \
"Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
-def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
+def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
layout, static_boxes, static_box_indices_shape, ref_out_shape):
mod = tvm.IRModule()
dtype = "float32"
@@ -872,6 +872,24 @@ def test_any_mirror_pad():
static_data_shape=(1, 256, 232, 232),
ref_out_shape=(1, 256, 234, 234))
+def verify_any_ndarray_size(data_np_shape):
+ v = relay.var("v", shape=any_dims(len(data_np_shape)), dtype='float32')
+ n = relay.ndarray_size(v, dtype='int32')
+ mod = tvm.IRModule()
+ mod['main'] = relay.Function([v], n)
+ np_data = np.zeros(data_np_shape, dtype='float32')
+ ref_res = np.size(np_data)
+
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(np_data)
+ tvm.testing.assert_allclose(result.asnumpy(), ref_res)
+
+def test_any_ndarray_size():
+ verify_any_ndarray_size((2,))
+ verify_any_ndarray_size((2, 2))
+ verify_any_ndarray_size((1, 2, 3, 4))
+
if __name__ == "__main__":
test_any_full()
test_any_full_like()
@@ -908,4 +926,4 @@ if __name__ == "__main__":
test_mixed_input_type()
test_any_crop_and_resize()
test_any_mirror_pad()
-
+ test_any_ndarray_size()