You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/12/25 04:13:43 UTC

[tvm] branch main updated: [TIR] Affine utility support iter lowerbound and diagnostics (#9699)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e2dcba2  [TIR] Affine utility support iter lowerbound and diagnostics (#9699)
e2dcba2 is described below

commit e2dcba2fde4f129c04bef6ab6310c74ba05f2a36
Author: wrongtest <wr...@gmail.com>
AuthorDate: Sat Dec 25 12:13:09 2021 +0800

    [TIR] Affine utility support iter lowerbound and diagnostics (#9699)
    
    * Enable free variables, iter lower bounds and diagnostics in the affine utility
    
    * fix lint issues and a comparison bug
    
    * update to use an iter shift instead of the IterMark min for the lower bound
    
    * add a test case for a fused iter sum with multiple lower bounds
    
    * add more affine-check test cases; fix bugs for a single iter and for duplicate constraints on an iter
    
    * add a newline to comment
    
    * forbid unmatched predicates
    
    Co-authored-by: baoxinqi <wr...@intellif.com>
---
 include/tvm/arith/iter_affine_map.h                |   8 +-
 src/arith/int_set.cc                               |   4 +-
 src/arith/iter_affine_map.cc                       | 471 ++++++++++++++++-----
 src/tir/schedule/analysis/analysis.cc              |   4 +-
 tests/python/unittest/test_arith_intset.py         |  87 ++++
 .../python/unittest/test_arith_iter_affine_map.py  | 200 +++++++++
 tests/python/unittest/test_tir_schedule_reorder.py |   5 +-
 tests/python/unittest/test_tir_schedule_rfactor.py |   8 +-
 .../test_tir_schedule_state_cached_flags.py        |  55 +++
 9 files changed, 720 insertions(+), 122 deletions(-)

diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
index 6c72cbe..22b4cd5 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -49,6 +49,7 @@
 #define TVM_ARITH_ITER_AFFINE_MAP_H_
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/ir/diagnostic.h>
 #include <tvm/ir/expr.h>
 #include <tvm/tir/var.h>
 
@@ -275,13 +276,14 @@ class IterSumExpr : public IterMapExpr {
  * \param predicate The predicate constraints on the input iterators
  * \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
  * \param analyzer Analyzer used to get context information.
+ * \param diag_ctx Diagnostic context.
  *
  * \return The detected pattern if a match exists,
  *         otherwise return an empty array.
  */
 Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
                                  const PrimExpr& predicate, bool require_bijective,
-                                 arith::Analyzer* analyzer);
+                                 arith::Analyzer* analyzer, DiagnosticContext diag_ctx);
 /*!
  * \brief Use IterVarMap detector to rewrite and simplify the indices
  *
@@ -333,6 +335,7 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
  * \param predicate The predicate constraints on the input iterators
  * \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
  * \param analyzer Analyzer used to get context information.
+ * \param diag_ctx Diagnostic context.
  *
  * \return The result list has length len(bindings) + 1
         [0, len(bindings)): The iter map matching result. The inner list is of length 2.
@@ -344,7 +347,8 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
 Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
                                       const Map<Var, Range>& input_iters,
                                       const Array<Var>& sub_iters, const PrimExpr& predicate,
-                                      bool require_bijective, arith::Analyzer* analyzer);
+                                      bool require_bijective, arith::Analyzer* analyzer,
+                                      DiagnosticContext diag_ctx);
 
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index e620e3b..55a1a5a 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -835,9 +835,10 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
     for (const Range& range : region) {
       affine_indices.push_back(range->min);
     }
+    DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
     iter_sum_exprs = DetectIterMap(
         /*indices=*/affine_indices, /*input_iters=*/var_dom,
-        /*predicate=*/predicate, /*require_bijective=*/false, analyzer);
+        /*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx);
   }
   if (iter_sum_exprs.empty()) {
     return NullOpt;
@@ -857,6 +858,7 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
     if (!analyzer->CanProve(range->extent >= split->scale)) {
       return NullOpt;
     }
+
     const PrimExpr& base = sum_expr->base;
     // IterSplitExpr: (source // lower_factor) % extent * scale
     // where `(source // lower_factor) % extent` is within [0, extent - 1]
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 34c35ce..c9d4b1e 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -160,13 +160,22 @@ class IterMarkSplitCollector {
   }
 };
 
+/*! \brief Record form of IterMark(x, extent) + offset */
+struct IterMarkWithOffset {
+  IterMark mark;
+  PrimExpr offset{0};
+  IterMarkWithOffset() {}
+  IterMarkWithOffset(IterMark mark, PrimExpr offset) : mark(mark), offset(offset) {}
+};
+
 /*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */
 class IterMapRewriter : public ExprMutator {
  public:
   using Parent = ExprMutator;
 
-  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
-      : analyzer_(analyzer) {
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
+                           DiagnosticContext diag_ctx)
+      : analyzer_(analyzer), diag_ctx_(diag_ctx) {
     for (auto kv : input_iters) {
       const Var& var = kv.first;
       const Range& vrng = kv.second;
@@ -192,9 +201,10 @@ class IterMapRewriter : public ExprMutator {
     return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
   }
 
-  IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
-                                    const PrimExpr& predicate_induced_extent) {
-    return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_extent);
+  IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min,
+                                    const PrimExpr& predicate_induced_max) {
+    return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min,
+                                      predicate_induced_max);
   }
 
   /*!
@@ -224,13 +234,21 @@ class IterMapRewriter : public ExprMutator {
     // The splits do not overlap with each other.
     collector.Collect(bindings);
     for (const IterMark& mark : collector.visited_) {
-      if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty())
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) {
+        diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+                       << "Fail to normalize iter mark splits: " << mark);
         return false;
+      }
     }
     if (require_bijective) {
       // all input marks must be visited
       for (const IterMark& mark : input_marks_) {
-        if (collector.visited_.count(mark) == 0) return false;
+        if (collector.visited_.count(mark) == 0) {
+          diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+                         << "The mapping is not bijective because input iter mark " << mark
+                         << " is not covered, ");
+          return false;
+        }
       }
     }
     return true;
@@ -278,7 +296,7 @@ class IterMapRewriter : public ExprMutator {
   PrimExpr VisitExpr(const PrimExpr& input_expr) final {
     auto expr = ExprMutator::VisitExpr(input_expr);
     if (expr->IsInstance<IterMapExprNode>()) {
-      ++unresolved_count_;
+      Fail(Diagnostic::Error(input_expr->span));
     }
     return expr;
   }
@@ -328,6 +346,13 @@ class IterMapRewriter : public ExprMutator {
     }
   };
 
+  void Fail(const Diagnostic& diagnostic) {
+    unresolved_count_++;
+    if (diag_ctx_.defined()) {
+      diag_ctx_.Emit(diagnostic);
+    }
+  }
+
   // Internal analyzer
   Analyzer* analyzer_;
   // Counter to keep track of unresolved cases.
@@ -336,8 +361,9 @@ class IterMapRewriter : public ExprMutator {
   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
   // input iter marks
   std::vector<IterMark> input_marks_;
-  // The map for sum that maps flattened form to IterMark with normal form and extent
-  // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+  // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
+  // an extra offset)
+  // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
   //          predicate: j*2 + k < 9
   // Then,    flattened form = IterSum(IterSplit(i, scale=9),
   //                                   IterSplit(j, scale=2),
@@ -347,11 +373,24 @@ class IterMapRewriter : public ExprMutator {
   //                                                              IterSplit(k, scale=1)),
   //                                                      extent=9)
   //                                             scale=1))
-  std::unordered_map<IterSumExpr, IterMark, IterSumHash, IterSumEqual> sum_fuse_map_;
+  // Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+  //          predicate: 1 <= j*2 + k < 9
+  // Then,    flattened form = IterSum(IterSplit(i, scale=8),
+  //                                   IterSplit(j, scale=2),
+  //                                   IterSplit(k, scale=1))
+  //          normal form    = IterSum(IterSplit(i, scale=8),
+  //                                   IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
+  //                                                              IterSplit(k, scale=1), base=-1),
+  //                                                      extent=9-1)
+  //                                             scale=1),
+  //                                   base=1)
+  std::unordered_map<IterSumExpr, IterMarkWithOffset, IterSumHash, IterSumEqual> sum_fuse_map_;
   // The map for sum that maps normal form to flattened form
   std::unordered_map<IterSumExpr, IterSumExpr, IterSumHash, IterSumEqual> flattened_map_;
   // The flattened forms of constrained iters
   std::vector<IterSumExpr> constrained_iters_flattened_;
+  // Diagnostic context
+  DiagnosticContext diag_ctx_;
 
   /*!
    * \brief Look for a split in splits that is not used such that its lower_factor is smallest.
@@ -407,19 +446,32 @@ class IterMapRewriter : public ExprMutator {
       }
       if (j == splits.size()) {
         // we do not allow incomplete split if the bindings should be bijective
-        if (require_bijective) return Array<IterSplitExpr>();
+        if (require_bijective) {
+          diag_ctx_.Emit(
+              Diagnostic::Error(mark->source->span)
+              << "Do not allow incomplete split in bijective checking, expected_lower_factor="
+              << expected_lower_factor);
+          return Array<IterSplitExpr>();
+        }
         // look for the next split skipping this lower factor
         // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2]
         // It is valid to only have [y / 6, y % 2] if bijective is not required
         // We can skip (y / 2) % 6
         j = SearchSkipLowerFactor(splits, used, expected_lower_factor);
         // split not found
-        if (j == splits.size()) return Array<IterSplitExpr>();
+        if (j == splits.size()) {
+          diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+                         << "Fail to find split skipping the lower factor in bijective-free "
+                            "checking, expected_lower_factor="
+                         << expected_lower_factor);
+          return Array<IterSplitExpr>();
+        }
       }
       used[j] = true;
       iters.push_back(splits[j]);
       expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
     }
+
     // Case 1. bijective is required.
     //         We check the extent we calculate is consistent with the extent of the mark
     // Case 2. bijective is not required.
@@ -427,42 +479,73 @@ class IterMapRewriter : public ExprMutator {
     //         For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not.
     if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) ||
         (!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) {
+      diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+                     << "Mark extent of " << mark
+                     << " is not compatible with expected_lower_factor=" << expected_lower_factor);
       return Array<IterSplitExpr>();
     }
     return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
   }
 
   /*!
-   * \brief Normalize the left hand side of iter constraint(expr < predicate_induced_extent)
-   * \param expr The left hand side of iter constraint.
-   * \param predicate_induced_extent Extent from iter constraint.
+   * \brief Normalize the iter expression with constraint (min <= expr < max)
+   * \param expr The iter expression.
+   * \param predicate_induced_min Closed lower bound from iter constraint, may be undefined.
+   * \param predicate_induced_max Open upper bound from iter constraint, may be undefined.
    * \return The Normalized expression.
    */
