You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/02/10 20:41:57 UTC
[tvm] 01/01: [Arith] Fix iter_affine_map with non-const extent
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch fix/iter_map
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 693aa52c733c94f82b8c3f75a1bac245aca164ce
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Feb 10 12:41:17 2021 -0800
[Arith] Fix iter_affine_map with non-const extent
---
src/arith/iter_affine_map.cc | 34 +++++++++++-----------
.../python/unittest/test_arith_iter_affine_map.py | 3 ++
2 files changed, 20 insertions(+), 17 deletions(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 7896db7..170e825 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -412,8 +412,8 @@ class IterMapRewriter : public ExprMutator {
return analyzer_->CanProve(floormod(lhs, rhs) == 0);
}
- PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
- PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+ PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig);
+ PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig);
static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
tir::ExprDeepEqual equal;
@@ -577,14 +577,14 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<PrimExpr>(op);
} else {
- return Mul(a, b);
+ return GetRef<PrimExpr>(op);
}
}
if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) {
// cannot multiply two iterators, mark as unresolved.
++unresolved_count_;
- return Mul(a, b);
+ return GetRef<PrimExpr>(op);
}
if (!a->IsInstance<IterMapExprNode>()) {
@@ -603,7 +603,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
}
}
-PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
+PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig) {
// floordiv(x*scale, rhs)
if (is_one(rhs)) return std::move(lhs);
if (!is_one(lhs->scale)) {
@@ -619,7 +619,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
- return floordiv(lhs, rhs);
+ return orig;
}
}
}
@@ -641,7 +641,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
- return floordiv(lhs, rhs);
+ return orig;
}
}
@@ -669,7 +669,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
if (b->IsInstance<IterMapExprNode>()) {
// cannot divide an iterator, mark as unresolved.
++unresolved_count_;
- return FloorDiv(a, b);
+ return GetRef<PrimExpr>(op);
}
if (a->IsInstance<IterSumExprNode>()) {
@@ -678,16 +678,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
return SplitFloorDivConst(opt.value(), b);
} else {
++unresolved_count_;
- return FloorDiv(a, b);
+ return GetRef<PrimExpr>(op);
}
} else {
ICHECK(a->IsInstance<IterSplitExprNode>());
IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
- return SplitFloorDivConst(ret, b);
+ return SplitFloorDivConst(ret, b, GetRef<PrimExpr>(op));
}
}
-PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
+PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig) {
// floormod(x*scale, rhs)
if (is_one(rhs)) return make_zero(lhs->dtype);
if (!is_one(lhs->scale)) {
@@ -701,7 +701,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
- return floormod(lhs, rhs);
+ return orig;
}
}
}
@@ -715,7 +715,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
- return floormod(lhs, rhs);
+ return orig;
}
}
@@ -743,21 +743,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
if (b->IsInstance<IterMapExprNode>()) {
// cannot mod an iterator, mark as unresolved.
++unresolved_count_;
- return FloorMod(a, b);
+ return GetRef<PrimExpr>(op);
}
if (a->IsInstance<IterSumExprNode>()) {
IterSumExpr ret = Downcast<IterSumExpr>(a);
if (auto opt = TryFuseIters(ret)) {
- return SplitFloorModConst(opt.value(), b);
+ return SplitFloorModConst(opt.value(), b, GetRef<PrimExpr>(op));
} else {
++unresolved_count_;
- return FloorMod(a, b);
+ return GetRef<PrimExpr>(op);
}
} else {
ICHECK(a->IsInstance<IterSplitExprNode>());
IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
- return SplitFloorModConst(ret, b);
+ return SplitFloorModConst(ret, b, GetRef<PrimExpr>(op))
}
}
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 620540c..6ab61fd 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -161,6 +161,9 @@ def test_split():
assert len(res) == 1
assert_iter_sum_pattern(res[0], 8, 0, scale=2)
+ res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)]))
+ assert len(res) == 0
+
def test_compound():
x = tvm.tir.Var("x", "int32"), 10