You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/01/09 03:49:25 UTC
[tvm] branch main updated: [TOPI] Treat undefined elements as constants in Array (#7232)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 02ef6e6 [TOPI] Treat undefined elements as constants in Array (#7232)
02ef6e6 is described below
commit 02ef6e6243dbe525dc7e0a2f10704add0d7c24d7
Author: Cody Yu <co...@gmail.com>
AuthorDate: Fri Jan 8 19:49:06 2021 -0800
[TOPI] Treat undefined elements as constants in Array (#7232)
* [TOPI] Treat undefined elements as constants in Array
* Add a checker
* fix
* add test case
---
include/tvm/topi/detail/constant_utils.h | 5 +++--
include/tvm/topi/transform.h | 1 +
tests/python/topi/python/test_topi_transform.py | 1 +
3 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h
index 49ce21b..92ff3a4 100644
--- a/include/tvm/topi/detail/constant_utils.h
+++ b/include/tvm/topi/detail/constant_utils.h
@@ -48,7 +48,8 @@ using namespace tvm::te;
inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImmNode>(); }
/*!
- * \brief Test whether the given Array has every element as constant integer
+ * \brief Test whether the given Array has every element as constant integer.
+ * Undefined elements are also treated as constants.
*
* \param array the array to query
*
@@ -57,7 +58,7 @@ inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImm
inline bool IsConstIntArray(Array<PrimExpr> array) {
bool is_const_int = true;
for (auto const& elem : array) {
- is_const_int &= elem->IsInstance<tvm::tir::IntImmNode>();
+ is_const_int &= !elem.defined() || elem->IsInstance<tvm::tir::IntImmNode>();
}
return is_const_int;
}
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index a04762f..261fdf9 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -612,6 +612,7 @@ inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
Array<PrimExpr> out_shape;
if (!is_static) {
+ ICHECK_EQ(strides.size(), src_tensor_dim);
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
}
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index 30434f6..e0018ba 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -817,6 +817,7 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
+ verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None])
@tvm.testing.uses_gpu