-  IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr,
-                                         const PrimExpr& predicate_induced_extent) {
-    // We are normalizing the left hand side of iter constraint(iter < predicate_induced_extent)
-    Optional<IterSplitExpr> opt = TryFuseIters(expr);
+  IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min,
+                                         PrimExpr predicate_induced_max) {
+    // normalize to zero base
+    PrimExpr base = expr->base;
+    if (!is_zero(base)) {
+      expr.CopyOnWrite()->base = 0;
+      if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
+      if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
+    }
+    Optional<IterSumExpr> opt = TryFuseIters(expr);
+    ICHECK(!opt.defined() || opt.value()->args.size() == 1);
     // scale should be 1
-    if (opt.defined() && is_one(opt.value()->scale)) {
-      IterSumExpr sum = Downcast<IterSumExpr>(opt.value()->source->source);
+    if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
+      const IterSplitExpr split = opt.value()->args[0];
+      IterSumExpr structured_form = Downcast<IterSumExpr>(split->source->source);
       // get the flattened form
-      auto it = flattened_map_.find(sum);
+      auto it = flattened_map_.find(structured_form);
       ICHECK(it != flattened_map_.end());
       IterSumExpr flattened_form = it->second;
-      // get the mark
+      // get the mark and offset of the structured_form
       auto it_mark = sum_fuse_map_.find(flattened_form);
       ICHECK(it_mark != sum_fuse_map_.end());
-      IterMark mark = it_mark->second;
-      mark.CopyOnWrite()->extent = min(predicate_induced_extent, mark->extent);
-      // update the bound of the lhs based on predicate_induced_extent
-      sum_fuse_map_[flattened_form] = mark;
+      IterMark mark = it_mark->second.mark;
+      PrimExpr mark_offset = it_mark->second.offset;
+      PrimExpr iter_min = mark_offset;
+      PrimExpr iter_max = iter_min + mark->extent;
+      if (predicate_induced_min.defined()) {
+        iter_min = max(predicate_induced_min, iter_min);
+      }
+      if (predicate_induced_max.defined()) {
+        iter_max = min(predicate_induced_max, iter_max);
+      }
+      if (!is_zero(iter_min)) {
+        // structured form's offset should be updated
+        flattened_map_.erase(structured_form);
+        structured_form.CopyOnWrite()->base = -iter_min;
+        mark.CopyOnWrite()->source = structured_form;
+        flattened_map_[structured_form] = flattened_form;
+      }
+      mark.CopyOnWrite()->extent = iter_max - iter_min;
+      sum_fuse_map_[flattened_form] = {mark, iter_min};
+
       // we need to note down the flattened form of constrained iterators
       // to check the validity of constraints, see also CheckConstraints()
       constrained_iters_flattened_.push_back(flattened_form);
-      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
+      expr.CopyOnWrite()->base = base + iter_min;
       return expr;
     }
-    ++unresolved_count_;
+    Fail(Diagnostic::Error(expr->span)
+         << "Fail to normalize " << expr << " with predicate bound [" << predicate_induced_min
+         << ", " << predicate_induced_max << ")");
     return expr;
   }
 
