You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/02/10 20:41:57 UTC

[tvm] 01/01: [Arith] Fix iter_affine_map with non-const extent

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch fix/iter_map
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 693aa52c733c94f82b8c3f75a1bac245aca164ce
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Feb 10 12:41:17 2021 -0800

    [Arith] Fix iter_affine_map with non-const extent
---
 src/arith/iter_affine_map.cc                       | 34 +++++++++++-----------
 .../python/unittest/test_arith_iter_affine_map.py  |  3 ++
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 7896db7..170e825 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -412,8 +412,8 @@ class IterMapRewriter : public ExprMutator {
     return analyzer_->CanProve(floormod(lhs, rhs) == 0);
   }
 
-  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
-  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig);
 
   static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
     tir::ExprDeepEqual equal;
@@ -577,14 +577,14 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
     if (op->a.same_as(a) && op->b.same_as(b)) {
       return GetRef<PrimExpr>(op);
     } else {
       return Mul(a, b);
     }
   }
 
   if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) {
     // cannot multiply two iterators, mark as unresolved.
     ++unresolved_count_;
-    return Mul(a, b);
+    return GetRef<PrimExpr>(op);  // Mul(a, b) would embed the IterMapExpr nodes a and b
   }
 
   if (!a->IsInstance<IterMapExprNode>()) {
@@ -603,7 +603,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
   }
 }
 
-PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
+PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig) {
   // floordiv(x*scale, rhs)
   if (is_one(rhs)) return std::move(lhs);
   if (!is_one(lhs->scale)) {
@@ -619,7 +619,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
       } else {
         // mark as unresolved.
         ++unresolved_count_;
-        return floordiv(lhs, rhs);
+        return orig;  // floordiv(lhs, rhs) would embed the IterSplitExpr lhs
       }
     }
   }
@@ -641,7 +641,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
   } else {
     // mark as unresolved.
     ++unresolved_count_;
-    return floordiv(lhs, rhs);
+    return orig;  // floordiv(lhs, rhs) would embed the IterSplitExpr lhs
   }
 }
 
@@ -669,7 +669,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
   if (b->IsInstance<IterMapExprNode>()) {
     // cannot divide an iterator, mark as unresolved.
     ++unresolved_count_;
-    return FloorDiv(a, b);
+    return GetRef<PrimExpr>(op);  // FloorDiv(a, b) would embed the IterMapExpr b
   }
 
   if (a->IsInstance<IterSumExprNode>()) {
@@ -678,16 +678,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
-      return SplitFloorDivConst(opt.value(), b);
+      return SplitFloorDivConst(opt.value(), b, GetRef<PrimExpr>(op));
     } else {
       ++unresolved_count_;
-      return FloorDiv(a, b);
+      return GetRef<PrimExpr>(op);
     }
   } else {
     ICHECK(a->IsInstance<IterSplitExprNode>());
     IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
-    return SplitFloorDivConst(ret, b);
+    return SplitFloorDivConst(ret, b, GetRef<PrimExpr>(op));
   }
 }
 
-PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
+PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig) {
   // floormod(x*scale, rhs)
   if (is_one(rhs)) return make_zero(lhs->dtype);
   if (!is_one(lhs->scale)) {
@@ -701,7 +701,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
       } else {
         // mark as unresolved.
         ++unresolved_count_;
-        return floormod(lhs, rhs);
+        return orig;  // floormod(lhs, rhs) would embed the IterSplitExpr lhs
       }
     }
   }
@@ -715,7 +715,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
   } else {
     // mark as unresolved.
     ++unresolved_count_;
-    return floormod(lhs, rhs);
+    return orig;  // floormod(lhs, rhs) would embed the IterSplitExpr lhs
   }
 }
 
@@ -743,21 +743,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
   if (b->IsInstance<IterMapExprNode>()) {
     // cannot mod an iterator, mark as unresolved.
     ++unresolved_count_;
-    return FloorMod(a, b);
+    return GetRef<PrimExpr>(op);  // FloorMod(a, b) would embed the IterMapExpr b
   }
 
   if (a->IsInstance<IterSumExprNode>()) {
     IterSumExpr ret = Downcast<IterSumExpr>(a);
     if (auto opt = TryFuseIters(ret)) {
-      return SplitFloorModConst(opt.value(), b);
+      return SplitFloorModConst(opt.value(), b, GetRef<PrimExpr>(op));
     } else {
       ++unresolved_count_;
-      return FloorMod(a, b);
+      return GetRef<PrimExpr>(op);
     }
   } else {
     ICHECK(a->IsInstance<IterSplitExprNode>());
     IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
-    return SplitFloorModConst(ret, b);
+    return SplitFloorModConst(ret, b, GetRef<PrimExpr>(op));
   }
 }
 
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 620540c..6ab61fd 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -161,6 +161,9 @@ def test_split():
     assert len(res) == 1
     assert_iter_sum_pattern(res[0], 8, 0, scale=2)
 
+    res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)]))
+    assert len(res) == 0  # divisor depends on y (not a constant), so detection must fail
+
 
 def test_compound():
     x = tvm.tir.Var("x", "int32"), 10