Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/21 17:51:03 UTC

[GitHub] [tvm] Lunderberg opened a new pull request, #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Lunderberg opened a new pull request, #12863:
URL: https://github.com/apache/tvm/pull/12863

   This commit adds a new sub-analyzer, `TransitiveComparisonAnalyzer`, which attempts to apply multiple known comparisons to prove an unknown.  For example, `a <= b` and `b <= c` imply that `a <= c`. This is needed to simplify the conditionals that arise from padded layout
   transformations (https://github.com/apache/tvm/issues/12261).
   
   While some of these conditions may be proven using `ConstIntBoundAnalyzer` or `IntSetAnalyzer`, each has limitations.  `ConstIntBoundAnalyzer` can only compare against a constant, `IntSetAnalyzer` internally calls `RewriteSimplifier`, which can result in infinite recursion, and neither can handle not-equal conditions, because doing so would require tracking multiple intervals per expression.  This PR therefore introduces a new sub-analyzer for these simplifications.
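The transitive idea can be illustrated outside of TVM. The following is a minimal sketch under stated assumptions, not the PR's implementation: each known fact is modeled as an edge `lhs <= rhs + offset` between named expressions, and a depth-first search sums the offsets along a chain. The `Known` struct and the `TightestUpperOffset` and `ProvesLE` helpers are hypothetical names. Strict comparisons are assumed to be normalized to non-strict ones first, and a fact `x >= y + c` is assumed to be stored as `y <= x - c`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified model of a known fact: `lhs <= rhs + offset`.
// Assumptions: strict comparisons were normalized beforehand
// (over the integers, `x < y + c` is `x <= y + (c - 1)`), and a fact
// `x >= y + c` is stored as `y <= x + (-c)`.
struct Known {
  std::string lhs;
  std::string rhs;
  int64_t offset;
};

// Depth-first search over simple paths from `from` to `to`, treating each
// known as an edge.  Returns the smallest total offset `c` for which the
// knowns imply `from <= to + c`, or std::nullopt if no chain exists.
std::optional<int64_t> TightestUpperOffset(const std::vector<Known>& knowns,
                                           const std::string& from,
                                           const std::string& to,
                                           std::set<std::string>& visited) {
  if (from == to) return 0;
  visited.insert(from);
  std::optional<int64_t> best;
  for (const Known& k : knowns) {
    if (k.lhs != from || visited.count(k.rhs) > 0) continue;
    // from <= k.rhs + k.offset, and k.rhs <= to + rest, so
    // from <= to + (k.offset + rest).
    if (auto rest = TightestUpperOffset(knowns, k.rhs, to, visited)) {
      int64_t total = k.offset + *rest;
      if (!best || total < *best) best = total;
    }
  }
  visited.erase(from);  // allow other paths to pass through `from`
  return best;
}

// True if the knowns prove `lhs <= rhs + offset`.
bool ProvesLE(const std::vector<Known>& knowns, const std::string& lhs,
              const std::string& rhs, int64_t offset) {
  std::set<std::string> visited;
  std::optional<int64_t> c = TightestUpperOffset(knowns, lhs, rhs, visited);
  return c.has_value() && *c <= offset;
}
```

With the knowns `a <= b`, `b <= c + 1` and `c <= d - 5`, `ProvesLE(knowns, "a", "d", -4)` returns true and `ProvesLE(knowns, "a", "d", -5)` returns false. The search enumerates simple paths, so it can be exponential in the number of knowns; that is acceptable for an illustration but not for a general-purpose prover.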


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989448408


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a graph edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;

Review Comment:
   This should already be handled by the check on [line 364](https://github.com/apache/tvm/pull/12863/files/2abbe2f47adbd2dfa23fe5bffb216689079b93ae#diff-3635e35a52a78480d112dd843d407cd80f9634c44e24b532f21fc8522eae8066R364).   These three conditions only apply when `other.result_ == kNE`.
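The case split discussed in this comment can be checked in a standalone model. The sketch below re-expresses the `Implies` logic for a single pair of expressions `(x, y)` using hypothetical `Op` and `Cmp` types instead of TVM's `Comparison` class. Because the identical-comparison case is tested first, the remaining `kNE` branches only need strict offset checks; in particular, `x >= y + c1` rules out `x == y + c2` only when `c1 > c2`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the comparison operators.  Each value
// describes `x OP y + offset` for one fixed pair of expressions (x, y).
enum class Op { kEQ, kLT, kLE, kGT, kGE, kNE };

struct Cmp {
  Op op;
  int64_t offset;
};

// Over the integers, strict comparisons can be folded into the offset:
//   x < y + c  <=>  x <= y + (c - 1)
//   x > y + c  <=>  x >= y + (c + 1)
Cmp Normalize(Cmp c) {
  if (c.op == Op::kLT) return {Op::kLE, c.offset - 1};
  if (c.op == Op::kGT) return {Op::kGE, c.offset + 1};
  return c;
}

bool IsNormalized(const Cmp& c) { return c.op != Op::kLT && c.op != Op::kGT; }

// True if `x a.op y + a.offset` being true guarantees `x b.op y + b.offset`.
// Returns false when the implication cannot be shown.
bool Implies(const Cmp& a, const Cmp& b) {
  assert(IsNormalized(a) && IsNormalized(b));

  // Identical comparisons trivially imply each other.
  if (a.op == b.op && a.offset == b.offset) return true;

  switch (b.op) {
    case Op::kLE:  // x <= y + c1 (or x == y + c1) with c1 <= c2
      return (a.op == Op::kEQ || a.op == Op::kLE) && a.offset <= b.offset;
    case Op::kGE:  // x >= y + c1 (or x == y + c1) with c1 >= c2
      return (a.op == Op::kEQ || a.op == Op::kGE) && a.offset >= b.offset;
    case Op::kNE:
      // The three cases that can prove x != y + c2.
      return (a.op == Op::kEQ && a.offset != b.offset) ||
             (a.op == Op::kLE && a.offset < b.offset) ||
             (a.op == Op::kGE && a.offset > b.offset);
    default:
      // Remaining case is b.op == kEQ, which only an identical
      // equality implies, and that was handled above.
      return false;
  }
}
```

For example, `Implies(Normalize({Op::kLT, 0}), {Op::kNE, 0})` is true (x < y implies x != y), while `Implies({Op::kGE, 5}, {Op::kNE, 5})` is false.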




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989464193


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a graph edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  auto frecover = [old_literal_size, new_literal_size, this]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if possible.  A comparison between two integer
+  // constants should already have been constant-folded, so this
+  // branch is not normally reached.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(

Review Comment:
   Hmm, good point.  I've renamed it to `DFSFromLHS`, which, together with the updated documentation, should hopefully make it less verbose.
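   The `TryCompare` body quoted just above runs this search twice, once starting from each side, and intersects the two answers with `&` after `Reverse`-ing the second one. Below is a minimal standalone sketch of only that combination step. It re-declares the `CompareResult` bit values from `include/tvm/arith/analyzer.h` (quoted later in this thread) so that it builds without TVM. `CombineSearches` is a hypothetical name and is not part of the PR.

```cpp
// Same values as TVM's CompareResult: bit 0 = EQ, bit 1 = LT, bit 2 = GT.
enum class CompareResult : int {
  kInconsistent = 0, kEQ = 1, kLT = 2, kLE = 3, kGT = 4, kGE = 5, kNE = 6, kUnknown = 7
};

constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) {
  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
}

// Swap the roles of the two operands: (a < b) is the same fact as (b > a).
constexpr CompareResult Reverse(CompareResult res) {
  switch (res) {
    case CompareResult::kLT: return CompareResult::kGT;
    case CompareResult::kLE: return CompareResult::kGE;
    case CompareResult::kGT: return CompareResult::kLT;
    case CompareResult::kGE: return CompareResult::kLE;
    default: return res;  // EQ, NE, Unknown and Inconsistent are symmetric
  }
}

// Hypothetical helper: merge a result describing `a OP b` (search from a)
// with a result describing `b OP a` (search from b).  Each result is a set of
// still-possible outcomes, so intersecting the sets keeps what both allow.
constexpr CompareResult CombineSearches(CompareResult a_vs_b, CompareResult b_vs_a) {
  return a_vs_b & Reverse(b_vs_a);
}
```

   For example, if the search from `a` proves `a <= b` and the search from `b` proves `b != a`, the intersection is `a < b`. If the two searches contradict each other, the result is `kInconsistent`.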




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989462775


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /* \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& expr, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param constraint A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if possible.  This int check should have been
+  // constant-folded earlier, so this check shouldn't occur.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(
+    Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input,
+    const PrimExpr& rhs_input) const {
+  Key lhs_key = lhs_key_input;
+  Key rhs_key = rhs_key_input;
+  int64_t offset = offset_input;
+
+  // Everything in `to_visit` has lhs as its lhs.
+  std::unordered_set<Key> seen;
+  std::unordered_set<Key> to_visit;
+  std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
+
+  // Utility function to add a new known statement
+  auto declare_known = [&](Comparison cmp) {
+    auto& prev_knowns = compared_to_x[cmp.rhs_];
+
+    for (auto& prev_known : prev_knowns) {
+      if (prev_known.Implies(cmp)) {
+        return;
+      }
+    }
+
+    if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) {
+      to_visit.insert(cmp.rhs_);
+      seen.insert(cmp.rhs_);
+    }
+
+    for (auto& prev_known : prev_knowns) {
+      if (cmp.Implies(prev_known)) {
+        prev_known = cmp;
+        return;
+      }
+    }
+
+    prev_knowns.push_back(cmp);
+  };
+
+  // Initialize the search based on any known (in)equalities that use
+  // the LHS of the comparison.
+  for (const auto& known : knowns_) {
+    if (auto normalized = known.WithLHS(lhs_key)) {
+      declare_known(normalized.value());
+    }
+  }
+  for (const auto& known : scoped_knowns_) {
+    if (auto normalized = known.WithLHS(lhs_key)) {
+      declare_known(normalized.value());
+    }
+  }
+
+  // Walk through the space of all comparisons that can be made with
+  // LHS.
+  while (to_visit.size()) {
+    Key middle_key = *to_visit.begin();
+    to_visit.erase(to_visit.begin());
+
+    std::vector<Comparison>& prev_knowns_using_middle = compared_to_x.at(middle_key);
+    ICHECK(compared_to_x.count(middle_key));
+
+    std::vector<Comparison> new_knowns_using_lhs;
+
+    auto attempt_transitive = [&](Comparison cmp) {
+      ICHECK(cmp.IsNormalized());
+
+      Key right_key = cmp.rhs_;
+
+      if (right_key == lhs_key) {
+        return;
+      }
+
+      for (const auto& prev : prev_knowns_using_middle) {
+        CompareResult new_result = CompareResult::kUnknown;
+        int64_t new_offset = prev.offset_ + cmp.offset_;
+
+        if (prev.result_ == CompareResult::kEQ) {
+          // x == y + c1 && y OP z + c2, x OP z + (c1 + c2)
+          new_result = cmp.result_;
+        } else if (cmp.result_ == CompareResult::kEQ) {
+          // x OP y + c1 && y == z + c2, x OP z + (c1 + c2)
+          new_result = prev.result_;
+        } else if (prev.result_ == cmp.result_ &&
+                   (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) {
+          // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2)
+          // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2)
+          //
+          // This condition is much simpler to write than the
+          // equivalent handling of < or of >, which is why the
+          // inequalities are normalized to <= and to >=.
+          new_result = prev.result_;
+        }
+
+        if (new_result != CompareResult::kUnknown) {
+          Comparison new_known(lhs_key, right_key, new_offset, new_result);
+          new_knowns_using_lhs.push_back(new_known);
+        }
+      }
+    };
+
+    // Attempt to prove a new comparison using one of the original
+    // known comparisons.  We want to find a known such that
+    // `(LHS OP1 middle) && (middle OP2 right)` can be simplified

Review Comment:
   I'm not sure of the best way to phrase this part.  The switch of convention is intentional, as the `right` in this comment might not be the RHS of the desired comparison.  However, this new comparison will always use the LHS of the desired comparison as its own LHS.  For example, if it is known that `a <= b`, `b <= c`, and `c <= d`, then the first step is to prove that `a <= c`.  While proving that, `a` is LHS, `b` is middle, and `c` is right.
   
   I've updated the comment for clarity, with an example.
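   A minimal sketch of one such step, written against simplified stand-ins rather than TVM's types: `Op`, a `Comparison` that uses small integer ids instead of `Key`, and the hypothetical helpers `Normalize` and `Compose`. It follows the rules in the `attempt_transitive` lambda quoted above. Strict inequalities are first normalized to `<=`/`>=` with an adjusted offset. After that, `==` chains with any operator, and matching `<=`/`<=` or `>=`/`>=` chain through the middle term.

```cpp
#include <cstdint>
#include <optional>

// Simplified operator set (an assumption for this sketch; TVM uses CompareResult).
enum class Op { kEQ, kLT, kLE, kGT, kGE, kNE };

// Hypothetical stand-in for the PR's Comparison: `lhs OP rhs + offset`.
struct Comparison {
  int lhs;
  int rhs;
  std::int64_t offset;
  Op op;
};

// Over the integers, x < y + c is x <= y + (c - 1), and x > y + c is
// x >= y + (c + 1), so strict inequalities can be rewritten away.
constexpr Comparison Normalize(Comparison cmp) {
  if (cmp.op == Op::kLT) {
    cmp.op = Op::kLE;
    cmp.offset -= 1;
  } else if (cmp.op == Op::kGT) {
    cmp.op = Op::kGE;
    cmp.offset += 1;
  }
  return cmp;
}

// From (lhs OP1 middle + c1) and (middle OP2 right + c2), derive
// (lhs OP right + (c1 + c2)) when the two operators allow it.
constexpr std::optional<Comparison> Compose(Comparison prev, Comparison cmp) {
  prev = Normalize(prev);
  cmp = Normalize(cmp);
  if (prev.rhs != cmp.lhs) {
    return std::nullopt;  // the two facts do not share a middle term
  }
  const std::int64_t offset = prev.offset + cmp.offset;
  if (prev.op == Op::kEQ) {
    return Comparison{prev.lhs, cmp.rhs, offset, cmp.op};
  }
  if (cmp.op == Op::kEQ) {
    return Comparison{prev.lhs, cmp.rhs, offset, prev.op};
  }
  if (prev.op == cmp.op && (prev.op == Op::kLE || prev.op == Op::kGE)) {
    return Comparison{prev.lhs, cmp.rhs, offset, prev.op};
  }
  return std::nullopt;  // e.g. (x <= y) and (y >= z) say nothing about x vs z
}

// The example above, with ids a=0, b=1, c=2, d=3.
constexpr Comparison kAB{0, 1, 0, Op::kLE};
constexpr Comparison kBC{1, 2, 0, Op::kLE};
constexpr Comparison kCD{2, 3, 0, Op::kLE};
constexpr auto kAC = Compose(kAB, kBC);   // a <= c  (LHS a, middle b, right c)
constexpr auto kAD = Compose(*kAC, kCD);  // a <= d  (LHS a, middle c, right d)
```

   Offsets add along the chain. For example, `a < b` (normalized to `a <= b - 1`) combined with `b <= c + 3` gives `a <= c + 2`.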




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989408515


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /* \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& expr, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param constraint A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a

Review Comment:
   I probably should have said `graph edge`, but even then I didn't specify which nodes the edge connects.  I've entirely rewritten this function's documentation with an example, as this is the heart of the transitive analyzer and should be explained more fully.
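   To illustrate the graph view: treat each known `x <= y + c` as an edge from node `x` to node `y` with weight `c`. The priors `(a<=b)`, `(b<=c+1)` and `(c<=d-5)` then form a path from `a` to `d`, and summing the weights proves `a <= d - 4`. The sketch below is an illustration under stated assumptions, not the PR's implementation. It swaps in a plain Bellman-Ford shortest-path search in place of the PR's DFS, handles only `<=` facts, and uses hypothetical names (`Edge`, `TightestUpperOffset`). The shortest path gives the tightest offset that the chained facts can prove.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <optional>

// One known fact `from <= to + offset`; nodes are small integer ids
// (an assumption standing in for the PR's Key type).
struct Edge {
  std::size_t from;
  std::size_t to;
  std::int64_t offset;
};

// Smallest k such that `lhs <= rhs + k` follows by chaining the edges, or
// nullopt if no chain connects lhs to rhs.  Chaining x <= y + c1 and
// y <= z + c2 gives x <= z + (c1 + c2), so this is a shortest-path problem.
template <std::size_t N, std::size_t E>
constexpr std::optional<std::int64_t> TightestUpperOffset(const std::array<Edge, E>& edges,
                                                          std::size_t lhs, std::size_t rhs) {
  std::array<std::int64_t, N> dist{};
  std::array<bool, N> reached{};
  dist[lhs] = 0;
  reached[lhs] = true;
  // Bellman-Ford: N - 1 rounds of relaxation reach every shortest simple path.
  for (std::size_t round = 0; round + 1 < N; ++round) {
    for (const Edge& e : edges) {
      const std::int64_t candidate = dist[e.from] + e.offset;
      if (reached[e.from] && (!reached[e.to] || candidate < dist[e.to])) {
        dist[e.to] = candidate;
        reached[e.to] = true;
      }
    }
  }
  if (!reached[rhs]) {
    return std::nullopt;
  }
  return dist[rhs];
}

// Priors (a<=b), (b<=c+1), (c<=d-5) with ids a=0, b=1, c=2, d=3.
constexpr std::array<Edge, 3> kPriors{{{0, 1, 0}, {1, 2, 1}, {2, 3, -5}}};
constexpr auto kAToD = TightestUpperOffset<4>(kPriors, 0, 3);  // a <= d + (-4)
```

   A `>=` fact `x >= y + c` could be added as the edge `y <= x - c`, and an equality as both directions. A negative cycle means the facts contradict each other. In that case the fixed number of rounds simply stops relaxing, rather than looping forever.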




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989385552


##########
include/tvm/arith/analyzer.h:
##########
@@ -317,6 +347,82 @@ class CanonicalSimplifier {
   Impl* impl_;
 };
 
+/*! \brief Structure for representing result of known
+ *
+ * Values are assigned to allow these flags to be used in bitwise
+ * operations.
+ */
+enum class CompareResult : int {
+  kInconsistent = 0,
+  kEQ = 1,
+  kLT = 2,
+  kLE = 3,
+  kGT = 4,
+  kGE = 5,
+  kNE = 6,
+  kUnknown = 7
+};
+
+inline constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) {
+  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
+}
+inline constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs) {
+  return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
+}
+
+/*!
+ * \brief Using previously specified knowns, compare the expressions provided
+ *
+ * Given known expressions [(a OP b), (b OP c), ..., (y OP z)], search
+ * for a known result for `(a OP z)`.
+ */
+class TransitiveComparisonAnalyzer {
+ public:
+  /* \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs);
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);

Review Comment:
   And updated
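   The `CompareResult` values in the header hunk above are bit sets over the three basic outcomes (EQ = 1, LT = 2, GT = 4), which is what gives `operator&` and `operator|` their meaning. A compound result is the set of outcomes that are still possible. Below is a minimal compile-time sketch of those identities. It re-declares the enum so it builds without TVM, and its `Negate` mirrors the internal helper in `transitive_comparison_analyzer.cc` quoted earlier in this thread.

```cpp
enum class CompareResult : int {
  kInconsistent = 0, kEQ = 1, kLT = 2, kLE = 3, kGT = 4, kGE = 5, kNE = 6, kUnknown = 7
};

constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) {
  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
}
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs) {
  return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
}

// Compound results are unions of the basic outcomes.
static_assert((CompareResult::kLT | CompareResult::kEQ) == CompareResult::kLE);
static_assert((CompareResult::kGT | CompareResult::kEQ) == CompareResult::kGE);
static_assert((CompareResult::kLT | CompareResult::kGT) == CompareResult::kNE);

// Negation is the complement within the three bits.  kUnknown and
// kInconsistent are kept fixed; a plain complement would swap them.
constexpr CompareResult Negate(CompareResult res) {
  switch (res) {
    case CompareResult::kInconsistent: return CompareResult::kInconsistent;
    case CompareResult::kUnknown: return CompareResult::kUnknown;
    default:
      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
  }
}
```

   For instance, the negation of `a <= b` is `a > b`, and intersecting `<=` with `>=` leaves only `==`.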




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989390769


##########
tests/python/unittest/test_tir_transform_simplify.py:
##########
@@ -138,6 +138,20 @@ def sls(n, d):
 
 class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):

Review Comment:
   That's correct: enabling this extension has minimal effect on the other tests.  That said, I like the cleanliness of having it disabled by default while testing, and explicitly enabling it where needed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989454291


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter a constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if possible.  A comparison between two integer
+  // constants should already have been constant-folded, so this
+  // branch should not normally be reached.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(
+    Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input,
+    const PrimExpr& rhs_input) const {
+  Key lhs_key = lhs_key_input;
+  Key rhs_key = rhs_key_input;
+  int64_t offset = offset_input;
+
+  // Everything in `to_visit` has lhs as its lhs.
+  std::unordered_set<Key> seen;
+  std::unordered_set<Key> to_visit;
+  std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
+
+  // Utility function to add a new known statement
+  auto declare_known = [&](Comparison cmp) {
+    auto& prev_knowns = compared_to_x[cmp.rhs_];

Review Comment:
   Ooh, good call.  Updated, and added some comments that explicitly state when a modification is being performed, to draw additional attention to it.
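The quoted hunk normalizes each `Comparison` so that `<` and `>` never appear. After that, `Implies` only needs to reason about `==`, `!=`, `<=`, and `>=` with integer offsets. The minimal standalone C++ sketch below assumes `int` keys in place of the PR's `Key` type, plus a local `Cmp` enum that mirrors `CompareResult`. It restates the quoted normalization and implication table for illustration and is not the merged implementation.

```cpp
#include <cassert>
#include <cstdint>

// Local mirror of the CompareResult values quoted from analyzer.h.
enum class Cmp : int { kInconsistent = 0, kEQ = 1, kLT = 2, kLE = 3, kGT = 4, kGE = 5, kNE = 6, kUnknown = 7 };

// Represents `lhs OP rhs + offset`.  For integers, < and > are rewritten
// as <= and >= by adjusting the offset, as in the quoted constructor.
struct Comparison {
  int lhs;
  int rhs;
  int64_t offset;
  Cmp result;

  Comparison(int l, int r, int64_t off, Cmp res) : lhs(l), rhs(r), offset(off), result(res) {
    if (result == Cmp::kLT) {  // x < y + c   <=>  x <= y + (c - 1)
      result = Cmp::kLE;
      offset -= 1;
    } else if (result == Cmp::kGT) {  // x > y + c   <=>  x >= y + (c + 1)
      result = Cmp::kGE;
      offset += 1;
    }
  }

  // True if `*this` being true guarantees that `other` is true.
  bool Implies(const Comparison& other) const {
    assert(lhs == other.lhs && rhs == other.rhs);  // same operands required
    const int64_t c1 = offset, c2 = other.offset;
    if (result == other.result && c1 == c2) return true;
    switch (other.result) {
      case Cmp::kLE:  // x == y + c1 or x <= y + c1, with c1 <= c2
        return (result == Cmp::kEQ || result == Cmp::kLE) && c1 <= c2;
      case Cmp::kGE:  // x == y + c1 or x >= y + c1, with c1 >= c2
        return (result == Cmp::kEQ || result == Cmp::kGE) && c1 >= c2;
      case Cmp::kNE:  // x can never reach y + c2
        return (result == Cmp::kEQ && c1 != c2) || (result == Cmp::kLE && c1 < c2) ||
               (result == Cmp::kGE && c1 > c2);
      default:
        return false;
    }
  }
};
```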





[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989415776


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter a constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};

Review Comment:
   That's exactly correct.  A `size_t` could have been used directly without introducing the `enum class Key`, but that would introduce the possibility of a `size_t` being erroneously used as a key (e.g. accidentally treating a loop iterator as a lookup key).
   
   Since the intent wasn't clear without reading into the implementation, I'm updating the docstrings for `Key`, `ExprToPreviousKey`, and `ExprToKey` to make the intended usage explicit.
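As described in the reply, `Key` acts as an interned ID for an expression. Using `enum class` rather than a bare `size_t` turns accidental mixing with ordinary integers into a compile error. The minimal sketch below shows the same pattern. The names `KeyTable`, `Find`, and `GetOrCreate` are hypothetical, and `std::string` stands in for `PrimExpr`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <unordered_map>

// Distinct type wrapping size_t: `Key k = 3;` or passing a loop counter
// where a Key is expected does not compile without an explicit cast.
enum class Key : std::size_t {};

class KeyTable {
 public:
  // Look up a key without creating one (analogous to ExprToPreviousKey).
  std::optional<Key> Find(const std::string& expr) const {
    auto it = keys_.find(expr);
    if (it == keys_.end()) return std::nullopt;
    return it->second;
  }

  // Return the existing key, or assign the next unused one (analogous to ExprToKey).
  Key GetOrCreate(const std::string& expr) {
    if (auto prev = Find(expr)) return *prev;
    Key key = static_cast<Key>(keys_.size());
    keys_.emplace(expr, key);
    return key;
  }

 private:
  // The PR keys its map on PrimExpr with StructuralHash/StructuralEqual;
  // string equality is used here only to keep the sketch self-contained.
  std::unordered_map<std::string, Key> keys_;
};
```

Each structural comparison happens once, when a key is created. After that, the transitive search can compare cheap integer-backed IDs.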





[GitHub] [tvm] csullivan merged pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
csullivan merged PR #12863:
URL: https://github.com/apache/tvm/pull/12863




[GitHub] [tvm] csullivan commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r988193162


##########
include/tvm/arith/analyzer.h:
##########
@@ -275,6 +275,36 @@ class RewriteSimplifier {
    */
   std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
+  /*! \brief Flags to enable more computationally-intensive simplifications
+   *
+   * These simplifications may be required for specific schedules, but
+   * would impose too high a compile-time cost to enable by default.
+   * They can be enabled on an as-needed basis by calling
+   * `RewriteSimplifier::SetEnabledFeatures` prior to using
+   * `RewriteSimplifier::operator()`.
+   */
+  enum Feature {
+    // No features enabled
+    kNone = 0,
+
+    /* When simplifying an inequality, attempt to use scope-based knowns.
+     *
+     * Example:
+     * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
+     */
+    kTransitivelyProveInequalities = (1 << 0),

Review Comment:
   I assume the powers of two are there so that features can be combined. Do we expect additional entries in the future? The bitwise shift implies this. If so, a comment to demystify it would be nice.
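The reviewer's reading matches the usual C++ idiom. Giving each feature its own bit lets several features be enabled in one integer and tested independently. The minimal sketch below illustrates that idiom. `IsEnabled` and `kHypotheticalOtherFeature` are hypothetical and do not come from the PR.

```cpp
#include <cassert>

// Mirrors the quoted `Feature` enum; kHypotheticalOtherFeature is NOT part
// of the PR and only shows why each flag occupies its own bit.
enum Feature {
  kNone = 0,
  kTransitivelyProveInequalities = (1 << 0),
  kHypotheticalOtherFeature = (1 << 1),
};

// `a | b` on an unscoped enum yields int, so cast the result back.
inline Feature operator|(Feature a, Feature b) {
  return static_cast<Feature>(static_cast<int>(a) | static_cast<int>(b));
}

// True if every bit of `flag` is set in `enabled`.
inline bool IsEnabled(Feature enabled, Feature flag) {
  return (static_cast<int>(enabled) & static_cast<int>(flag)) == static_cast<int>(flag);
}
```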



##########
tests/python/unittest/test_tir_transform_simplify.py:
##########
@@ -138,6 +138,20 @@ def sls(n, d):
 
 class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):

Review Comment:
   I'll assume the runtime isn't significantly altered by enabling `transitively_prove_inequalities` for the existing tests in addition to the ones you're adding here. If this wasn't intentional, feel free to add a new base class.



##########
tests/python/unittest/test_tir_transform_simplify.py:
##########
@@ -547,5 +561,129 @@ def before(A: T.Buffer[16, "float32"]):
     expected = before
 
 
+class TestRemoveTransitivelyProvableCondition(BaseBeforeAfter):
+    """Remove comparisons that may be proven using multiple others
+
+    For example, the `0 < i` and `i <= j` conditions can be used to prove
+    that `0 < j`.
+    """
+
+    i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"]
+    zero = tvm.tir.IntImm("int32", 0)
+
+    test_case = tvm.testing.parameter(
+        (tvm.tir.all(zero < i, i <= j), zero < j, True),
+        # Transitive comparisons from LT
+        (tvm.tir.all(i < j, j < k), i < k, True),
+        (tvm.tir.all(i < j, j == k), i < k, True),
+        (tvm.tir.all(i < j, j <= k), i < k, True),
+        (tvm.tir.all(i < j, j > k), i < k, False),
+        (tvm.tir.all(i < j, j >= k), i < k, False),
+        (tvm.tir.all(i < j, j != k), i < k, False),
+        # Transitive comparisons from LE
+        (tvm.tir.all(i <= j, j < k), i < k, True),
+        (tvm.tir.all(i <= j, j == k), i == k, False),
+        (tvm.tir.all(i <= j, j == k), i <= k, True),
+        (tvm.tir.all(i <= j, j <= k), i <= k, True),
+        (tvm.tir.all(i <= j, j <= k), i < k, False),
+        (tvm.tir.all(i <= j, j > k), i < k, False),
+        (tvm.tir.all(i <= j, j >= k), i < k, False),
+        (tvm.tir.all(i <= j, j != k), i < k, False),
+        # Transitive comparisons from GT
+        (tvm.tir.all(i > j, j > k), i > k, True),
+        (tvm.tir.all(i > j, j == k), i > k, True),
+        (tvm.tir.all(i > j, j >= k), i > k, True),
+        (tvm.tir.all(i > j, j < k), i > k, False),
+        (tvm.tir.all(i > j, j <= k), i > k, False),
+        (tvm.tir.all(i > j, j != k), i > k, False),
+        # Transitive comparisons from GE
+        (tvm.tir.all(i >= j, j > k), i > k, True),
+        (tvm.tir.all(i >= j, j == k), i == k, False),
+        (tvm.tir.all(i >= j, j == k), i >= k, True),
+        (tvm.tir.all(i >= j, j >= k), i >= k, True),
+        (tvm.tir.all(i >= j, j >= k), i > k, False),
+        (tvm.tir.all(i >= j, j < k), i > k, False),
+        (tvm.tir.all(i >= j, j <= k), i > k, False),
+        (tvm.tir.all(i >= j, j != k), i > k, False),
+        # GT or LT may be used to prove NE
+        (tvm.tir.all(i == j, j != k), i != k, True),
+        (tvm.tir.all(i == j, j < k), i != k, True),
+        (tvm.tir.all(i == j, j > k), i != k, True),
+        (tvm.tir.all(i == j, j != k), i < k, False),
+        (tvm.tir.all(i == j, j != k), i > k, False),
+        # Because these are integers, x<y is equivalent to x <= y-1,
+        # and may be used in equivalent simplifications.
+        (tvm.tir.all(i < j, j < k), i < k, True),

Review Comment:
   Duplicates [L577](https://github.com/apache/tvm/pull/12863/files#diff-d7436dae3a0ec5555c249400c293d8d035753562f13fddadc0ebadf4f4c0d997R577). I think you want `(tvm.tir.all(i <= j-1, j < k), i < k, True),`.
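The reviewer's suggestion relies on a fact stated in the test comment: for integers, `i < j` and `i <= j - 1` reduce to the same normalized comparison. A test written with `i <= j-1` therefore exercises the offset handling instead of repeating an existing case. The minimal sketch below shows that reduction, following the PR's `ExtractOffsets` (constants moved to the right-hand side) and the `Comparison` constructor (`<` rewritten as `<=` with offset minus one). `Term`, `Op`, and `Normalize` are hypothetical stand-ins for `PrimExpr` handling.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <tuple>

// Hypothetical stand-in for an expression of the form `base + constant`.
struct Term {
  std::string base;  // the non-constant part
  int64_t constant;
};

enum class Op { kLT, kLE };

// Returns (lhs_base, rhs_base, offset) such that the comparison is
// equivalent to `lhs_base <= rhs_base + offset`.
std::tuple<std::string, std::string, int64_t> Normalize(const Term& lhs, Op op, const Term& rhs) {
  // (x + c1) OP (y + c2)  <=>  x OP y + (c2 - c1)
  int64_t offset = rhs.constant - lhs.constant;
  // For integers, x < y + c  <=>  x <= y + (c - 1)
  if (op == Op::kLT) offset -= 1;
  return {lhs.base, rhs.base, offset};
}
```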
   



##########
include/tvm/arith/analyzer.h:
##########
@@ -317,6 +347,82 @@ class CanonicalSimplifier {
   Impl* impl_;
 };
 
+/*! \brief Structure for representing the result of a comparison
+ *
+ * Values are assigned to allow these flags to be used in bitwise
+ * operations.
+ */
+enum class CompareResult : int {
+  kInconsistent = 0,
+  kEQ = 1,
+  kLT = 2,
+  kLE = 3,
+  kGT = 4,
+  kGE = 5,
+  kNE = 6,
+  kUnknown = 7
+};
+
+inline constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) {
+  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
+}
+inline constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs) {
+  return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
+}
+
+/*!
+ * \brief Using previously specified knowns, compare the expressions provided
+ *
+ * Given known expressions [(a OP b), (b OP c), ..., (y OP z)], search
+ * for a known result for `(a OP z)`.
+ */
+class TransitiveComparisonAnalyzer {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs);
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);

Review Comment:
   TVM_DLL, here and elsewhere 
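The bit assignments in `CompareResult` make each value a set of possible outcomes: bit 0 means "may be equal", bit 1 "may be less", and bit 2 "may be greater". Intersecting two independently proven results with `&` therefore keeps only the outcomes both allow. This is how the implementation's `TryCompare` combines `from_lhs & from_rhs`. The standalone C++14 sketch below copies the enum and operators from the quoted hunk. It adds `Negate`, which uses the same bit trick as the PR, and a bit-level `Reverse` that is an alternative to the PR's switch-based version, not a copy of it.

```cpp
#include <cassert>

// Copy of the quoted encoding.
enum class CompareResult : int {
  kInconsistent = 0, kEQ = 1, kLT = 2, kLE = 3, kGT = 4, kGE = 5, kNE = 6, kUnknown = 7
};

constexpr CompareResult operator&(CompareResult a, CompareResult b) {
  return CompareResult(static_cast<int>(a) & static_cast<int>(b));
}
constexpr CompareResult operator|(CompareResult a, CompareResult b) {
  return CompareResult(static_cast<int>(a) | static_cast<int>(b));
}

// Negation flips every "possible outcome" bit; kInconsistent and kUnknown
// are returned unchanged, as in the PR's Negate.
constexpr CompareResult Negate(CompareResult r) {
  if (r == CompareResult::kInconsistent || r == CompareResult::kUnknown) return r;
  return CompareResult(~static_cast<int>(r) & static_cast<int>(CompareResult::kUnknown));
}

// Swapping the operands exchanges the LT and GT bits and keeps the EQ bit.
constexpr CompareResult Reverse(CompareResult r) {
  const int v = static_cast<int>(r);
  const int eq = v & 1, lt = (v >> 1) & 1, gt = (v >> 2) & 1;
  return CompareResult(eq | (gt << 1) | (lt << 2));
}
```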



##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter a constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};

Review Comment:
   This is functioning as an ID for an expr, such that once a relationship between two keys (e.g. equality) is established, it can be looked up without needing to re-evaluate equality. Do I read your intent correctly?
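A small standalone sketch of the interning pattern described in this comment. It assumes plain `std::string` keys as a stand-in for `PrimExpr` (the PR hashes with `StructuralHash`/`StructuralEqual` instead), and `KeyTable`, `Find`, and `ToKey` are hypothetical names. Each structurally new expression receives the next integer ID, so any later check on a stored relationship reduces to comparing two integers.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the analyzer's Key: a strongly-typed integer ID.
enum class Key : std::size_t {};

// Interns expressions (strings standing in for PrimExpr) to Keys.
class KeyTable {
 public:
  // Look up an existing key without inserting (mirrors ExprToPreviousKey).
  std::optional<Key> Find(const std::string& expr) const {
    auto it = table_.find(expr);
    if (it == table_.end()) return std::nullopt;
    return it->second;
  }

  // Return the existing key, or assign the next unused one (mirrors ExprToKey).
  Key ToKey(const std::string& expr) {
    if (auto prev = Find(expr)) return *prev;
    Key new_key = static_cast<Key>(table_.size());
    table_.emplace(expr, new_key);
    return new_key;
  }

 private:
  std::unordered_map<std::string, Key> table_;
};
```

Because `Key` is a scoped enum, it cannot silently mix with plain integers such as the `offset_` field of a comparison.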



##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a graph edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {

Review Comment:
   nit: Can define `temp` here in the lambda capture list, `[..., temp = expr]() {...};` since it isn't used elsewhere.
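A minimal sketch of the suggested shape, assuming a simplified store of scoped knowns held as strings. `ScopedKnowns` and `Enter` are hypothetical names, and the out-of-order check is an assumption about what such a cleanup could verify, not a copy of the PR's lambda body. The init-capture `constraint = expr` copies the argument straight into the closure, which removes the need for a separate `temp` local.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical simplified analyzer: scoped knowns are stored as strings.
class ScopedKnowns {
 public:
  // Record a constraint and return a callback that removes it again.
  std::function<void()> Enter(const std::string& expr) {
    std::size_t old_size = knowns_.size();
    knowns_.push_back(expr);
    std::size_t new_size = knowns_.size();

    // C++14 init-capture: `constraint = expr` copies the argument directly
    // into the closure, so no separate local variable is needed.
    return [this, old_size, new_size, constraint = expr]() {
      // Scopes are expected to be exited in LIFO order.
      if (knowns_.size() != new_size) {
        throw std::logic_error("Constraint " + constraint + " exited out of order");
      }
      knowns_.erase(knowns_.begin() + static_cast<std::ptrdiff_t>(old_size), knowns_.end());
    };
  }

  std::size_t size() const { return knowns_.size(); }

 private:
  std::vector<std::string> knowns_;
};
```

The returned callback captures `this`, so it must be invoked before the owning object is destroyed.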


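The doc comment on `TryCompareFromLHS` describes the knowns as edges of a graph. One way to sketch that idea, assuming every known has already been normalized to `lhs <= rhs + offset` over integer keys, is a Bellman-Ford style relaxation that finds the tightest provable offset. `LEFact` and `TightestUpperBound` are hypothetical names, and this is a swapped-in shortest-path formulation rather than the PR's own search.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// A normalized known fact: `lhs <= rhs + offset`, with lhs/rhs as integer keys.
struct LEFact {
  std::size_t lhs;
  std::size_t rhs;
  std::int64_t offset;
};

// Returns the smallest c such that the facts prove `from <= to + c`,
// or std::nullopt if no chain of facts connects `from` to `to`.
// Each fact is an edge lhs -> rhs with weight `offset`; path weights add up.
std::optional<std::int64_t> TightestUpperBound(const std::vector<LEFact>& facts,
                                               std::size_t num_keys, std::size_t from,
                                               std::size_t to) {
  // best[v] holds the smallest c found so far with `from <= v + c`.
  std::vector<std::optional<std::int64_t>> best(num_keys);
  best[from] = 0;  // trivially, from <= from + 0
  for (std::size_t iter = 0; iter + 1 < num_keys; ++iter) {
    bool changed = false;
    for (const LEFact& f : facts) {
      if (!best[f.lhs]) continue;
      // from <= lhs + best[lhs] <= rhs + offset + best[lhs]
      std::int64_t candidate = *best[f.lhs] + f.offset;
      if (!best[f.rhs] || candidate < *best[f.rhs]) {
        best[f.rhs] = candidate;
        changed = true;
      }
    }
    if (!changed) break;
  }
  return best[to];
}
```

With keys a=0, b=1, c=2, d=3, the priors (a<=b), (b<=c+1) and (c<=d-5) give a bound of -4, i.e. (a<=d-4). A known of the form `x >= y + c` can enter the same graph as the edge `y <= x - c`. Every value the relaxation produces corresponds to an actual chain of knowns, so a returned bound is always implied by them; it is guaranteed to be the tightest only when the knowns contain no contradiction (a cycle with negative total offset).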

##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;

Review Comment:
   Should we also check the value of `other.result_` here to guard against an erroneous match, e.g. `other.result_ == CompareResult::kEQ` / `x == y + c2`?
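In the quoted hunk, the three `kNE` cases sit inside `if (other.result_ == CompareResult::kNE)`, so `other.result_` is already fixed there. The case table can also be checked mechanically. The sketch below mirrors the `Implies` rules for normalized comparisons `x OP y + c`, writing `d = x - y`, and compares them with brute-force enumeration over a window of `d` values. `Op`, `Holds`, `Implies`, and `ImpliesByEnumeration` are hypothetical names, and the window of -10..10 is an assumption that is only wide enough for offsets in [-3, 3].

```cpp
#include <cassert>
#include <cstdint>

// Normalized operators only: < and > have already been rewritten as <= and >=.
enum class Op { kEQ, kNE, kLE, kGE };

// Does `x OP y + c` hold when x - y == d?
bool Holds(Op op, std::int64_t d, std::int64_t c) {
  switch (op) {
    case Op::kEQ: return d == c;
    case Op::kNE: return d != c;
    case Op::kLE: return d <= c;
    case Op::kGE: return d >= c;
  }
  return false;
}

// Mirror of the reviewed rules: does `x op1 y + c1` imply `x op2 y + c2`?
bool Implies(Op op1, std::int64_t c1, Op op2, std::int64_t c2) {
  if (op1 == op2 && c1 == c2) return true;
  if (op2 == Op::kLE && c1 <= c2 && (op1 == Op::kEQ || op1 == Op::kLE)) return true;
  if (op2 == Op::kGE && c1 >= c2 && (op1 == Op::kEQ || op1 == Op::kGE)) return true;
  if (op2 == Op::kNE) {
    if (op1 == Op::kEQ && c1 != c2) return true;
    if (op1 == Op::kLE && c1 < c2) return true;
    if (op1 == Op::kGE && c1 > c2) return true;
  }
  return false;
}

// Brute force: the implication fails iff some d satisfies op1 but violates op2.
// For small offsets, a counterexample (if any) lies inside this window.
bool ImpliesByEnumeration(Op op1, std::int64_t c1, Op op2, std::int64_t c2) {
  for (std::int64_t d = -10; d <= 10; ++d) {
    if (Holds(op1, d, c1) && !Holds(op2, d, c2)) return false;
  }
  return true;
}
```

For these offsets the rules agree with enumeration in every case; in particular the `kGE` to `kNE` rule requires `c1 > c2`, mirroring `c1 < c2` in the `kLE` case. Since `Implies` checks that both comparisons share the same LHS and RHS, callers would need to align them first (for example via `WithLHS`).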


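The `Negate` helper quoted above relies on `CompareResult` being a bitmask of which outcomes remain possible. The sketch below assumes the encoding EQ = 1, LT = 2, GT = 4 (so kLE = 3, kGE = 5, kNE = 6, kUnknown = 7); check the actual values in `tvm/arith/analyzer.h` before relying on them. `Cmp` is a hypothetical mirror type, and the bitwise `Reverse` is an alternative to the PR's `switch`, not its implementation.

```cpp
#include <cassert>

// Hypothetical mirror of a bitmask comparison result:
// bit 0 = EQ possible, bit 1 = LT possible, bit 2 = GT possible.
enum class Cmp : int {
  kInconsistent = 0,
  kEQ = 1,
  kLT = 2,
  kLE = 3,  // LT | EQ
  kGT = 4,
  kGE = 5,  // GT | EQ
  kNE = 6,  // LT | GT
  kUnknown = 7,
};

// Negation flips each "possible" bit; kInconsistent and kUnknown are
// returned unchanged, matching the special cases in the reviewed Negate().
Cmp Negate(Cmp res) {
  if (res == Cmp::kInconsistent || res == Cmp::kUnknown) return res;
  return static_cast<Cmp>(~static_cast<int>(res) & static_cast<int>(Cmp::kUnknown));
}

// Swapping the operands exchanges the LT and GT bits and keeps the EQ bit.
Cmp Reverse(Cmp res) {
  int bits = static_cast<int>(res);
  int eq = bits & 1;
  int lt = (bits >> 1) & 1;
  int gt = (bits >> 2) & 1;
  return static_cast<Cmp>(eq | (gt << 1) | (lt << 2));
}
```

Under this encoding negating twice returns the original value, and `Reverse` leaves the symmetric results `kEQ` and `kNE` unchanged.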

##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /* \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& expr, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to clean up.  May be
+   * `nullptr` if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(subexpr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if possible.  Comparisons between two integer
+  // constants should already have been constant-folded, so this case
+  // is not expected to occur.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(
+    Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input,
+    const PrimExpr& rhs_input) const {
+  Key lhs_key = lhs_key_input;
+  Key rhs_key = rhs_key_input;
+  int64_t offset = offset_input;
+
+  // Everything in `to_visit` has lhs as its lhs.
+  std::unordered_set<Key> seen;
+  std::unordered_set<Key> to_visit;
+  std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
+
+  // Utility function to add a new known statement
+  auto declare_known = [&](Comparison cmp) {
+    auto& prev_knowns = compared_to_x[cmp.rhs_];
+
+    for (auto& prev_known : prev_knowns) {
+      if (prev_known.Implies(cmp)) {
+        return;
+      }
+    }
+
+    if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) {
+      to_visit.insert(cmp.rhs_);
+      seen.insert(cmp.rhs_);
+    }
+
+    for (auto& prev_known : prev_knowns) {
+      if (cmp.Implies(prev_known)) {
+        prev_known = cmp;
+        return;
+      }
+    }
+
+    prev_knowns.push_back(cmp);
+  };
+
+  // Initialize the search based on any known (in)equalities that use
+  // the LHS of the comparison.
+  for (const auto& known : knowns_) {
+    if (auto normalized = known.WithLHS(lhs_key)) {
+      declare_known(normalized.value());
+    }
+  }
+  for (const auto& known : scoped_knowns_) {
+    if (auto normalized = known.WithLHS(lhs_key)) {
+      declare_known(normalized.value());
+    }
+  }
+
+  // Walk through the space of all comparisons that can be made with
+  // LHS.
+  while (to_visit.size()) {
+    Key middle_key = *to_visit.begin();
+    to_visit.erase(to_visit.begin());
+
+    ICHECK(compared_to_x.count(middle_key));
+    std::vector<Comparison>& prev_knowns_using_middle = compared_to_x.at(middle_key);
+
+    std::vector<Comparison> new_knowns_using_lhs;
+
+    auto attempt_transitive = [&](Comparison cmp) {
+      ICHECK(cmp.IsNormalized());
+
+      Key right_key = cmp.rhs_;
+
+      if (right_key == lhs_key) {
+        return;
+      }
+
+      for (const auto& prev : prev_knowns_using_middle) {
+        CompareResult new_result = CompareResult::kUnknown;
+        int64_t new_offset = prev.offset_ + cmp.offset_;
+
+        if (prev.result_ == CompareResult::kEQ) {
+          // x == y + c1 && y OP z + c2, x OP z + (c1 + c2)
+          new_result = cmp.result_;
+        } else if (cmp.result_ == CompareResult::kEQ) {
+          // x OP y + c1 && y == z + c2, x OP z + (c1 + c2)
+          new_result = prev.result_;
+        } else if (prev.result_ == cmp.result_ &&
+                   (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) {
+          // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2)
+          // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2)
+          //
+          // This condition is much simpler to write than the
+          // equivalent handling of < or of >, which is why the
+          // inequalities are normalized to <= and to >=.
+          new_result = prev.result_;
+        }
+
+        if (new_result != CompareResult::kUnknown) {
+          Comparison new_known(lhs_key, right_key, new_offset, new_result);
+          new_knowns_using_lhs.push_back(new_known);
+        }
+      }
+    };
+
+    // Attempt to prove a new comparison using one of the original
+    // known comparisons.  We want to find a known such that
+    // `(LHS OP1 middle) && (middle OP2 right)` can be simplified

Review Comment:
   nit: The comment switches convention from LHS/RHS to LHS/right. Maybe just use left/middle/right.
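The composition rule under discussion, combining `left OP1 middle` with `middle OP2 right`, can be illustrated with a minimal standalone sketch. It is not the TVM implementation: `Op`, `Cmp`, `NormalizeLT`, `NormalizeGT` and `Compose` are hypothetical names, and only the integer case with `==`, `<=` and `>=` (what remains after normalization) is modeled, mirroring the `attempt_transitive` lambda quoted above.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for CompareResult, keeping only the operators
// that remain after normalization.
enum class Op { kEQ, kLE, kGE };

// Represents `x OP y + offset` for integer-valued x and y.
struct Cmp {
  Op op;
  std::int64_t offset;
};

// For integers, x < y + c is equivalent to x <= y + (c - 1), and
// x > y + c is equivalent to x >= y + (c + 1).
constexpr Cmp NormalizeLT(std::int64_t offset) { return Cmp{Op::kLE, offset - 1}; }
constexpr Cmp NormalizeGT(std::int64_t offset) { return Cmp{Op::kGE, offset + 1}; }

// Given (left OP1 middle + c1) and (middle OP2 right + c2), return the
// comparison it implies between left and right, if any.
constexpr std::optional<Cmp> Compose(Cmp left_middle, Cmp middle_right) {
  const std::int64_t offset = left_middle.offset + middle_right.offset;
  if (left_middle.op == Op::kEQ) return Cmp{middle_right.op, offset};
  if (middle_right.op == Op::kEQ) return Cmp{left_middle.op, offset};
  if (left_middle.op == middle_right.op) return Cmp{left_middle.op, offset};
  return std::nullopt;  // A <= step followed by a >= step proves nothing.
}

// left < middle and middle <= right + 3 imply left <= right + 2.
static_assert(Compose(NormalizeLT(0), Cmp{Op::kLE, 3}).value().offset == 2,
              "left <= right + 2");
```

Normalizing `<` and `>` away first is what keeps `Compose` down to three cases, which is the same reason given in the comment inside `attempt_transitive`.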



##########
include/tvm/arith/analyzer.h:
##########
@@ -275,6 +275,36 @@ class RewriteSimplifier {
    */
   std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
+  /*! \brief Flags to enable more computationally-intensive simplifications
+   *
+   * These simplifications may be required for specific schedules, but
+   * would impose too high a compile-time cost to enable by default.
+   * They can be enabled on an as-needed basis by calling
+   * `RewriteSimplifier::SetEnabledFeatures` prior to using
+   * `RewriteSimplifier::operator()`.
+   */
+  enum Feature {
+    // No features enabled
+    kNone = 0,
+
+    /* When simplifying an inequality, attempt to use scope-based knowns.
+     *
+     * Example:
+     * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
+     */
+    kTransitivelyProveInequalities = (1 << 0),
+  };
+
+  /*! \brief Enable an optional feature or features
+   *
+   * \param flags A bitwise OR of all optional features that should be
+   * enabled.
+   */
+  void SetEnabledFeatures(Feature flags);

Review Comment:
   Use of TVM_DLL on member functions.
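Since `SetEnabledFeatures` takes a single `Feature` while its docstring asks for a bitwise OR of features, callers combining flags will likely need a cast: for an unscoped enum, `a | b` yields an `int`, which does not implicitly convert back to the enum type. A small sketch using a hypothetical `ToySimplifier` (not TVM's `RewriteSimplifier`, and without `TVM_DLL`, TVM's symbol-export macro) shows the pattern; `kHypotheticalSecondFeature` is invented purely to have two flags to combine.

```cpp
// Hypothetical stand-in mirroring the Feature flags in the diff above.
struct ToySimplifier {
  enum Feature {
    kNone = 0,
    kTransitivelyProveInequalities = (1 << 0),
    kHypotheticalSecondFeature = (1 << 1),  // invented for illustration only
  };

  constexpr void SetEnabledFeatures(Feature flags) { enabled_ = flags; }

  // True if every bit of `flag` is set in the enabled mask.
  constexpr bool IsEnabled(Feature flag) const { return (enabled_ & flag) == flag; }

  Feature enabled_{kNone};
};

// `a | b` on an unscoped enum produces an int, so combining flags needs
// an explicit cast back to Feature before calling SetEnabledFeatures.
constexpr ToySimplifier::Feature kBoth = static_cast<ToySimplifier::Feature>(
    ToySimplifier::kTransitivelyProveInequalities | ToySimplifier::kHypotheticalSecondFeature);

constexpr bool TransitiveEnabledAfterSet() {
  ToySimplifier simplifier{};
  simplifier.SetEnabledFeatures(kBoth);
  return simplifier.IsEnabled(ToySimplifier::kTransitivelyProveInequalities);
}

static_assert(TransitiveEnabledAfterSet(), "flag enabled via combined mask");
```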



##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@

Review Comment:
   nit: A diagram of the AST referenced by `node edge` in this line would make the following description quite clear. 
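One way to read "node edge" in the docstring is this: each known comparison `x <= y + c` is a directed edge from `x` to `y` with weight `c`, and proving `lhs <= rhs + offset` amounts to finding a path whose summed weight is at most `offset`. The sketch below draws the docstring's example as such a graph and computes the tightest offset with Bellman-Ford relaxation. It is an illustration under simplifying assumptions (only `<=` edges, nodes as hypothetical integer ids, no detection of contradictory knowns), not the search that `TryCompareFromLHS` actually performs.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <optional>

// Each known `from <= to + offset` is a directed, weighted edge between
// expression nodes.  For the priors in the docstring:
//
//          0            +1             -5
//   a ---------> b ----------> c ----------> d
//    (a <= b)     (b <= c + 1)   (c <= d - 5)
//
// Summing the weights along a -> b -> c -> d gives a <= d + (0 + 1 - 5),
// that is, a <= d - 4.
struct Edge {
  std::size_t from;
  std::size_t to;
  std::int64_t offset;
};

// Smallest c such that `from <= to + c` follows from the edges, or
// nullopt if `to` is unreachable.  Relaxation handles negative offsets;
// negative cycles (contradictory knowns) are not detected here.
template <std::size_t kNumNodes, std::size_t kNumEdges>
constexpr std::optional<std::int64_t> TightestOffset(const std::array<Edge, kNumEdges>& edges,
                                                     std::size_t from, std::size_t to) {
  std::array<bool, kNumNodes> reached{};
  std::array<std::int64_t, kNumNodes> best{};
  reached[from] = true;
  // With N nodes, no shortest path needs more than N - 1 edges.
  for (std::size_t iter = 0; iter + 1 < kNumNodes; ++iter) {
    for (const Edge& e : edges) {
      if (!reached[e.from]) continue;
      const std::int64_t candidate = best[e.from] + e.offset;
      if (!reached[e.to] || candidate < best[e.to]) {
        reached[e.to] = true;
        best[e.to] = candidate;
      }
    }
  }
  if (!reached[to]) return std::nullopt;
  return best[to];
}

// Hypothetical node ids for the example above.
enum : std::size_t { kA, kB, kC, kD, kNumExampleNodes };

constexpr std::array<Edge, 3> kKnowns{{{kA, kB, 0}, {kB, kC, 1}, {kC, kD, -5}}};

static_assert(TightestOffset<kNumExampleNodes>(kKnowns, kA, kD) == -4, "a <= d - 4");
```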



##########
include/tvm/arith/analyzer.h:
##########
@@ -275,6 +275,36 @@ class RewriteSimplifier {
+  enum Feature {

Review Comment:
   Extensions?



##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+  /*! \brief Known comparisons based on of scope-based statements

Review Comment:
   Unclear sentence
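Judging from the surrounding code, the distinction the docstring is after is this: `knowns_` holds comparisons that stay true for the rest of the analysis (from `Bind`, such as a Let binding or an iterator range), while `scoped_knowns_` holds comparisons that are true only inside a scope (from `EnterConstraint`, such as an IfThenElse condition) and are dropped by the returned exit function. A minimal model of that lifecycle follows; `ToyKnowns` is hypothetical, each known is reduced to an integer id, and the fixed capacity of 8 entries is not bounds-checked.

```cpp
#include <array>
#include <cstddef>

class ToyKnowns {
 public:
  // Definitional known: stays visible for the rest of the analysis.
  constexpr void Bind(int known) { knowns_[num_knowns_++] = known; }

  // Scoped known: the returned callable truncates the scoped list back to
  // its previous size, like the `frecover` lambda in EnterConstraint.
  constexpr auto EnterConstraint(int known) {
    const std::size_t old_size = num_scoped_;
    scoped_[num_scoped_++] = known;
    return [this, old_size]() { num_scoped_ = old_size; };
  }

  // Number of knowns a query would currently consult.
  constexpr std::size_t NumVisible() const { return num_knowns_ + num_scoped_; }

 private:
  std::array<int, 8> knowns_{};
  std::array<int, 8> scoped_{};
  std::size_t num_knowns_ = 0;
  std::size_t num_scoped_ = 0;
};

constexpr std::size_t VisibleInsideThenAfterScope() {
  ToyKnowns k{};
  k.Bind(1);                               // e.g. the range of a loop iterator
  auto exit_scope = k.EnterConstraint(2);  // e.g. an IfThenElse condition
  const std::size_t inside = k.NumVisible();
  exit_scope();
  return inside * 10 + k.NumVisible();
}

static_assert(VisibleInsideThenAfterScope() == 21, "2 knowns inside the scope, 1 after");
```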



##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter a constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to clean up.  May be
+   * `nullptr` if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on of scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;

Review Comment:
   Any comment on why this representation is beneficial? E.g., if it is to normalize, perhaps a brief description on IsNormalized could provide clarity.
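The normalization in question relies on the operands being integers, which `TryCompare` requires. For integers, `x < y + c` holds exactly when `x <= y + (c - 1)`, and `x > y + c` holds exactly when `x >= y + (c + 1)`. Folding the strict operators into the offset leaves only EQ, NE, LE and GE, which keeps the case analysis in `Implies` smaller. The sketch below works under simplifying assumptions. The names `Op`, `Cmp`, `Normalize`, `Holds` and `LEImplies` are hypothetical, and concrete integers replace the expressions so the equivalence can be checked by brute force:

```cpp
#include <cassert>
#include <cstdint>

// Simplified stand-in for CompareResult (hypothetical).
enum class Op { kEQ, kNE, kLT, kLE, kGT, kGE };

// Represents `x OP y + offset`, with x and y left abstract.
struct Cmp {
  Op op;
  int64_t offset;
};

// Rewrite strict comparisons into non-strict ones by adjusting the
// offset; valid only because x and y are integers.
Cmp Normalize(Cmp c) {
  if (c.op == Op::kLT) return {Op::kLE, c.offset - 1};
  if (c.op == Op::kGT) return {Op::kGE, c.offset + 1};
  return c;
}

// Evaluate `x OP y + offset` for concrete integers.
bool Holds(Cmp c, int64_t x, int64_t y) {
  int64_t r = y + c.offset;
  switch (c.op) {
    case Op::kEQ: return x == r;
    case Op::kNE: return x != r;
    case Op::kLT: return x < r;
    case Op::kLE: return x <= r;
    case Op::kGT: return x > r;
    case Op::kGE: return x >= r;
  }
  return false;
}

// For normalized LE facts on the same (x, y) pair,
// x <= y + c1 implies x <= y + c2 whenever c1 <= c2.
bool LEImplies(Cmp a, Cmp b) {
  return a.op == Op::kLE && b.op == Op::kLE && a.offset <= b.offset;
}
```

This matches the quoted constructor, which rewrites kLT as kLE with `offset_ -= 1` and kGT as kGE with `offset_ += 1`. `IsNormalized` then only has to confirm that no LT or GT remains.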



##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if possible.  A comparison between two integer
+  // constants should already have been constant-folded, so this branch
+  // is not expected to be reached.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(
+    Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input,
+    const PrimExpr& rhs_input) const {
+  Key lhs_key = lhs_key_input;
+  Key rhs_key = rhs_key_input;
+  int64_t offset = offset_input;
+
+  // Everything in `to_visit` has lhs as its lhs.
+  std::unordered_set<Key> seen;
+  std::unordered_set<Key> to_visit;
+  std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
+
+  // Utility function to add a new known statement
+  auto declare_known = [&](Comparison cmp) {
+    auto& prev_knowns = compared_to_x[cmp.rhs_];

Review Comment:
   I scratched my head for a while over whether `compared_to_x` only ever contained default-initialized vectors, until I noticed you are updating the map value by reference. It might have helped to spell out the type as `std::vector<Comparison>&` to call attention to the container 🤷
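The review point is that `auto& prev_knowns = compared_to_x[cmp.rhs_];` mutates the map in place, which is easy to miss. The sketch below is a simplified, hypothetical version of the path search documented on `TryCompareFromLHS`. It is restricted to `<=` facts, uses integer keys in place of expressions, and relies on a Bellman-Ford style relaxation rather than the PR's own traversal. `LEFact` and `TightestOffset` are illustrative names. The map reference type is spelled out explicitly, as suggested:

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <vector>

// Hypothetical known fact `lhs <= rhs + offset`, with expressions
// replaced by integer keys.
struct LEFact {
  int lhs;
  int rhs;
  int64_t offset;
};

// Find the smallest c such that `from <= to + c` follows by chaining
// known facts, e.g. (a<=b), (b<=c+1), (c<=d-5) give (a<=d-4).
// Returns nullopt if no chain of facts leads from `from` to `to`.
std::optional<int64_t> TightestOffset(const std::vector<LEFact>& knowns, int from, int to) {
  // best[k] == c records the proven statement `from <= k + c`.
  std::map<int, int64_t> best;
  best[from] = 0;

  // A simple chain never uses more edges than there are facts, so
  // knowns.size() rounds of relaxation are enough.
  for (std::size_t round = 0; round < knowns.size(); ++round) {
    bool changed = false;
    for (const LEFact& fact : knowns) {
      auto lhs_it = best.find(fact.lhs);
      if (lhs_it == best.end()) continue;
      // from <= lhs + c and lhs <= rhs + o imply from <= rhs + (c + o).
      int64_t candidate = lhs_it->second + fact.offset;
      auto [rhs_it, inserted] = best.try_emplace(fact.rhs, candidate);
      // Explicit reference type: assigning through it updates the map.
      int64_t& current = rhs_it->second;
      if (inserted || candidate < current) {
        current = candidate;
        changed = true;
      }
    }
    if (!changed) break;
  }

  auto it = best.find(to);
  if (it == best.end()) return std::nullopt;
  return it->second;
}
```

Since `std::map` iterators stay valid across insertions, the reference taken from `try_emplace` remains safe to write through.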



##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /* \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& expr, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param constraint A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check the this comparison implies
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
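+   * Treating each available `Comparison` as an edge between two
+   * expressions, search for a path from lhs to rhs.  For example, the
+   * priors (a<=b), (b<=c+1) and (c<=d-5) can be used to prove that
+   * (a<=d-4).
+   *
+   * \param lhs_key The key of the left-hand side of the comparison
+   *
+   * \param rhs_key The key of the right-hand side of the comparison
+   *
+   * \param offset The constant offset added to the right-hand side
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison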
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if both sides are integer constants.  Such
+  // comparisons should already have been constant-folded, so this
+  // branch is not expected to be reached.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(

Review Comment:
   nit: TryCompareFromLHS is a bit long
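To make the path search described in the `TryCompareFromLHS` docstring concrete, here is a minimal standalone sketch. It is not the PR's implementation: it assumes every known has already been normalized to an upper bound `lhs <= rhs + offset` over integer keys (the hypothetical `UpperBound` struct), it ignores `==` and `!=`, and it swaps the PR's worklist for a plain Bellman-Ford-style relaxation so that contradictory knowns (a cycle with negative total offset) terminate instead of looping. `TightestUpperBound`, `ProvesLE`, and `CheckDocstringExample` are illustrative names only. Keeping the relaxation loop separate from small query helpers like these is one possible way to address the length concern.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <set>
#include <vector>

// Hypothetical stand-in for the analyzer's `Comparison`: only normalized
// upper bounds `lhs <= rhs + offset` are modeled, with each expression
// replaced by an integer key.
struct UpperBound {
  int lhs;
  int rhs;
  int64_t offset;
};

// Returns the smallest `c` for which the knowns prove
// `target_lhs <= target_rhs + c`, or nullopt if no chain of knowns links
// the two keys or the knowns are contradictory.
std::optional<int64_t> TightestUpperBound(const std::vector<UpperBound>& knowns, int target_lhs,
                                          int target_rhs) {
  std::set<int> keys = {target_lhs};
  for (const UpperBound& known : knowns) {
    keys.insert(known.lhs);
    keys.insert(known.rhs);
  }

  // best[k] is the tightest proven `c` such that `target_lhs <= k + c`.
  std::map<int, int64_t> best = {{target_lhs, 0}};  // Trivially, x <= x + 0.

  // Without contradictions, every bound settles within keys.size() - 1 rounds.
  for (std::size_t round = 0; round < keys.size(); ++round) {
    bool changed = false;
    for (const UpperBound& known : knowns) {
      auto from = best.find(known.lhs);
      if (from == best.end()) continue;  // known.lhs is not yet linked to target_lhs
      // x <= m + c1 && m <= z + c2  =>  x <= z + (c1 + c2)
      const int64_t candidate = from->second + known.offset;
      auto to = best.find(known.rhs);
      if (to == best.end() || candidate < to->second) {
        best[known.rhs] = candidate;
        changed = true;
      }
    }
    if (!changed) break;
    // Still tightening in the final round: a cycle with negative total offset.
    if (round + 1 == keys.size()) return std::nullopt;
  }

  auto it = best.find(target_rhs);
  if (it == best.end()) return std::nullopt;
  return it->second;
}

// True if the knowns prove `lhs <= rhs + offset`.
bool ProvesLE(const std::vector<UpperBound>& knowns, int lhs, int rhs, int64_t offset) {
  const std::optional<int64_t> tightest = TightestUpperBound(knowns, lhs, rhs);
  return tightest.has_value() && *tightest <= offset;
}

// The docstring's example, with keys a=0, b=1, c=2, d=3.
void CheckDocstringExample() {
  const std::vector<UpperBound> knowns = {{0, 1, 0}, {1, 2, 1}, {2, 3, -5}};
  assert(TightestUpperBound(knowns, 0, 3) == std::optional<int64_t>(-4));  // a <= d - 4
  assert(ProvesLE(knowns, 0, 3, -4));
  assert(!ProvesLE(knowns, 0, 3, -5));                    // a <= d - 5 does not follow
  assert(!TightestUpperBound(knowns, 3, 0).has_value());  // nothing bounds d above by a
}
```

Lower bounds come from the same search with the roles swapped: `x >= y + c` holds exactly when `y <= x - c`, so `TightestUpperBound(knowns, y, x)` bounds `x` from below, much as `TryCompare` runs `TryCompareFromLHS` in both directions and combines the two results.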



##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter a constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to clean up.  May be
+   * `nullptr` if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Integer keys for previously-seen expressions, to avoid repeated deep-equality checks
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
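+   * Treating each available `Comparison` as an edge between two
+   * expressions, search for a path from lhs to rhs.  For example, the
+   * priors (a<=b), (b<=c+1) and (c<=d-5) can be used to prove that
+   * (a<=d-4).
+   *
+   * \param lhs_key The key of the left-hand side of the comparison
+   *
+   * \param rhs_key The key of the right-hand side of the comparison
+   *
+   * \param offset The constant offset added to the right-hand side
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison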
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if both sides are integer constants.  Such
+  // comparisons should already have been constant-folded, so this
+  // branch is not expected to be reached.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(
+    Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input,
+    const PrimExpr& rhs_input) const {
+  Key lhs_key = lhs_key_input;
+  Key rhs_key = rhs_key_input;
+  int64_t offset = offset_input;
+
+  // Everything in `to_visit` has lhs as its lhs.
+  std::unordered_set<Key> seen;
+  std::unordered_set<Key> to_visit;
+  std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
+
+  // Utility function to add a new known statement
+  auto declare_known = [&](Comparison cmp) {
+    auto& prev_knowns = compared_to_x[cmp.rhs_];
+
+    for (auto& prev_known : prev_knowns) {
+      if (prev_known.Implies(cmp)) {
+        return;
+      }
+    }
+
+    if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) {
+      to_visit.insert(cmp.rhs_);
+      seen.insert(cmp.rhs_);
+    }
+
+    for (auto& prev_known : prev_knowns) {
+      if (cmp.Implies(prev_known)) {
+        prev_known = cmp;
+        return;
+      }
+    }
+
+    prev_knowns.push_back(cmp);
+  };
+
+  // Initialize the search based on any known (in)equalities that use
+  // the LHS of the comparison.
+  for (const auto& known : knowns_) {
+    if (auto normalized = known.WithLHS(lhs_key)) {
+      declare_known(normalized.value());
+    }
+  }
+  for (const auto& known : scoped_knowns_) {
+    if (auto normalized = known.WithLHS(lhs_key)) {
+      declare_known(normalized.value());
+    }
+  }
+
+  // Walk through the space of all comparisons that can be made with
+  // LHS.
+  while (to_visit.size()) {
+    Key middle_key = *to_visit.begin();
+    to_visit.erase(to_visit.begin());
+
+    ICHECK(compared_to_x.count(middle_key));
+    std::vector<Comparison>& prev_knowns_using_middle = compared_to_x.at(middle_key);
+
+    std::vector<Comparison> new_knowns_using_lhs;
+
+    auto attempt_transitive = [&](Comparison cmp) {
+      ICHECK(cmp.IsNormalized());
+
+      Key right_key = cmp.rhs_;
+
+      if (right_key == lhs_key) {
+        return;
+      }
+
+      for (const auto& prev : prev_knowns_using_middle) {
+        CompareResult new_result = CompareResult::kUnknown;
+        int64_t new_offset = prev.offset_ + cmp.offset_;
+
+        if (prev.result_ == CompareResult::kEQ) {
+          // x == y + c1 && y OP z + c2, x OP z + (c1 + c2)
+          new_result = cmp.result_;
+        } else if (cmp.result_ == CompareResult::kEQ) {
+          // x OP y + c1 && y == z + c2, x OP z + (c1 + c2)
+          new_result = prev.result_;
+        } else if (prev.result_ == cmp.result_ &&
+                   (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) {
+          // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2)
+          // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2)
+          //
+          // This condition is much simpler to write than the
+          // equivalent handling of < or of >, which is why the
+          // inequalities are normalized to <= and to >=.

Review Comment:
   Ah, here is the reasoning for such normalization! Some discussion of this in the docstring brief for IsNormalized would be great :).
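For readers following this thread, here is a minimal standalone sketch of the normalization being discussed. It assumes integer-valued expressions, since the analyzer only handles integer dtypes. A strict `x < y + c` becomes `x <= y + (c - 1)`, and `x > y + c` becomes `x >= y + (c + 1)`. Once only `<=`, `>=`, `==` and `!=` remain, chaining two same-direction inequalities reduces to adding their offsets. The `Op`, `Cmp`, `Normalize` and `Chain` names are hypothetical and are not TVM's internal `Comparison` API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical operator set for illustration; not TVM's CompareResult enum.
enum class Op { kLT, kLE, kEQ, kGE, kGT, kNE };

// Represents `lhs OP rhs + offset` over integers; lhs/rhs are opaque ids.
struct Cmp {
  int lhs;
  int rhs;
  int64_t offset;
  Op op;
};

// Over the integers:
//   x < y + c  <=>  x <= y + (c - 1)
//   x > y + c  <=>  x >= y + (c + 1)
Cmp Normalize(Cmp c) {
  if (c.op == Op::kLT) {
    c.op = Op::kLE;
    c.offset -= 1;
  } else if (c.op == Op::kGT) {
    c.op = Op::kGE;
    c.offset += 1;
  }
  return c;
}

// Combine (x OP1 y + c1) and (y OP2 z + c2) into (x OP z + (c1 + c2)).
// Assumes both inputs are already normalized.  Returns nullopt when
// nothing can be concluded (e.g. <= followed by >=, or != followed by !=).
std::optional<Cmp> Chain(const Cmp& a, const Cmp& b) {
  if (a.rhs != b.lhs) return std::nullopt;
  Op op;
  if (a.op == Op::kEQ) {
    op = b.op;  // x == y + c1 && y OP z + c2  =>  x OP z + (c1 + c2)
  } else if (b.op == Op::kEQ) {
    op = a.op;  // x OP y + c1 && y == z + c2  =>  x OP z + (c1 + c2)
  } else if (a.op == b.op && (a.op == Op::kLE || a.op == Op::kGE)) {
    op = a.op;  // same-direction non-strict inequalities compose directly
  } else {
    return std::nullopt;
  }
  return Cmp{a.lhs, b.rhs, a.offset + b.offset, op};
}
```

With this representation, `x < y` and `y < z` normalize to `x <= y - 1` and `y <= z - 1`, which chain to `x <= z - 2` (that is, `x < z - 1`) without any special-casing of strict inequalities.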





[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989376748


##########
include/tvm/arith/analyzer.h:
##########
@@ -275,6 +275,36 @@ class RewriteSimplifier {
    */
   std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
+  /*! \brief Flags to enable more computationally-intensive simplifications
+   *
+   * These simplifications may be required for specific schedules, but
+   * would impose too high a compile-time cost to enable by default.
+   * They can be enabled on an as-needed basis by calling
+   * `RewriteSimplifier::SetEnabledFeatures` prior to using
+   * `RewriteSimplifier::operator()`.
+   */
+  enum Feature {
+    // No features enabled
+    kNone = 0,
+
+    /* When simplifying an inequality, attempt to use scope-based knowns.
+     *
+     * Example:
+     * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
+     */
+    kTransitivelyProveInequalities = (1 << 0),

Review Comment:
   That's correct, additional entries are expected (e.g. https://github.com/apache/tvm/pull/12972).  Good call on the documentation; I've added a comment about it.
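Because `Feature` is a plain (unscoped) enum, an expression such as `kA | kB` evaluates to `int`. Combining several flags for a `SetEnabledFeatures(Feature flags)` parameter therefore needs a cast back to the enum type. The sketch below mirrors the shape of the enum in this diff. `kHypotheticalExtra`, `Combine` and `IsEnabled` are illustrative names, not part of this PR; the extra flag stands in for the additional entries mentioned above.

```cpp
#include <cassert>

// Mirrors the shape of RewriteSimplifier::Feature.
// kHypotheticalExtra is illustrative only and does not exist in TVM.
enum Feature {
  kNone = 0,
  kTransitivelyProveInequalities = (1 << 0),
  kHypotheticalExtra = (1 << 1),
};

// Bitwise OR of two unscoped enumerators yields int, so cast back to Feature.
// The value 3 is within the enum's value range (two bits are needed for
// the largest enumerator), so the cast is well-defined.
inline Feature Combine(Feature a, Feature b) {
  return static_cast<Feature>(static_cast<int>(a) | static_cast<int>(b));
}

// A feature is enabled if its bit is set in the flags.
inline bool IsEnabled(Feature flags, Feature f) {
  return (static_cast<int>(flags) & static_cast<int>(f)) != 0;
}
```

Using one bit per feature keeps each optimization independently toggleable, so later flags can be added without changing the `SetEnabledFeatures` signature.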





[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989456039


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to clean up.  May be
+   * `nullptr` if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if possible.  A comparison of two integer
+  // constants should already have been constant-folded, so this
+  // branch is not expected to be reached in practice.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(
+    Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input,
+    const PrimExpr& rhs_input) const {
+  Key lhs_key = lhs_key_input;
+  Key rhs_key = rhs_key_input;
+  int64_t offset = offset_input;
+
+  // Everything in `to_visit` has lhs as its lhs.
+  std::unordered_set<Key> seen;
+  std::unordered_set<Key> to_visit;
+  std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
+
+  // Utility function to add a new known statement
+  auto declare_known = [&](Comparison cmp) {
+    auto& prev_knowns = compared_to_x[cmp.rhs_];
+
+    for (auto& prev_known : prev_knowns) {
+      if (prev_known.Implies(cmp)) {
+        return;
+      }
+    }
+
+    if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) {
+      to_visit.insert(cmp.rhs_);
+      seen.insert(cmp.rhs_);
+    }
+
+    for (auto& prev_known : prev_knowns) {
+      if (cmp.Implies(prev_known)) {
+        prev_known = cmp;
+        return;
+      }
+    }
+
+    prev_knowns.push_back(cmp);
+  };
+
+  // Initialize the search based on any known (in)equalities that use
+  // the LHS of the comparison.
+  for (const auto& known : knowns_) {
+    if (auto normalized = known.WithLHS(lhs_key)) {
+      declare_known(normalized.value());
+    }
+  }
+  for (const auto& known : scoped_knowns_) {
+    if (auto normalized = known.WithLHS(lhs_key)) {
+      declare_known(normalized.value());
+    }
+  }
+
+  // Walk through the space of all comparisons that can be made with
+  // LHS.
+  while (to_visit.size()) {
+    Key middle_key = *to_visit.begin();
+    to_visit.erase(to_visit.begin());
+
+    std::vector<Comparison>& prev_knowns_using_middle = compared_to_x.at(middle_key);
+    ICHECK(compared_to_x.count(middle_key));
+
+    std::vector<Comparison> new_knowns_using_lhs;
+
+    auto attempt_transitive = [&](Comparison cmp) {
+      ICHECK(cmp.IsNormalized());
+
+      Key right_key = cmp.rhs_;
+
+      if (right_key == lhs_key) {
+        return;
+      }
+
+      for (const auto& prev : prev_knowns_using_middle) {
+        CompareResult new_result = CompareResult::kUnknown;
+        int64_t new_offset = prev.offset_ + cmp.offset_;
+
+        if (prev.result_ == CompareResult::kEQ) {
+          // x == y + c1 && y OP z + c2, x OP z + (c1 + c2)
+          new_result = cmp.result_;
+        } else if (cmp.result_ == CompareResult::kEQ) {
+          // x OP y + c1 && y == z + c2, x OP z + (c1 + c2)
+          new_result = prev.result_;
+        } else if (prev.result_ == cmp.result_ &&
+                   (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) {
+          // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2)
+          // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2)
+          //
+          // This condition is much simpler to write than the
+          // equivalent handling of < or of >, which is why the
+          // inequalities are normalized to <= and to >=.

Review Comment:
   Added the details in the `TransitiveComparisonAnalyzer::Impl::Comparison::Comparison` constructor where the normalization is performed, with references to it here and in `IsNormalized`.
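With strict inequalities normalized away, the implication check used to prune redundant knowns reduces to comparing offsets. Below is a standalone sketch of that reasoning for a fixed pair `(x, y)` over the integers. The `Rel`, `Known` and `Implies` names are hypothetical; the function covers the same cases as the `Comparison::Implies` member in the diff, but it is not that code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical normalized relations: no < or > remain after normalization.
enum class Rel { kLE, kGE, kEQ, kNE };

// Represents `x REL y + offset` for a fixed pair (x, y).
struct Known {
  Rel rel;
  int64_t offset;
};

// Returns true if `a` being true guarantees that `b` is true.
bool Implies(const Known& a, const Known& b) {
  if (a.rel == b.rel && a.offset == b.offset) return true;
  switch (b.rel) {
    case Rel::kLE:
      // x <= y + c1 (or x == y + c1) with c1 <= c2 gives x <= y + c2.
      return (a.rel == Rel::kLE || a.rel == Rel::kEQ) && a.offset <= b.offset;
    case Rel::kGE:
      // x >= y + c1 (or x == y + c1) with c1 >= c2 gives x >= y + c2.
      return (a.rel == Rel::kGE || a.rel == Rel::kEQ) && a.offset >= b.offset;
    case Rel::kNE:
      // Any known that excludes the value y + c2 proves x != y + c2.
      return (a.rel == Rel::kEQ && a.offset != b.offset) ||
             (a.rel == Rel::kLE && a.offset < b.offset) ||
             (a.rel == Rel::kGE && a.offset > b.offset);
    case Rel::kEQ:
      // Only an identical equality (handled above) implies ==.
      return false;
  }
  return false;
}
```

If `<` and `>` were kept as separate relations, each of these branches would need extra strict/non-strict combinations, which is the motivation for normalizing in the constructor.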





[GitHub] [tvm] Lunderberg commented on pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #12863:
URL: https://github.com/apache/tvm/pull/12863#issuecomment-1259627256

   The additional simplifications are now optional and can be opted into either through explicit flags with `RewriteSimplifier::SetEnabledFeatures`, or through a `PassContext` when using `tir::transform::Simplify`.  Both currently support only a single flag, for the transitive search through inequalities; additional flags will be added for the simplifications required by #12261.
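The transitive search can be illustrated with the example from the `TryCompareFromLHS` docstring: the priors `(a<=b)`, `(b<=c+1)` and `(c<=d-5)` prove `(a<=d-4)`. The sketch below is restricted to `<=` knowns over integers. Each known `u <= v + c` is treated as a weighted edge, and the sketch computes the tightest offset reachable from the LHS. For brevity it swaps in a plain Bellman-Ford relaxation instead of the worklist-with-`Implies`-pruning used in the diff, and it omits the `==`, `>=` and `!=` handling. `Edge` and `TightestUpperOffset` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Known fact `lhs <= rhs + offset`, with variables identified by index.
struct Edge {
  int lhs;
  int rhs;
  int64_t offset;
};

// Returns the smallest c such that the knowns prove `from <= to + c`,
// or nullopt if no chain of knowns connects `from` to `to`.
// At most (num_vars - 1) relaxation rounds, so contradictory
// (negative-cycle) knowns cannot cause an infinite loop.
std::optional<int64_t> TightestUpperOffset(int num_vars, const std::vector<Edge>& knowns,
                                           int from, int to) {
  const int64_t kInf = std::numeric_limits<int64_t>::max();
  std::vector<int64_t> best(num_vars, kInf);
  best[from] = 0;  // trivially, from <= from + 0
  for (int round = 0; round + 1 < num_vars; ++round) {
    bool changed = false;
    for (const Edge& e : knowns) {
      if (best[e.lhs] == kInf) continue;
      // from <= lhs + best[lhs]  and  lhs <= rhs + offset
      //   =>  from <= rhs + (best[lhs] + offset)
      int64_t candidate = best[e.lhs] + e.offset;
      if (candidate < best[e.rhs]) {
        best[e.rhs] = candidate;
        changed = true;
      }
    }
    if (!changed) break;
  }
  if (best[to] == kInf) return std::nullopt;
  return best[to];
}
```

With `a, b, c, d` numbered `0..3`, the docstring's priors yield an offset of `-4`. A query `a <= d + k` would then be proven whenever the returned offset is at most `k`.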




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989383994


##########
include/tvm/arith/analyzer.h:
##########
@@ -275,6 +275,36 @@ class RewriteSimplifier {
    */
   std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
+  /*! \brief Flags to enable more computationally-intensive simplifications
+   *
+   * These simplifications may be required for specific schedules, but
+   * would impose too high a compile-time cost to enable by default.
+   * They can be enabled on an as-needed basis by calling
+   * `RewriteSimplifier::SetEnabledFeatures` prior to using
+   * `RewriteSimplifier::operator()`.
+   */
+  enum Feature {
+    // No features enabled
+    kNone = 0,
+
+    /* When simplifying an inequality, attempt to use scope-based knowns.
+     *
+     * Example:
+     * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
+     */
+    kTransitivelyProveInequalities = (1 << 0),
+  };
+
+  /*! \brief Enable an optional feature or features
+   *
+   * \param flags A bitwise OR of all optional features that should be
+   * enabled.
+   */
+  void SetEnabledFeatures(Feature flags);

Review Comment:
   Updated to use `TVM_DLL` on the new member functions.





[GitHub] [tvm] csullivan commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989642453


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& range, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param expr A constraint expression.
+   *
+   * \return An exit function that must be called to clean up.  May be
+   * `nullptr` if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {
+    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
+    scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
+  };
+  return frecover;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
+                                                             const PrimExpr& rhs_expr) const {
+  // Currently only supports integer checks
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return CompareResult::kUnknown;
+  }
+
+  // Bail out early if possible.  Comparisons of two integer constants
+  // should have been constant-folded earlier, so this shouldn't occur.
+  auto* x_int = lhs_expr.as<IntImmNode>();
+  auto* y_int = rhs_expr.as<IntImmNode>();
+  if (x_int && y_int) {
+    if (x_int->value < y_int->value) {
+      return CompareResult::kLT;
+    } else if (x_int->value > y_int->value) {
+      return CompareResult::kGT;
+    } else {
+      return CompareResult::kEQ;
+    }
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  auto lhs_key = ExprToPreviousKey(lhs);
+  auto rhs_key = ExprToPreviousKey(rhs);
+
+  if (!lhs_key.has_value() || !rhs_key.has_value()) {
+    return CompareResult::kUnknown;
+  }
+
+  auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
+  auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
+  auto output = from_lhs & from_rhs;
+
+  return output;
+}
+
+CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS(

Review Comment:
   Oh shoot, sorry! I meant that the function was quite long and could be easier to understand if it were factored out, but that was completely a nitpick and isn't necessary, given that everything else looks good now. I do like the new name too!
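
The doc comment on `TryCompareFromLHS` describes treating each known `Comparison` as a graph edge and searching for a path from lhs to rhs, summing the offsets along the way. A minimal sketch of that idea, assuming only `<=` edges and integer keys standing in for `Key`; `LeEdge` and `TightestUpperOffset` are hypothetical names, the real function also handles `>=`, `==`, and `!=`, and its exact search order is not visible in the quoted hunk. Because offsets can be negative, the sketch uses Bellman-Ford-style relaxation rather than Dijkstra:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <unordered_map>
#include <vector>

// Hypothetical edge type: the known fact `from <= to + offset`, where
// `from` and `to` are integer keys standing in for interned expressions.
struct LeEdge {
  int from;
  int to;
  int64_t offset;
};

// Returns the smallest c such that chaining the knowns proves
// `lhs <= rhs + c`, or std::nullopt if no chain connects lhs to rhs.
// Offsets add along a path, so this is a shortest-path problem; since
// offsets may be negative, Bellman-Ford-style relaxation is used.
std::optional<int64_t> TightestUpperOffset(const std::vector<LeEdge>& knowns, int lhs,
                                           int rhs) {
  std::unordered_map<int, int64_t> best{{lhs, 0}};
  // A simple path uses at most knowns.size() edges, which also bounds
  // the work if the knowns contain a negative (inconsistent) cycle.
  for (std::size_t iter = 0; iter < knowns.size(); ++iter) {
    bool changed = false;
    for (const LeEdge& edge : knowns) {
      auto from_it = best.find(edge.from);
      if (from_it == best.end()) continue;  // `from` not reachable yet
      int64_t candidate = from_it->second + edge.offset;
      auto [to_it, inserted] = best.try_emplace(edge.to, candidate);
      if (inserted || candidate < to_it->second) {
        to_it->second = candidate;
        changed = true;
      }
    }
    if (!changed) break;
  }
  auto it = best.find(rhs);
  if (it == best.end()) return std::nullopt;
  return it->second;
}
```

With the priors from the doc comment, (a<=b), (b<=c+1), and (c<=d-5), the tightest offset from `a` to `d` is -4, so `a <= d + c` is provable for any `c >= -4`.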

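`Comparison::Implies` is a case analysis over two normalized comparisons that share the same lhs and rhs. A standalone sketch of the same table, assuming the four operators that remain after `<` and `>` are folded away; `Op`, `Normalized`, and this free function `Implies` are hypothetical stand-ins for the nested TVM types:

```cpp
#include <cassert>
#include <cstdint>

// Operators left after normalization (< and > are folded into <= and >=).
enum class Op { kEQ, kNE, kLE, kGE };

// Represents `x OP y + offset` for a fixed pair (x, y).
struct Normalized {
  Op op;
  int64_t offset;
};

// Returns true if `a` being true guarantees that `b` is true.
bool Implies(const Normalized& a, const Normalized& b) {
  // Identical comparisons trivially imply each other.
  if (a.op == b.op && a.offset == b.offset) return true;
  switch (b.op) {
    case Op::kLE:
      // x <= y + c1 or x == y + c1, with c1 <= c2, gives x <= y + c2
      return (a.op == Op::kLE || a.op == Op::kEQ) && a.offset <= b.offset;
    case Op::kGE:
      // x >= y + c1 or x == y + c1, with c1 >= c2, gives x >= y + c2
      return (a.op == Op::kGE || a.op == Op::kEQ) && a.offset >= b.offset;
    case Op::kNE:
      if (a.op == Op::kEQ) return a.offset != b.offset;
      if (a.op == Op::kLE) return a.offset < b.offset;  // x < y + c2
      if (a.op == Op::kGE) return a.offset > b.offset;  // x > y + c2
      return false;
    case Op::kEQ:
      // Only an identical EQ (handled above) can imply EQ.
      return false;
  }
  return false;
}
```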

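The quoted `Negate` helper, and the `from_lhs & from_rhs` combination in `TryCompare`, both rely on `CompareResult` being a bitmask of the outcomes that are still possible. A minimal sketch, assuming the encoding EQ=1, LT=2, GT=4 (so LE, GE, NE, and Unknown are unions of those bits, consistent with masking against `kUnknown`); the enum `Cmp` and the `Intersect` helper are hypothetical, and `Reverse` here swaps the LT and GT bits instead of using a switch:

```cpp
#include <cassert>

// Hypothetical stand-in for CompareResult: each bit marks an outcome
// that is still possible (bit 0 = equal, bit 1 = less, bit 2 = greater).
enum class Cmp : int {
  kInconsistent = 0,  // no outcome possible
  kEQ = 1,
  kLT = 2,
  kLE = 3,  // EQ | LT
  kGT = 4,
  kGE = 5,  // EQ | GT
  kNE = 6,  // LT | GT
  kUnknown = 7,  // any outcome possible
};

// Negation keeps exactly the outcomes the original excluded.
Cmp Negate(Cmp c) {
  if (c == Cmp::kInconsistent || c == Cmp::kUnknown) return c;
  return Cmp(~static_cast<int>(c) & static_cast<int>(Cmp::kUnknown));
}

// Swapping lhs and rhs exchanges the "less" and "greater" bits.
Cmp Reverse(Cmp c) {
  int v = static_cast<int>(c);
  int eq = v & 1;
  int lt = (v >> 1) & 1;
  int gt = (v >> 2) & 1;
  return Cmp(eq | (gt << 1) | (lt << 2));
}

// Two independent results can be combined by keeping the outcomes both allow.
Cmp Intersect(Cmp a, Cmp b) { return Cmp(static_cast<int>(a) & static_cast<int>(b)); }
```

For example, one search proving `<=` and another proving `!=` combine to `<`.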

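`ExtractOffsets` and the `Comparison` constructor together turn `(a + c1) OP (b + c2)` into `a OP' b + offset`, where OP' is never `<` or `>`. A simplified sketch, assuming operands of the form `base + constant` with the base represented by a string rather than a `PrimExpr` key; `Operand`, `NormalizedCmp`, and `Normalize` are hypothetical names:

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical operand `base + constant`; `base` stands in for a PrimExpr key.
struct Operand {
  std::string base;
  int64_t constant;
};

enum class Op { kLT, kLE, kGT, kGE, kEQ, kNE };

// Represents `lhs OP rhs + offset`, with < and > folded away.
struct NormalizedCmp {
  std::string lhs;
  std::string rhs;
  int64_t offset;
  Op op;
};

NormalizedCmp Normalize(const Operand& lhs, Op op, const Operand& rhs) {
  // (a + c1) OP (b + c2)  <=>  a OP b + (c2 - c1)
  NormalizedCmp out{lhs.base, rhs.base, rhs.constant - lhs.constant, op};
  // For integers: x < y + c <=> x <= y + (c - 1), x > y + c <=> x >= y + (c + 1)
  if (out.op == Op::kLT) {
    out.op = Op::kLE;
    out.offset -= 1;
  } else if (out.op == Op::kGT) {
    out.op = Op::kGE;
    out.offset += 1;
  }
  return out;
}
```

For example, `(i + 3) < (j + 5)` normalizes to `i <= j + 1`.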


[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989392512


##########
tests/python/unittest/test_tir_transform_simplify.py:
##########
@@ -547,5 +561,129 @@ def before(A: T.Buffer[16, "float32"]):
     expected = before
 
 
+class TestRemoveTransitivelyProvableCondition(BaseBeforeAfter):
+    """Remove comparisons that may be proven using multiple others
+
+    For example, the `0 < i` and `i <= j` conditions can be used to prove
+    that `0 < j`.
+    """
+
+    i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"]
+    zero = tvm.tir.IntImm("int32", 0)
+
+    test_case = tvm.testing.parameter(
+        (tvm.tir.all(zero < i, i <= j), zero < j, True),
+        # Transitive comparisons from LT
+        (tvm.tir.all(i < j, j < k), i < k, True),
+        (tvm.tir.all(i < j, j == k), i < k, True),
+        (tvm.tir.all(i < j, j <= k), i < k, True),
+        (tvm.tir.all(i < j, j > k), i < k, False),
+        (tvm.tir.all(i < j, j >= k), i < k, False),
+        (tvm.tir.all(i < j, j != k), i < k, False),
+        # Transitive comparisons from LE
+        (tvm.tir.all(i <= j, j < k), i < k, True),
+        (tvm.tir.all(i <= j, j == k), i == k, False),
+        (tvm.tir.all(i <= j, j == k), i <= k, True),
+        (tvm.tir.all(i <= j, j <= k), i <= k, True),
+        (tvm.tir.all(i <= j, j <= k), i < k, False),
+        (tvm.tir.all(i <= j, j > k), i < k, False),
+        (tvm.tir.all(i <= j, j >= k), i < k, False),
+        (tvm.tir.all(i <= j, j != k), i < k, False),
+        # Transitive comparisons from GT
+        (tvm.tir.all(i > j, j > k), i > k, True),
+        (tvm.tir.all(i > j, j == k), i > k, True),
+        (tvm.tir.all(i > j, j >= k), i > k, True),
+        (tvm.tir.all(i > j, j < k), i > k, False),
+        (tvm.tir.all(i > j, j <= k), i > k, False),
+        (tvm.tir.all(i > j, j != k), i > k, False),
+        # Transitive comparisons from GE
+        (tvm.tir.all(i >= j, j > k), i > k, True),
+        (tvm.tir.all(i >= j, j == k), i == k, False),
+        (tvm.tir.all(i >= j, j == k), i >= k, True),
+        (tvm.tir.all(i >= j, j >= k), i >= k, True),
+        (tvm.tir.all(i >= j, j >= k), i > k, False),
+        (tvm.tir.all(i >= j, j < k), i > k, False),
+        (tvm.tir.all(i >= j, j <= k), i > k, False),
+        (tvm.tir.all(i >= j, j != k), i > k, False),
+        # GT or LT may be used to prove NE
+        (tvm.tir.all(i == j, j != k), i != k, True),
+        (tvm.tir.all(i == j, j < k), i != k, True),
+        (tvm.tir.all(i == j, j > k), i != k, True),
+        (tvm.tir.all(i == j, j != k), i < k, False),
+        (tvm.tir.all(i == j, j != k), i > k, False),
+        # Because these are integers, x<y is equivalent to x <= y-1,
+        # and may be used in equivalent simplifications.
+        (tvm.tir.all(i < j, j < k), i < k, True),

Review Comment:
   Thank you for the catch. Yes, this was intended to be an `i <= j-1` test.
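
The expected values in the parameter table can be sanity-checked independently of TVM by brute force. A sketch that enumerates small integer values of `i`, `j`, and `k`; `ImpliedOnBox` is a hypothetical helper, and agreement on a finite box is evidence, not a proof for all integers. The last checks exercise the `i <= j-1` form mentioned above:

```cpp
#include <cassert>
#include <functional>

using Pred = std::function<bool(int, int, int)>;

// Returns false if some (i, j, k) in [lo, hi]^3 satisfies the premise
// but not the conclusion; true otherwise.
bool ImpliedOnBox(const Pred& premise, const Pred& conclusion, int lo = -4, int hi = 4) {
  for (int i = lo; i <= hi; ++i) {
    for (int j = lo; j <= hi; ++j) {
      for (int k = lo; k <= hi; ++k) {
        if (premise(i, j, k) && !conclusion(i, j, k)) return false;
      }
    }
  }
  return true;
}
```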





[GitHub] [tvm] Lunderberg commented on pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #12863:
URL: https://github.com/apache/tvm/pull/12863#issuecomment-1258687309

   Current failures are timeouts during `apps/microtvm/ethosu/run_demo.sh`.  The demo passes locally, but its runtime has increased from 1m22s to 5m48s, which would trip the timeout in `tests/scripts/task_demo_microtvm.sh`.  Before this commit is merged, the additional simplifications should be made opt-in, so that they only run where explicitly enabled.




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989412233


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on of scope-based statements

Review Comment:
   Thank you, reworded.
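
The two vectors differ in lifetime: `knowns_` holds facts that are always true (Let bindings, iterator ranges), while `scoped_knowns_` holds facts, such as an `IfThenElse` condition, that only hold inside a scope. A simplified sketch of the push/pop pattern used by `EnterConstraint`, assuming facts are plain strings; `ScopedKnowns` is a hypothetical class, not TVM code:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

class ScopedKnowns {
 public:
  // Facts that hold everywhere (e.g. from Bind).
  std::vector<std::string> knowns;
  // Facts that only hold inside the currently entered constraints.
  std::vector<std::string> scoped_knowns;

  // Appends the facts of a constraint and returns a callback that removes
  // exactly those facts again, mirroring the shape of EnterConstraint.
  std::function<void()> EnterConstraint(const std::vector<std::string>& facts) {
    std::size_t old_size = scoped_knowns.size();
    scoped_knowns.insert(scoped_knowns.end(), facts.begin(), facts.end());
    std::size_t new_size = scoped_knowns.size();
    return [this, old_size, new_size]() {
      // Scopes must be exited in LIFO order.
      assert(scoped_knowns.size() == new_size);
      scoped_knowns.erase(scoped_knowns.begin() + static_cast<std::ptrdiff_t>(old_size),
                          scoped_knowns.end());
    };
  }
};
```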





[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989382211


##########
include/tvm/arith/analyzer.h:
##########
@@ -275,6 +275,36 @@ class RewriteSimplifier {
    */
   std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
+  /*! \brief Flags to enable more computationally-intensive simplifications
+   *
+   * These simplifications may be required for specific schedules, but
+   * would impose too high a compile-time cost to enable by default.
+   * They can be enabled on an as-needed basis by calling
+   * `RewriteSimplifier::SetEnabledFeatures` prior to using
+   * `RewriteSimplifier::operator()`.
+   */
+  enum Feature {

Review Comment:
   I like that wording much more than `Feature`.  Updated here, and in all other occurrences in this PR.
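
The doc comment describes flags that leave the expensive simplifications off by default, to be enabled per call site before using the simplifier. A generic sketch of that opt-in bit-flag pattern; `SimplifierFlag`, its enumerators, `Simplifier`, and `SetEnabledFlags` are hypothetical names chosen for illustration, not the TVM API:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical flags; each expensive rewrite gets its own bit.
enum class SimplifierFlag : std::uint32_t {
  kNone = 0,
  kTransitiveComparisons = 1u << 0,
  kOtherExpensiveRule = 1u << 1,
};

constexpr SimplifierFlag operator|(SimplifierFlag a, SimplifierFlag b) {
  return SimplifierFlag(static_cast<std::uint32_t>(a) | static_cast<std::uint32_t>(b));
}

class Simplifier {
 public:
  void SetEnabledFlags(SimplifierFlag flags) { flags_ = flags; }
  bool IsEnabled(SimplifierFlag flag) const {
    return (static_cast<std::uint32_t>(flags_) & static_cast<std::uint32_t>(flag)) != 0;
  }

 private:
  // Everything expensive is off by default, so compile time is unchanged
  // unless a caller opts in.
  SimplifierFlag flags_ = SimplifierFlag::kNone;
};
```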





[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989450306


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& expr, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param constraint A constraint expression.
+   *
+   * \return An exit function that must be called to clean up.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key>
+TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
+  auto it = expr_to_key.find(expr);
+  if (it != expr_to_key.end()) {
+    return it->second;
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
+    const PrimExpr& expr) {
+  if (auto prev = ExprToPreviousKey(expr)) {
+    return prev.value();
+  } else {
+    Key new_key = Key(expr_to_key.size());
+    expr_to_key[expr] = new_key;
+    return new_key;
+  }
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
+  // These < and > should be removed during normalization.
+  return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
+  if (new_lhs == lhs_) {
+    return *this;
+  } else if (new_lhs == rhs_) {
+    return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
+  } else {
+    return std::nullopt;
+  }
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison
+TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
+  return Comparison(lhs_, rhs_, offset_, Negate(result_));
+}
+
+bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
+    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
+  ICHECK(lhs_ == other.lhs_);
+  ICHECK(rhs_ == other.rhs_);
+  ICHECK(IsNormalized());
+  ICHECK(other.IsNormalized());
+
+  if (result_ == other.result_ && offset_ == other.offset_) {
+    // if c1 == c2, x != y + c1 => x != y + c2
+    // if c1 == c2, x == y + c1 => x == y + c2
+    return true;
+  }
+
+  if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
+      // if c1 <= c2, x <= y + c1 => x <= y + c2
+      // if c1 <= c2, x == y + c1 => x <= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
+    if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
+      // if c1 >= c2, x == y + c1 => x >= y + c2
+      // if c1 >= c2, x >= y + c1 => x >= y + c2
+      return true;
+    }
+  }
+
+  if (other.result_ == CompareResult::kNE) {
+    if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
+      // if c1 != c2, x == y + c1 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kLE && offset_ < other.offset_) {
+      // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
+      return true;
+    }
+
+    if (result_ == CompareResult::kGE && offset_ > other.offset_) {
+      // if c1 > c2, x >= y + c1 => x > y + c2 => x != y + c2
+      return true;
+    }
+  }
+
+  return false;
+}
+
+TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
+TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
+  return impl_->TryCompare(lhs, rhs);
+}
+
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+  impl_->Bind(var, expr, allow_override);
+}
+void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
+                                                  std::vector<Comparison>* vec) {
+  for (const auto& subexpr : ExtractConstraints(expr)) {
+    if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
+      if (auto cmp = FromExpr(subexpr)) {
+        vec->push_back(cmp.value());
+      }
+    }
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
+                                              bool allow_override) {
+  auto it = prev_bindings_.find(var);
+  if (it != prev_bindings_.end()) {
+    ExprDeepEqual expr_equal;
+    bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
+                                 !expr_equal(range->extent, (*it).second->extent);
+    if (differs_from_previous) {
+      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
+                             << " conflicts with previous binding as " << (*it).second;
+      if (auto key = ExprToPreviousKey(var)) {
+        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
+                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
+                      knowns_.end());
+      }
+    }
+  }
+
+  prev_bindings_.Set(var, range);
+
+  if (is_const_int(range->extent, 1)) {
+    AddKnown(var == range->min, &knowns_);
+  } else {
+    AddKnown(var >= range->min, &knowns_);
+    AddKnown(var < range->min + range->extent, &knowns_);
+  }
+}
+
+void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
+                                              bool allow_override) {
+  Bind(var, Range::FromMinExtent(expr, 1), allow_override);
+}
+
+std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
+  size_t old_literal_size = scoped_knowns_.size();
+  AddKnown(expr, &scoped_knowns_);
+  size_t new_literal_size = scoped_knowns_.size();
+
+  PrimExpr temp = expr;
+  auto frecover = [old_literal_size, new_literal_size, this, temp]() {

Review Comment:
   Removed `temp` entirely instead.  It was used during debugging to know which constraints were being exited, but is no longer necessary.  (That said, I kept forgetting the init captures in the first place.  :P)
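The exit function returned by `EnterConstraint` only has to restore `scoped_knowns_` to its earlier size, so the two sizes are all it needs to capture. The sketch below is a minimal, hypothetical model of that pattern: the `ScopedKnowns` class and its `int` facts are illustrative stand-ins, not TVM code. Had a copy of the expression still been needed, a C++14 init capture such as `[temp = expr]` would have avoided the separate local variable.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for the analyzer's scoped_knowns_ storage.  An int
// replaces the real Comparison type; only the scoping mechanics matter here.
class ScopedKnowns {
 public:
  // Append the facts for a new scope and return a callback that removes
  // them.  The lambda captures `this` and two sizes by value; it keeps no
  // copy of the constraint itself.
  std::function<void()> EnterConstraint(const std::vector<int>& facts) {
    std::size_t old_size = knowns_.size();
    knowns_.insert(knowns_.end(), facts.begin(), facts.end());
    std::size_t new_size = knowns_.size();

    return [this, old_size, new_size]() {
      // Scopes are expected to exit innermost-first, so nothing else should
      // have been appended since this scope was entered.
      assert(knowns_.size() == new_size && "constraint scopes exited out of order");
      (void)new_size;  // keep the capture used when NDEBUG removes the assert
      knowns_.resize(old_size);
    };
  }

  std::size_t size() const { return knowns_.size(); }

 private:
  std::vector<int> knowns_;
};
```

Nested scopes must be exited innermost-first; the assertion is this sketch's own guard against an out-of-order exit.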

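`Negate` in the diff negates a result with a bitwise complement, which only works if `CompareResult` is a bit set over the three possible orderings EQ, LT and GT. The sketch below spells out that encoding; the numeric values are an assumption consistent with the complement trick, not copied from the diff. It also adds `ReverseBits`, a hypothetical bit-twiddling equivalent of the `Reverse` switch, which swaps the LT and GT bits.

```cpp
#include <cassert>

// Assumed bit-flag encoding: each bit marks one possible ordering of lhs
// relative to rhs, and compound results are unions of those bits.
enum class CompareResult : int {
  kInconsistent = 0,
  kEQ = 1,
  kLT = 2,
  kLE = 3,       // kEQ | kLT
  kGT = 4,
  kGE = 5,       // kEQ | kGT
  kNE = 6,       // kLT | kGT
  kUnknown = 7,  // kEQ | kLT | kGT
};

// Negation keeps exactly the orderings the original result excluded.
CompareResult Negate(CompareResult res) {
  switch (res) {
    case CompareResult::kInconsistent:
    case CompareResult::kUnknown:
      return res;
    default:
      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
  }
}

// Hypothetical alternative to the Reverse() switch: swapping lhs and rhs
// leaves the EQ bit alone and exchanges the LT and GT bits.
CompareResult ReverseBits(CompareResult res) {
  int r = static_cast<int>(res);
  int eq = r & static_cast<int>(CompareResult::kEQ);
  int lt = r & static_cast<int>(CompareResult::kLT);
  int gt = r & static_cast<int>(CompareResult::kGT);
  return CompareResult(eq | (lt << 1) | (gt >> 1));
}
```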

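The search described for `TryCompareFromLHS` treats each normalized prior `lhs <= rhs + offset` as a weighted edge, so proving `a <= d + k` reduces to finding a path from `a` to `d` whose offsets sum to at most `k`. The sketch below illustrates that idea with Bellman-Ford relaxation over integer keys, chosen because offsets may be negative. It handles only `<=` edges, and `Edge` and `TightestUpperOffset` are hypothetical names, not the analyzer's actual traversal.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// One normalized prior: lhs <= rhs + offset, with expressions identified by
// small integer keys (mirroring the Key enum in the diff).
struct Edge {
  std::size_t lhs;
  std::size_t rhs;
  int64_t offset;
};

// Hypothetical helper: the smallest B found such that lhs <= rhs + B follows
// from a chain of priors, or nullopt if no chain connects lhs to rhs.
std::optional<int64_t> TightestUpperOffset(std::size_t num_keys, const std::vector<Edge>& priors,
                                           std::size_t lhs, std::size_t rhs) {
  constexpr int64_t kNone = std::numeric_limits<int64_t>::max();
  // best[k] is the smallest c found so far with lhs <= k + c.
  std::vector<int64_t> best(num_keys, kNone);
  best[lhs] = 0;
  // At most num_keys - 1 rounds of relaxation are needed for simple paths.
  for (std::size_t iter = 0; iter + 1 < num_keys; ++iter) {
    bool changed = false;
    for (const Edge& e : priors) {
      if (best[e.lhs] == kNone) continue;
      // lhs <= e.lhs + best[e.lhs] and e.lhs <= e.rhs + e.offset
      int64_t candidate = best[e.lhs] + e.offset;
      if (candidate < best[e.rhs]) {
        best[e.rhs] = candidate;
        changed = true;
      }
    }
    if (!changed) break;
  }
  if (best[rhs] == kNone) return std::nullopt;
  return best[rhs];
}
```

With keys a=0, b=1, c=2, d=3 and the priors from the doc comment, `TightestUpperOffset(4, {{0, 1, 0}, {1, 2, 1}, {2, 3, -5}}, 0, 3)` returns -4, i.e. `a <= d - 4`.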



[GitHub] [tvm] Lunderberg commented on a diff in pull request #12863: [TIR][Arith] Prove conditionals by transitively applying knowns

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989445868


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+
+#include <optional>
+#include <vector>
+
+#include "constraint_extract.h"
+#include "pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class TransitiveComparisonAnalyzer::Impl {
+ public:
+  /*! \brief Using previously specified knowns, compare the expressions provided
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The most specific result that can be proven about the
+   * comparison.  If nothing can be proven, returns kUnknown.
+   */
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+
+  /*! \brief Bind a variable as being equal to a known expression
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*! \brief Bind a variable as being within a specified range
+   *
+   * \param var The variable of interest.
+   * \param range The known range
+   * \param allow_override Whether to allow override of existing information.
+   */
+  void Bind(const tir::Var& var, const Range& expr, bool allow_override = false);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param constraint A constraint expression.
+   *
+   * \return An exit function that must be called to cleanup.  May be
+   * `nullptr`, if no cleanup is required.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& expr);
+
+ private:
+  // Utility class to avoid needing to repeatedly call ExprDeepEqual
+  enum class Key : size_t {};
+  std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
+  Key ExprToKey(const PrimExpr& expr);
+  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
+
+  /*! \brief Internal representation of a comparison operator */
+  struct Comparison {
+    /*! \brief Construct a comparison that represents `lhs OP rhs +
+     * offset`, where the operation is specified by the CompareResult.
+     */
+    Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
+
+    /*! \brief Utility function to validate that all GT and LT results
+     *  have been normalized out
+     */
+    bool IsNormalized() const;
+
+    /*! \brief Move the specified expression to the LHS.
+     *
+     * \param new_lhs The argument that should be moved to the LHS of the
+     * comparison.
+     *
+     * \return If possible, returns a comparison that is equivalent to
+     * the current comparison, but with the specified LHS.  If not
+     * possible, returns nullopt.
+     */
+    std::optional<Comparison> WithLHS(Key new_lhs) const;
+
+    /*! \brief Create the negation of the current comparison */
+    Comparison Negated() const;
+
+    /*! \brief Check whether this comparison implies another
+     *
+     * Returns true if this comparison being true implies that the
+     * other comparison must also be true.  Returns false if the other
+     * comparison cannot be shown to be true.
+     */
+    bool Implies(const Comparison& other) const;
+
+    // The LHS of the comparison
+    Key lhs_;
+
+    // The RHS of the comparison, not including any constant offset.
+    Key rhs_;
+
+    // Additive offset on rhs
+    int64_t offset_{0};
+
+    // The comparison operator.
+    CompareResult result_{CompareResult::kInconsistent};
+  };
+
+  /*! \brief Generate a Comparison representing the given expression */
+  std::optional<Comparison> FromExpr(const PrimExpr& expr);
+
+  /*! \brief Utility function used by Bind and EnterConstraint
+   *
+   * \param expr The comparison expression, to be converted into
+   * internal Comparison objects.
+   *
+   * \param vec The vector to which the Comparison objects should be
+   * appended.
+   */
+  void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
+
+  /*! \brief Attempt to compare, starting at the lhs.
+   *
+   * Taking each available `Comparison` as a node edge, search for a
+   * path from lhs to rhs.  For example, the priors (a<=b), (b<=c+1)
+   * and (c<=d-5) can be used to prove that (a<=d-4).
+   *
+   * \param lhs The left-hand side of the comparison
+   *
+   * \param rhs The right-hand side of the comparison
+   *
+   * \return The result of the comparison
+   */
+  CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
+                                  const PrimExpr& rhs) const;
+
+  /*! \brief Previous Range bindings
+   *
+   * Tracked separately to handle the `allow_override` option used by
+   * all sub-analyzers when binding variables.
+   */
+  Map<Var, Range> prev_bindings_;
+
+  /*! \brief Known comparisons based on definitionally-true statements
+   *
+   * For example, a Let binding, or the range of an iterator.
+   */
+  std::vector<Comparison> knowns_;
+
+  /*! \brief Known comparisons based on scope-based statements
+   *
+   * For example, the condition of an IfThenElse, which is known to be
+   * true while within the if scope.
+   */
+  std::vector<Comparison> scoped_knowns_;
+};
+
+namespace {
+
+// Internal utility, return the CompareResult resulting from swapping
+// the left-hand side with the right-hand side.
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kEQ:
+      return CompareResult::kEQ;
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    case CompareResult::kNE:
+      return CompareResult::kNE;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.
+CompareResult Negate(CompareResult res) {
+  switch (res) {
+    case CompareResult::kInconsistent:
+      return CompareResult::kInconsistent;
+    case CompareResult::kUnknown:
+      return CompareResult::kUnknown;
+    default:
+      return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
+  }
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      return {x.Eval(), -c.Eval()->value};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto lhs_split = extract_offset(lhs);
+  auto rhs_split = extract_offset(rhs);
+  return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;

Review Comment:
   Sounds good. I added a description in the constructor, along with an explanation of why the internal representation normalizes everything to GE/LE instead of GT/LT.

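The constructor's rewrite of `<` and `>` into `<=` and `>=` depends on every value being an integer: `x < y + c` holds exactly when `x <= y + c - 1`, and `x > y + c` exactly when `x >= y + c + 1`. The minimal sketch below restates that rewrite so it can be brute-force checked against concrete integers; `Op`, `Cmp`, `Normalize` and `Holds` are hypothetical names, not part of the diff.

```cpp
#include <cassert>
#include <cstdint>

enum class Op { kLT, kLE, kGT, kGE };

// A comparison "x OP y + offset" over integers.
struct Cmp {
  Op op;
  int64_t offset;
};

// Mirror of the constructor logic in the diff: strict comparisons become
// non-strict by shifting the offset by one (valid only for integers).
Cmp Normalize(Cmp c) {
  if (c.op == Op::kLT) return {Op::kLE, c.offset - 1};
  if (c.op == Op::kGT) return {Op::kGE, c.offset + 1};
  return c;
}

// Evaluate the comparison for concrete integer values of x and y.
bool Holds(Cmp c, int64_t x, int64_t y) {
  int64_t rhs = y + c.offset;
  switch (c.op) {
    case Op::kLT: return x < rhs;
    case Op::kLE: return x <= rhs;
    case Op::kGT: return x > rhs;
    case Op::kGE: return x >= rhs;
  }
  return false;
}
```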

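Once both comparisons are in EQ/LE/GE/NE form over the same pair of keys, `Comparison::Implies` reduces to comparing two offsets. The self-contained sketch below restates those rules with hypothetical `Op`/`Cmp` types and evaluates them on the difference `d = x - y`, so each rule can be checked for soundness. Note that the rule from `x >= y + c1` to `x != y + c2` needs `c1 > c2`, which is the condition the code tests.

```cpp
#include <cassert>
#include <cstdint>

// Normalized operators only: LT/GT have already been rewritten to LE/GE.
enum class Op { kEQ, kLE, kGE, kNE };

// "x OP y + offset" for a fixed pair of expressions x and y.
struct Cmp {
  Op op;
  int64_t offset;
};

// Whether the comparison holds, given the integer difference d = x - y.
bool Holds(Cmp c, int64_t d) {
  switch (c.op) {
    case Op::kEQ: return d == c.offset;
    case Op::kLE: return d <= c.offset;
    case Op::kGE: return d >= c.offset;
    case Op::kNE: return d != c.offset;
  }
  return false;
}

// Restatement of the rules in Comparison::Implies: does `a` imply `b`?
bool Implies(Cmp a, Cmp b) {
  if (a.op == b.op && a.offset == b.offset) return true;
  if (b.op == Op::kLE && a.offset <= b.offset && (a.op == Op::kEQ || a.op == Op::kLE)) {
    return true;  // x <= y + c1 (or ==) with c1 <= c2 gives x <= y + c2
  }
  if (b.op == Op::kGE && a.offset >= b.offset && (a.op == Op::kEQ || a.op == Op::kGE)) {
    return true;  // x >= y + c1 (or ==) with c1 >= c2 gives x >= y + c2
  }
  if (b.op == Op::kNE) {
    if (a.op == Op::kEQ && a.offset != b.offset) return true;
    if (a.op == Op::kLE && a.offset < b.offset) return true;  // needs c1 < c2
    if (a.op == Op::kGE && a.offset > b.offset) return true;  // needs c1 > c2
  }
  return false;
}
```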