@@ -473,16 +556,12 @@ class IterMapRewriter : public ExprMutator {
    */
   IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
     // We are normalizing a regular iter
-    if (expr->args.size() <= 1) return expr;
-    PrimExpr base = expr->base;
-    expr.CopyOnWrite()->base = make_zero(expr->dtype);
-    Optional<IterSplitExpr> opt = TryFuseIters(expr);
-    expr.CopyOnWrite()->base = base;
+    if (expr->args.size() < 1) return expr;
+    Optional<IterSumExpr> opt = TryFuseIters(expr);
     if (opt.defined()) {
-      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
-      return expr;
+      return opt.value();
     } else {
-      ++unresolved_count_;
+      Fail(Diagnostic::Error(expr->span) << "Fail to normalize iter sum with offset: " << expr);
       return expr;
     }
   }
@@ -504,17 +583,16 @@ class IterMapRewriter : public ExprMutator {
   }
 
   /*!
-   * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn
-   *      = (x1*s1 + x2*s2 + ... + xn)*cn
-   *      = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn)
-   *      = [IterSplit(IterMark(y), scale=cn)]
-   *    return a corresponding IterSplitExpr if needed.
+   * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base
+   *      = (x1*s1 + x2*s2 + ... + xn)*cn + base
+   *      = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) + base
+   *      = [IterSplit(IterMark(y), scale=cn)] + base
+   *    return a corresponding IterSumExpr with extra offset if needed.
    *    Try to normalize IterSum into a fused IterMark
    * \param expr The input sum.
-   * \return The split with the fused IterMark if succeed.
+   * \return The sum with the fused IterMark and extra offset if succeed.
    */
-  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
-    if (!is_zero(expr->base)) return NullOpt;
+  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
     // select the iterators in order
     std::vector<bool> visited(expr->args.size(), false);
     std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -530,8 +608,13 @@ class IterMapRewriter : public ExprMutator {
         }
       }
     }
-    if (!base_scale) return NullOpt;
+    if (!base_scale) {
+      diag_ctx_.Emit(Diagnostic::Error(expr->span)
+                     << "Fuse iters failed, can not find a valid base scale");
+      return NullOpt;
+    }
     // check if it can be remapped into a fused pattern.
+    PrimExpr expected_extra_base = 0;
     PrimExpr expected_scale = base_scale.value();
     for (size_t i = 0; i < expr->args.size();) {
       // find j such that expr->args[j] has expected scale
@@ -539,7 +622,11 @@ class IterMapRewriter : public ExprMutator {
       for (; j < expr->args.size(); ++j) {
         if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
       }
-      if (j == expr->args.size()) return NullOpt;
+      if (j == expr->args.size()) {
+        diag_ctx_.Emit(Diagnostic::Error(expr->span)
+                       << "Fuse iters failed, can not find expected scale " << expected_scale);
+        return NullOpt;
+      }
       // look for the longest constrained iter started from expr->args[j]
       // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
       //          predicate: j*2 + k < 9
@@ -569,15 +656,21 @@ class IterMapRewriter : public ExprMutator {
                 break;
             }
           }
-          if (k == expr->args.size()) return NullOpt;
+          if (k == expr->args.size()) {
+            diag_ctx_.Emit(Diagnostic::Error(expr->span)
+                           << "Fuse iters failed, can not find flattened iter match constraint "
+                           << constraint_to_match.value());
+            return NullOpt;
+          }
           visited[k] = true;
           flattened_iters.push_back(expr->args[k]);
         }
         auto iter = sum_fuse_map_.find(constraint_to_match.value());
         ICHECK(iter != sum_fuse_map_.end());
-        IterMark iter_matched = iter->second;
-        grouped_iters.emplace_back(iter_matched, expected_scale);
-        expected_scale *= iter_matched->extent;
+        const IterMarkWithOffset& iter_matched = iter->second;
+        grouped_iters.emplace_back(iter_matched.mark, expected_scale);
+        expected_extra_base += iter_matched.offset * expected_scale;
+        expected_scale *= iter_matched.mark->extent;
         // move forward
         i += constraint_to_match.value()->args.size();
       } else {
@@ -594,18 +687,28 @@ class IterMapRewriter : public ExprMutator {
     IterSumExpr structured_form = expr, flattened_form = expr;
     flattened_form.CopyOnWrite()->args =
         Array<IterSplitExpr>(flattened_iters.rbegin(), flattened_iters.rend());
+    flattened_form.CopyOnWrite()->base = 0;
     structured_form.CopyOnWrite()->args =
         Array<IterSplitExpr>(grouped_iters.rbegin(), grouped_iters.rend());
+    structured_form.CopyOnWrite()->base = 0;
     auto it = sum_fuse_map_.find(flattened_form);
     if (it != sum_fuse_map_.end()) {
       // old iter
-      return IterSplitExpr(it->second, base_scale.value());
+      if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) {
+        // the extra offset is not consistent with the old one
+        diag_ctx_.Emit(Diagnostic::Error(expr->span)
+                       << "Fuse iters failed, the extra offset is not consistent with old");
+        return NullOpt;
+      }
+      return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())},
+                         expr->base + expected_extra_base);
     } else {
       // new iter, form a new mark
       IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
-      sum_fuse_map_[flattened_form] = mark;
+      sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
       flattened_map_[structured_form] = flattened_form;
-      return IterSplitExpr(mark, base_scale.value());
+      return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
+                         expr->base + expected_extra_base);
     }
   }
 
@@ -667,34 +770,126 @@ class IterMapRewriter : public ExprMutator {
 struct IterConstraint {
   // The expr of the iter
   PrimExpr iter;
+  // The expr of the lower_bound
+  PrimExpr lower_bound;
   // The expr of the upper_bound
   PrimExpr upper_bound;
   // The size of the iter, which is the number of nodes
   size_t expr_size = 0;
 
-  IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size)
-      : iter(std::move(iter)), upper_bound(std::move(upper_bound)), expr_size(size) {}
+  IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size)
+      : iter(std::move(iter)),
+        lower_bound(std::move(lower_bound)),
+        upper_bound(std::move(upper_bound)),
+        expr_size(size) {}
 };
 
 /*!
  * \brief Split the predicate into `(a < b) && (c < d) && ...`
  * \param pred The predicate to be split.
- * \return A list of pairs, each element of which are lhs and rhs of the '<' sign,
- *         empty if the split failed.
+ * \return A list of IterConstraint, empty if the split failed.
  */
