You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/02/10 20:41:56 UTC

[tvm] branch fix/iter_map created (now 693aa52)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a change to branch fix/iter_map
in repository https://gitbox.apache.org/repos/asf/tvm.git.


      at 693aa52  [Arith] Fix iter_affine_map with non-const extent

This branch includes the following new commits:

     new 693aa52  [Arith] Fix iter_affine_map with non-const extent

The 1 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[tvm] 01/01: [Arith] Fix iter_affine_map with non-const extent

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch fix/iter_map
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 693aa52c733c94f82b8c3f75a1bac245aca164ce
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Feb 10 12:41:17 2021 -0800

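    [Arith] Fix iter_affine_map with non-const extent

    When IterMapRewriter cannot resolve a Mul, FloorDiv or FloorMod, return
    the original expression instead of rebuilding it from the rewritten
    operands, which may be IterMapExpr nodes.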
---
 src/arith/iter_affine_map.cc                       | 38 ++++++++++++----------
 .../python/unittest/test_arith_iter_affine_map.py  |  7 ++++
 2 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 7896db7..170e825 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -412,8 +412,10 @@ class IterMapRewriter : public ExprMutator {
     return analyzer_->CanProve(floormod(lhs, rhs) == 0);
   }
 
-  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
-  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+  // Rewrite floordiv/floormod(lhs, rhs). If it cannot be resolved, mark it as
+  // unresolved and return `orig` (the expression being rewritten) unchanged.
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig);
 
   static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
     tir::ExprDeepEqual equal;
@@ -584,7 +586,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
   if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) {
     // cannot multiply two iterators, mark as unresolved.
     ++unresolved_count_;
-    return Mul(a, b);
+    return GetRef<PrimExpr>(op);
   }
 
   if (!a->IsInstance<IterMapExprNode>()) {
@@ -603,7 +605,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
   }
 }
 
-PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
+PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs,
+                                             const PrimExpr& orig) {
   // floordiv(x*scale, rhs)
   if (is_one(rhs)) return std::move(lhs);
   if (!is_one(lhs->scale)) {
@@ -619,7 +622,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
       } else {
         // mark as unresolved.
         ++unresolved_count_;
-        return floordiv(lhs, rhs);
+        return orig;
       }
     }
   }
@@ -641,7 +644,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
   } else {
     // mark as unresolved.
     ++unresolved_count_;
-    return floordiv(lhs, rhs);
+    return orig;
   }
 }
 
@@ -669,7 +672,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
   if (b->IsInstance<IterMapExprNode>()) {
     // cannot divide an iterator, mark as unresolved.
     ++unresolved_count_;
-    return FloorDiv(a, b);
+    return GetRef<PrimExpr>(op);
   }
 
   if (a->IsInstance<IterSumExprNode>()) {
@@ -678,16 +681,17 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
-      return SplitFloorDivConst(opt.value(), b);
+      return SplitFloorDivConst(opt.value(), b, GetRef<PrimExpr>(op));
     } else {
       ++unresolved_count_;
-      return FloorDiv(a, b);
+      return GetRef<PrimExpr>(op);
     }
   } else {
     ICHECK(a->IsInstance<IterSplitExprNode>());
     IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
-    return SplitFloorDivConst(ret, b);
+    return SplitFloorDivConst(ret, b, GetRef<PrimExpr>(op));
   }
 }
 
-PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
+PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs,
+                                             const PrimExpr& orig) {
   // floormod(x*scale, rhs)
   if (is_one(rhs)) return make_zero(lhs->dtype);
   if (!is_one(lhs->scale)) {
@@ -701,7 +705,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
       } else {
         // mark as unresolved.
         ++unresolved_count_;
-        return floormod(lhs, rhs);
+        return orig;
       }
     }
   }
@@ -715,7 +719,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
   } else {
     // mark as unresolved.
     ++unresolved_count_;
-    return floormod(lhs, rhs);
+    return orig;
   }
 }
 
@@ -743,21 +747,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
   if (b->IsInstance<IterMapExprNode>()) {
     // cannot mod an iterator, mark as unresolved.
     ++unresolved_count_;
-    return FloorMod(a, b);
+    return GetRef<PrimExpr>(op);
   }
 
   if (a->IsInstance<IterSumExprNode>()) {
     IterSumExpr ret = Downcast<IterSumExpr>(a);
     if (auto opt = TryFuseIters(ret)) {
-      return SplitFloorModConst(opt.value(), b);
+      return SplitFloorModConst(opt.value(), b, GetRef<PrimExpr>(op));
     } else {
       ++unresolved_count_;
-      return FloorMod(a, b);
+      return GetRef<PrimExpr>(op);
     }
   } else {
     ICHECK(a->IsInstance<IterSplitExprNode>());
     IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
-    return SplitFloorModConst(ret, b);
+    return SplitFloorModConst(ret, b, GetRef<PrimExpr>(op));
   }
 }
 
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 620540c..6ab61fd 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -161,6 +161,13 @@ def test_split():
     assert len(res) == 1
     assert_iter_sum_pattern(res[0], 8, 0, scale=2)
 
+    # floordiv / floormod by the non-constant flm(flm(y, 8), 6) is unresolved
+    res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)]))
+    assert len(res) == 0
+
+    res = tvm.arith.detect_iter_map([flm(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)]))
+    assert len(res) == 0
+
 
 def test_compound():
     x = tvm.tir.Var("x", "int32"), 10