Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/13 04:05:34 UTC

[GitHub] [tvm] zhangyicole opened a new pull request, #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual.

zhangyicole opened a new pull request, #12761:
URL: https://github.com/apache/tvm/pull/12761

   Refer to issue https://github.com/apache/tvm/issues/10211. The CSE pass cannot handle commutativity because the arith system may not be able to handle commutativity.
   
   Equality of two expressions (PrimExpr) is determined structurally: the hierarchical structures of the two expressions are traversed together and compared along the way. If the two structures are the same and the smallest child nodes (nodes that cannot be traversed further, typically Var) are the same, the expressions are considered equal.
   
   Before PrimExpr comparison, a series of rewrite rules is applied to the expressions to handle some algebraic properties. For example, x * y + x * z is rewritten as x * (y + z), which handles distributivity.
   However, commutativity cannot be handled by a rewrite rule, because such a rule would send the rewriter into an infinite loop. As a result, some equivalent expressions cannot be recognized as equal, for example a * b * c != a * c * b and (a * b) * c != a * (b * c).
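
The termination problem described above can be illustrated with a toy rewriter. This is a minimal sketch in plain Python, not TVM code: expressions are nested tuples, and `swap_rule`, `ordered_swap_rule`, and `rewrite_to_fixpoint` are hypothetical names introduced only for illustration.

```python
# Toy expressions: a variable is a string, a product is ("*", lhs, rhs).
# None of these helpers exist in TVM; they only illustrate the argument.


def swap_rule(expr):
    """Unguarded commutativity rule: x * y -> y * x (always applicable)."""
    if isinstance(expr, tuple) and expr[0] == "*":
        return ("*", expr[2], expr[1])
    return expr


def ordered_swap_rule(expr):
    """Guarded rule: swap only when the operands are out of order."""
    if isinstance(expr, tuple) and expr[0] == "*" and repr(expr[1]) > repr(expr[2]):
        return ("*", expr[2], expr[1])
    return expr


def rewrite_to_fixpoint(expr, rule, max_steps=8):
    """Apply `rule` until nothing changes.

    Returns (result, steps_taken, converged).
    """
    for step in range(max_steps):
        new_expr = rule(expr)
        if new_expr == expr:
            return expr, step, True
        expr = new_expr
    return expr, max_steps, False


if __name__ == "__main__":
    # The unguarded rule keeps flipping b * a <-> a * b until the step budget runs out.
    print(rewrite_to_fixpoint(("*", "b", "a"), swap_rule))
    # The guarded rule performs one swap and then stops.
    print(rewrite_to_fixpoint(("*", "b", "a"), ordered_swap_rule))
```

The guarded variant terminates only because it imposes a total order on the operands, which is the same idea as sorting sub-terms before comparison.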
   
   To solve this problem, one solution is to sort and rewrite the expressions according to the `StructuralHash` of the Var nodes in the expressions before comparing them. If two expressions are equivalent under commutativity, then sorting will produce expressions with the same structure.
   
   Under the assumption that the two expressions have the same structure, the condition for two expressions to be equal under commutativity can be refined further: each commutative sub-expression must contain the same set of elements. The commutative sub-expressions can be found by constructing the expression syntax tree: they are the sub-trees of the expression tree whose operator (OP) child nodes are all identical.
   
   An example:
   ```
       Var: a, b, c, d, e
   
       StructuralHash(a > b > c > d > e)
   
       cse_var_1 = a * b * c + d + e
   
       cse_var_2 = b * a * c + e + d
   ```
   To sort and rewrite cse_var_1 and cse_var_2, first extract their innermost sub-expressions a * b * c and b * a * c and rewrite both as a * b * c. Then extract the enclosing sub-expressions (a * b * c) + d + e and (a * b * c) + e + d and rewrite both in sorted order as (a * b * c) + d + e. Both variables now hold the same expression.
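
The sort-and-rewrite idea can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the code in this PR: expressions are nested tuples `(op, lhs, rhs)` with variable names as strings, `repr` stands in for `StructuralHash` as the sort key, and `flatten`, `normalize`, and `commutative_equal` are hypothetical names.

```python
# Toy model: a variable is a string, a binary op is (op, lhs, rhs).
# Only "+" and "*" are treated as commutative in this sketch.
COMMUTATIVE_OPS = {"+", "*"}


def flatten(expr, op):
    """Collect the operands of a maximal chain of the same operator `op`."""
    if isinstance(expr, tuple) and expr[0] == op:
        return flatten(expr[1], op) + flatten(expr[2], op)
    return [expr]


def normalize(expr, key=repr):
    """Rebuild each commutative chain left-associatively with sorted operands.

    `key` plays the role of the hash used for ordering; `repr` is an
    assumption made here so that the result is deterministic.
    """
    if not isinstance(expr, tuple):
        return expr
    op, lhs, rhs = expr
    if op not in COMMUTATIVE_OPS:
        # Non-commutative operators keep their operand order.
        return (op, normalize(lhs, key), normalize(rhs, key))
    operands = sorted((normalize(e, key) for e in flatten(expr, op)), key=key)
    result = operands[0]
    for operand in operands[1:]:
        result = (op, result, operand)
    return result


def commutative_equal(lhs, rhs):
    """Compare two expressions modulo commutativity and associativity."""
    return normalize(lhs) == normalize(rhs)


if __name__ == "__main__":
    # cse_var_1 = a * b * c + d + e
    cse_var_1 = ("+", ("+", ("*", ("*", "a", "b"), "c"), "d"), "e")
    # cse_var_2 = b * a * c + e + d
    cse_var_2 = ("+", ("+", ("*", ("*", "b", "a"), "c"), "e"), "d")
    print(commutative_equal(cse_var_1, cse_var_2))
```

The sorted order chosen by `repr` differs from the order in the example above, but any fixed total order works as long as both sides are normalized with the same key.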
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] zhangyicole commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
zhangyicole commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1248891043

   > Hello!
   > 
   > Thank you for the discussion and the PR. Just a few thoughts here:
   > 
   > * **1.** Indeed, we knew that the existing `Analyzer::Simplify()` would not deal with commutativity, because it does not implement a normalization procedure that is guaranteed to converge towards the normal form (which would indeed imply sorting sub-terms, etc). Rather, all it does is try to "do its best" by rewriting some known patterns, with no guarantee of converging to a normal form. For this reason, they could not deal with commutativity in `Simplify()`, because that would indeed lead to non-terminating rewrite sequences (or, more realistically, return junk, as in practice they stop rewriting after N rewrites are done). It often works fairly well in practice, but there is no guarantee of being complete (it behaves like a heuristic). However, it really must be correct (i.e., the result of `Simplify()` must really be equivalent to its input under algebraic laws).
   > * **2.** Before trying to make `Analyzer::Simplify()` able to deal with commutativity, it could be useful to **see if people are in practice facing the issue where a lot of redundant computations appear, but written differently due to the commutativity of some operators (like + and *)**. If so, it would be cool to see such concrete examples. To be honest, I don't even think that that many people are turning ON the already existing `bool identify_equiv_terms` of the CSE pass `Pass CommonSubexprElimTIR`, which uses `Analyzer::Simplify()`, which itself does what it can with associativity, neutral elements, etc. These things are probably pretty rare, and commutativity is probably too.
   > 
   > Although I designed the pass in a way that it can potentially identify terms that are equivalent [according to any equivalence relation](https://github.com/apache/tvm/blob/main/src/tir/transforms/common_subexpr_elim_tools.cc#L730-#L750) (instead of just the syntactical equality `ExprDeepEqual`), it might not necessarily be needed often in practice. This design, which allows using a custom equivalence relation, did not cost much more, so I went for it, in case it could become useful to someone one day for a particular case, so that it would not be necessary to write another CSE pass to deal with that. But I don't necessarily think that we should do comparisons modulo equivalence often, and this is why by default `bool identify_equiv_terms` of the `Pass CommonSubexprElimTIR` is set to false.
   > 
   > * **3.** If we decide that `Analyzer::Simplify()` (or another new Analyzer!) should deal with commutativity, then it should itself deal with commutativity, rather than baking commutativity into `ExprDeepEqual`, which is supposed to be just a deep **syntactical** equality and is used as such in many, many places of TVM's codebase (not just the CSE). So clearly it should not be changed (as @tqchen noted too).
   > * **4.** Remember that normalizing terms properly in order to deal with commutativity (which indeed includes sorting sub-terms) will likely be computationally expensive, which will make compilation of ML models longer. Actually, just the smaller work that `Analyzer::Simplify()` does is already time consuming, and that's probably why people leave the `bool identify_equiv_terms` of the CSE pass set to false (which is the default behavior, as I wrote earlier). It might make people want even less to turn ON this `bool identify_equiv_terms`. Perhaps the pseudo-normalization that `Analyzer::Simplify()` does is not too bad as a compromise: still usable in practice when needed (i.e., does not take too long), and deals with most simplifications needed (although not commutativity).
   > * **5.** If all the other algebraic properties (associativity, simplification of neutral elements, etc) are still done by the **pseudo**-normalization `Analyzer::Simplify()`, which is not guaranteed to find a normal form, I am not sure that a "normalizer for commutativity" built on top would be complete, even just with regard to commutativity. Is it worth it to then make `Analyzer::Simplify()` slower while still being incomplete?
   > 
   > Thanks!
   
   Hi @FranckQC, thanks for the questions and comments.
   I quite agree with what you said about the complexity of `Analyzer::Simplify()`: its overhead is very high. Changing `ExprDeepEqual` would not improve it and may slow it down, so, as suggested by tqchen, making this a subclass may be a good approach.
   
   This submission comes from my earlier work: I needed to compare the input and output sizes of an operator containing a reshape operation, and I rewrote deep equal to meet that need. After seeing your issue, I thought this rewrite could be helpful, so I submitted this PR.
   




[GitHub] [tvm] tqchen commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1262368947

   of course




[GitHub] [tvm] masahi commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
masahi commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1281950599

   Please update the PR title and description to reflect the current status. In particular, please make it more concise and explain the goal clearly.




[GitHub] [tvm] zhangyicole commented on a diff in pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
zhangyicole commented on code in PR #12761:
URL: https://github.com/apache/tvm/pull/12761#discussion_r997906925


##########
python/tvm/tir/analysis/analysis.py:
##########
@@ -331,3 +331,32 @@ def OOBChecker():
         The result pass
     """
     return _ffi_api.OOBChecker()  # type: ignore
+
+
+def communicative_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool:

Review Comment:
   Thanks for the review! It was a typo; I have corrected it.



##########
python/tvm/tir/analysis/analysis.py:
##########
@@ -331,3 +331,32 @@ def OOBChecker():
         The result pass
     """
     return _ffi_api.OOBChecker()  # type: ignore
+
+
+def communicative_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool:
+    """Deeply compare two nested expressions that have communicative equality.
+
+    Parameters
+    ----------
+    lhs : PrimExpr
+        The left operand.
+
+    rhs : PrimExpr
+        The right operand.
+
+    Returns
+    -------
+    result : bool
+        The comparison result
+
+    Note
+    ----
+
+    This function is an extension of py:func:`tvm.ir.expr_deep_equal`, it can
+    handle commutativity. The function will not return true for (x + y) vs (y + x).

Review Comment:
   Thanks for the review! The description here is indeed misleading; I have replaced "The function" with the exact function name to avoid misunderstanding.





[GitHub] [tvm] masahi commented on a diff in pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
masahi commented on code in PR #12761:
URL: https://github.com/apache/tvm/pull/12761#discussion_r997846843


##########
python/tvm/tir/analysis/analysis.py:
##########
@@ -331,3 +331,32 @@ def OOBChecker():
         The result pass
     """
     return _ffi_api.OOBChecker()  # type: ignore
+
+
+def communicative_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool:
+    """Deeply compare two nested expressions that have communicative equality.
+
+    Parameters
+    ----------
+    lhs : PrimExpr
+        The left operand.
+
+    rhs : PrimExpr
+        The right operand.
+
+    Returns
+    -------
+    result : bool
+        The comparison result
+
+    Note
+    ----
+
+    This function is an extension of py:func:`tvm.ir.expr_deep_equal`, it can
+    handle commutativity. The function will not return true for (x + y) vs (y + x).

Review Comment:
   By "The function", which function are you talking about here? If you start by saying "This function" and continue with "The function" in the next sentence, people will think that they refer to the same function.





[GitHub] [tvm] zhangyicole commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
zhangyicole commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1248891229

   > It would be useful to make it as a different equality comparator (rather than change DeepEqual's behavior, e.g. we can do CommunicativeDeepEqual as a subclass), as communicative rewrite is something that goes deeper.
   > 
   > Another possibility is to add a canonicalization pass to canonicalize the expressions before CSE
   
   Hi @tqchen, thanks a lot for the review!
   I agree that it would be useful to make it a different equality comparator;
   this would not have an impact on TVM's infrastructure. So how about making it a new subclass, as you suggested?




[GitHub] [tvm] masahi commented on a diff in pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
masahi commented on code in PR #12761:
URL: https://github.com/apache/tvm/pull/12761#discussion_r997835193


##########
python/tvm/tir/analysis/analysis.py:
##########
@@ -331,3 +331,32 @@ def OOBChecker():
         The result pass
     """
     return _ffi_api.OOBChecker()  # type: ignore
+
+
+def communicative_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool:

Review Comment:
   Is "communicative" a word? Do you mean "commutative"?





[GitHub] [tvm] FranckQC commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
FranckQC commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1262304347

   > @FranckQC @masahi can you help to take a look
   
   Sure, will do on Monday if that's ok :)




[GitHub] [tvm] masahi commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
masahi commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1246669670

   I can take a look at this tomorrow.




[GitHub] [tvm] tqchen commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1262276883

   @FranckQC @masahi can you help to take a look




[GitHub] [tvm] masahi commented on a diff in pull request #12761: [TIR, analysis] Add CommutativeDeepEqual to handle commutativity in expression comparison

Posted by GitBox <gi...@apache.org>.
masahi commented on code in PR #12761:
URL: https://github.com/apache/tvm/pull/12761#discussion_r1002150356


##########
src/tir/analysis/deep_equal.cc:
##########
@@ -26,10 +26,324 @@
 #include <tvm/node/structural_equal.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
-
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprMutator
 namespace tvm {
 namespace tir {
 
+class SortExprByHashVisitor : public ExprVisitor {
+ public:
+  void VisitExpr_(const VarNode* op) final;
+  void VisitExpr_(const SizeVarNode* op) final;
+  void VisitExpr_(const LoadNode* op) final;
+  void VisitExpr_(const BufferLoadNode* op) final;
+  void VisitExpr_(const ProducerLoadNode* op) final;
+  void VisitExpr_(const LetNode* op) final;
+  void VisitExpr_(const CallNode* op) final;
+  void VisitExpr_(const AddNode* op) final;
+  void VisitExpr_(const SubNode* op) final;
+  void VisitExpr_(const MulNode* op) final;
+  void VisitExpr_(const DivNode* op) final;
+  void VisitExpr_(const ModNode* op) final;
+  void VisitExpr_(const FloorDivNode* op) final;
+  void VisitExpr_(const FloorModNode* op) final;
+  void VisitExpr_(const MinNode* op) final;
+  void VisitExpr_(const MaxNode* op) final;
+  void VisitExpr_(const EQNode* op) final;
+  void VisitExpr_(const NENode* op) final;
+  void VisitExpr_(const LTNode* op) final;
+  void VisitExpr_(const LENode* op) final;
+  void VisitExpr_(const GTNode* op) final;
+  void VisitExpr_(const GENode* op) final;
+  void VisitExpr_(const AndNode* op) final;
+  void VisitExpr_(const OrNode* op) final;
+  void VisitExpr_(const ReduceNode* op) final;
+  void VisitExpr_(const CastNode* op) final;
+  void VisitExpr_(const NotNode* op) final;
+  void VisitExpr_(const SelectNode* op) final;
+  void VisitExpr_(const RampNode* op) final;
+  void VisitExpr_(const BroadcastNode* op) final;
+  void VisitExpr_(const ShuffleNode* op) final;
+  void VisitExpr_(const IntImmNode* op) final;
+  void VisitExpr_(const FloatImmNode* op) final;
+  void VisitExpr_(const StringImmNode* op) final;
+  void VisitExpr_(const AnyNode* op) final;
+
+  std::vector<std::pair<int, std::vector<PrimExpr>>> op_stack;
+  int cur_max_tree_idx = 0;
+  int pre_max_tree_idx = 0;
+
+ private:
+  std::string pre_bin_op = "null";
+  int stack_idx = 0;
+  int cur_tree_idx = 0;
+};
+
+#define TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(OpName)                 \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) {           \
+    std::string cur_bin_op = op->_type_key;                            \
+    std::string cur_pre_bin_op = pre_bin_op;                           \
+    int cur_stack_idx = stack_idx;                                     \
+    if (cur_bin_op != cur_pre_bin_op || cur_pre_bin_op == "null") {    \
+      std::vector<PrimExpr> expr_stack;                                \
+      if (cur_tree_idx + 1 > pre_max_tree_idx) {                       \
+        return;                                                        \
+      }                                                                \
+      op_stack.emplace_back(std::make_pair(cur_tree_idx, expr_stack)); \
+      cur_tree_idx += 1;                                               \
+      cur_max_tree_idx = std::max(cur_max_tree_idx, cur_tree_idx);     \
+      cur_stack_idx = op_stack.size();                                 \
+      stack_idx = cur_stack_idx;                                       \
+      cur_pre_bin_op = cur_bin_op;                                     \
+      pre_bin_op = cur_pre_bin_op;                                     \
+    }                                                                  \
+    int cur_tree_idx_temp = cur_tree_idx;                              \
+    if ((op->a).as<OpName>() == nullptr) {                             \
+      op_stack[stack_idx - 1].second.emplace_back(op->a);              \
+    }                                                                  \
+    if ((op->b).as<OpName>() == nullptr) {                             \
+      op_stack[stack_idx - 1].second.emplace_back(op->b);              \
+    }                                                                  \
+    this->VisitExpr(op->a);                                            \
+    pre_bin_op = cur_pre_bin_op;                                       \
+    stack_idx = cur_stack_idx;                                         \
+    cur_tree_idx = cur_tree_idx_temp;                                  \
+    this->VisitExpr(op->b);                                            \
+    pre_bin_op = cur_pre_bin_op;                                       \
+    stack_idx = cur_stack_idx;                                         \
+    cur_tree_idx = cur_tree_idx_temp;                                  \
+  }
+
+#define TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(OpName)    \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) { \
+    std::string cur_pre_bin_op = "null";                     \
+    pre_bin_op = cur_pre_bin_op;                             \
+    this->VisitExpr(op->a);                                  \
+    pre_bin_op = cur_pre_bin_op;                             \
+    this->VisitExpr(op->b);                                  \
+  }
+
+#define TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(OpName) \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) { return; }
+
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(AddNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(MulNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(AndNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(OrNode)
+
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(SubNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(DivNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(ModNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(FloorDivNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(FloorModNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(MinNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(MaxNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(EQNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(NENode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(LTNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(LENode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(GTNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(GENode)
+
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(VarNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(SizeVarNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(LoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(BufferLoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ProducerLoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(LetNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(CallNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ReduceNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(CastNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(NotNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(SelectNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(RampNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(BroadcastNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ShuffleNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(IntImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(FloatImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(StringImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(AnyNode)
+
+class SortExprByHashMutator : public StmtExprMutator {
+ public:
+  void Init() {

Review Comment:
   Replace with constructor



##########
src/tir/analysis/deep_equal.cc:
##########
@@ -70,10 +384,45 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
   return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, NullOpt);
 }
 
+class CommutativeDeepEqual : public ExprDeepEqual {
+ public:
+  bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
+    // quick path
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() && rhs.defined()) return false;
+    if (!rhs.defined() && lhs.defined()) return false;
+    if (lhs->type_index() != rhs->type_index()) return false;
+    if (auto* plhs = lhs.as<IntImmNode>()) {
+      auto* prhs = rhs.as<IntImmNode>();
+      return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
+    }
+    if (lhs.as<AnyNode>()) {
+      return false;
+    }
+    SortExprByHashMutator sort;
+    sort.pre_max_tree_idx = INT32_MAX;
+    auto sort_lhs = sort.Rewrite(lhs);
+    while (sort.pre_max_tree_idx != -1) {
+      sort_lhs = sort.Rewrite(sort_lhs);
+    }
+    sort.pre_max_tree_idx = INT32_MAX;
+    auto sort_rhs = sort.Rewrite(rhs);

Review Comment:
   Replace with `SortExprByHashMutator::Rewrite(...)` and avoid using the same sorter twice.



##########
src/tir/analysis/deep_equal.cc:
##########
@@ -26,10 +26,324 @@
 #include <tvm/node/structural_equal.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
-
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprMutator
 namespace tvm {
 namespace tir {
 
+class SortExprByHashVisitor : public ExprVisitor {
+ public:
+  void VisitExpr_(const VarNode* op) final;
+  void VisitExpr_(const SizeVarNode* op) final;
+  void VisitExpr_(const LoadNode* op) final;
+  void VisitExpr_(const BufferLoadNode* op) final;
+  void VisitExpr_(const ProducerLoadNode* op) final;
+  void VisitExpr_(const LetNode* op) final;
+  void VisitExpr_(const CallNode* op) final;
+  void VisitExpr_(const AddNode* op) final;
+  void VisitExpr_(const SubNode* op) final;
+  void VisitExpr_(const MulNode* op) final;
+  void VisitExpr_(const DivNode* op) final;
+  void VisitExpr_(const ModNode* op) final;
+  void VisitExpr_(const FloorDivNode* op) final;
+  void VisitExpr_(const FloorModNode* op) final;
+  void VisitExpr_(const MinNode* op) final;
+  void VisitExpr_(const MaxNode* op) final;
+  void VisitExpr_(const EQNode* op) final;
+  void VisitExpr_(const NENode* op) final;
+  void VisitExpr_(const LTNode* op) final;
+  void VisitExpr_(const LENode* op) final;
+  void VisitExpr_(const GTNode* op) final;
+  void VisitExpr_(const GENode* op) final;
+  void VisitExpr_(const AndNode* op) final;
+  void VisitExpr_(const OrNode* op) final;
+  void VisitExpr_(const ReduceNode* op) final;
+  void VisitExpr_(const CastNode* op) final;
+  void VisitExpr_(const NotNode* op) final;
+  void VisitExpr_(const SelectNode* op) final;
+  void VisitExpr_(const RampNode* op) final;
+  void VisitExpr_(const BroadcastNode* op) final;
+  void VisitExpr_(const ShuffleNode* op) final;
+  void VisitExpr_(const IntImmNode* op) final;
+  void VisitExpr_(const FloatImmNode* op) final;
+  void VisitExpr_(const StringImmNode* op) final;
+  void VisitExpr_(const AnyNode* op) final;
+
+  std::vector<std::pair<int, std::vector<PrimExpr>>> op_stack;
+  int cur_max_tree_idx = 0;
+  int pre_max_tree_idx = 0;
+
+ private:
+  std::string pre_bin_op = "null";
+  int stack_idx = 0;
+  int cur_tree_idx = 0;

Review Comment:
   Please document these params. Otherwise it's impossible to understand your code.



##########
src/tir/analysis/deep_equal.cc:
##########
@@ -26,10 +26,324 @@
 #include <tvm/node/structural_equal.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
-
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprMutator
 namespace tvm {
 namespace tir {
 
+class SortExprByHashVisitor : public ExprVisitor {
+ public:
+  void VisitExpr_(const VarNode* op) final;
+  void VisitExpr_(const SizeVarNode* op) final;
+  void VisitExpr_(const LoadNode* op) final;
+  void VisitExpr_(const BufferLoadNode* op) final;
+  void VisitExpr_(const ProducerLoadNode* op) final;
+  void VisitExpr_(const LetNode* op) final;
+  void VisitExpr_(const CallNode* op) final;
+  void VisitExpr_(const AddNode* op) final;
+  void VisitExpr_(const SubNode* op) final;
+  void VisitExpr_(const MulNode* op) final;
+  void VisitExpr_(const DivNode* op) final;
+  void VisitExpr_(const ModNode* op) final;
+  void VisitExpr_(const FloorDivNode* op) final;
+  void VisitExpr_(const FloorModNode* op) final;
+  void VisitExpr_(const MinNode* op) final;
+  void VisitExpr_(const MaxNode* op) final;
+  void VisitExpr_(const EQNode* op) final;
+  void VisitExpr_(const NENode* op) final;
+  void VisitExpr_(const LTNode* op) final;
+  void VisitExpr_(const LENode* op) final;
+  void VisitExpr_(const GTNode* op) final;
+  void VisitExpr_(const GENode* op) final;
+  void VisitExpr_(const AndNode* op) final;
+  void VisitExpr_(const OrNode* op) final;
+  void VisitExpr_(const ReduceNode* op) final;
+  void VisitExpr_(const CastNode* op) final;
+  void VisitExpr_(const NotNode* op) final;
+  void VisitExpr_(const SelectNode* op) final;
+  void VisitExpr_(const RampNode* op) final;
+  void VisitExpr_(const BroadcastNode* op) final;
+  void VisitExpr_(const ShuffleNode* op) final;
+  void VisitExpr_(const IntImmNode* op) final;
+  void VisitExpr_(const FloatImmNode* op) final;
+  void VisitExpr_(const StringImmNode* op) final;
+  void VisitExpr_(const AnyNode* op) final;
+
+  std::vector<std::pair<int, std::vector<PrimExpr>>> op_stack;
+  int cur_max_tree_idx = 0;
+  int pre_max_tree_idx = 0;
+
+ private:
+  std::string pre_bin_op = "null";
+  int stack_idx = 0;
+  int cur_tree_idx = 0;
+};
+
+#define TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(OpName)                 \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) {           \
+    std::string cur_bin_op = op->_type_key;                            \
+    std::string cur_pre_bin_op = pre_bin_op;                           \
+    int cur_stack_idx = stack_idx;                                     \
+    if (cur_bin_op != cur_pre_bin_op || cur_pre_bin_op == "null") {    \
+      std::vector<PrimExpr> expr_stack;                                \
+      if (cur_tree_idx + 1 > pre_max_tree_idx) {                       \
+        return;                                                        \
+      }                                                                \
+      op_stack.emplace_back(std::make_pair(cur_tree_idx, expr_stack)); \
+      cur_tree_idx += 1;                                               \
+      cur_max_tree_idx = std::max(cur_max_tree_idx, cur_tree_idx);     \
+      cur_stack_idx = op_stack.size();                                 \
+      stack_idx = cur_stack_idx;                                       \
+      cur_pre_bin_op = cur_bin_op;                                     \
+      pre_bin_op = cur_pre_bin_op;                                     \
+    }                                                                  \
+    int cur_tree_idx_temp = cur_tree_idx;                              \
+    if ((op->a).as<OpName>() == nullptr) {                             \
+      op_stack[stack_idx - 1].second.emplace_back(op->a);              \
+    }                                                                  \
+    if ((op->b).as<OpName>() == nullptr) {                             \
+      op_stack[stack_idx - 1].second.emplace_back(op->b);              \
+    }                                                                  \
+    this->VisitExpr(op->a);                                            \
+    pre_bin_op = cur_pre_bin_op;                                       \
+    stack_idx = cur_stack_idx;                                         \
+    cur_tree_idx = cur_tree_idx_temp;                                  \
+    this->VisitExpr(op->b);                                            \
+    pre_bin_op = cur_pre_bin_op;                                       \
+    stack_idx = cur_stack_idx;                                         \
+    cur_tree_idx = cur_tree_idx_temp;                                  \
+  }
+
+#define TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(OpName)    \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) { \
+    std::string cur_pre_bin_op = "null";                     \
+    pre_bin_op = cur_pre_bin_op;                             \
+    this->VisitExpr(op->a);                                  \
+    pre_bin_op = cur_pre_bin_op;                             \
+    this->VisitExpr(op->b);                                  \
+  }
+
+#define TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(OpName) \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) { return; }
+
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(AddNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(MulNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(AndNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(OrNode)
+
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(SubNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(DivNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(ModNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(FloorDivNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(FloorModNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(MinNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(MaxNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(EQNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(NENode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(LTNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(LENode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(GTNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(GENode)
+
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(VarNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(SizeVarNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(LoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(BufferLoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ProducerLoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(LetNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(CallNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ReduceNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(CastNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(NotNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(SelectNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(RampNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(BroadcastNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ShuffleNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(IntImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(FloatImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(StringImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(AnyNode)
+
+class SortExprByHashMutator : public StmtExprMutator {
+ public:
+  void Init() {
+    pre_bin_op = "null";
+    stack_idx = 0;
+    cur_tree_idx = 0;
+    full_stack_size = 0;
+  }
+
+  PrimExpr Rewrite(const PrimExpr& op) {
+    Init();
+    SortExprByHashVisitor sort_visitor;
+    sort_visitor.pre_max_tree_idx = pre_max_tree_idx;
+    sort_visitor(op);
+    for (auto& stack_pair : sort_visitor.op_stack) {
+      if (stack_pair.first == sort_visitor.cur_max_tree_idx - 1) {
+        std::sort(stack_pair.second.begin(), stack_pair.second.end(),
+                  [](PrimExpr expr_a, PrimExpr expr_b) {
+                    int64_t hash_a = tvm::StructuralHash()(expr_a);
+                    int64_t hash_b = tvm::StructuralHash()(expr_b);
+                    return hash_a < hash_b;
+                  });
+      }
+    }
+    op_stack.swap(sort_visitor.op_stack);
+    pre_max_tree_idx = sort_visitor.cur_max_tree_idx;
+    PrimExpr result = StmtExprMutator::VisitExpr(op);
+    pre_max_tree_idx = sort_visitor.cur_max_tree_idx - 1;
+    return result;
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const SizeVarNode* op) final;
+  PrimExpr VisitExpr_(const LoadNode* op) final;
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final;
+  PrimExpr VisitExpr_(const LetNode* op) final;
+  PrimExpr VisitExpr_(const CallNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const DivNode* op) final;
+  PrimExpr VisitExpr_(const ModNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+  PrimExpr VisitExpr_(const MinNode* op) final;
+  PrimExpr VisitExpr_(const MaxNode* op) final;
+  PrimExpr VisitExpr_(const EQNode* op) final;
+  PrimExpr VisitExpr_(const NENode* op) final;
+  PrimExpr VisitExpr_(const LTNode* op) final;
+  PrimExpr VisitExpr_(const LENode* op) final;
+  PrimExpr VisitExpr_(const GTNode* op) final;
+  PrimExpr VisitExpr_(const GENode* op) final;
+  PrimExpr VisitExpr_(const AndNode* op) final;
+  PrimExpr VisitExpr_(const OrNode* op) final;
+  PrimExpr VisitExpr_(const ReduceNode* op) final;
+  PrimExpr VisitExpr_(const CastNode* op) final;
+  PrimExpr VisitExpr_(const NotNode* op) final;
+  PrimExpr VisitExpr_(const SelectNode* op) final;
+  PrimExpr VisitExpr_(const RampNode* op) final;
+  PrimExpr VisitExpr_(const BroadcastNode* op) final;
+  PrimExpr VisitExpr_(const ShuffleNode* op) final;
+  PrimExpr VisitExpr_(const IntImmNode* op) final;
+  PrimExpr VisitExpr_(const FloatImmNode* op) final;
+  PrimExpr VisitExpr_(const StringImmNode* op) final;
+  PrimExpr VisitExpr_(const AnyNode* op) final;
+
+  int pre_max_tree_idx = 0;
+
+ private:
+  std::vector<std::pair<int, std::vector<PrimExpr>>> op_stack;
+  std::string pre_bin_op = "null";
+  int stack_idx = 0;
+  int full_stack_size = 0;
+  int cur_tree_idx = 0;
+};
+
+#define TVM_DEFINE_BIN_OP_SORT_BY_HASH_MUTATOR(Op)                                            \
+  PrimExpr SortExprByHashMutator::VisitExpr_(const Op##Node* op) {                            \
+    std::string cur_bin_op = op->_type_key;                                                   \
+    std::string cur_pre_bin_op = pre_bin_op;                                                  \
+    int cur_stack_idx = stack_idx;                                                            \
+    if (cur_bin_op != cur_pre_bin_op) {                                                       \
+      if (cur_tree_idx + 1 > pre_max_tree_idx) {                                              \
+        return GetRef<PrimExpr>(op);                                                          \
+      }                                                                                       \
+      if (cur_tree_idx + 1 == pre_max_tree_idx) {                                             \
+        PrimExpr expr_sorted =                                                                \
+            Op(op_stack[full_stack_size].second[0], op_stack[full_stack_size].second[1]);     \
+        for (std::size_t idx = 0; idx < op_stack[full_stack_size].second.size() - 2; idx++) { \
+          expr_sorted = Op(expr_sorted, op_stack[full_stack_size].second[idx + 2]);           \
+        }                                                                                     \
+        full_stack_size += 1;                                                                 \
+        cur_stack_idx = full_stack_size;                                                      \
+        cur_tree_idx += 1;                                                                    \
+        return expr_sorted;                                                                   \
+      }                                                                                       \
+      full_stack_size += 1;                                                                   \
+      cur_stack_idx = full_stack_size;                                                        \
+      cur_tree_idx += 1;                                                                      \
+      stack_idx = cur_stack_idx;                                                              \
+      cur_pre_bin_op = cur_bin_op;                                                            \
+      pre_bin_op = cur_pre_bin_op;                                                            \
+    }                                                                                         \
+    PrimExpr a;                                                                               \
+    PrimExpr b;                                                                               \
+    int cur_tree_idx_temp = cur_tree_idx;                                                     \
+    GetRef<PrimExpr>(op);                                                                     \

Review Comment:
   Is this line necessary?



##########
src/tir/analysis/deep_equal.cc:
##########
@@ -26,10 +26,324 @@
 #include <tvm/node/structural_equal.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
-
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprMutator
 namespace tvm {
 namespace tir {
 
+class SortExprByHashVisitor : public ExprVisitor {
+ public:
+  void VisitExpr_(const VarNode* op) final;
+  void VisitExpr_(const SizeVarNode* op) final;
+  void VisitExpr_(const LoadNode* op) final;
+  void VisitExpr_(const BufferLoadNode* op) final;
+  void VisitExpr_(const ProducerLoadNode* op) final;
+  void VisitExpr_(const LetNode* op) final;
+  void VisitExpr_(const CallNode* op) final;
+  void VisitExpr_(const AddNode* op) final;
+  void VisitExpr_(const SubNode* op) final;
+  void VisitExpr_(const MulNode* op) final;
+  void VisitExpr_(const DivNode* op) final;
+  void VisitExpr_(const ModNode* op) final;
+  void VisitExpr_(const FloorDivNode* op) final;
+  void VisitExpr_(const FloorModNode* op) final;
+  void VisitExpr_(const MinNode* op) final;
+  void VisitExpr_(const MaxNode* op) final;
+  void VisitExpr_(const EQNode* op) final;
+  void VisitExpr_(const NENode* op) final;
+  void VisitExpr_(const LTNode* op) final;
+  void VisitExpr_(const LENode* op) final;
+  void VisitExpr_(const GTNode* op) final;
+  void VisitExpr_(const GENode* op) final;
+  void VisitExpr_(const AndNode* op) final;
+  void VisitExpr_(const OrNode* op) final;
+  void VisitExpr_(const ReduceNode* op) final;
+  void VisitExpr_(const CastNode* op) final;
+  void VisitExpr_(const NotNode* op) final;
+  void VisitExpr_(const SelectNode* op) final;
+  void VisitExpr_(const RampNode* op) final;
+  void VisitExpr_(const BroadcastNode* op) final;
+  void VisitExpr_(const ShuffleNode* op) final;
+  void VisitExpr_(const IntImmNode* op) final;
+  void VisitExpr_(const FloatImmNode* op) final;
+  void VisitExpr_(const StringImmNode* op) final;
+  void VisitExpr_(const AnyNode* op) final;
+
+  std::vector<std::pair<int, std::vector<PrimExpr>>> op_stack;
+  int cur_max_tree_idx = 0;
+  int pre_max_tree_idx = 0;
+
+ private:
+  std::string pre_bin_op = "null";
+  int stack_idx = 0;
+  int cur_tree_idx = 0;
+};
+
+#define TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(OpName)                 \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) {           \
+    std::string cur_bin_op = op->_type_key;                            \
+    std::string cur_pre_bin_op = pre_bin_op;                           \
+    int cur_stack_idx = stack_idx;                                     \
+    if (cur_bin_op != cur_pre_bin_op || cur_pre_bin_op == "null") {    \
+      std::vector<PrimExpr> expr_stack;                                \
+      if (cur_tree_idx + 1 > pre_max_tree_idx) {                       \
+        return;                                                        \
+      }                                                                \
+      op_stack.emplace_back(std::make_pair(cur_tree_idx, expr_stack)); \
+      cur_tree_idx += 1;                                               \
+      cur_max_tree_idx = std::max(cur_max_tree_idx, cur_tree_idx);     \
+      cur_stack_idx = op_stack.size();                                 \
+      stack_idx = cur_stack_idx;                                       \
+      cur_pre_bin_op = cur_bin_op;                                     \
+      pre_bin_op = cur_pre_bin_op;                                     \
+    }                                                                  \
+    int cur_tree_idx_temp = cur_tree_idx;                              \
+    if ((op->a).as<OpName>() == nullptr) {                             \
+      op_stack[stack_idx - 1].second.emplace_back(op->a);              \
+    }                                                                  \
+    if ((op->b).as<OpName>() == nullptr) {                             \
+      op_stack[stack_idx - 1].second.emplace_back(op->b);              \
+    }                                                                  \
+    this->VisitExpr(op->a);                                            \
+    pre_bin_op = cur_pre_bin_op;                                       \
+    stack_idx = cur_stack_idx;                                         \
+    cur_tree_idx = cur_tree_idx_temp;                                  \
+    this->VisitExpr(op->b);                                            \
+    pre_bin_op = cur_pre_bin_op;                                       \
+    stack_idx = cur_stack_idx;                                         \
+    cur_tree_idx = cur_tree_idx_temp;                                  \
+  }
+
+#define TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(OpName)    \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) { \
+    std::string cur_pre_bin_op = "null";                     \
+    pre_bin_op = cur_pre_bin_op;                             \
+    this->VisitExpr(op->a);                                  \
+    pre_bin_op = cur_pre_bin_op;                             \
+    this->VisitExpr(op->b);                                  \
+  }
+
+#define TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(OpName) \
+  void SortExprByHashVisitor::VisitExpr_(const OpName* op) { return; }
+
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(AddNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(MulNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(AndNode)
+TVM_DEFINE_BIN_OP_SORT_BY_HASH_VISITOR(OrNode)
+
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(SubNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(DivNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(ModNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(FloorDivNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(FloorModNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(MinNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(MaxNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(EQNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(NENode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(LTNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(LENode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(GTNode)
+TVM_DEFINE_BIN_OP_NO_SORT_BY_HASH_VISITOR(GENode)
+
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(VarNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(SizeVarNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(LoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(BufferLoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ProducerLoadNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(LetNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(CallNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ReduceNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(CastNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(NotNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(SelectNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(RampNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(BroadcastNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(ShuffleNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(IntImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(FloatImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(StringImmNode)
+TVM_DEFINE_PASS_OP_SORT_BY_HASH_VISITOR(AnyNode)
+
+class SortExprByHashMutator : public StmtExprMutator {
+ public:
+  void Init() {
+    pre_bin_op = "null";
+    stack_idx = 0;
+    cur_tree_idx = 0;
+    full_stack_size = 0;
+  }
+
+  PrimExpr Rewrite(const PrimExpr& op) {
+    Init();
+    SortExprByHashVisitor sort_visitor;
+    sort_visitor.pre_max_tree_idx = pre_max_tree_idx;
+    sort_visitor(op);
+    for (auto& stack_pair : sort_visitor.op_stack) {
+      if (stack_pair.first == sort_visitor.cur_max_tree_idx - 1) {
+        std::sort(stack_pair.second.begin(), stack_pair.second.end(),
+                  [](PrimExpr expr_a, PrimExpr expr_b) {
+                    int64_t hash_a = tvm::StructuralHash()(expr_a);
+                    int64_t hash_b = tvm::StructuralHash()(expr_b);
+                    return hash_a < hash_b;
+                  });
+      }
+    }
+    op_stack.swap(sort_visitor.op_stack);
+    pre_max_tree_idx = sort_visitor.cur_max_tree_idx;
+    PrimExpr result = StmtExprMutator::VisitExpr(op);
+    pre_max_tree_idx = sort_visitor.cur_max_tree_idx - 1;
+    return result;
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const SizeVarNode* op) final;
+  PrimExpr VisitExpr_(const LoadNode* op) final;
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final;
+  PrimExpr VisitExpr_(const LetNode* op) final;
+  PrimExpr VisitExpr_(const CallNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const DivNode* op) final;
+  PrimExpr VisitExpr_(const ModNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+  PrimExpr VisitExpr_(const MinNode* op) final;
+  PrimExpr VisitExpr_(const MaxNode* op) final;
+  PrimExpr VisitExpr_(const EQNode* op) final;
+  PrimExpr VisitExpr_(const NENode* op) final;
+  PrimExpr VisitExpr_(const LTNode* op) final;
+  PrimExpr VisitExpr_(const LENode* op) final;
+  PrimExpr VisitExpr_(const GTNode* op) final;
+  PrimExpr VisitExpr_(const GENode* op) final;
+  PrimExpr VisitExpr_(const AndNode* op) final;
+  PrimExpr VisitExpr_(const OrNode* op) final;
+  PrimExpr VisitExpr_(const ReduceNode* op) final;
+  PrimExpr VisitExpr_(const CastNode* op) final;
+  PrimExpr VisitExpr_(const NotNode* op) final;
+  PrimExpr VisitExpr_(const SelectNode* op) final;
+  PrimExpr VisitExpr_(const RampNode* op) final;
+  PrimExpr VisitExpr_(const BroadcastNode* op) final;
+  PrimExpr VisitExpr_(const ShuffleNode* op) final;
+  PrimExpr VisitExpr_(const IntImmNode* op) final;
+  PrimExpr VisitExpr_(const FloatImmNode* op) final;
+  PrimExpr VisitExpr_(const StringImmNode* op) final;
+  PrimExpr VisitExpr_(const AnyNode* op) final;
+
+  int pre_max_tree_idx = 0;
+
+ private:
+  std::vector<std::pair<int, std::vector<PrimExpr>>> op_stack;
+  std::string pre_bin_op = "null";
+  int stack_idx = 0;
+  int full_stack_size = 0;
+  int cur_tree_idx = 0;

Review Comment:
   Document them.



##########
src/tir/analysis/deep_equal.cc:
##########
@@ -70,10 +384,45 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
   return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, NullOpt);
 }
 
+class CommutativeDeepEqual : public ExprDeepEqual {
+ public:
+  bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
+    // quick path
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() && rhs.defined()) return false;
+    if (!rhs.defined() && lhs.defined()) return false;
+    if (lhs->type_index() != rhs->type_index()) return false;
+    if (auto* plhs = lhs.as<IntImmNode>()) {
+      auto* prhs = rhs.as<IntImmNode>();
+      return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
+    }
+    if (lhs.as<AnyNode>()) {
+      return false;
+    }
+    SortExprByHashMutator sort;
+    sort.pre_max_tree_idx = INT32_MAX;
+    auto sort_lhs = sort.Rewrite(lhs);
+    while (sort.pre_max_tree_idx != -1) {
+      sort_lhs = sort.Rewrite(sort_lhs);
+    }

Review Comment:
   What does this loop do? Why not do it inside `Rewrite`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] FranckQC commented on pull request #12761: [TIR, analysis] Add CommutativeDeepEqual to handle commutativity in expression comparison

Posted by GitBox <gi...@apache.org>.
FranckQC commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1319338998

   I was finally able to take some time this week to have a closer look at this PR. It is definitely better than it was before, as it leaves the current DeepEqual unchanged, which was very important for me. That was point 3 in my earlier comment, which I now consider addressed, thank you :).
   
   However, I agree with @masahi's comments. The implementation could use a lot more comments and documentation about the variables being used, what the functions do, and what each part of the algorithm does. That would help a lot when reading the code, and would increase the confidence one can have in the implementation.
   
   It's great that there are quite a lot of tests; thanks for taking the time to write them. However, I'd also like to see a real use for this new equivalence relation (that was point 2 in my earlier comment). TVM is a compiler, not a tool like Matlab for doing algebraic manipulations of mathematical terms, so I would really like to see some natural use cases where this gets used in or integrated into a pass, or into something else that ultimately leads to improvements in the code produced when compiling some ML models.
   
   **More minor thing:** Finally, I'd like to know how one is supposed to use this `CommutativeDeepEqual` alongside the `Analyzer::Simplify()` function, which performs other kinds of simplification (simplification of neutral elements, applying distributivity, etc.) but unfortunately can't handle commutativity, as discussed earlier in the thread. The goal would be a function that uses all the algebraic properties available. I imagine it would call `Simplify()` on both sides and then use this new `CommutativeDeepEqual`. Would that be enough to be complete?
   The reason for asking is the following: I believe most people interested in equality modulo commutativity will come here after discovering that `Analyzer::Simplify()` can't do all the simplifications for them. So when they learn that there is this `CommutativeDeepEqual` equivalence relation for dealing with commutativity, their first question will likely be "how do I combine both?". So I think demonstrating that could be useful.
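
The "simplify both sides, then compare modulo commutativity" combination can be sketched on a toy expression type. This is only an illustration under stated assumptions: `Node`, `ToySimplify`, `Key` and `EqualModuloSimplifyAndCommutativity` are hypothetical names, `ToySimplify` is a stand-in that only drops neutral elements (it is not `Analyzer::Simplify()`), and operands are ordered by a serialized string key rather than by `StructuralHash`.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins, NOT TVM types. A leaf is a variable name or an integer literal.
struct Node;
using NodePtr = std::shared_ptr<Node>;
struct Node {
  char op;           // 'v' = leaf, '+' or '*' = binary operator
  std::string leaf;  // variable name or literal text, used when op == 'v'
  NodePtr lhs, rhs;  // operands, used when op is '+' or '*'
};

NodePtr Leaf(const std::string& text) {
  return std::make_shared<Node>(Node{'v', text, nullptr, nullptr});
}
NodePtr Op(char op, NodePtr lhs, NodePtr rhs) {
  assert((op == '+' || op == '*') && lhs && rhs);
  return std::make_shared<Node>(Node{op, "", std::move(lhs), std::move(rhs)});
}

bool IsLiteral(const NodePtr& n, const std::string& text) {
  return n->op == 'v' && n->leaf == text;
}

// Step 1 (stand-in for a simplifier): drop neutral elements, bottom-up.
NodePtr ToySimplify(const NodePtr& n) {
  if (n->op == 'v') return n;
  NodePtr a = ToySimplify(n->lhs);
  NodePtr b = ToySimplify(n->rhs);
  const std::string neutral = (n->op == '+') ? "0" : "1";
  if (IsLiteral(a, neutral)) return b;
  if (IsLiteral(b, neutral)) return a;
  return Op(n->op, a, b);
}

// Step 2 (stand-in for a commutative comparison): build a key in which the
// operands of every chain of the same operator are flattened and sorted.
void Collect(const NodePtr& n, char op, std::vector<std::string>* keys);

std::string Key(const NodePtr& n) {
  if (n->op == 'v') return n->leaf;
  std::vector<std::string> keys;
  Collect(n, n->op, &keys);
  std::sort(keys.begin(), keys.end());
  std::string out(1, n->op);
  for (const auto& k : keys) out += "," + k;
  return "[" + out + "]";
}

void Collect(const NodePtr& n, char op, std::vector<std::string>* keys) {
  if (n->op == op) {
    Collect(n->lhs, op, keys);
    Collect(n->rhs, op, keys);
  } else {
    keys->push_back(Key(n));
  }
}

bool EqualModuloSimplifyAndCommutativity(const NodePtr& x, const NodePtr& y) {
  return Key(ToySimplify(x)) == Key(ToySimplify(y));
}
```

In this toy, `(a + 0) * b` and `b * a` only compare equal when both steps are applied; the sorted key alone still sees the `+ 0`. The toy does not settle the completeness question: it still treats `a * (b + c)` and `a * b + a * c` as different, because its simplifier has no distributivity rule.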
   
   The most important things for me at this stage of this PR are to add comments to the code and to show some real use case/integration for it. Thank you for your work!




[GitHub] [tvm] tqchen commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1246801723

   It would be useful to make it a different equality comparator (rather than changing DeepEqual's behavior; e.g. we can add CommutativeDeepEqual as a subclass), as commutative rewriting is something that goes deeper.
   
   Another possibility is to add a canonicalization pass that canonicalizes the expressions before CSE.
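
A minimal sketch of the first option, on a toy model rather than TVM's classes: the syntactic comparator is left untouched, a subclass adds order-insensitive comparison, and a CSE-like client chooses which one to use. Everything here is hypothetical for illustration (`Product`, `ToyDeepEqual`, `ToyCommutativeDeepEqual`, `CountDistinct`); a product such as `a*b*c` is modeled as nothing more than its list of factor names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Toy model: a product such as a*b*c is its list of factor names, in written order.
using Product = std::vector<std::string>;

// Stand-in for the existing syntactic comparator; its behavior is not changed.
struct ToyDeepEqual {
  bool operator()(const Product& lhs, const Product& rhs) const { return lhs == rhs; }
};

// Subclass that sorts private copies of both sides, then defers to the base.
struct ToyCommutativeDeepEqual : ToyDeepEqual {
  bool operator()(Product lhs, Product rhs) const {
    std::sort(lhs.begin(), lhs.end());
    std::sort(rhs.begin(), rhs.end());
    return ToyDeepEqual::operator()(lhs, rhs);
  }
};

// A CSE-like client: count how many computations remain distinct under the
// equivalence relation it is given.
using EqualFn = std::function<bool(const Product&, const Product&)>;

std::size_t CountDistinct(const std::vector<Product>& exprs, const EqualFn& equal) {
  assert(static_cast<bool>(equal));
  std::vector<Product> representatives;
  for (const auto& e : exprs) {
    bool seen = std::any_of(representatives.begin(), representatives.end(),
                            [&](const Product& r) { return equal(r, e); });
    if (!seen) representatives.push_back(e);
  }
  return representatives.size();
}
```

With the three products `a*b*c`, `a*c*b` and `a*b`, the syntactic comparator keeps all three apart, while the commutative one merges the first two. Callers that rely on purely syntactic equality are unaffected, since they keep using the base comparator.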




[GitHub] [tvm] zhangyicole commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
zhangyicole commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1261696154

   Hi, @tqchen. I have implemented the function as a subclass. Would you like to take a look and merge it if everything looks good?




[GitHub] [tvm] tvm-bot commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1281852891

   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
    * Built docs for commit 43698188c6e2e9bcac6734c180be896c01d1ccca can be found [here](https://pr-docs.tlcpack.ai/PR-12761/11/docs/index.html).
   
   Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)




[GitHub] [tvm] FranckQC commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
FranckQC commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1247109999

   Hello!
   
   Thank you for the discussion and the PR.
   Just a few thoughts here:
   
   - **1.** Indeed, we knew that `Analyzer::Simplify()` would not deal with commutativity, because it does not implement a normalization procedure that is guaranteed to converge towards a normal form (which would indeed imply sorting sub-terms, etc.). Rather, all it does is try to "do its best" by rewriting some known patterns, with no guarantee of converging to a normal form. It often works fairly well in practice, but there is no guarantee of completeness (it behaves like a heuristic). However, it really must be correct (i.e., the result of `Simplify()` must really be equivalent to its input under the algebraic laws).
   
   - **2.** Before trying to make `Analyzer::Simplify()` able to deal with commutativity, it could be useful to **see if people are in practice facing the issue where a lot of redundant computations appear, but written differently due to the commutativity of some operators (like + and *)**. If so, it would be cool to see such concrete examples. To be honest, I don't even think that many people are turning ON the already existing `bool identify_equiv_terms` of the CSE pass `Pass CommonSubexprElimTIR`, which uses `Analyzer::Simplify()`, which itself does what it can with associativity, neutral elements, etc. These cases are probably pretty rare, and commutativity cases probably are too.
   
   Although I designed the pass in a way that it can potentially identify terms that are equivalent [according to any equivalence relation](https://github.com/apache/tvm/blob/main/src/tir/transforms/common_subexpr_elim_tools.cc#L730-#L750) (instead of just the syntactical equality `ExprDeepEqual`), that might not necessarily be needed often in practice. This design, which allows the use of a custom equivalence relation, did not cost much more, so I went for it in case it became useful to someone one day for a particular case, so that there would be no need to write another CSE pass to deal with that. But I don't necessarily think that we should do comparisons modulo equivalence often, and this is why `bool identify_equiv_terms` of the CSE pass `Pass CommonSubexprElimTIR` is set to false by default.
   
   - **3.** If we decide that `Analyzer::Simplify()` (or another new Analyzer!) should deal with commutativity, then it should itself deal with commutativity, rather than baking commutativity into `ExprDeepEqual`, which is supposed to be just a deep **syntactical** equality and is used as such in many, many places of TVM's codebase (not just the CSE). So clearly it should not be changed (as @tqchen noted too).
   
   - **4.** Remember that normalizing terms properly in order to deal with commutativity (which indeed includes sorting sub-terms) will likely be computationally expensive, which will make compilation of ML models longer. Actually, just the smaller amount of work that `Analyzer::Simplify()` does is already time-consuming, and that's probably why people leave the `bool identify_equiv_terms` of the CSE pass `Pass CommonSubexprElimTIR` set to false (which is the default behavior, as I wrote earlier). A slower normalization might make people even less inclined to turn on this `bool identify_equiv_terms`. Perhaps the pseudo-normalization that `Analyzer::Simplify()` does is not too bad as a compromise: still usable in practice when needed (i.e., it does not take too long), and it deals with most of the simplifications needed.
   
   - **5.** If all the other algebraic properties (associativity, simplification of neutral elements, etc.) are still handled by the **pseudo**-normalization `Analyzer::Simplify()`, which is not guaranteed to find a normal form, I am not sure that a commutativity simplification built on top of it would be complete, even with regard to commutativity alone. Is it then worth making `Analyzer::Simplify()` slower while it still remains incomplete?
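
To make "normal form" and "sorting sub-terms" from points 1 and 4 concrete, here is a minimal sketch on a toy expression type. It is not TVM code and not the PR's algorithm: `Expr`, `Var`, `Bin`, `Print` and `Normalize` are hypothetical names, and operands are ordered by their printed form rather than by `StructuralHash`. Each chain of the same commutative operator is flattened, its operands are normalized and sorted, and the chain is rebuilt left-deep.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression node; NOT tvm::PrimExpr. All names below are hypothetical.
struct Expr;
using ExprPtr = std::shared_ptr<Expr>;
struct Expr {
  std::string op;    // "var", "+", "*" (commutative) or "-" (not commutative)
  std::string name;  // variable name, used only when op == "var"
  ExprPtr a, b;      // operands, used only for binary operators
};

ExprPtr Var(const std::string& name) {
  return std::make_shared<Expr>(Expr{"var", name, nullptr, nullptr});
}
ExprPtr Bin(const std::string& op, ExprPtr a, ExprPtr b) {
  assert(a && b);
  return std::make_shared<Expr>(Expr{op, "", std::move(a), std::move(b)});
}

// Plain syntactic printing: operand order is preserved as written.
std::string Print(const ExprPtr& e) {
  if (e->op == "var") return e->name;
  return "(" + Print(e->a) + " " + e->op + " " + Print(e->b) + ")";
}

ExprPtr Normalize(const ExprPtr& e);

// Gather the normalized operands of a maximal chain of the same operator.
void Flatten(const ExprPtr& e, const std::string& op, std::vector<ExprPtr>* out) {
  if (e->op == op) {
    Flatten(e->a, op, out);
    Flatten(e->b, op, out);
  } else {
    out->push_back(Normalize(e));
  }
}

// Normal form: sort the operands of each + or * chain and rebuild it left-deep.
// Non-commutative operators keep their operand order.
ExprPtr Normalize(const ExprPtr& e) {
  if (e->op == "var") return e;
  if (e->op != "+" && e->op != "*") return Bin(e->op, Normalize(e->a), Normalize(e->b));
  std::vector<ExprPtr> operands;
  Flatten(e, e->op, &operands);
  std::sort(operands.begin(), operands.end(),
            [](const ExprPtr& x, const ExprPtr& y) { return Print(x) < Print(y); });
  ExprPtr result = operands[0];
  for (std::size_t i = 1; i < operands.size(); ++i) result = Bin(e->op, result, operands[i]);
  return result;
}
```

Two properties are worth checking on such a procedure: it is idempotent (normalizing twice changes nothing), and expressions that differ only by reordering or regrouping `+`/`*` operands print identically afterwards. The sketch also hints at the cost concern in point 4: the comparator re-prints whole sub-terms on every comparison, so a practical version would presumably cache keys or hashes.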
   
   Thanks!




[GitHub] [tvm] zhangyicole commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
zhangyicole commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1272242410

   Hi, @FranckQC @masahi. Is there anything in the code that I need to update?




[GitHub] [tvm] zhangyicole commented on pull request #12761: [TIR, analysis] Add expr hash sort in ExprDeepEqual

Posted by GitBox <gi...@apache.org>.
zhangyicole commented on PR #12761:
URL: https://github.com/apache/tvm/pull/12761#issuecomment-1281779841

   Hi, @FranckQC @masahi. If nothing needs to be changed in this PR, can you help merge it?