-std::vector<IterConstraint> MatchUpperBoundConstraints(PrimExpr pred) {
+std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
+                                                  const Map<Var, Range>& input_iters) {
   std::vector<IterConstraint> result;
   arith::PVar<PrimExpr> lhs, rhs, rest;
   for (;;) {
-    if ((rest && (lhs < rhs)).Match(pred)) {
-      result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
-      pred = rest.Eval();
+    // try to extract comparisons
+    bool is_finish = false;
+    bool is_greater = false;
+    bool is_equal = false;
+    if ((rest && (lhs < rhs)).Match(pred) || ((lhs < rhs) && rest).Match(pred)) {
+      // pass
     } else if ((lhs < rhs).Match(pred)) {
-      result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
-      break;
+      is_finish = true;
+    } else if ((rest && (lhs <= rhs)).Match(pred) || ((lhs <= rhs) && rest).Match(pred)) {
+      is_equal = true;
+    } else if ((lhs <= rhs).Match(pred)) {
+      is_equal = true;
+      is_finish = true;
+    } else if ((rest && (lhs > rhs)).Match(pred) || ((lhs > rhs) && rest).Match(pred)) {
+      is_greater = true;
+    } else if ((lhs > rhs).Match(pred)) {
+      is_greater = true;
+      is_finish = true;
+    } else if ((rest && (lhs >= rhs)).Match(pred) || ((lhs >= rhs) && rest).Match(pred)) {
+      is_greater = true;
+      is_equal = true;
+    } else if ((lhs >= rhs).Match(pred)) {
+      is_greater = true;
+      is_equal = true;
+      is_finish = true;
     } else {
       return std::vector<IterConstraint>();
     }
+    PrimExpr lhs_expr = lhs.Eval();
+    PrimExpr rhs_expr = rhs.Eval();
+    // we only accept predicates on integers
+    if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) &&
+          (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) {
+      return std::vector<IterConstraint>();
+    }
+    // determine iter and bound; if we cannot distinguish them simply,
+    // try to divide (lhs - rhs) into itervar-aware and itervar-free parts
+    auto f_use_itervar = [&input_iters](const VarNode* v) {
+      return input_iters.count(GetRef<Var>(v));
+    };
+    bool bound_at_left;
+    if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
+      bound_at_left = true;
+    } else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
+      bound_at_left = false;
+    } else {
+      bound_at_left = false;  // accumulate bound to rhs
+      PrimExpr sum_parts = lhs_expr - rhs_expr;
+      lhs_expr = 0;
+      rhs_expr = 0;
+      std::function<void(const PrimExpr&, bool)> f_extract =
+          [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
+            if (const AddNode* add = part.as<AddNode>()) {
+              f_extract(add->a, sign);
+              f_extract(add->b, sign);
+            } else if (const SubNode* sub = part.as<SubNode>()) {
+              f_extract(sub->a, sign);
+              f_extract(sub->b, !sign);
+            } else if (UsesVar(part, f_use_itervar)) {
+              lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
+            } else {
+              rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
+            }
+          };
+      f_extract(sum_parts, true);
+      arith::Analyzer analyzer;
+      lhs_expr = analyzer.Simplify(lhs_expr);
+      rhs_expr = analyzer.Simplify(rhs_expr);
+    }
+    PrimExpr lower_bound, upper_bound, iter;
+    if (is_greater) {
+      if (bound_at_left) {
+        // bound > iter
+        upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
+        iter = rhs_expr;
+      } else {
+        // iter > bound
+        lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
+        iter = lhs_expr;
+      }
+    } else {
+      if (bound_at_left) {
+        // bound < iter
+        lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
+        iter = rhs_expr;
+      } else {
+        // iter < bound
+        upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
+        iter = lhs_expr;
+      }
+    }
+    result.emplace_back(iter, lower_bound, upper_bound, 0);
+    if (is_finish) {
+      break;
+    }
+    pred = rest.Eval();
   }
   return result;
 }
@@ -711,14 +906,17 @@ bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
 
 Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
                                  const PrimExpr& predicate, bool require_bijective,
-                                 arith::Analyzer* analyzer) {
+                                 arith::Analyzer* analyzer, DiagnosticContext diag_ctx) {
   // Overall detection algorithm is divided into two steps:
   // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
   // - Step1: IterIndependenceChecker checks if the iterator are independent.
-
   if (!IterRangeSanityCheck(input_iters)) return Array<IterSumExpr>();
-  std::vector<IterConstraint> constraints = MatchUpperBoundConstraints(predicate);
-  if (!is_one(predicate) && constraints.empty()) return Array<IterSumExpr>();
+  std::vector<IterConstraint> constraints = MatchBoundConstraints(predicate, input_iters);
+  if (!is_one(predicate) && constraints.empty()) {
+    diag_ctx.Emit(Diagnostic::Error(predicate->span)
+                  << "Fail to collect constraints from iteration predicate: " << predicate);
+    return Array<IterSumExpr>();
+  }
 
   // We have to make sure when we visit an iterator, all the constraints related with its successors
   // in the iter var graph has been visited, where the expression of this iterator will contain the
@@ -731,13 +929,17 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
       constraints.begin(), constraints.end(),
       [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
 
-  IterMapRewriter rewriter(analyzer, input_iters);
+  IterMapRewriter rewriter(analyzer, input_iters, diag_ctx);
   // Step0.0: rewrite constraints in the order from size-small ones to size-big ones
   for (const IterConstraint& constraint : constraints) {
-    PrimExpr res = rewriter.RewriteIterConstraint(constraint.iter, constraint.upper_bound);
+    rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound);
     if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
   }
-  if (!rewriter.CheckConstraints()) return Array<IterSumExpr>();
+  if (!rewriter.CheckConstraints()) {
+    diag_ctx.Emit(Diagnostic::Error(predicate->span)
+                  << "Illegal iteration constraints: " << predicate);
+    return Array<IterSumExpr>();
+  }
   // Step0.1: rewrite indices
   Array<IterSumExpr> results;
   for (PrimExpr value : indices) {
@@ -745,7 +947,10 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
     if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
   }
   // Step1: IterIndependenceChecker checks if the iterator are independent.
-  if (!rewriter.CheckMapping(results, require_bijective)) return Array<IterSumExpr>();
+  if (!rewriter.CheckMapping(results, require_bijective)) {
+    diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Iterators are not independent");
+    return Array<IterSumExpr>();
+  }
 
   return results;
 }
@@ -754,7 +959,8 @@ TVM_REGISTER_GLOBAL("arith.DetectIterMap")
     .set_body_typed([](const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
                        const PrimExpr& input_pred, bool is_bijective) {
       arith::Analyzer ana;
-      return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana);
+      DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
+      return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, diag_ctx);
     });
 
 PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {
@@ -768,7 +974,6 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Parent::VisitExpr_(op);
   }
