You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/01/09 03:49:25 UTC

[tvm] branch main updated: [TOPI] Treat undefined elements as constants in Array (#7232)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 02ef6e6  [TOPI] Treat undefined elements as constants in Array (#7232)
02ef6e6 is described below

commit 02ef6e6243dbe525dc7e0a2f10704add0d7c24d7
Author: Cody Yu <co...@gmail.com>
AuthorDate: Fri Jan 8 19:49:06 2021 -0800

    [TOPI] Treat undefined elements as constants in Array (#7232)
    
    * [TOPI] Treat undefined elements as constants in Array
    
    * Add a checker
    
    * fix
    
    * add test case
---
 include/tvm/topi/detail/constant_utils.h        | 5 +++--
 include/tvm/topi/transform.h                    | 1 +
 tests/python/topi/python/test_topi_transform.py | 1 +
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h
index 49ce21b..92ff3a4 100644
--- a/include/tvm/topi/detail/constant_utils.h
+++ b/include/tvm/topi/detail/constant_utils.h
@@ -48,7 +48,8 @@ using namespace tvm::te;
 inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImmNode>(); }
 
 /*!
- * \brief Test whether the given Array has every element as constant integer
+ * \brief Test whether the given Array has every element as constant integer.
+ * Undefined elements are also treated as constants.
  *
  * \param array the array to query
  *
@@ -57,7 +58,7 @@ inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImm
 inline bool IsConstIntArray(Array<PrimExpr> array) {
   bool is_const_int = true;
   for (auto const& elem : array) {
-    is_const_int &= elem->IsInstance<tvm::tir::IntImmNode>();
+    is_const_int &= !elem.defined() || elem->IsInstance<tvm::tir::IntImmNode>();
   }
   return is_const_int;
 }
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index a04762f..261fdf9 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -612,6 +612,7 @@ inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
 
   Array<PrimExpr> out_shape;
   if (!is_static) {
+    ICHECK_EQ(strides.size(), src_tensor_dim);
     for (size_t i = 0; i < src_tensor_dim; ++i) {
       out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
     }
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index 30434f6..e0018ba 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -817,6 +817,7 @@ def test_strided_slice():
     verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
     verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
     verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
+    verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None])
 
 
 @tvm.testing.uses_gpu