You are viewing a plain-text version of this content (the hyperlink to the canonical version is not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by dr...@apache.org on 2023/04/28 00:18:13 UTC

[tvm] branch v0.12.0 updated: [cherry-pick][ARITH] Enhance CanonicalSimplify to Simplify ProdDiv and [ci] disable merge (#14715)

This is an automated email from the ASF dual-hosted git repository.

driazati pushed a commit to branch v0.12.0
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/v0.12.0 by this push:
     new dd25968edc [cherry-pick][ARITH] Enhance CanonicalSimplify to Simplify ProdDiv and [ci] disable merge (#14715)
dd25968edc is described below

commit dd25968edc6491c6dfb5455eba8e4c413fd07d51
Author: driazati <94...@users.noreply.github.com>
AuthorDate: Thu Apr 27 17:18:00 2023 -0700

    [cherry-pick][ARITH] Enhance CanonicalSimplify to Simplify ProdDiv and [ci] disable merge (#14715)
    
    This disables the merge-with-main behavior for this branch and includes the code from #14725 so that CI passes. It also disables the Docker image build stage: the GPU image build is currently broken on this branch, has no bearing on the release, and was being triggered erroneously anyway, since this PR contains no Docker changes.
---
 ci/jenkins/generated/arm_jenkinsfile.groovy        |   5 +-
 ci/jenkins/generated/cortexm_jenkinsfile.groovy    |   5 +-
 ci/jenkins/generated/cpu_jenkinsfile.groovy        |   5 +-
 ci/jenkins/generated/docker_jenkinsfile.groovy     |   7 +-
 ci/jenkins/generated/gpu_jenkinsfile.groovy        |   5 +-
 ci/jenkins/generated/hexagon_jenkinsfile.groovy    |   5 +-
 ci/jenkins/generated/i386_jenkinsfile.groovy       |   5 +-
 ci/jenkins/generated/lint_jenkinsfile.groovy       |   5 +-
 .../generated/minimal_cross_isa_jenkinsfile.groovy |   5 +-
 ci/jenkins/generated/minimal_jenkinsfile.groovy    |   5 +-
 ci/jenkins/generated/riscv_jenkinsfile.groovy      |   5 +-
 ci/jenkins/generated/wasm_jenkinsfile.groovy       |   5 +-
 ci/jenkins/templates/docker_jenkinsfile.groovy.j2  |   2 +-
 ci/jenkins/templates/utils/Prepare.groovy.j2       |   3 +-
 src/arith/bound_deducer.cc                         |   4 +
 src/arith/canonical_simplify.cc                    | 106 +++++++++++++++++++++
 src/arith/pattern_match.h                          |  17 ++++
 tests/python/unittest/test_arith_deduce_bound.py   |  10 +-
 18 files changed, 171 insertions(+), 33 deletions(-)

diff --git a/ci/jenkins/generated/arm_jenkinsfile.groovy b/ci/jenkins/generated/arm_jenkinsfile.groovy
index 4c830dce2c..14ad1ad780 100644
--- a/ci/jenkins/generated/arm_jenkinsfile.groovy
+++ b/ci/jenkins/generated/arm_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.672484
+// Generated at 2023-04-25T11:40:51.453275
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/cortexm_jenkinsfile.groovy b/ci/jenkins/generated/cortexm_jenkinsfile.groovy
index d8a4d4671e..3f0347c37b 100644
--- a/ci/jenkins/generated/cortexm_jenkinsfile.groovy
+++ b/ci/jenkins/generated/cortexm_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.614676
+// Generated at 2023-04-25T11:40:51.505590
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy
index cdd2564e05..caaeafcb78 100644
--- a/ci/jenkins/generated/cpu_jenkinsfile.groovy
+++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.563887
+// Generated at 2023-04-25T11:40:51.627063
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/docker_jenkinsfile.groovy b/ci/jenkins/generated/docker_jenkinsfile.groovy
index 32dec7863b..d0b37bdc52 100644
--- a/ci/jenkins/generated/docker_jenkinsfile.groovy
+++ b/ci/jenkins/generated/docker_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.699838
+// Generated at 2023-04-26T17:36:59.403201
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
@@ -845,7 +846,7 @@ def deploy() {
 
 
 
-if (rebuild_docker_images) {
+if (false && rebuild_docker_images) {
   stage('Docker Image Build') {
     parallel(
       'ci_arm': {
diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy
index 390c8ddc3d..428caedcbf 100644
--- a/ci/jenkins/generated/gpu_jenkinsfile.groovy
+++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.640362
+// Generated at 2023-04-25T11:40:51.523364
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/hexagon_jenkinsfile.groovy b/ci/jenkins/generated/hexagon_jenkinsfile.groovy
index 58fe4d14c9..e774518be8 100644
--- a/ci/jenkins/generated/hexagon_jenkinsfile.groovy
+++ b/ci/jenkins/generated/hexagon_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.512545
+// Generated at 2023-04-25T11:40:51.434735
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/i386_jenkinsfile.groovy b/ci/jenkins/generated/i386_jenkinsfile.groovy
index b5bf5cb1fe..3f3fed8244 100644
--- a/ci/jenkins/generated/i386_jenkinsfile.groovy
+++ b/ci/jenkins/generated/i386_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.590456
+// Generated at 2023-04-25T11:40:51.488582
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/lint_jenkinsfile.groovy b/ci/jenkins/generated/lint_jenkinsfile.groovy
index ed5aa8d679..52d6036d2c 100644
--- a/ci/jenkins/generated/lint_jenkinsfile.groovy
+++ b/ci/jenkins/generated/lint_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.725728
+// Generated at 2023-04-25T11:40:51.545459
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy
index 4c748e3f20..f6d8f52c64 100644
--- a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy
+++ b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-07T23:01:16.071376
+// Generated at 2023-04-25T11:40:51.596303
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/minimal_jenkinsfile.groovy b/ci/jenkins/generated/minimal_jenkinsfile.groovy
index 72864ec4ca..6b25b37063 100644
--- a/ci/jenkins/generated/minimal_jenkinsfile.groovy
+++ b/ci/jenkins/generated/minimal_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.540335
+// Generated at 2023-04-25T11:40:51.561737
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/riscv_jenkinsfile.groovy b/ci/jenkins/generated/riscv_jenkinsfile.groovy
index 2dfeb35612..47f2d6c92f 100644
--- a/ci/jenkins/generated/riscv_jenkinsfile.groovy
+++ b/ci/jenkins/generated/riscv_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.792163
+// Generated at 2023-04-25T11:40:51.472038
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/generated/wasm_jenkinsfile.groovy b/ci/jenkins/generated/wasm_jenkinsfile.groovy
index 27e8f6570e..bd84e4fef2 100644
--- a/ci/jenkins/generated/wasm_jenkinsfile.groovy
+++ b/ci/jenkins/generated/wasm_jenkinsfile.groovy
@@ -60,7 +60,7 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2023-02-02T20:12:16.748767
+// Generated at 2023-04-25T11:40:51.612532
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // These are set at runtime from data in ci/jenkins/docker-images.yml, update
@@ -150,7 +150,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 b/ci/jenkins/templates/docker_jenkinsfile.groovy.j2
index beb9b478ba..f395f45dca 100644
--- a/ci/jenkins/templates/docker_jenkinsfile.groovy.j2
+++ b/ci/jenkins/templates/docker_jenkinsfile.groovy.j2
@@ -179,7 +179,7 @@ def deploy() {
 
 
 
-if (rebuild_docker_images) {
+if (false && rebuild_docker_images) {
   stage('Docker Image Build') {
     parallel(
     {% for image in images %}
diff --git a/ci/jenkins/templates/utils/Prepare.groovy.j2 b/ci/jenkins/templates/utils/Prepare.groovy.j2
index d5aebdc070..66db58c366 100644
--- a/ci/jenkins/templates/utils/Prepare.groovy.j2
+++ b/ci/jenkins/templates/utils/Prepare.groovy.j2
@@ -20,7 +20,8 @@ def init_git() {
     update_upstream_revision("HEAD")
   } else {
     // This is PR branch so merge with latest main.
-    merge_with_main()
+    // merge_with_main()
+    update_upstream_revision("HEAD")
   }
 
   sh(
diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc
index 7cfe8681be..d52ae7e6fd 100644
--- a/src/arith/bound_deducer.cc
+++ b/src/arith/bound_deducer.cc
@@ -344,6 +344,10 @@ void BoundDeducer::Deduce() {
   expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
 
   this->VisitExpr(expr_);
+
+  if (success_) {
+    result_ = analyzer_.Simplify(result_);
+  }
 }
 
 void BoundDeducer::Relax() {
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index 11fb041511..14c91934d3 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -633,6 +633,27 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
    */
   void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible,
                               SumExpr* out_non_divisible);
+  /*!
+   * \brief Cancel the product factors that lhs and rhs have in common
+   *        (equal symbolic terms and the GCD of their constant scales).
+   *
+   * The following two relations hold for floordiv/mod and truncdiv/mod.
+   * Note that they do not hold for Euclidean div and mod.
+   *
+   * This is because the floordiv/mod and truncdiv/mod results are
+   * uniquely determined by the value of the real-division result, and
+   * the relations hold for real division.
+   *
+   * - div((a0 * a1 * c), (b0 * b1 * c)) = div((a0 * a1), (b0 * b1))
+   * - mod((a0 * a1 * c), (b0 * b1 * c)) = mod((a0 * a1), (b0 * b1)) * c
+   *
+   * \param lhs The left operand; replaced by its remaining factors.
+   * \param rhs The right operand; replaced by its remaining factors.
+   * \param common_scale Output: the product of the cancelled factors.
+   * \return Whether any factor was cancelled (outputs set only then).
+   * \note Mainly targets a symbolic rhs; a constant rhs is skipped.
+   */
+  bool ProdDivSimplify(PrimExpr* lhs, PrimExpr* rhs, PrimExpr* common_scale);
   /*!
    * \brief Normalize expr to normal expr.
    * \param expr The input expression.
@@ -862,6 +883,66 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
   return lhs;
 }
 
+bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
+                                                PrimExpr* common_scale) {
+  // the constant rhs case is covered by other simplifier so
+  // we just skip to save the time
+  if (prhs->as<IntImmNode>()) return false;
+  // collect lhs products and try to eliminate by matching them to prod in rhs
+  Array<Optional<PrimExpr>> lhs_prods;
+  PrimExpr new_rhs = make_const(prhs->dtype(), 1);
+  PrimExpr new_common_scale = make_const(prhs->dtype(), 1);
+  int64_t lhs_cscale = 1, rhs_cscale = 1;
+  int num_elimination = 0;
+
+  // collect lhs product and constant scale.
+  auto fcollect_lhs = [&](PrimExpr value) {
+    if (auto* intimm = value.as<tir::IntImmNode>()) {
+      lhs_cscale *= intimm->value;
+    } else {
+      lhs_prods.push_back(value);
+    }
+  };
+  UnpackReduction<tir::MulNode>(*plhs, fcollect_lhs);
+
+  // collect rhs product and try to eliminate when possible
+  PEqualChecker<PrimExpr> deep_equal;
+  auto fcollect_rhs = [&](PrimExpr value) {
+    if (auto* intimm = value.as<tir::IntImmNode>()) {
+      rhs_cscale *= intimm->value;
+    } else {
+      // try eliminate from lhs
+      for (size_t i = 0; i < lhs_prods.size(); ++i) {
+        if (lhs_prods[i].defined() && deep_equal(value, lhs_prods[i].value())) {
+          lhs_prods.Set(i, NullOpt);
+          ++num_elimination;
+          new_common_scale = new_common_scale * value;
+          return;
+        }
+      }
+      // if elimination is not possible then construct the expression.
+      new_rhs = new_rhs * value;
+    }
+  };
+  UnpackReduction<tir::MulNode>(*prhs, fcollect_rhs);
+  // find gcd of const scales.
+  int64_t cscale_gcd = ZeroAwareGCD(lhs_cscale, rhs_cscale);
+  lhs_cscale /= cscale_gcd;
+  rhs_cscale /= cscale_gcd;
+  // if no elimination is possible
+  if (num_elimination == 0 && cscale_gcd == 1) return false;
+
+  // construct prod via canonical form
+  PrimExpr new_lhs = make_const(plhs->dtype(), 1);
+  for (Optional<PrimExpr> val : lhs_prods) {
+    if (val.defined()) new_lhs = new_lhs * val.value();
+  }
+  *plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale);
+  *prhs = new_rhs * make_const(prhs->dtype(), rhs_cscale);
+  *common_scale = new_common_scale * make_const(prhs->dtype(), cscale_gcd);
+  return true;
+}
+
 PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
@@ -913,6 +994,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
   // normal path
   a = Normalize(a);
   b = Normalize(b);
+  PrimExpr scale;
+  // note this is the case where b is not constant
+  if (ProdDivSimplify(&a, &b, &scale)) {
+    // use operator ver so it can constant fold if b == 1
+    return truncdiv(a, b);
+  }
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<PrimExpr>(op);
   } else {
@@ -967,6 +1054,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
   // normal path
   a = Normalize(a);
   b = Normalize(b);
+  PrimExpr scale;
+  if (ProdDivSimplify(&a, &b, &scale)) {
+    // use operator ver so it can const fold.
+    return floordiv(a, b);
+  }
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<PrimExpr>(op);
   } else {
@@ -1088,6 +1180,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
   // normal path
   a = Normalize(a);
   b = Normalize(b);
+
+  PrimExpr scale;
+  if (ProdDivSimplify(&a, &b, &scale)) {
+    // use operator version here so it can const fold b == 1
+    return truncmod(a, b) * scale;
+  }
+
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<PrimExpr>(op);
   } else {
@@ -1146,6 +1245,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
   // normal path
   a = Normalize(a);
   b = Normalize(b);
+
+  PrimExpr scale;
+  if (ProdDivSimplify(&a, &b, &scale)) {
+    // use operator version here so it can const fold b == 1
+    return floormod(a, b) * scale;
+  }
+
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<PrimExpr>(op);
   } else {
diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h
index 55b51d7a31..0bb172e560 100644
--- a/src/arith/pattern_match.h
+++ b/src/arith/pattern_match.h
@@ -915,6 +915,23 @@ matches_one_of(const TPattern&... patterns) {
   return PMatchesOneOf<TPattern...>(patterns...);
 }
 
+/*!
+ * \brief Unpack reduction by calling each leaf via fleaf.
+ *
+ * \param value The expression value.
+ * \tparam TNode the reduction node to match.
+ * \tparam FLeaf The callback function at leaf.
+ */
+template <typename TNode, typename FLeaf>
+inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
+  if (const TNode* node = value.as<TNode>()) {
+    UnpackReduction<TNode, FLeaf>(node->a, fleaf);
+    UnpackReduction<TNode, FLeaf>(node->b, fleaf);
+  } else {
+    fleaf(value);
+  }
+}
+
 }  // namespace arith
 }  // namespace tvm
 #endif  // TVM_ARITH_PATTERN_MATCH_H_
diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py
index 45ecb62755..a36fd21479 100644
--- a/tests/python/unittest/test_arith_deduce_bound.py
+++ b/tests/python/unittest/test_arith_deduce_bound.py
@@ -114,12 +114,10 @@ def test_deduce():
     assert str(res9.max_value) == "neg_inf"
     assert str(res9.min_value) == "pos_inf"
 
-    # Unsatisfiable Mul in `EQ`
-    res10 = tvm.arith.deduce_bound(
-        a, (b * a == b), {b: b_s}, {}
-    )  # simplifier is not able to prove that (b % b == 0)
-    assert str(res10.max_value) == "neg_inf"
-    assert str(res10.min_value) == "pos_inf"
+    res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {})
+    # simplifier is now able to prove symbolic relation (b * a % b == 0)
+    tvm.testing.assert_prim_expr_equal(res10.max_value, 1)
+    tvm.testing.assert_prim_expr_equal(res10.min_value, 1)
 
 
 def test_check():