-
   PrimExpr a = this->DirectMutate(op->a);
   PrimExpr b = this->DirectMutate(op->b);
 
@@ -858,7 +1063,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
 
   if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) {
     // cannot multiply two iterators, mark as unresolved.
-    ++unresolved_count_;
+    Fail(Diagnostic::Error(op->span) << "Cannot multiply two iterators: " << GetRef<PrimExpr>(op));
     return GetRef<PrimExpr>(op);
   }
 
@@ -894,7 +1099,9 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs,
         lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1);
       } else {
         // mark as unresolved.
-        ++unresolved_count_;
+        Fail(Diagnostic::Error(orig->span)
+             << "Can not prove floordiv rhs " << rhs << " divisible by lhs scale " << lhs->scale
+             << ", lhs=" << lhs);
         return orig;
       }
     }
@@ -916,7 +1123,8 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs,
     return std::move(lhs);
   } else {
     // mark as unresolved.
-    ++unresolved_count_;
+    Fail(Diagnostic::Error(orig->span)
+         << "Can not prove floordiv lhs extent " << lhs->extent << " divisible by rhs " << rhs);
     return orig;
   }
 }
@@ -944,16 +1152,24 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
 
   if (b->IsInstance<IterMapExprNode>()) {
     // cannot divide an iterator, mark as unresolved.
-    ++unresolved_count_;
+    Fail(Diagnostic::Error(op->span) << "Cannot divide an iterator: " << GetRef<PrimExpr>(op));
     return GetRef<PrimExpr>(op);
   }
 
   if (a->IsInstance<IterSumExprNode>()) {
     IterSumExpr ret = Downcast<IterSumExpr>(a);
-    if (Optional<IterSplitExpr> opt = TryFuseIters(ret)) {
-      return SplitFloorDivConst(opt.value(), b, GetRef<PrimExpr>(op));
+    if (Optional<IterSumExpr> opt = TryFuseIters(ret)) {
+      IterSumExpr sum = opt.value();
+      if (!is_zero(sum->base)) {
+        Fail(Diagnostic::Error(op->span)
+             << "Fuse IterSumExpr " << ret
+             << " failed, cannot floordiv an IterSumExpr with nonzero base");
+        return GetRef<PrimExpr>(op);
+      }
+      ICHECK_EQ(sum->args.size(), 1U);
+      return SplitFloorDivConst(sum->args[0], b, GetRef<PrimExpr>(op));
     } else {
-      ++unresolved_count_;
+      Fail(Diagnostic::Error(op->span) << "Fuse IterSumExpr " << ret << " failed");
       return GetRef<PrimExpr>(op);
     }
   } else {
@@ -977,7 +1193,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs,
         rhs = floordiv(rhs, lhs->scale);
       } else {
         // mark as unresolved.
-        ++unresolved_count_;
+        Fail(Diagnostic::Error(orig->span) << "Can not prove floormod rhs " << rhs
+                                           << " divisible by " << lhs->scale << ", lhs=" << lhs);
         return orig;
       }
     }
@@ -991,7 +1208,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs,
     return std::move(lhs);
   } else {
     // mark as unresolved.
-    ++unresolved_count_;
+    Fail(Diagnostic::Error(orig->span)
+         << "Can not prove floormod lhs extent " << lhs->extent << " divisible by rhs " << rhs);
     return orig;
   }
 }
@@ -1019,16 +1237,23 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
 
   if (b->IsInstance<IterMapExprNode>()) {
     // cannot mod an iterator, mark as unresolved.
-    ++unresolved_count_;
+    Fail(Diagnostic::Error(op->span) << "Cannot mod an iterator: " << GetRef<PrimExpr>(op));
     return GetRef<PrimExpr>(op);
   }
 
   if (a->IsInstance<IterSumExprNode>()) {
     IterSumExpr ret = Downcast<IterSumExpr>(a);
-    if (Optional<IterSplitExpr> opt = TryFuseIters(ret)) {
-      return SplitFloorModConst(opt.value(), b, GetRef<PrimExpr>(op));
+    if (Optional<IterSumExpr> opt = TryFuseIters(ret)) {
+      IterSumExpr sum = opt.value();
+      if (!is_zero(sum->base)) {
+        Fail(Diagnostic::Error(op->span)
+             << "Fuse IterSumExpr " << ret
+             << " failed, cannot floormod an IterSumExpr with nonzero base");
+        return GetRef<PrimExpr>(op);
+      }
+      return SplitFloorModConst(sum->args[0], b, GetRef<PrimExpr>(op));
     } else {
-      ++unresolved_count_;
+      Fail(Diagnostic::Error(op->span) << "Fail to fuse iters of " << ret);
       return GetRef<PrimExpr>(op);
     }
   } else {
@@ -1039,19 +1264,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
 }
 
 /*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */
-class IterMapToExprNormalizer {
+class IterMapToExprNormalizer : public ExprMutator {
  public:
   explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}
 
-  PrimExpr Convert(const IterMapExpr& expr) {
+  PrimExpr Convert(const IterMapExpr& expr) { return VisitExpr(expr); }
+
+ private:
+  /*! \brief Override VisitExpr for iter expr type processing */
+  PrimExpr VisitExpr(const PrimExpr& expr) override {
     if (const auto* op = expr.as<IterSplitExprNode>()) {
       return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
     } else if (const auto* op = expr.as<IterSumExprNode>()) {
       return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
     } else {
-      ICHECK(expr.defined());
-      LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey();
-      return 0;
+      return ExprMutator::VisitExpr(expr);
     }
   }
 
@@ -1071,7 +1298,7 @@ class IterMapToExprNormalizer {
     } else if (const auto* op = expr->source->source.as<IterSumExprNode>()) {
       source = ConvertIterSumExpr(GetRef<IterSumExpr>(op));
     } else {
-      LOG(FATAL) << "Unexpected source of IterSplitExpr";
+      source = VisitExpr(expr->source->source);
     }
     if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
       return source * expr->scale;
@@ -1100,8 +1327,9 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, R
                                 const PrimExpr& input_pred, bool require_bijective) {
   if (!IterRangeSanityCheck(input_iters)) return indices;
   Analyzer analyzer;
+  DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
   Array<IterSumExpr> rewrite =
-      DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer);
+      DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer, diag_ctx);
   if (rewrite.empty()) {
     return indices;
   }
@@ -1128,8 +1356,9 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, R
 class SubspaceDivider {
  public:
   explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector,
-                           const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters)
-      : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {}
+                           const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters,
+                           DiagnosticContext diag_ctx)
+      : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters), diag_ctx_(diag_ctx) {}
 
   size_t unresolved_count() const { return unresolved_count_; }
 
@@ -1190,7 +1419,10 @@ class SubspaceDivider {
       return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1);
     } else if (expr->args.size() == 1) {
       // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base)
-      if (!is_one(expr->args[0]->scale)) return Fail();
+      if (!is_one(expr->args[0]->scale)) {
+        return Fail(Diagnostic::Error(expr->span)
+                    << "Expect split scale be 1, got " << expr->args[0]->scale);
+      }
       DivisionResult res = DivideIterSplitExpr(expr->args[0]);
       if (!is_zero(expr->base)) res = AddBase(res, expr->base);
       return res;
@@ -1208,7 +1440,9 @@ class SubspaceDivider {
       DivisionResult arg_division = DivideIterSplitExpr(arg);
       IterSplitExpr new_arg;
       if (arg_division.IsInner()) {
-        if (!inner) return Fail();
+        if (!inner)
+          return Fail(Diagnostic::Error(expr->span)
+                      << "Current division is inner but outer division exists for previous args");
         new_arg = arg_division.GetInnerAsSplit();
         inner_args.push_back(new_arg);
         inner = true;
@@ -1217,11 +1451,13 @@ class SubspaceDivider {
         outer_args.push_back(new_arg);
         inner = false;
       } else {
-        return Fail();
+        return Fail(Diagnostic::Error(expr->span)
+                    << "Division of " << arg << " is neither inner nor outer");
       }
       extent *= new_arg->extent;
     }
-    if (!scale_is_one) return Fail();
+    if (!scale_is_one)
+      return Fail(Diagnostic::Error(expr->span) << "Expect all iter sum arg's scale be 1");
     bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent);
     const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0);
     const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base);
@@ -1240,7 +1476,8 @@ class SubspaceDivider {
         inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent);
         return DivisionResult::Inner(inner_source, mark_extent);
       } else {
-        return Fail();
+        return Fail(Diagnostic::Error(expr->span)
+                    << "Either inner or outer args should exists if need predicate: " << expr);
       }
     }
     return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent);
@@ -1250,8 +1487,11 @@ class SubspaceDivider {
   PrimExpr GetInnerPreds() const { return inner_preds_; }
 
  private:
-  DivisionResult Fail() {
+  DivisionResult Fail(const Diagnostic& diagnostic) {
     unresolved_count_++;
+    if (diag_ctx_.defined()) {
+      diag_ctx_.Emit(diagnostic);
+    }
     return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
   }
 
@@ -1330,7 +1570,10 @@ class SubspaceDivider {
           if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor))
             break;
         }
-        if (j == splits.size()) return Fail();
+        if (j == splits.size())
+          return Fail(Diagnostic::Error(expr->span)
+                      << "Can not find expected lower factor " << expected_lower_factor
+                      << " in splits of " << expr->source);
         used[j] = true;
         if (!encountered_boundary) {
           inner_iters.push_back(splits[j]);
@@ -1341,7 +1584,9 @@ class SubspaceDivider {
         if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent))
           encountered_boundary = true;
       }
-      if (!encountered_boundary) return Fail();
+      if (!encountered_boundary)
+        return Fail(Diagnostic::Error(expr->span)
+                    << "Can not find inner/outer boundary of " << expr);
       for (const IterSplitExpr& inner_iter : inner_iters) {
         IterSplitExpr new_iter = inner_iter;
         new_iter.CopyOnWrite()->source = inner_mark;
@@ -1355,7 +1600,8 @@ class SubspaceDivider {
         split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent));
       }
     } else {
-      return Fail();
+      return Fail(Diagnostic::Error(expr->span)
+                  << "Source expr to divide is neither var nor IterSumExpr");
     }
     return split_map_.at(expr);
   }
@@ -1371,15 +1617,18 @@ class SubspaceDivider {
   std::unordered_map<IterSplitExpr, DivisionResult, ObjectPtrHash, ObjectPtrEqual> split_map_;
   // predicate of outer space and inner space;
   PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)};
+  // diagnostic context
+  DiagnosticContext diag_ctx_;
 };
 
 Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
                                       const Map<Var, Range>& input_iters,
                                       const Array<Var>& sub_iters, const PrimExpr& predicate,
-                                      bool require_bijective, arith::Analyzer* analyzer) {
+                                      bool require_bijective, arith::Analyzer* analyzer,
+                                      DiagnosticContext diag_ctx) {
   if (!IterRangeSanityCheck(input_iters)) return Array<Array<IterMark>>();
   const Array<IterSumExpr>& maps =
-      DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer);
+      DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer, diag_ctx);
   if (maps.empty()) return {};
 
   std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> inner_iter_set;
@@ -1389,7 +1638,7 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
 
   IterMarkSplitCollector collector;
   collector.Collect(maps);
-  SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set);
+  SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set, diag_ctx);
 
   std::vector<Array<IterMark>> results;
   for (const IterSumExpr& expr : maps) {
@@ -1409,7 +1658,9 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
                        const Array<Var>& sub_iters, const PrimExpr& predicate,
                        bool require_bijective) {
       arith::Analyzer ana;
-      return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana);
+      DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
+      return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana,
+                            diag_ctx);
     });
 
 class InverseAffineIterMapTransformer {
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 6d744a6..0a7d57e 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -415,12 +415,14 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
   if (loop_var_ranges.empty()) {
     return true;
   }
+  DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
   Array<arith::IterSumExpr> results = arith::DetectIterMap(
       /*indices=*/realize->iter_values,
       /*input_iters=*/loop_var_ranges,
       /*predicate=*/realize->predicate,
       /*require_bijective=*/false,
-      /*analyzer=*/analyzer);
+      /*analyzer=*/analyzer,
+      /*diag_ctx*/ diag_ctx);
   if (results.empty()) {
     return false;
   }
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index 618fd55..b40f3c9 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -16,6 +16,8 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm import tir
+from tvm.ir.base import structural_equal
 
 
 class IntSetChecker:
@@ -233,6 +235,90 @@ def test_region_lower_bound_negative_scale():
     assert int_set_1.max_value.value == 35
 
 
+def test_region_lower_bound_for_non_perfect_tile():
+    h1 = tvm.tir.Var("h1", "int32")
+    h2 = tvm.tir.Var("h2", "int32")
+    h3 = tvm.tir.Var("h3", "int32")
+    analyzer = tvm.arith.Analyzer()
+
+    def do_test_point_access(point, predicates, var_dom, expect):
+        regions = tvm.arith.estimate_region_lower_bound(
+            region=[
+                tvm.ir.Range.from_min_extent(min_value=point, extent=1),
+            ],
+            var_dom=var_dom,
+            predicate=tvm.tir.all(*predicates),
+        )
+        if expect is None:  # expect a failure
+            assert regions is None
+        else:
+            assert len(regions) == 1
+            for binding, expect_min, expect_max in expect:
+                min_diff = expect_min - regions[0].min_value
+                assert analyzer.simplify(tir.stmt_functor.substitute(min_diff, binding), 3) == 0
+                max_diff = expect_max - regions[0].max_value
+                assert analyzer.simplify(tir.stmt_functor.substitute(max_diff, binding), 3) == 0
+
+    # non-uniform tiling, single inner variable
+    # h3 == 0: region is [1, 9]
+    # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9]
+    # h3 > 26: region is [h3 * 8, 223]
+    do_test_point_access(
+        point=h3 * 8 + h2,
+        predicates=[1 <= h3 * 8 + h2, h3 * 8 + h2 < 224],
+        var_dom={
+            h2: tvm.ir.Range(begin=0, end=10),
+        },
+        expect=[
+            (
+                {},
+                tvm.tir.max(h3 * 8, 1),
+                tvm.tir.max(h3 * 8, 1)
+                - tvm.tir.max(h3 * 8, 214)
+                - tvm.tir.max(1 - h3 * 8, 0)
+                + 223,
+            ),
+            ({h3: 0}, 1, 9),
+            ({h3: 10}, h3 * 8, h3 * 8 + 9),
+            ({h3: 27}, h3 * 8, 223),
+        ],
+    )
+
+    # non-uniform tiling, two inner variables
+    do_test_point_access(
+        point=h3 * 8 + h2 * 5 + h1,
+        predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224],
+        var_dom={
+            h2: tvm.ir.Range(begin=0, end=2),
+            h1: tvm.ir.Range(begin=0, end=5),
+        },
+        expect=[
+            (
+                {},
+                tvm.tir.max(h3 * 8, 1),
+                tvm.tir.max(h3 * 8, 1)
+                - tvm.tir.max(h3 * 8, 214)
+                - tvm.tir.max(1 - h3 * 8, 0)
+                + 223,
+            ),
+            ({h3: 0}, 1, 9),
+            ({h3: 10}, h3 * 8, h3 * 8 + 9),
+            ({h3: 27}, h3 * 8, 223),
+        ],
+    )
+
+    # should fail on incompatible predicates
+    do_test_point_access(
+        point=h3 * 8 + h2 * 5 + h1,
+        predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224],
+        var_dom={
+            h2: tvm.ir.Range(begin=0, end=2),
+            h1: tvm.ir.Range(begin=0, end=5),
+        },
+        expect=None,
+    )
+
+
 def test_union_lower_bound():
     neg_inf = tvm.arith.int_set.neg_inf()
     pos_inf = tvm.arith.int_set.pos_inf()
@@ -257,4 +343,5 @@ if __name__ == "__main__":
     test_region_lower_bound_split_predicate()
     test_region_lower_bound_multiple_variables()
     test_region_lower_bound_negative_scale()
+    test_region_lower_bound_for_non_perfect_tile()
     test_union_lower_bound()
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index c307034..6b3c295 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -199,8 +199,171 @@ def test_predicate():
     x = tvm.tir.Var("x", "int32"), 13
     y = tvm.tir.Var("y", "int32"), 10
 
+    # available constraints
+    # upper bound only
     res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 128)
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 128, 0)
+    res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] <= 127)
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 128, 0)
 
+    # lower bound only
+    res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5)
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 124, 6)
+    res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6)
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 124, 6)
+
+    # lower bound + upper bound
+    res = tvm.arith.detect_iter_map(
+        [x[0] * 10 + y[0]],
+        var_dom([x, y]),
+        tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128),
+    )
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 122, 6)
+    res = tvm.arith.detect_iter_map(
+        [x[0] * 10 + y[0]],
+        var_dom([x, y]),
+        tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127),
+    )
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 122, 6)
+
+    # constraint on one fused iter
+    i = tvm.tir.Var("i", "int32")
+    j = tvm.tir.Var("j", "int32")
+    k = tvm.tir.Var("k", "int32")
+    res = tvm.arith.detect_iter_map(
+        [i * 8 + j * 2 + k],
+        var_dom([(i, 11), (j, 5), (k, 2)]),
+        tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9),
+    )
+    assert_iter_sum_pattern(res[0], 88, 1)
+
+    # constraint on single var
+    res = tvm.arith.detect_iter_map([i], var_dom([(i, 48)]), tvm.tir.all(i < 10))
+    assert_iter_sum_pattern(res[0], 10, 0)
+
+    # iterations are subparts of constraint, invalid, case 1
+    res = tvm.arith.detect_iter_map(
+        [i, j, k],
+        var_dom([(i, 128), (j, 128), (k, 128)]),
+        tvm.tir.all(i * 16384 + j * 128 + k < 100),
+    )
+    assert len(res) == 0
+
+    # iterations are subparts of constraint, invalid, case 2
+    res = tvm.arith.detect_iter_map(
+        [i * 128 + j, k],
+        var_dom([(i, 128), (j, 128), (k, 128)]),
+        tvm.tir.all(i * 16384 + j * 128 + k < 100),
+    )
+    assert len(res) == 0
+
+    # constraint on nested fused iters
+    res = tvm.arith.detect_iter_map(
+        [i * 8 + j * 2 + k],
+        var_dom([(i, 11), (j, 5), (k, 2)]),
+        tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25),
+    )
+    assert_iter_sum_pattern(res[0], 22, 3)
+
+    # duplicate constraint on one fused iter
+    res = tvm.arith.detect_iter_map(
+        [i * 6 + j * 2 + k],
+        var_dom([(i, 11), (j, 5), (k, 2)]),
+        tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9),
+    )
+    assert_iter_sum_pattern(res[0], 66, 2)
+
+    # duplicate constraint on nested fused iters
+    res = tvm.arith.detect_iter_map(
+        [i * 6 + j * 2 + k],
+        var_dom([(i, 11), (j, 5), (k, 2)]),
+        tvm.tir.all(
+            1 <= j * 2 + k,
+            2 <= j * 2 + k,
+            j * 2 + k < 8,
+            j * 2 + k < 9,
+            3 <= i * 6 + j * 2 + k,
+            i * 6 + j * 2 + k < 25,
+            1 <= i * 6 + j * 2 + k,
+            i * 6 + j * 2 + k < 18,
+        ),
+    )
+    assert_iter_sum_pattern(res[0], 15, 3)
+
+    # constraint on non-disjoint fused iters should fail
+    res = tvm.arith.detect_iter_map(
+        [i * 8 + j * 2 + k],
+        var_dom([(i, 11), (j, 5), (k, 2)]),
+        tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j),
+    )
+    assert len(res) == 0
+
+    # constraint on many disjoint fused iters, case 1
+    # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2)
+    # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1)
+    # i1 * 60 in [60, 240), extent=180 (= scale of i0)
+    i0 = tvm.tir.Var("i0", "int32")
+    i1 = tvm.tir.Var("i1", "int32")
+    i2 = tvm.tir.Var("i2", "int32")
+    i3 = tvm.tir.Var("i3", "int32")
+    i4 = tvm.tir.Var("i4", "int32")
+    i5 = tvm.tir.Var("i5", "int32")
+    res = tvm.arith.detect_iter_map(
+        [i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5],
+        var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]),
+        tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5),
+    )
+    assert_iter_sum_pattern(res[0], 540, 93)
+
+    # constraint on many disjoint fused iters, case 2
+    res = tvm.arith.detect_iter_map(
+        [i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4],
+        var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]),
+        tvm.tir.all(3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10),
+    )
+    assert_iter_sum_pattern(res[0], 135, 28)
+
+    # constraint on split iters
+    res = tvm.arith.detect_iter_map(
+        [i % 16, i // 16],
+        var_dom([(i, 1024)]),
+        tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12),
+        require_bijective=True,
+    )
+    assert_iter_sum_pattern(res[0], 7, 3)
+    assert_iter_sum_pattern(res[1], 8, 4)
+
+    # constraint on split iters, nested case 1
+    res = tvm.arith.detect_iter_map(
+        [(i * 32 + j) % 16],
+        var_dom([(i, 5), (j, 32)]),
+        tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10),
+    )
+    assert_iter_sum_pattern(res[0], 7, 3)
+
+    # constraint on split iters, nested case 2
+    res = tvm.arith.detect_iter_map(
+        [(i * 32 + j) % 16],
+        var_dom([(i, 5), (j, 32)]),
+        tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32),
+    )
+    assert len(res) == 0
+    res = tvm.arith.detect_iter_map(
+        [(i * 32 + j - 1) % 16, (i * 32 + j - 1) // 16],
+        var_dom([(i, 5), (j, 32)]),
+        tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64),
+    )
+    assert_iter_sum_pattern(res[0], 16, 0)
+    assert_iter_sum_pattern(res[1], 4, 0)
+
+    # non-standard form of predicate
+    res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0])
     assert len(res) == 1
     assert_iter_sum_pattern(res[0], 128, 0)
 
@@ -651,6 +814,10 @@ def test_normalize_iter_map_to_expr():
     )
     tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5))
 
+    # iter mark wrap a complex expr
+    split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 1024), 1, 1024, 1)
+    tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x[0] * y[0] + 1)
+
 
 def test_inverse_affine_iter_map():
     analyzer = tvm.arith.Analyzer()
@@ -712,6 +879,38 @@ def test_inverse_affine_iter_map():
     assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
 
 
+def test_free_variables():
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
+
+    # illegal iter map if z is also an iterator in dom
+    res = tvm.arith.detect_iter_map([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)]))
+    assert len(res) == 0
+
+    # valid iter map if z is free (not in dom); z terms, even nonlinear ones, go into the base
+    res = tvm.arith.detect_iter_map(
+        [z * 19 + y * 3 + x],
+        var_dom(
+            [
+                (x, 3),
+                (y, 3),
+            ]
+        ),
+    )
+    assert_iter_sum_pattern(res[0], 9, z * 19)
+    res = tvm.arith.detect_iter_map(
+        [z * z + y * 3 + x],
+        var_dom(
+            [
+                (x, 3),
+                (y, 3),
+            ]
+        ),
+    )
+    assert_iter_sum_pattern(res[0], 9, z * z)
+
+
 if __name__ == "__main__":
     test_split()
     test_trivial()
@@ -722,3 +921,4 @@ if __name__ == "__main__":
     test_subspace_division()
     test_complex()
     test_inverse_affine_iter_map()
+    test_free_variables()
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py
index 8267a36..fd2d82d 100644
--- a/tests/python/unittest/test_tir_schedule_reorder.py
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -217,9 +217,8 @@ def test_reorder_with_predicate():
     sch = tir.Schedule(elementwise_predicate, debug_mask="all")
     block_b = sch.get_block("B")
     i, j, k, l = sch.get_loops(block_b)
-    sch.reorder(l, i)
-    tvm.ir.assert_structural_equal(elementwise_reordered_with_predicate, sch.mod["main"])
-    verify_trace_roundtrip(sch=sch, mod=elementwise_predicate)
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(l, i)
 
 
 def test_reorder_fail_with_multi_appearance_loops():
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index 35d4f5f..f5fc5a7 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -690,11 +690,9 @@ def test_reduction_rfactor_predicate():  # pylint: disable=invalid-name
     s = tir.Schedule(rowsum_predicate, debug_mask="all")
     B = s.get_block("B")
     _, ko, _ = s.get_loops(B)
-    rf_block = s.rfactor(ko, 1)
-    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_predicate_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("B_rf")))
-    assert s.get(B).same_as(s.get(s.get_block("B")))
-    verify_trace_roundtrip(s, mod=rowsum_predicate)
+    # TODO: should be a tvm.tir.ScheduleError
+    with pytest.raises(tvm.TVMError):
+        rf_block = s.rfactor(ko, 1)
 
 
 def test_reduction_rfactor_with_annotation():
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index e3bd000..d86af72 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -314,6 +314,45 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None:
                         C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
 
 
+@T.prim_func
+def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [224, 224], dtype="float32")
+    Y = T.match_buffer(b, [224, 224], dtype="float32")
+    cache = T.alloc_buffer([224, 224], dtype="float32")
+    for hh_0, ww_0 in T.grid(28, 28):
+        for ax0 in T.serial(0, 10):
+            for ax1 in T.serial(0, 10):
+                with T.block("cache"):
+                    h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
+                    w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
+                    T.where(
+                        1 <= hh_0 * 8 + ax0
+                        and hh_0 * 8 + ax0 < 225
+                        and 1 <= ww_0 * 8 + ax1
+                        and ww_0 * 8 + ax1 < 225
+                    )
+                    cache[h, w] = X[h, w]
+        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
+            with T.block("compute"):
+                h = T.axis.spatial(224, hh_0 * 8 + hh_1)
+                w = T.axis.spatial(224, ww_0 * 8 + ww_1)
+                kh, kw = T.axis.remap("RR", [khh, kww])
+                with T.init():
+                    Y[h, w] = 0.0
+                Y[h, w] = T.max(
+                    Y[h, w],
+                    T.if_then_else(
+                        T.likely(1 <= h + kh, dtype="bool")
+                        and T.likely(h + kh < 225, dtype="bool")
+                        and T.likely(1 <= w + kw, dtype="bool")
+                        and T.likely(w + kw < 225, dtype="bool"),
+                        cache[h + kh - 1, w + kw - 1],
+                        0.0,
+                        dtype="float32",
+                    ),
+                )
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
@@ -702,5 +741,21 @@ def test_warp_memory_negative():
     # pylint: enable=protected-access
 
 
+def test_non_perfect_tiling_cache():
+    s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags(
+        affine_binding=False,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "compute")) == CachedFlags(
+        affine_binding=True,
+        region_cover=False,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))