Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/10 00:21:00 UTC

[GitHub] [tvm] FranckQC opened a new pull request #9482: Implementation of Common Subexpression Elimination for TIR

FranckQC opened a new pull request #9482:
URL: https://github.com/apache/tvm/pull/9482


   Hi everyone,
   
   We would like to upstream some work that we did at Qualcomm.
   The goal of this PR is to implement a Common Subexpression Elimination (CSE) pass for TIR, which identifies redundant computations (both within statements and within expressions) and replaces them with a fresh variable, introduced before the first occurrence of the redundant computation.
   
   Note that the pass does not only perform commoning on full expressions: it can also common out subexpressions. For instance, if the program computes both (w+x)+(y+z) and (w+x)+u, the pass will introduce the subexpression (w+x) into a new variable.
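To make the subexpression case concrete, here is a small Python sketch. It uses a hypothetical toy IR (nested tuples), not TVM's PrimExpr API, and simply counts every subexpression, reporting those seen more than once:

```python
from collections import Counter

# Toy IR (hypothetical, for illustration only): a variable is a str, a constant
# is an int, and a binary operation is a tuple (op, lhs, rhs).

def count_subexpressions(expr, table):
    """Record every non-atomic (sub)expression of expr in the table."""
    if isinstance(expr, tuple):
        table[expr] += 1
        _, lhs, rhs = expr
        count_subexpressions(lhs, table)
        count_subexpressions(rhs, table)

def redundant_subexpressions(exprs):
    table = Counter()
    for expr in exprs:
        count_subexpressions(expr, table)
    return [expr for expr, count in table.items() if count > 1]

e1 = ("+", ("+", "w", "x"), ("+", "y", "z"))  # (w+x) + (y+z)
e2 = ("+", ("+", "w", "x"), "u")              # (w+x) + u
print(redundant_subexpressions([e1, e2]))     # [('+', 'w', 'x')]
```

Only (w+x) is reported, which is the subexpression the pass would introduce into a new variable in this example.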
   
   If desired, it will be easy in the future to make the notion of equivalence between terms more flexible, for instance to identify expressions modulo commutativity (identifying (x+y) with (y+x)) or modulo associativity (identifying (x+y)+z with x+(y+z)). Replacing the function bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) will be the only change needed to do that. The typical way to rewrite it for such extensions would be to compute a canonical representative of a and a canonical representative of b, and then compare them with strict syntactic equality.
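The canonical-representative idea can be sketched as follows, again on a hypothetical tuple-based toy IR rather than on PrimExpr. This sketch only handles commutativity of + and *, ordering operands by their repr as an arbitrary but fixed total order; associativity is deliberately not handled:

```python
# Toy IR (hypothetical): str = variable, int = constant, (op, lhs, rhs) = binary op.
COMMUTATIVE_OPS = {"+", "*"}

def canonical(expr):
    """Canonical representative modulo commutativity: operands of commutative
    operators are put in a fixed order (here, ordered by their repr)."""
    if not isinstance(expr, tuple):
        return expr
    op, lhs, rhs = expr
    lhs, rhs = canonical(lhs), canonical(rhs)
    if op in COMMUTATIVE_OPS and repr(rhs) < repr(lhs):
        lhs, rhs = rhs, lhs
    return (op, lhs, rhs)

def equivalent_terms(a, b):
    # Strict syntactic equality, applied to the canonical representatives.
    return canonical(a) == canonical(b)
```

Supporting associativity as well would mean extending `canonical` (for example by flattening nested + into a sorted operand list) while leaving `equivalent_terms` untouched, which is the point made above.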
   
   The main CSE pass is declared and implemented respectively in the files common_subexpr_elim.h and common_subexpr_elim.cc.
   The function Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) is a good entry point as it contains many comments about what the pass is doing.
   
   The general idea of this pass is that it tries to introduce at the current level (the current root) the computations that are redundant and that can be introduced there (they should only contain variables that are in scope). This notion of variables in scope is implemented with a context, which is a vector of pairs (var, MaybeValue). The context is used not only for checking that the variables appearing in candidate computations are known at this point, but also for checking whether a computation has already been introduced into a variable.
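The two uses of the context can be sketched in Python on a toy model. Here the context is an ordered list of (var, value-or-None) pairs, with None standing in for a MaybeValue without a value (e.g. a loop variable); matching a computation uses plain equality where the real pass would use EquivalentTerms:

```python
# Toy IR (hypothetical): str = variable, int = constant, (op, lhs, rhs) = binary op.

def free_vars(expr):
    if isinstance(expr, str):
        return {expr}
    if isinstance(expr, tuple):
        _, lhs, rhs = expr
        return free_vars(lhs) | free_vars(rhs)
    return set()  # constants use no variable

def in_scope(expr, context):
    """True if every variable used by expr is bound in the context."""
    bound = {var for var, _ in context}
    return free_vars(expr) <= bound

def var_holding(expr, context):
    """Return a variable already bound to expr, if any (None otherwise)."""
    for var, value in context:
        if value is not None and value == expr:
            return var
    return None

# "i" is bound without a value (like a loop variable); "x" is bound by a let.
context = [("i", None), ("x", ("+", "i", 1))]
```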
   
   For greater flexibility in the future, a strong distinction is already in place between:
   
      - Syntactic computations, which are maintained in a hashtable associating expressions (the computations already seen) to a size_t (the number of times the computation has been seen).
      - Semantic entities, which are obtained from the syntactic computations by merging equivalent computations (where this notion of "equivalent" is customizable). Semantic entities are stored in a vector of pairs (expr, size_t) where, again, the number is the number of times that expr or equivalent computations have been seen.
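The step from syntactic computations to semantic entities can be sketched as below. The equivalence is passed in as a function, mirroring the customizable notion described above; the tuple IR and the `same_modulo_commutativity` helper are hypothetical, chosen only for the example:

```python
# Toy IR (hypothetical): str = variable, (op, lhs, rhs) = binary op.

def syntactic_to_semantic(table, equivalent):
    """Merge a {expr: count} table of syntactic computations into a list of
    (expr, count) semantic entities, summing counts of equivalent computations.
    The first expression seen is kept as the representative."""
    entities = []
    for expr, count in table.items():
        for entity in entities:
            if equivalent(entity[0], expr):
                entity[1] += count
                break
        else:  # no equivalent entity found: start a new one
            entities.append([expr, count])
    return [(expr, count) for expr, count in entities]

def same_modulo_commutativity(a, b):
    return a == b or (a[0] == b[0] == "+" and (a[1], a[2]) == (b[2], b[1]))

table = {("+", "x", "y"): 2, ("+", "y", "x"): 1, ("*", "x", "y"): 1}
strict = syntactic_to_semantic(table, lambda a, b: a == b)    # nothing merged
relaxed = syntactic_to_semantic(table, same_modulo_commutativity)
```

With strict equality the three computations stay separate; with the relaxed equivalence, (x+y) and (y+x) form a single entity seen three times.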
   
   The VisitStmt() method starts by computing the syntactic computations (implemented in an auxiliary analysis), then it merges equivalent computations to obtain the semantic computations. Then it sorts these semantic computations from biggest to smallest in order to always consider first the biggest computations. The rest will essentially be a loop over all these candidates, which will stay sorted.
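A sketch of the sorting step follows. The size measure used here (number of nodes in the expression tree) is an assumption for illustration; the pass may measure the size of an expression differently:

```python
# Toy IR (hypothetical): str = variable, int = constant, (op, lhs, rhs) = binary op.

def size(expr):
    """Number of nodes in the expression tree (atoms count as 1)."""
    if isinstance(expr, tuple):
        return 1 + size(expr[1]) + size(expr[2])
    return 1

# Semantic computations as (expr, count) pairs, in no particular order.
candidates = [
    (("+", "a", "b"), 2),
    (("*", ("+", "a", "b"), "c"), 2),
    (("-", "d", 1), 3),
]
# Biggest first. Python's sort is stable even with reverse=True, so candidates
# of equal size keep their relative order.
ordered = sorted(candidates, key=lambda entity: size(entity[0]), reverse=True)
```

Considering (a+b)*c before (a+b) means that when the whole product is redundant, it is introduced as one variable rather than first spending a variable on its (a+b) part.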
   
   When dealing with a candidate computation, there are three cases that can happen:
   
      - 1 - Rare case: a variable in the context already contains this computation. This variable can't have been introduced by the CSE, as we would have performed the replacements at that time (see case 2). So this is the case where the user (or a previous TIR pass) has written something like "let x = A in ...A...A...".
       -> In this case, we simply perform the replacements of A with x in the current result. These replacements are done by an auxiliary transform/Mutator, declared and implemented in replace_expr_selected.h and in replace_expr_selected.cc.
   
      - 2 - Case where we need to introduce the current computation inside a new variable: this is the case where all the variables used by the current computation are within scope (i.e. are present in the context) and where our internal heuristic/predicate tells us to introduce this computation into a new variable.
       -> In this case, a new variable new_var_i is generated, all the locations that use this computation in result are replaced by this fresh variable (using the same auxiliary Mutator mentioned in 1.), and the current result is replaced by let new_var_i = currentComputation in result.
   
      - 3 - Case where we can't or don't want to introduce this computation inside a new variable: this is the case where we either can't introduce the current computation inside a new variable (because it contains variables that are not yet in scope there) or don't want to (because our internal heuristic/predicate did not select it).
       -> In this case, we compute the direct sub-expressions of the current computation (implemented by an auxiliary analysis), and we add them to the vector of semantic computations so that they have a chance to be considered later. Note that they are added while preserving the order.
       Note that we do not add all the sub-expressions of the current expression but only its direct sub-expressions, since we always consider candidates from biggest to smallest, and since some candidates are mutually exclusive. Otherwise it would be more computationally intensive, and it would raise the problem of cleaning the vector of candidate computations when one of them gets introduced into a variable. Evaluating them lazily, by only looking at the direct sub-expressions, is both more efficient and simpler.
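The loop over candidates and its three cases can be sketched end to end on the toy tuple IR. This is a simplified model, not the pass itself: the "current result" is a list of expressions plus the let bindings created so far, the internal heuristic is replaced by the assumption "introduce when seen at least twice", and the variable names `cse_var_N` are hypothetical:

```python
# Toy IR (hypothetical): str = variable, int = constant, (op, lhs, rhs) = binary op.

def size(e):
    """Number of nodes in the expression tree."""
    return 1 + size(e[1]) + size(e[2]) if isinstance(e, tuple) else 1

def free_vars(e):
    if isinstance(e, tuple):
        return free_vars(e[1]) | free_vars(e[2])
    return {e} if isinstance(e, str) else set()

def replace(e, target, var):
    """Replace every occurrence of `target` in `e` by the variable `var`."""
    if e == target:
        return var
    if isinstance(e, tuple):
        return (e[0], replace(e[1], target, var), replace(e[2], target, var))
    return e

def insert_sorted(cands, start, expr, count):
    """Add `count` occurrences of `expr` after index `start`, keeping `cands`
    sorted from biggest to smallest (ties keep their insertion order)."""
    for j in range(start + 1, len(cands)):
        if cands[j][0] == expr:
            cands[j][1] += count
            return
    j = start + 1
    while j < len(cands) and size(cands[j][0]) >= size(expr):
        j += 1
    cands.insert(j, [expr, count])

def cse(body, context):
    """`body`: list of expressions computed at the current level.
    `context`: list of (var, value-or-None) pairs in scope.
    Returns (lets, body) where `lets` lists the new bindings, outermost first."""
    table = {}
    for e in body:
        if isinstance(e, tuple):
            table[e] = table.get(e, 0) + 1
    cands = sorted(([e, n] for e, n in table.items()),
                   key=lambda c: size(c[0]), reverse=True)
    in_scope = {v for v, _ in context}
    lets, i = [], 0
    while i < len(cands):
        expr, count = cands[i]
        holder = next((v for v, val in context if val == expr), None)
        if holder is not None:
            # Case 1: a variable of the context already holds this computation.
            body = [replace(e, expr, holder) for e in body]
            lets = [(v, replace(val, expr, holder)) for v, val in lets]
        elif count >= 2 and free_vars(expr) <= in_scope:
            # Case 2: redundant and in scope -> introduce a fresh variable.
            var = f"cse_var_{len(lets)}"
            body = [replace(e, expr, var) for e in body]
            lets = [(var, expr)] + [(v, replace(val, expr, var)) for v, val in lets]
        else:
            # Case 3: lazily add the direct sub-expressions as new candidates.
            for child in expr[1:]:
                if isinstance(child, tuple):
                    insert_sorted(cands, i, child, count)
        i += 1
    return lets, body
```

On the (w+x)+(y+z) and (w+x)+u example, neither full expression is redundant, so both are expanded (case 3); (w+x) then reaches a count of 2 and is introduced as `cse_var_0` (case 2). Because every child is strictly smaller than its parent, new candidates always land after the current index, so no already-processed entry ever needs cleaning.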
   
   Once the entire vector of semantic computations has been tried, the main function VisitStmt() calls the general dispatcher, which will in turn call the appropriate handlers. The only specific task of the overridden handlers is to update the context as new variables are introduced into scope (via Let-In, via For loops, etc.) or leave the current scope. Thus, they update the context before and after the calls to VisitStmt() and VisitExpr() on the child nodes.
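The push/pop discipline of those handlers can be sketched with a toy visitor. The statement shapes ("let", "for", "seq", "evaluate") are hypothetical stand-ins for TIR nodes, and the sketch only records which variables are in scope at each leaf, where the real handlers would instead recurse into VisitStmt()/VisitExpr() and run the CSE again at that level:

```python
class ContextTracker:
    """Hypothetical sketch: a visitor that pushes bindings into the context when
    entering a Let or For scope and pops them when leaving it."""

    def __init__(self):
        self.context = []  # list of (var, value-or-None)
        self.seen = []     # variables in scope at each visited leaf

    def visit(self, node):
        kind = node[0]
        if kind == "let":    # ("let", var, value, body)
            _, var, value, body = node
            self.context.append((var, value))
            self.visit(body)
            self.context.pop()
        elif kind == "for":  # ("for", loop_var, extent, body)
            _, var, _extent, body = node
            self.context.append((var, None))  # a loop variable has no single value
            self.visit(body)
            self.context.pop()
        elif kind == "seq":  # ("seq", stmt1, stmt2, ...)
            for stmt in node[1:]:
                self.visit(stmt)
        else:                # ("evaluate", expr): a leaf statement
            self.seen.append([v for v, _ in self.context])

prog = ("for", "i", 16,
        ("seq",
         ("let", "x", ("+", "i", 1), ("evaluate", ("*", "x", 2))),
         ("evaluate", "i")))
tracker = ContextTracker()
tracker.visit(prog)
```

After the visit, `x` was only in scope inside its Let body, and the context is empty again once the loop is left.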
   
   Please do not hesitate if you have any question.
   Thank you.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034443098


   Sorry, I forgot to answer the other parts of your message @mbs-octoml. Many thanks for it, by the way!
   
   > Thanks so much for the clear and complete comments, very much appreciate that.
   > Echo @Hzfengsy suggestion to switch to TVMScript for your tests -- perhaps just pretty printing what you already constructs would be a short cut.
   
   Yes, I didn't know about TVMScript before. When writing the tests, I initially looked at some other test files and took inspiration from them. Unfortunately, the ones I looked at might not have reflected the most up-to-date way of doing things, sorry for that! :-(
   I guess we can still improve the tests later on. For now they accomplish their main function, which I think is the most important thing, even though they are probably a bit more verbose than they would be with TVMScript, where I could directly write the expected TIR code instead of unfolding the TIR code that the CSE pass has produced.
    
   > The code duplication across PrimExpr & Stmnt in CommonSubexpressionElimintor is unfortunate but suspect removing that would only obscure things.
   
   Yes, I agree that this duplication is a little bit unfortunate. @masahi did point it out [here](https://github.com/apache/tvm/pull/9482#discussion_r790201064). It also annoyed me a bit at the beginning, so I tried to factor it out a few times, including an attempt described in my answer [here](https://github.com/apache/tvm/pull/9482#discussion_r797243979). But all my attempts ended up with something much too complicated for what we would gain.
   
   In fact, we just happen to want to do almost exactly the same treatment for Expr() and Stmt() from an algorithmic point of view, but from a data-type point of view quite a few things are still different. That's a pretty rare situation. In the end, I decided not to force things, and to leave it like that.
   And the positive side of not factorizing VisitStmt() and VisitExpr() through some weird magic is that we can still easily customize the algorithm differently for expressions and statements if at some point we need to (I can't think of any reason we would want that, but still! :) )
   
   Many thanks again!
   





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031958392


   This issue https://github.com/apache/tvm/issues/10180 could be relevant for the i386 failure.





[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797139585



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as we need to hash similarly deeply equal terms.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not do variables remapping), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
+          to each statement or expression of the program its table of computations. Its purpose is
+          to avoid the CSE pass from recomputing repeatedly the same tables of computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements
+  // It maps each known statement to its table of computations
+  std::unordered_map<Stmt, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_stmt_table_computations_;
+
+  // Part of the cache for expressions
+  // It maps each known expression to its table of computations
+  std::unordered_map<PrimExpr, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_expr_table_computations_;
+};
+
+/*!
+ * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
+          or by a statement.
+ * \note Computations here are considered syntactically, meaning that semantically equivalent
+          computations that are not syntactically the same are not merged together.
+ */
+class ComputationsDoneBy : public StmtExprVisitor {
+ public:
+  // Toplevel (static) methods
+  static TableOfComputations GetComputationsDoneBy(
+      const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
+      std::function<bool(const PrimExpr&)> can_contain_computations);
+  static TableOfComputations GetComputationsDoneBy(

Review comment:
       While I agree that `ExtractSubComputations` is a nicer name, I think it is better to stick with the concept of "computations being done by a node", which I have tried to describe [at the beginning of the implementation of this class](https://github.com/apache/tvm/pull/9482/files#diff-f83c46530e92c628fd309499ceffa32cb2fb0505633ac0e754e2bdef4d518962R68-R84).
   
   The name `ExtractSubComputations` would be quite misleading about what the function actually returns. For an input expression, it does return its subexpressions (if recursing inside the current node is allowed by the second predicate `can_contain_computations`), but only when the original expression itself is not eligible (according to the first predicate `is_eligible_computation`); otherwise, the original expression itself is what gets returned.
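The behaviour described in this reply can be sketched on a hypothetical tuple-based toy IR (not the C++ `ComputationsDoneBy` class). The two predicates play the roles of `is_eligible_computation` and `can_contain_computations`; the "call" node and the predicate definitions are assumptions made for the example:

```python
from collections import Counter

# Toy IR (hypothetical): str = variable, (op, lhs, rhs) = operation,
# ("call", fn_name, arg) = function call.

def computations_done_by(expr, is_eligible, can_contain, table=None):
    """An eligible expression is recorded as a whole (its subterms are not
    visited); otherwise we recurse into its children, if `can_contain` allows."""
    if table is None:
        table = Counter()
    if is_eligible(expr):
        table[expr] += 1
    elif isinstance(expr, tuple) and can_contain(expr):
        for child in expr[1:]:
            computations_done_by(child, is_eligible, can_contain, table)
    return table

def contains_call(e):
    return isinstance(e, tuple) and (e[0] == "call" or any(contains_call(c) for c in e[1:]))

def eligible(e):
    return isinstance(e, tuple) and not contains_call(e)

a_plus_b = ("+", "a", "b")
root = ("*", ("call", "f", a_plus_b), a_plus_b)
```

For `root`, the product contains a call and is therefore not eligible, so its children are explored and (a+b) is found twice. For an eligible expression such as (a+b)+c, the function returns the expression itself rather than its subexpressions, which is why "extracting sub-computations" would be a misleading name.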







[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1029593259


   Thank you so much @masahi for all your comments and advice, which always come so quickly! I'll do that tomorrow morning.
   You've been doing an amazing job helping me upstream this pass, thanks a lot!





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972608426


   Many thanks for the interesting questions and comments, I really appreciate it.
   
   > 
   > 
   > I have another three questions about the pass.
   > 
   >     1. Since function calls are non-eligible, it seems to prevent many common non-stateful builtin computations from being optimized, such as `tir.floordiv`, `tir.shift_right`, `tir.likely`, `tir.log` and etc.  TVM already have op annotation like `PURE`:
   >        https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
   >        what if to add an annotation type like `DETERMINISTIC` to mark the function return same output on same input. It should be
   >        safe to eliminate common calls if the call op is both pure and deterministic.
   
   Yes, I definitely agree that we could, in the future, relax the restriction that currently prevents function calls from being commoned out, for functions that we know to be pure, including some builtins.
   I think the "pure" tag alone would be enough for this purpose, if it has the meaning I imagine, which is this one (quoted from Wikipedia):
   
   #################
   a pure function is a function that has the following properties:
   
   1. The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments or input streams).
   2. The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments or input/output streams).
   
   Thus a pure function is a computational analogue of a mathematical function. Some authors, particularly from the imperative language community, use the term "pure" for all functions that just have the above property 2. 
   #################
   
   What is crucial for commoning out function calls is only condition 1. As long as the function always returns the same output for the same input, we can safely put the result of applying the function to some input into a variable (or any kind of "cache").
   
   Please note that the Wikipedia article later mentions that condition 2 is needed for performing CSE, but it is talking about a different problem from the one we are discussing (and is not very precise about it). It is not talking about commoning out function calls, but about commoning out redundant terms between which some function calls happen, as in:
   x = a+b+c;
   f(); // Here, if f can change the content of some variables like a, b or c, it prevents the CSE on the redundant term (a+b+c)
   y = a+b+c;
   
   But this kind of consideration does not concern us, as TIR is always in a (weak) SSA form: the content of variables is never updated, and can be assumed to be the same everywhere throughout the execution of the program. Thus, only condition 1 matters to us.
   
   If the kPure tag has both the meaning of 1 and 2, then we will be able to simply use it in the future when relaxing the condition.
   If it's only the condition 2, then we will indeed need another tag.
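The relaxation discussed here could look like the following toy version of the forbidden-computation check. `DETERMINISTIC_OPS` is a hypothetical stand-in for whatever op attribute (kPure or a new tag) would end up marking condition 1; it is not an existing TVM API, and loads stay forbidden exactly as in the current pass:

```python
# Toy IR (hypothetical): ("call", op_name, arg), ("load", buf, idx),
# ("buffer_load", buf, idx), or (op, lhs, rhs) for arithmetic.

# Hypothetical set of ops known to return the same output for the same input.
DETERMINISTIC_OPS = {"tir.floordiv", "tir.shift_right", "tir.log"}

def forbidden(expr):
    """Loads stay forbidden; calls are allowed only for deterministic ops."""
    if not isinstance(expr, tuple):
        return False
    kind = expr[0]
    if kind in ("load", "buffer_load"):
        return True
    if kind == "call":
        return expr[1] not in DETERMINISTIC_OPS
    return False

def contains_forbidden(expr):
    if forbidden(expr):
        return True
    return isinstance(expr, tuple) and any(contains_forbidden(c) for c in expr[1:])
```

A computation such as log(x)+1 would then become a candidate for commoning, while anything containing a call to an op outside the set, or a load, would still be rejected.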
   
   >     2. Is the let binding semantic in tir explicitly what we want? Is the value expr is specified to be evaluated at let scope entry? otherwise there may be no reuse effect if let-binding allow the value expr get lazily expanded in let body.
   
   I must admit that, now that I think of it, I am not entirely sure... Can anyone confirm?
   
   >     3. Is this possible for some complex (but common) expr to be lifted out of control flow scope? eg
   >        ```python
   >        if cond0:
   >            tir.evaluate(complex_e)  # slow path
   >        else:
   >           if cond2: 
   >                tir.evaluate(complex_e)  # slow path
   >           else:
   >                ...  # general fast path
   >        ```
   >        after CSE maybe comes to
   >        ```python
   >        x = tir.Var("int32")
   >        with tir.let(x, complex_e):  # always evaluate slow path
   >            if cond0:
   >                tir.evaluate(x) 
   >            else:
   >                if cond2: 
   >                    tir.evaluate(x)
   >                else:
   >                    ... # general fast path
   >        ```
   
   Yes, that is possible. Although, now that I think of it, it would probably be better not to do it in the exact example that you provided, as the expression `complex_e` of the slow path now always gets evaluated, even if we end up in the general fast branch...
   We should only lift things through "if" statements if both the then branch and the else branch compute them. This restriction is not needed for preserving the semantics (which were already preserved), but it will avoid this kind of anti-optimization. That should be easily fixable.
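The rule proposed in this reply can be sketched as a "computed on every path" analysis: a computation may be lifted above an If only if it is computed in both branches (recursively for nested Ifs). This is a toy model with hypothetical statement shapes, and a sketch of the proposed rule rather than necessarily what the final pass implements:

```python
# Toy IR (hypothetical): str = variable, int = constant, (op, lhs, rhs) = binary op.
# Statements: ("evaluate", expr), ("seq", s1, s2, ...), ("if", cond, then_s, else_s).

def subterms(expr):
    """All non-atomic subexpressions of expr."""
    if not isinstance(expr, tuple):
        return set()
    return {expr} | subterms(expr[1]) | subterms(expr[2])

def must_compute(stmt):
    """Computations evaluated on every execution path through stmt."""
    kind = stmt[0]
    if kind == "evaluate":
        return subterms(stmt[1])
    if kind == "seq":
        return set().union(*(must_compute(s) for s in stmt[1:]))
    if kind == "if":
        _, cond, then_s, else_s = stmt
        return subterms(cond) | (must_compute(then_s) & must_compute(else_s))
    raise ValueError(f"unknown statement kind: {kind}")

complex_e = ("*", ("+", "a", "b"), "c")
# The example above: complex_e is missing from the fast path.
slow_fast = ("if", "cond0", ("evaluate", complex_e),
             ("if", "cond2", ("evaluate", complex_e), ("evaluate", ("+", "d", 1))))
# Both branches compute complex_e.
both = ("if", "cond0", ("evaluate", complex_e), ("evaluate", ("-", complex_e, 1)))
```

In the example from the question, `complex_e` is not computed on the fast path, so it is not liftable above the outer If; when both branches compute it, it is.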
   
   Many thanks!






[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1018993039


   Hi everyone.
   
   So sorry for the delayed answer.
   
   Yesterday and today I pushed a few more commits that improve this CSE pass. The biggest improvement concerns the behavior of the pass for If and For nodes, for which expressions were previously allowed to be lifted out of the control-flow scope.
   I spotted that thanks to the question of @wrongtest, so many thanks again for it.
   It was definitely a bad thing (although rare, and not breaking semantics preservation): lifting a redundant computation that was only present in one execution path (say, the THEN branch of an If) to just above the If was an anti-optimization, as the computation might never have been needed at all if the execution flow reaches the ELSE branch.
   Now, it won't do this kind of thing. Instead, it will introduce the computation at the beginning of the block where it is redundant, but it won't lift it outside of that block.
   
   The only things left are to re-run clang-format and to address the last few issues reported by the CI. However, the build already succeeds on both Windows and macOS.
   I'm hoping to be done with everything by the end of the weekend, so that it can be merged next week for people to enjoy the CSE! I hope it will be useful.
   
   Kind regards,
   Franck





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790200752



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpressions Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithm std::find
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also do not even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables, these checks fail)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+}
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the one we are having now.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
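+ * \param stmt The body of the function being analyzed (also used to check that the names
+                of new variables are not already in use)
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed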
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e., present in `context_`).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction was performed when a computation was using an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry so as not to carry out-of-scope declarations
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If the `value` and the `body` of the let-in have been rewritten to the same thing
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Please remove duplication with `CommonSubexpressionEliminator::VisitExpr` if possible (e.g., by using a template).
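The greedy loop of `VisitExpr` above (the part the review suggests sharing with `VisitStmt`) can be summarized with a minimal Python sketch on a toy tuple-based IR. The IR encoding and all helper names are assumptions for illustration, not TVM's API; the rule "introduce a variable when a computation occurs at least twice and all its variables are in scope" is an assumed stand-in for `PredicateIntroVarForComputation` plus the `UndefinedVars()` check, and the sketch skips the final recursive visit into let values and bodies.

```python
# Toy IR (an assumption for illustration, not TVM's API):
#   str                        -> variable, e.g. "x"
#   int                        -> constant
#   (op, lhs, rhs)             -> binary operation, e.g. ("+", "w", "x")
#   ("let", var, value, body)  -> let-in binding built by the sketch


def is_eligible(e):
    # Atoms (variables and constants) are never worth binding to a new variable.
    return isinstance(e, tuple)


def size(e):
    # Number of nodes, used to consider computations from biggest to smallest.
    if isinstance(e, tuple):
        return 1 + sum(size(c) for c in e[1:])
    return 1


def free_vars(e):
    if isinstance(e, str):
        return {e}
    if isinstance(e, tuple):
        return set().union(*(free_vars(c) for c in e[1:]))
    return set()


def replace(e, target, var):
    # Replace every occurrence of `target` in `e` with the variable name `var`.
    if e == target:
        return var
    if isinstance(e, tuple):
        return (e[0],) + tuple(replace(c, target, var) for c in e[1:])
    return e


def add_sorted(worklist, expr, count, start):
    # Merge with an equal pending entry, else insert while keeping decreasing size order.
    for j in range(start, len(worklist)):
        if worklist[j][0] == expr:
            worklist[j] = (expr, worklist[j][1] + count)
            return
    j = start
    while j < len(worklist) and size(worklist[j][0]) >= size(expr):
        j += 1
    worklist.insert(j, (expr, count))


def cse(expr, vars_in_scope):
    result = expr
    # (computation, number of occurrences), sorted by decreasing size.
    worklist = [(expr, 1)] if is_eligible(expr) else []
    nb_var = 0
    i = 0
    while i < len(worklist):
        comp, count = worklist[i]
        # Assumed predicate: redundant, and every variable it uses is in scope.
        if count >= 2 and free_vars(comp) <= vars_in_scope:
            nb_var += 1
            var = f"cse_var_{nb_var}"
            result = ("let", var, comp, replace(result, comp, var))
        else:
            # Not introduced here: consider its direct eligible subexpressions instead.
            for sub in comp[1:]:
                if is_eligible(sub):
                    add_sorted(worklist, sub, count, i + 1)
        i += 1
    return result


E = ("*", ("+", ("+", "w", "x"), ("+", "y", "z")), ("+", ("+", "w", "x"), "u"))
print(cse(E, {"w", "x", "y", "z", "u"}))
# ('let', 'cse_var_1', ('+', 'w', 'x'), ('*', ('+', 'cse_var_1', ('+', 'y', 'z')), ('+', 'cse_var_1', 'u')))
```

On the example from the PR description, `(w+x)+(y+z)` combined with `(w+x)+u`, the sketch binds `w+x` once. Removing `w` from `vars_in_scope` leaves the expression untouched, which mirrors the `UndefinedVars()` check.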




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] wrongtest commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
wrongtest commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972463175


   Hi @FranckQC, I am also very interested in the CSE pass. In our use cases, we suffer from duplicated index computations produced by loop unrolling, like:
   ```
   A[i*256 + j*16 + 0] = B[i*256 + j*16 + 4096]
   A[i*256 + j*16 + 1] = B[i*256 + j*16 + 4097]
   ...
   ```
   We currently have to rely on the optimization abilities of the target backend, which may or may not eliminate them. It would be great if CSE could handle some of these at the TIR level.
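For the unrolled pattern above, one quick way to see what CSE could share is to count the compound subterms of the index expressions. The sketch below uses a toy tuple encoding (an assumption, not TIR) and assumes the indices are parsed left-associatively as `(i*256 + j*16) + k`; under that assumption, `i*256 + j*16` is the largest repeated subterm and the natural candidate for a `cse_var`.

```python
from collections import Counter


def subterms(e):
    # Yield every compound subterm of a tuple-encoded expression such as ("+", a, b).
    if isinstance(e, tuple):
        yield e
        for child in e[1:]:
            yield from subterms(child)


def size(e):
    if isinstance(e, tuple):
        return 1 + sum(size(c) for c in e[1:])
    return 1


def base():
    # i*256 + j*16, rebuilt for every index just as unrolling duplicates it.
    return ("+", ("*", "i", 256), ("*", "j", 16))


# Indices of the first two unrolled statements:
#   A[i*256 + j*16 + 0] = B[i*256 + j*16 + 4096]
#   A[i*256 + j*16 + 1] = B[i*256 + j*16 + 4097]
indices = []
for k in range(2):
    indices.append(("+", base(), k))         # index into A
    indices.append(("+", base(), 4096 + k))  # index into B

counts = Counter(t for idx in indices for t in subterms(idx))
repeated = [t for t, n in counts.items() if n >= 2]
best = max(repeated, key=size)
print(best, counts[best])  # ('+', ('*', 'i', 256), ('*', 'j', 16)) 4
```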



[GitHub] [tvm] wrongtest commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
wrongtest commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972538596


   I have three more questions about the pass.
   
   1. Since function calls are not eligible, many common stateless builtin computations cannot be optimized, such as `tir.floordiv`, `tir.shift_right`, `tir.likely`, `tir.log`, etc. TVM already has op annotations like `PURE`:
       https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
       What about adding an annotation type like `DETERMINISTIC` to mark that a function returns the same output for the same input? It should be safe to eliminate common calls if the call op is both pure and deterministic.
   
   2. Are the let-binding semantics in TIR explicitly what we want? Is the value expression specified to be evaluated at let scope entry? Otherwise, there may be no reuse effect if let-binding allows the value expression to be lazily expanded in the let body.
   
   3. Is it possible for some complex (but common) expression to be lifted out of its scope? E.g.:
       ```python
       if cond0:
           # slow path
           tir.evaluate(complex_e)
       else:
           if cond2:
               ...
           else:
               # general fast path
               ...
       ```
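Point 1 could translate into an eligibility predicate along the lines of this hypothetical Python sketch on a toy tuple IR. The `pure`/`deterministic` flags and the op table are assumptions: `DETERMINISTIC` is only the annotation proposed above, not an existing TVM attribute, and `my.random` is a made-up op. A call stays forbidden unless its op is known to be both pure and deterministic, and unknown ops are treated conservatively.

```python
# Hypothetical op attributes; only ops flagged both pure and deterministic are safe to common.
OP_ATTRS = {
    "tir.log": {"pure": True, "deterministic": True},
    "tir.shift_right": {"pure": True, "deterministic": True},
    "my.random": {"pure": False, "deterministic": False},  # hypothetical op
}


def call_is_commonable(op_name):
    attrs = OP_ATTRS.get(op_name, {})  # unknown ops: no guarantees, so they stay forbidden
    return attrs.get("pure", False) and attrs.get("deterministic", False)


def is_forbidden(e):
    # Calls are encoded as ("call", op_name, *args).
    return isinstance(e, tuple) and e[0] == "call" and not call_is_commonable(e[1])


def contains_forbidden(e):
    if is_forbidden(e):
        return True
    return isinstance(e, tuple) and any(contains_forbidden(c) for c in e[1:])


def is_eligible(e):
    # Atoms are never eligible; compound terms are eligible if nothing inside is forbidden.
    return isinstance(e, tuple) and not contains_forbidden(e)


print(is_eligible(("+", ("call", "tir.log", "x"), 1)))     # True
print(is_eligible(("+", ("call", "my.random"), 1)))        # False
print(is_eligible(("+", ("call", "unknown.op", "x"), 1)))  # False
```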



[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-964681835


   Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   
   What I wanted to do is to eliminate common expressions that span across the host and GPU. For example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE a `log2(N)` expression that appears both in the host and in the GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note the many calls to `call_spirv_pure_glsl450`, which would be totally unnecessary if we had TIR-level CSE): https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   
   Is this PR going to solve my problem?
   



[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790202166



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as deeply equal terms need to hash to the same value.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not remap variables), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
+          to each statement or expression of the program its table of computations. Its purpose is
+          to prevent the CSE pass from repeatedly recomputing the same tables of computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements
+  // It maps each known statement to its table of computations
+  std::unordered_map<Stmt, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_stmt_table_computations_;
+
+  // Part of the cache for expressions
+  // It maps each known expression to its table of computations
+  std::unordered_map<PrimExpr, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_expr_table_computations_;
+};
+
+/*!
+ * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
+          or by a statement.
+ * \note Computations here are considered syntactically, meaning that semantically equivalent
+          computations that are not syntactically the same are not merged together.
+ */
+class ComputationsDoneBy : public StmtExprVisitor {
+ public:
+  // Toplevel (static) methods
+  static TableOfComputations GetComputationsDoneBy(

Review comment:
       Please make it a free function since the usage `ComputationsDoneBy::GetComputationsDoneBy` is quite verbose. We can also remove the declaration of `ComputationsDoneBy` from this header. See how other passes are implemented; there is no need for static methods.
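The hashing choice documented above for `TableOfComputations` (structural) versus `CacheOfComputations` (by pointer) can be illustrated in a few lines of Python, written as a free function in the spirit of this review comment. `Node`, `structural_key` and `computations_done_by` are hypothetical stand-ins: Python's default object hashing plays the role of `ObjectPtrHash`/`ObjectPtrEqual` (identity), and a tuple key built from a node's content plays the role of `StructuralHash` + `ExprDeepEqual`.

```python
from collections import Counter


class Node:
    """Toy IR node; default __hash__/__eq__ are identity-based, like ObjectPtrHash/ObjectPtrEqual."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args


def structural_key(n):
    # Content-based key: separately built but deeply equal nodes get the same key.
    if isinstance(n, Node):
        return (n.op,) + tuple(structural_key(a) for a in n.args)
    return n


def computations_done_by(nodes):
    # Free function (no static method needed) counting computations by structure.
    return Counter(structural_key(n) for n in nodes)


a = Node("+", "x", "y")
b = Node("+", "x", "y")  # built separately, same content as `a`

print(len(Counter([a, b])))                           # 2: identity keeps them apart (cache-style)
print(computations_done_by([a, b])[("+", "x", "y")])  # 2: merged into one entry (table-style)
```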





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1029592504


   > Or perhaps I should simply update these tests by lowering without the CSE pass (i.e. with the pass disabled)
   
   Yes that's what I was about to suggest.



[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1013423769


   ping @FranckQC 



[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031844098


   Oh I see, sorry!
   Before I push again, would you mind looking at the [unittest: CPU] error in the file test_vta_insn.py, please?
   I disabled the CSE pass for it before each build (as I did for some other tests), but for some reason this one is still failing and I don't really understand what's going on there.



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031857978


   There is also some trouble with two of the [python3: i386] tests: test_large_grpah (??) and test_large_graph.
   
   These, together with the trouble with test_gemm in test_vta_insn.py, are the only remaining issues.
   Should I revert my changes on test_vta_insn.py, which did not work?



[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU. For example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE a `log2(N)` expression that appears both in the host and in the GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note the many calls to `call_spirv_pure_glsl450`, which would be totally unnecessary if we had TIR-level CSE): https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
   In principle, every redundant subterm that is eligible for being commoned out (i.e., that does not contain function calls, etc.) will be commoned out. There are also a few restrictions due to some specifics of TVM (for instance, Ramp and Broadcast nodes are never commoned), but these are rare.
   
   Unfortunately, in the TIR code snippet that you have uploaded, it seems that the redundant subexpression is just a function call. Such calls can't be commoned out into a new variable by the CSE pass, as the pass has no guarantee that the function has no side effects and always produces the same outputs for the same inputs. Without this guarantee, commoning out such function calls could change the program's semantics, so it is not done, as preserving the semantics of the program is vital.
   
   I can imagine that for functions that are guaranteed to have no side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future, and it would rely on some "NoSideEffect" tag on functions.
   
   However, please note that if you had some other redundancies, this CSE pass would common out whatever redundancies you have that are eligible.
   For instance:
   Assume the term (f(42) + f(42)) + (x*y + x*y) appears somewhere.
   It is not eligible in its entirety, as it contains function calls.
   But its subterm (x*y + x*y) is eligible, so its redundant part x*y will be commoned out.
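The eligibility rule can be modelled in a few lines of Python. This is a toy sketch, not the pass's C++ code: the tuple term representation and every helper name below are hypothetical. It assumes that leaves (variables, constants) and any term containing a call are ineligible, and that a candidate seen only once is split into its direct subterms, processing the largest candidates first.

```python
from collections import Counter

# Hypothetical toy terms (not TVM's PrimExpr):
#   ("var", name), ("const", value), ("add", a, b), ("mul", a, b), ("call", fname, arg)


def children(e):
    """Direct subterms of a term; leaves have none."""
    if e[0] in ("add", "mul"):
        return [e[1], e[2]]
    if e[0] == "call":
        return [e[2]]
    return []


def contains_call(e):
    return e[0] == "call" or any(contains_call(c) for c in children(e))


def is_eligible(e):
    # Leaves are not worth a new variable; calls (and terms containing them) are excluded.
    return e[0] not in ("var", "const") and not contains_call(e)


def computations_done(e):
    """Maximal eligible subterms of e, with the number of times each one is seen."""
    if is_eligible(e):
        return Counter({e: 1})
    table = Counter()
    for c in children(e):
        table.update(computations_done(c))
    return table


def size(e):
    return 1 + sum(size(c) for c in children(e))


def redundant_computations(e):
    """Largest first: keep terms seen at least twice, split the others into direct subterms."""
    table = computations_done(e)
    redundant = []
    while table:
        term = max(table, key=size)
        count = table.pop(term)
        if count >= 2:
            redundant.append(term)
        else:
            for c in children(term):
                if is_eligible(c):
                    table[c] += count
    return redundant


x, y = ("var", "x"), ("var", "y")
f42 = ("call", "f", ("const", 42))
xy = ("mul", x, y)
term = ("add", ("add", f42, f42), ("add", xy, xy))  # (f(42) + f(42)) + (x*y + x*y)
print(redundant_computations(term))  # [('mul', ('var', 'x'), ('var', 'y'))]
```

In this model the whole term is rejected because of the calls, `(x*y + x*y)` is seen only once and is therefore split, and its two `x*y` occurrences are then counted together as a redundant computation.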
   
   Does that answer your question?
   Please do not hesitate to tell me if you need help trying the pass.
   
   Kind regards.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] wrongtest commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
wrongtest commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972543408


   BTW, the current host/device splitting machinery (SplitHostDevice) seems to work quite well with common expression bindings!
   The test script below shows that the common expression is evaluated in the host function and passed as a kernel parameter to the device function.
   
   ```python
   import tvm
   from tvm.script import tir as T
   
   @T.prim_func
   def func(a: T.handle, b: T.handle, n: T.int32) -> None:
       threadIdx_x = T.env_thread("threadIdx.x")
       A = T.match_buffer(a, [256], dtype="int32")
       B = T.match_buffer(b, [256], dtype="int32")
       common_expr = T.var("int32")
       # for common_expr in range(n // 8, n // 8 + 1):
       with T.let(common_expr, n // 8):
           for i in T.serial(0, common_expr):
               T.launch_thread(threadIdx_x, 8)
               T.store(B.data, i * 8 + threadIdx_x, common_expr + T.load("int32", A.data, i * 8 + threadIdx_x), True)
   
   mod = tvm.IRModule.from_expr(func)
   mod = tvm.tir.transform.Apply(lambda f: f.with_attr({"global_symbol": "main", "target": tvm.target.Target("cuda")}))(mod)
   mod = tvm.tir.transform.SplitHostDevice()(mod)
   print(mod.script())
   
   
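   # printed script of the resulting mod (note: it shows the `for`-loop binding of the commented-out line, not the `T.let` version)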
   @tvm.script.ir_module
   class Module:
       @T.prim_func
       def main(a: T.handle, b: T.handle, n: T.int32) -> None:
           # function attr dict
           T.func_attr({"global_symbol": "main", "target": None})
           A = T.match_buffer(a, [256], dtype="int32")
           B = T.match_buffer(b, [256], dtype="int32")
           # body
           for common_expr in T.serial(n // 8, n // 8 + 1):
               for i in T.serial(0, common_expr):
                   T.evaluate(T.tvm_call_packed("main_kernel0", B.data, A.data, common_expr, i, 8, dtype="int32"))
   
       @T.prim_func
       def main_kernel0(B_1: T.Ptr[global T.int32], A_1: T.Ptr[global T.int32], common_expr: T.int32, i: T.int32) -> None:
           # function attr dict
           T.func_attr({"target": cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32, "tir.noalias": 1, "global_symbol": "main_kernel0", "tir.device_thread_axis": [T.iter_var(threadIdx_x, [0:8], "ThreadIndex", "threadIdx.x")], "tir.is_global_func": 1, "calling_conv": 2})
           # var definition
           threadIdx_x = T.env_thread("threadIdx.x")
           # body
           T.launch_thread(threadIdx_x, 8)
           T.store(B_1, i * 8 + threadIdx_x, common_expr + T.load("int32", A_1, i * 8 + threadIdx_x), True)
   ```





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972608426


   Many thanks for the interesting questions and comments; I really appreciate it.
   
   > I have three more questions about the pass.
   > 
   >     1. Since function calls are non-eligible, it seems to prevent many common non-stateful builtin computations from being optimized, such as `tir.floordiv`, `tir.shift_right`, `tir.likely`, `tir.log`, etc. TVM already has op annotations like `PURE`:
   >        https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
   >        What about adding an annotation type like `DETERMINISTIC` to mark that a function returns the same output for the same input? It should be
   >        safe to eliminate common calls if the call op is both pure and deterministic.
   
   Yes, I definitely agree that we could, in the future, relax the restriction which currently prevents function calls from being commoned out, for some functions that we know to be pure, including some builtins.
   I think the "pure" tag alone would be enough for this purpose, if it has the meaning that I imagine, which is this one (quoted from Wikipedia):
   
   #################
   a pure function is a function that has the following properties:
   
   1. The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments or input streams).
   2. The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments or input/output streams).
   
   Thus a pure function is a computational analogue of a mathematical function. Some authors, particularly from the imperative language community, use the term "pure" for all functions that just have the above property 2. 
   #################
   
   What is crucial for us when commoning out function calls is only condition 1. As long as the function always returns the same output for the same input, we are fine with putting the result of applying the function to some input into a variable (or any kind of "cache").
   
   Please note that the Wikipedia article later mentions that condition 2 is needed for performing CSE, but it is talking about a different problem from the one we are discussing (and isn't very precise about it). It does not talk about commoning out function calls there. It is talking about commoning out some redundant terms between which some function calls happen, as in:
   x = a+b+c;
   f(); // Here, if f can change the content of some variables like a, b or c, it prevents the CSE on the redundant term (a+b+c)
   y = a+b+c;
   
   But this kind of consideration does not concern us, as TIR is always in a (weak) SSA form: the content of variables won't be updated, and can be assumed to be the same everywhere throughout the execution of the program. Thus, only condition 1 matters to us.
   
   If the kPure tag means both condition 1 and condition 2, then we will simply be able to use it in the future when relaxing the restriction.
   If it only means condition 2, then we will indeed need another tag.
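Why the guarantee matters can be seen with a toy Python comparison of a call evaluated twice versus the same call introduced once into a variable. The functions below are hypothetical illustrations; note that the stateful one breaks both conditions quoted above (its result depends on hidden state, and each call updates that state).

```python
def make_stateful_f():
    """Hypothetical impure function: each call reads and updates hidden state."""
    state = {"calls": 0}

    def f(x):
        state["calls"] += 1
        return x + state["calls"]  # same input, different output on each call

    return f


def pure_f(x):
    return x * x  # same input always gives the same output


def original(f, x):
    return f(x) + f(x)  # the call is evaluated twice


def commoned(f, x):
    t = f(x)  # the call is introduced once into a fresh variable
    return t + t


print(original(pure_f, 3), commoned(pure_f, 3))  # 18 18
print(original(make_stateful_f(), 10), commoned(make_stateful_f(), 10))  # 23 22
```

Commoning out the call preserves the result only for the pure function, which is why the pass conservatively leaves calls alone until some tag provides the guarantee.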
   
   >     2. Are the let-binding semantics in TIR explicitly what we want? Is the value expr specified to be evaluated at let scope entry? Otherwise there may be no reuse effect if let-binding allows the value expr to get lazily expanded in the let body.
   
   I must admit that, now that I think of it, I am not entirely sure... Can anyone confirm that, please?
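The concern in question 2 can be made concrete with a toy Python model of the two possible readings of a let binding. It only illustrates why eager evaluation is needed for any reuse; it makes no claim about which reading TIR actually specifies, and all names are hypothetical.

```python
def eager_let(value_fn, body_fn):
    """Evaluate the bound value once, at scope entry; every use reads the result."""
    v = value_fn()
    return body_fn(lambda: v)


def lazy_let(value_fn, body_fn):
    """Expand the value into the body: every use re-evaluates it."""
    return body_fn(value_fn)


evaluations = {"count": 0}


def complex_e():
    evaluations["count"] += 1  # count how many times the bound value is computed
    return 42


def body(use_x):
    return use_x() + use_x()  # the bound variable is used twice


for let in (eager_let, lazy_let):
    evaluations["count"] = 0
    result = let(complex_e, body)
    print(let.__name__, result, evaluations["count"])
# eager_let 84 1
# lazy_let 84 2
```

Both readings give the same value, but only the eager one computes `complex_e` once; under the lazy reading, introducing a let would bring no saving.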
   
   >     3. Is it possible for some complex (but common) expr to be lifted out of a control-flow scope? E.g.
   >        ```python
   >        if cond0:
   >            tir.evaluate(complex_e)  # slow path
   >        else:
   >            if cond2:
   >                tir.evaluate(complex_e)  # slow path
   >            else:
   >                ...  # general fast path
   >        ```
   >        after CSE, it might become:
   >        ```python
   >        x = tir.Var("int32")
   >        with tir.let(x, complex_e):  # always evaluate slow path
   >            if cond0:
   >                tir.evaluate(x) 
   >            else:
   >                if cond2: 
   >                    tir.evaluate(x)
   >                else:
   >                    ... # general fast path
   >        ```
   
   Yes, that is possible. Although, now that I think of it, it would probably be better not to do it in the kind of example that you provided, as the expression `complex_e` of the slow path now always gets evaluated, even when we end up taking the general fast branch...
   We should only lift things through "if" statements if both the then branch and the else branch compute them. This restriction is not needed for semantics preservation (which was already fine), but it will avoid this kind of anti-optimization. That should be easily fixable.
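The rule proposed above could be expressed roughly as follows. This is a sketch under assumptions: computations are tracked in count tables (modelled with Python's `collections.Counter`, with strings standing in for expressions), the helper names are hypothetical, and the two-table intersection sums the counts, as the PR's `IntersectionOf2TablesOfComputations` helper does.

```python
from collections import Counter


def intersect(t1, t2):
    """Keep the computations present in both tables, summing their counts."""
    return Counter({k: t1[k] + t2[k] for k in t1 if k in t2})


def liftable_through_if(cond, then_branch, else_branch):
    """Computations that may be introduced before an if statement."""
    # The condition is always evaluated, so its computations can always be hoisted.
    result = Counter(cond)
    # A computation from the branches is hoisted only if *both* branches compute it.
    result.update(intersect(then_branch, else_branch))
    return result


# The nested example from the question: complex_e only appears on the slow paths.
inner = liftable_through_if(Counter(), Counter({"complex_e": 1}), Counter({"fast_e": 1}))
outer = liftable_through_if(Counter(), Counter({"complex_e": 1}), inner)
print(outer)  # Counter()

# A computation done by both branches is still hoisted.
print(liftable_through_if(Counter(), Counter({"x+y": 1}), Counter({"x+y": 1, "b-x": 1})))
# Counter({'x+y': 2})
```

With this rule, `complex_e` is no longer hoisted above `cond0`, so the general fast path stays fast.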
   
   Many thanks!





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972608426


   Many thanks for the interesting questions and comments; I really appreciate it.
   
   > I have three more questions about the pass.
   > 
   >     1. Since function calls are non-eligible, it seems to prevent many common non-stateful builtin computations from being optimized, such as `tir.floordiv`, `tir.shift_right`, `tir.likely`, `tir.log`, etc. TVM already has op annotations like `PURE`:
   >        https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
   >        What about adding an annotation type like `DETERMINISTIC` to mark that a function returns the same output for the same input? It should be
   >        safe to eliminate common calls if the call op is both pure and deterministic.
   
   Yes, I definitely agree that we could, in the future, relax the restriction which currently prevents function calls from being commoned out, for some functions that we know to be pure, including some builtins.
   I think the "pure" tag alone would be enough for this purpose, if it has the meaning that I imagine, which is this one (quoted from Wikipedia):
   
   #################
   a pure function is a function that has the following properties:
   
   1. The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments or input streams).
   2. The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments or input/output streams).
   
   Thus a pure function is a computational analogue of a mathematical function. Some authors, particularly from the imperative language community, use the term "pure" for all functions that just have the above property 2. 
   #################
   
   What is crucial for us when commoning out function calls is only condition 1. As long as the function always returns the same output for the same input, we are fine with putting the result of applying the function to some input into a variable (or any kind of "cache").
   
   Please note that the Wikipedia article later mentions that condition 2 is needed for performing CSE, but it is talking about something different (and isn't very precise about it). It does not talk about commoning out function calls there. It is talking about commoning out some redundant terms between which some function calls happen, as in:
   x = a+b+c;
   f(); // Here, if f can change the content of some variables like a, b or c, it prevents the CSE on the redundant term (a+b+c)
   y = a+b+c;
   
   But this kind of consideration does not concern us, as TIR is in a (weak) SSA form: the content of variables won't be updated, and can be assumed to be the same everywhere throughout the execution of the program. Thus, only condition 1 matters to us.
   
   If the kPure tag means both condition 1 and condition 2, then we will be able to just use it in the future when relaxing the restriction.
   If it only means condition 2, then we will indeed need another tag.
   
   >     2. Are the let-binding semantics in TIR explicitly what we want? Is the value expr specified to be evaluated at let scope entry? Otherwise there may be no reuse effect if let-binding allows the value expr to get lazily expanded in the let body.
   
   I must admit that, now that I think of it, I am not entirely sure... Can anyone confirm that?
   
   >     3. Is it possible for some complex (but common) expr to be lifted out of a control-flow scope? E.g.
   >        ```python
   >        if cond0:
   >            tir.evaluate(complex_e)  # slow path
   >        else:
   >            if cond2:
   >                tir.evaluate(complex_e)  # slow path
   >            else:
   >                ...  # general fast path
   >        ```
   >        after CSE, it might become:
   >        ```python
   >        x = tir.Var("int32")
   >        with tir.let(x, complex_e):  # always evaluate slow path
   >            if cond0:
   >                tir.evaluate(x) 
   >            else:
   >                if cond2: 
   >                    tir.evaluate(x)
   >                else:
   >                    ... # general fast path
   >        ```
   
   Yes, that is possible. Although, now that I think of it, it would probably be better not to do it in the exact example that you provided, as the expression `complex_e` of the slow path now always gets evaluated, even when we end up taking the general fast branch...
   We should only lift things through "if" statements if both the then branch and the else branch compute them. This restriction is not needed for semantics preservation (which was already fine), but it will avoid this kind of anti-optimization. That should be easily fixable.
   
   Many thanks!





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790202435



##########
File path: include/tvm/tir/transform.h
##########
@@ -458,6 +458,14 @@ TVM_DLL Pass FlattenBuffer();
  */
 TVM_DLL Pass TextureFlatten();
 
+/*!
+ * \brief Implements a Common Subexpression Elimination (CSE)
+ *        which introduces let-in bindings for duplicated sub-expressions.
+ * \param enable_cse_tir Whether common subexpression elimination is enabled.
+ * \return The pass.
+ */
+TVM_DLL Pass CommonSubexprElim(bool enable_cse_tir = true);

Review comment:
       Please choose another name to make it explicit that this pass is for TIR, to avoid confusion with the existing relay-level CSE `EliminateCommonSubexpr`.







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r796246711



##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* \file common_subexpr_elim_tools.cc
+* \brief Implementation of analysis tools and utility functions used
+          by the Common Subexpression Elimination (CSE) pass.
+*/
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the declaration of the pass
+
+#include <algorithm>      // For std::find_if
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static variable of the class ComputationsDoneBy, and C++ requires such a static
+// attribute to be defined here, otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactic entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
+   is the number of times that this computation is seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates : `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one helps to select only "eligible" computations, and the second one helps to only
+   select computations that are located at appropriate location (i.e., it tells in which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for their own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+   its child nodes (ie, the union of the results of the recursive calls).
+   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+   On some nodes, it will return something more complicated that uses the intersection of the
+   computations done by the children nodes.
+   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+   (x+y) seen twice, but it won't report b-x as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+   for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
+   (often because it's a load node or a function call node for instance), in which case it will
+   return the union of the recursive calls on its children, as long as the other predicate
+   `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+   With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+   itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
+   might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyses the program at the toplevel (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily previously obtained when computing the computations done by the
+   root (as the computations done by the root are by definition the union of the computations done
+   by the children nodes).
+
+   The somewhat difficult aspect of the implementation is the interaction between this caching of
+   results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are
+   void methods which can't return anything, and instead need to accumulate a result into a member
+   variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods) just
+   call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't
+   want to override each of these specialized methods to change this behaviour, then
+   `table_of_computations_` will necessarily be shared by all the children of a given node.
+   That requires care when writing into the cache.
+*/
+
+/*!
+ * \brief Does the union of two tables of computations.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Does it directly in the first argument A for efficiency, as the union of A and B
+ *       necessarily gives something which contains A, so we avoid its copy.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Does the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+            const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of two tables of computations.
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the intersection for N tables, we need to
+ *       know how to do it for two. That's because we would compute for N tables using the
+ *       associativity of the intersection : T1 Inter T2 Inter T3 ... Inter Tn
+ *       = ((T1 Inter T2) Inter T3) ... Inter Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic intersection
+ *       over N tables.
+ */
+TableOfComputations IntersectionOf3TablesOfComputations(const TableOfComputations& table1,

Review comment:
       Agreed :)
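The note in the diff states that a union or intersection over N tables would be a left fold of the 2-table operation. A Python sketch of that fold, assuming tables are modelled as `collections.Counter` and using hypothetical names that mirror (but are not) the C++ helpers; as in the diff, the union adds counts and the intersection keeps the shared computations with summed counts:

```python
from collections import Counter
from functools import reduce


def union_of_2(t1, t2):
    """Union of two tables: counts of shared computations are added."""
    result = Counter(t1)
    result.update(t2)
    return result


def intersection_of_2(t1, t2):
    """Intersection of two tables: keep shared computations, summing their counts."""
    return Counter({k: t1[k] + t2[k] for k in t1 if k in t2})


def union_of_n(*tables):
    # T1 U T2 U ... U Tn = ((T1 U T2) U T3) ... U Tn
    return reduce(union_of_2, tables, Counter())


def intersection_of_n(first, *rest):
    # T1 Inter T2 ... Inter Tn = ((T1 Inter T2) Inter T3) ... Inter Tn
    return reduce(intersection_of_2, rest, Counter(first))


t1 = Counter({"x+y": 1, "a*b": 1})
t2 = Counter({"x+y": 2})
t3 = Counter({"x+y": 1, "b-x": 1})
print(union_of_n(t1, t2, t3))         # Counter({'x+y': 4, 'a*b': 1, 'b-x': 1})
print(intersection_of_n(t1, t2, t3))  # Counter({'x+y': 4})
```

Written this way, the 3-table versions become special cases of the generic fold, and only the 2-table operations need dedicated code.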







[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034147763


   Please have a look again @Hzfengsy @wrongtest 
   
   I'm merging this this week unless there are other comments @tqchen @junrushao1994 @vinx13  





[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031839327


   For the docs, I unfortunately still get a red (failing) check.
   It seems to be only warnings, though (due to HTTP errors when fetching some docs), so I'm not sure why it's considered a failure (it returns 1).
   
   ![image](https://user-images.githubusercontent.com/89943638/152858417-d7e865fb-88e8-4082-a3a7-44ef5ae6b2e9.png)
   





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034881613


   Thanks @FranckQC @Hzfengsy @wrongtest @mbs-octoml @jroesch this is merged!!





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > Hi @FranckQC, I have wanted TIR-level CSE for a long time, so I'm very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU. For example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE a `log2(N)` expression that appears both in the host and in the GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note the many calls to `call_spirv_pure_glsl450`, which would be totally unnecessary if we had TIR-level CSE) https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
   In principle, every redundant subterm that is eligible for being commoned out (i.e., that does not contain function calls, etc.) will be commoned out. There are also a few restrictions due to some specifics of TVM, but these are rare.
   
   Unfortunately, in the TIR code snippet that you have uploaded, it seems that the redundant subexpression is just a function call. These can't be commoned out into a new variable by the CSE pass, as it has no guarantee that the function is pure, i.e. that it has no side effects and always produces the same output for the same inputs. Without this guarantee, commoning out such function calls could change the program's semantics, so it is not done, as preserving the semantics of the program is vital.
   
   I can imagine that for functions that are guaranteed to have no side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future. It would rely on some kind of "NoSideEffect" tag on functions.
   
   However, please note that if you had some other redundancies, this CSE pass would common out whatever redundancies you have that are eligible.
   For instance:
   Assume the term (f(42) + f(42)) + (x*y + x*y) appears somewhere.
   It is not eligible in its entirety, as it contains function calls.
   But its subterm (x*y + x*y) is eligible, so its redundant part x*y will be commoned out.
   
   Does that answer your question?
   Please do not hesitate to tell me if you need help trying the pass.
   
   Kind regards.





[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969398522


   The pointer to the submodule vta-hw has been rolled-back to its previous state.



[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031914483


   cc @tmoreau89 if he has any insight on why the VTA test is broken despite CSE being turned off.



[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1032077140


   Current status before running it again :
   ![image](https://user-images.githubusercontent.com/89943638/152893889-3d53456e-c28c-4698-bef7-1b2ce2dfe3b7.png)
   
   Hopefully, the docs should be OK now.
   Perhaps the i386 build will also be OK, as its failure was a weird memory allocation issue.
   
   But we will still have to deal with the failure on the VTA test.



[GitHub] [tvm] masahi merged pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #9482:
URL: https://github.com/apache/tvm/pull/9482


   



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972611870


   > 
   > 
   > > Does that answer your question?
   > 
   > @FranckQC Thanks, yes absolutely. I can work on extending this pass to support my use case, after this is merged.
   
   Great! I'll of course be happy to help if needed.



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1018993039


   Hi everyone.
   
   So sorry for the delayed answer.
   
   Yesterday and today I pushed a few more commits that improve this CSE pass. The biggest improvement is the behavior of the pass for If and For nodes, for which expressions were previously allowed to be lifted out of the control-flow scope.
   I spotted this thanks to @wrongtest's question, so many thanks again for it.
   It was definitely a bad thing (although rare, and not breaking semantics preservation): if a redundant computation was only present in one execution path (let's say the THEN branch of an If), lifting it just above the If was an anti-optimization, as it might never have been computed at all if the execution flow reached the ELSE branch.
   Now, it won't do this kind of thing anymore. Instead, it will introduce the computation at the beginning of the block where it is redundant, without lifting it outside of that block.
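A rough Python model of the placement rule described above follows. It is a sketch under stated assumptions, not the pass itself: each part of the If is summarized by a Counter of the computations it contains, and the returned strings only name where a new variable would be introduced.

```python
from collections import Counter


def placement(comp, cond, then_branch, else_branch):
    """Where a toy CSE would introduce `comp` around an If node.

    cond, then_branch and else_branch are Counters mapping each computation
    to the number of times it is seen in that part of the If.
    """
    parts_seen_in = sum(1 for part in (cond, then_branch, else_branch) if part[comp] > 0)
    if parts_seen_in >= 2 or cond[comp] >= 2:
        # Computed on every execution path (the condition always runs),
        # so computing it once before the If cannot add work on any path.
        return "before the If"
    if then_branch[comp] >= 2:
        return "start of the THEN branch"  # redundant only on this path
    if else_branch[comp] >= 2:
        return "start of the ELSE branch"
    return "not introduced"


cond = Counter({"x+y": 1})
then_branch = Counter({"x+y": 1, "a*b": 2})
else_branch = Counter({"b-x": 1})
for comp in ("x+y", "a*b", "b-x"):
    print(comp, "->", placement(comp, cond, then_branch, else_branch))
# x+y -> before the If
# a*b -> start of the THEN branch
# b-x -> not introduced
```

Here a*b, redundant only in the THEN branch, stays inside that branch instead of being lifted above the If.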
   
   The only things left are to re-run clang-format and to address the last little things reported by the CI. However, the build already succeeds on both Windows and macOS.
   I'm hoping to be done with everything by the end of the weekend, so that it can be merged next week for people to enjoy the CSE! I hope it will be useful.
   
   Kind regards,
   Franck



[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201528



##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* \file common_subexpr_elim_tools.cc
+* \brief Implementation of analysis tools and utility functions used
+          by the Common Subexpression Elimination (CSE) pass.
+*/
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the declaration of the pass
+
+#include <algorithm>      // For std::find_if
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static variable of the class ComputationsDoneBy, and C++ requires to define here
+// such static attribute, otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactical entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
+   is the number of times that this computation is being seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates : `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one helps to select only "eligible" computations, and the second one helps to only
+   select computations that are located at appropriate location (i.e., it tells in which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for his own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+   its child nodes (ie, the union of the results of the recursive calls).
+   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+   On some nodes, it will return something more complicated that uses the intersection of the
+   computations done by the children nodes.
+   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+   (x+y) seen twice but it won't report b-x as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+   for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
+   (often because it's a load node or a function call node for instance), in which case it will
+   return the union of the recursive calls on its children, as long as the other predicate
+   `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+   With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+   itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
+   might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyses the program at the toplevel (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily previously obtained when computing the computations done by the
+   root (as the computations done by the root are by definition the union of the computations done
+   by the children nodes).
+
+   The somehow difficult aspect of the implementation is the interaction between this caching of
+   results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are
+   void methods which can't return anything, and instead need to accumulate a result into a member
+   variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods), just
+   call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't
+   want to override each of these specialized methods to change this behaviour, then
+   `table_of_computations_` will necessarily be shared by all the children of a given node.
+   That requires to be careful when trying to write into the cache.
+*/
+
+/*!
+ * \brief Does the union of two tables of computations.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Does it directly in the first argument A for efficiency, as the union of A and B
+ *       necessarily gives something which contains A, so we avoid its copy.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Does the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+            const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of two tables of computations.
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the intersection for N tables, we need to
+ *       know how to do it for two. That's because we would compute for N tables using the
+ *       associativity of the intersection : T1 Inter T2 Inter T3 ... Inter Tn
+ *       = ((T1 Inter T2) Inter T3) ... Inter Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic intersection
+ *       over N tables.
+ */
+TableOfComputations IntersectionOf3TablesOfComputations(const TableOfComputations& table1,
+                  const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = IntersectionOf2TablesOfComputations(table1, table2);
+  result = IntersectionOf2TablesOfComputations(result, table3);
+  return result;
+}
+
+/*!
+ * \brief Recompute the number of times that each computation in table_main
+          is being seen in table_bloc1, table_bloc2 and table_bloc3. It sets
+          each element to the sum of the times it is seen in each individual bloc.
+ * \param table_main The main table, for which we recompute the counters.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note This function is needed because both the intersection (A Inter B) and the union
+ *        (A U B U C) add the individual counters found in A, B and C. So when we treat for
+ *        instance an If (which contains a Cond, a Then branch and an Else branch),
+ *        it will compute (Then Inter Else) U (Cond Inter Then) U (Cond Inter Else).
+ *        In order to get back to the appropriate number (for instance, 3 if seen one time in each
+ *        bloc), it is therefore necessary to recompute the counters afterwards, which is what this
+ *        function does.
+ */
+void RecomputeNbTimesSeenInThreeBlocs(TableOfComputations& table_main,
+    const TableOfComputations& table_bloc1, const TableOfComputations& table_bloc2,
+    const TableOfComputations& table_bloc3) {

Review comment:
       I prefer taking a vector of `ComputationTable` here, simplify the body, and remove `ThreeBlocs` suffix.
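Following this suggestion, a list-taking version could look roughly like the Python sketch below; Counters stand in for the C++ table type and all function names are hypothetical. It also replays the If case from the `\note` above: a computation seen once in each of Cond, Then and Else gets a count of 6 after the pairwise intersections and the union, and recomputing brings it back to 3.

```python
from collections import Counter


def union(*tables):
    """Union of tables: counters of common computations are added."""
    result = Counter()
    for table in tables:
        result.update(table)  # Counter.update adds counts instead of replacing them
    return result


def intersection(table1, table2):
    """Keep computations present in both tables, adding their counters."""
    return Counter({comp: table1[comp] + table2[comp] for comp in table1 if comp in table2})


def recompute_nb_times_seen(table_main, tables):
    """Reset each counter of table_main to the sum of its counts in the list `tables`."""
    for comp in table_main:
        table_main[comp] = sum(table[comp] for table in tables)


cond = Counter({"x+y": 1})
then_b = Counter({"x+y": 1})
else_b = Counter({"x+y": 1, "b-x": 1})
table = union(intersection(then_b, else_b), intersection(cond, then_b), intersection(cond, else_b))
print(table)  # Counter({'x+y': 6})
recompute_nb_times_seen(table, [cond, then_b, else_b])
print(table)  # Counter({'x+y': 3})
```

Taking a list makes a single helper serve both the two-block and the three-block cases.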





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201951



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as we need to hash similarly deeply equal terms.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not do variables remapping), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
+          to each statement or expression of the program its table of computations. Its purpose is
+          to avoid the CSE pass from recomputing repeatedly the same tables of computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements
+  // It maps each known statement to its table of computations
+  std::unordered_map<Stmt, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_stmt_table_computations_;
+
+  // Part of the cache for expressions
+  // It maps each known expression to its table of computations
+  std::unordered_map<PrimExpr, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_expr_table_computations_;
+};
+
+/*!
+ * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
+          or by a statement.
+ * \note Computations here are considered syntactically, meaning that semantically equivalent
+          computations that are not syntactically the same are not merged together.
+ */
+class ComputationsDoneBy : public StmtExprVisitor {
+ public:
+  // Toplevel (static) methods
+  static TableOfComputations GetComputationsDoneBy(
+      const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
+      std::function<bool(const PrimExpr&)> can_contain_computations);
+  static TableOfComputations GetComputationsDoneBy(

Review comment:
       I prefer the name `ExtractSubComputations` 
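The header quoted above keys the table of computations structurally (StructuralHash + ExprDeepEqual) but keys the cache by pointer (ObjectPtrHash/ObjectPtrEqual). Below is a small Python analogy of why both choices fit, with a frozen dataclass playing the role of a PrimExpr; the classes and function names are illustrative assumptions, not TVM code.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class BinOp:
    """Toy stand-in for a PrimExpr; frozen dataclasses hash and compare by value."""
    op: str
    lhs: object
    rhs: object


class IdentityCache:
    """Maps a node to its table by object identity, like ObjectPtrHash/ObjectPtrEqual."""

    def __init__(self):
        # id(node) -> (node, table); storing the node keeps it alive, so its id stays unique
        self._entries = {}

    def get(self, node):
        entry = self._entries.get(id(node))
        return entry[1] if entry is not None else None

    def put(self, node, table):
        self._entries[id(node)] = (node, table)


def computations_of(node, cache):
    """Table of computations of node, keyed structurally and cached per object."""
    cached = cache.get(node)
    if cached is not None:
        return cached
    table = Counter()
    if isinstance(node, BinOp):
        table[node] += 1
        table.update(computations_of(node.lhs, cache))
        table.update(computations_of(node.rhs, cache))
    cache.put(node, table)
    return table


a = BinOp("+", "x", "y")
b = BinOp("+", "x", "y")  # a distinct object, structurally equal to a
root = BinOp("*", a, b)
cache = IdentityCache()
table = computations_of(root, cache)
print(table[a])                               # 2: a and b are merged in the table
print(cache.get(a) is cache.get(b))           # False: the cache keeps them apart
print(computations_of(root, cache) is table)  # True: the second query hits the cache
```

Structural keys are what let two occurrences of x+y count as one redundant computation, while identity keys make cache lookups cheap and tied to a specific node of the program.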





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797995975



##########
File path: python/tvm/tir/transform/transform.py
##########
@@ -310,6 +310,15 @@ def BF16TypeLowering():
     """
     return _ffi_api.BF16TypeLowering()  # type: ignore
 
+def CommonSubexprElim(enable_cse_tir: bool = True):

Review comment:
       I didn't know that other TIR passes have a specific option for disabling. The `disabled_pass` option I brought up is enforced at https://github.com/apache/tvm/blob/3fbce70a8ad7e032de8c402fb27e0396435c8eca/src/ir/transform.cc#L479. I believe it has higher precedence than other disabling options, since it is enforced at a higher level in the stack. I don't know why things are this way; maybe @tqchen or @zhiics knows?
   
   For Vectorize, I think this option is needed because we need to turn a loop with `vectorized` annotation into a sequential one if we want to skip this pass. 
   
   https://github.com/apache/tvm/blob/c6f62aafc91e2600ed7772597fd4238c924c2a1b/src/tir/transforms/vectorize_loop.cc#L584-L588 
   
   Not so long ago, the contents of `driver_api.cc` were actually written in Python, so exposing this option to Python made sense at that time. It's still good to have it today for experimentation purposes, as you said. You can keep it for CSE if you find it useful, but for the purpose of disabling CSE, `disabled_pass` already does the job (but not for `Vectorize`, as explained above).
   
   > The pass LoopPartition is quite weird, as its boolean for activation/deactivation is passed directly to the CreatePassList() method:
   
   Indeed it is weird. I remember our `LoopPartition` implementation was problematic in the past (maybe it still is today); this is probably an artifact from that time. Anyway, I don't think there is a deep reason behind it.
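A toy Python model of the two disabling mechanisms discussed above, to make the Vectorize case concrete. Everything here is hypothetical (the pass name, the list-of-loop-kinds "program", the `run_pipeline` helper): it only mimics a context-level `disabled_pass` check that happens before any pass body runs, versus a pass-specific flag that still lets the pass lower its annotations.

```python
# Toy model (not TVM code): a "program" is just a list of loop kinds.


def make_vectorize(enable_vectorize=True):
    """Build a toy Vectorize pass with its own pass-specific flag."""
    def vectorize(loops):
        # Even when the flag is off the pass still runs: "vectorized" loops are
        # rewritten into serial ones instead of being left un-lowered.
        new_kind = "vector" if enable_vectorize else "serial"
        return [new_kind if kind == "vectorized" else kind for kind in loops]
    return vectorize


def run_pipeline(passes, loops, disabled_pass=()):
    """Run (name, pass_fn) pairs in order, skipping names listed in disabled_pass."""
    for name, pass_fn in passes:
        if name in disabled_pass:  # context-level check: the pass body never runs
            continue
        loops = pass_fn(loops)
    return loops


loops = ["serial", "vectorized"]
print(run_pipeline([("Vectorize", make_vectorize())], loops))
# ['serial', 'vector']
print(run_pipeline([("Vectorize", make_vectorize(False))], loops))
# ['serial', 'serial']
print(run_pipeline([("Vectorize", make_vectorize())], loops, disabled_pass={"Vectorize"}))
# ['serial', 'vectorized']  <- the annotation survives un-lowered
```

A CSE pass leaves nothing behind that later passes must lower, which is why skipping it entirely through the context-level list is enough in that case.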
   






[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1029590772


   
   ![image](https://user-images.githubusercontent.com/89943638/152464025-9be63301-c629-4ca5-af3d-8cf168f420ff.png)
   
   Alright, so now we have the following that are OK:
   - Both the windows and macOS builds
   - All the sanity checks, including the linters and clang-format (which no longer make any change to the code base, yay!)
   
   Only 6 additional tests (test_lower_build_te_schedule, test_lower_build_tir_func, test_lower_build_tir_module, test_lower_build_lowered_module, test_tensor_compute1 and test_tensor_compute2), which I did not have locally, are failing. For these, the CSE does some work and therefore breaks the syntactic comparison between the obtained TIR and the expected TIR.
   I'll try to check that the TIR code obtained after lowering with the CSE on is indeed what we expect, and I'll update the expected TIR with it.
   Or perhaps I should simply update these tests by lowering without the CSE pass (i.e. with the pass disabled)?
   
   Hoping to see that merged soon, I think we are not far!
   :)



[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797171727



##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* \file common_subexpr_elim_tools.cc
+* \brief Implementation of analysis tools and utility functions used
+          by the Common Subexpression Elimination (CSE) pass.
+*/
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the declaration of the pass
+
+#include <algorithm>      // For std::find_if
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static variable of the class ComputationsDoneBy, and C++ requires to define here
+// such static attribute, otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactical entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
+   is the number of times that this computation is being seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates: `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one helps to select only "eligible" computations, and the second one helps to only
+   select computations that are located at appropriate locations (i.e., it tells in which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for their own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+   its child nodes (i.e., the union of the results of the recursive calls).
+   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+   On some nodes, it will return something more complicated that uses the intersection of the
+   computations done by the children nodes.
+   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+   (x+y) seen twice but it won't report b-x, as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+   for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
+   (often because it's a load node or a function call node for instance), in which case it will
+   return the union of the recursive calls on its children, as long as the other predicate
+   `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+   With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+   itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
+   might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyzes the program at the toplevel (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily previously obtained when computing the computations done by the
+   root (as the computations done by the root are by definition the union of the computations done
+   by the children nodes).
+
+   The somewhat difficult aspect of the implementation is the interaction between this caching of
+   results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are
+   void methods which can't return anything, and instead need to accumulate a result into a member
+   variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods) just
+   call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't
+   want to override each of these specialized methods to change this behaviour, then
+   `table_of_computations_` will necessarily be shared by all the children of a given node.
+   This requires care when writing into the cache.
+*/
+
+/*!
+ * \brief Does the union of two tables of computations. The counters of the computations
+ *        present in both tables are added.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Works directly on the first argument A for efficiency: as the union of A and B
+ *       necessarily contains A, this avoids copying A.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Does the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union: T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+            const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
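+ * \brief Does the intersection of two tables of computations. The counter of each computation
+ *        present in both tables is the sum of its counters in the two tables.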
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the intersection for N tables, we need to
+ *       know how to do it for two. That's because we would compute for N tables using the
+ *       associativity of the intersection: T1 Inter T2 Inter T3 ... Inter Tn
+ *       = ((T1 Inter T2) Inter T3) ... Inter Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic intersection
+ *       over N tables.
+ */
+TableOfComputations IntersectionOf3TablesOfComputations(const TableOfComputations& table1,
+                  const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = IntersectionOf2TablesOfComputations(table1, table2);
+  result = IntersectionOf2TablesOfComputations(result, table3);
+  return result;
+}
+
+/*!
+ * \brief Recomputes the number of times that each computation in table_main
+ *        is seen in table_bloc1, table_bloc2 and table_bloc3. It sets each counter
+ *        to the sum of the times the computation is seen in each individual block.
+ * \param table_main The main table, for which we recompute the counters.
+ * \param table_bloc1 One of the three tables, which won't change.
+ * \param table_bloc2 One of the three tables, which won't change.
+ * \param table_bloc3 One of the three tables, which won't change.
+ * \note This function is needed because both the intersection (A Inter B) and the union
+ *       (A U B U C) add the individual counters found in A, B and C. So when we treat for
+ *       instance an If (which contains a Cond, a Then branch and an Else branch),
+ *       it will compute (Then Inter Else) U (Cond Inter Then) U (Cond Inter Else).
+ *       In order to get back to the appropriate number (for instance, 3 if seen one time in each
+ *       block), it is therefore necessary to recompute the counters afterwards, which is what this
+ *       function does.
+void RecomputeNbTimesSeenInThreeBlocs(TableOfComputations& table_main,
+    const TableOfComputations& table_bloc1, const TableOfComputations& table_bloc2,
+    const TableOfComputations& table_bloc3) {

Review comment:
       Good idea! Thanks for the remark!
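The counter rules spelled out in the doc comments above (the union and the intersection both add counters, and an If node combines its blocks as (Then Inter Else) U (Cond Inter Then) U (Cond Inter Else) before the counters are recomputed) can be checked with a small standalone C++ sketch. It is only an illustration under stated assumptions, not the PR's code: the body of RecomputeNbTimesSeenInThreeBlocs is cut off in this excerpt, plain strings stand in for PrimExpr keys, and `Table`, `UnionInto`, `Intersect`, `CountIn` and `ComputationsOfIf` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical sketch (not TVM code): strings stand in for PrimExpr keys,
// mapping each computation to the number of times it is seen.
using Table = std::unordered_map<std::string, std::size_t>;

// Union written into the first table: counters of shared computations are added.
void UnionInto(Table& main_table, const Table& aux) {
  for (const auto& kv : aux) main_table[kv.first] += kv.second;
}

// Intersection: computations present in both tables, with their counters summed.
Table Intersect(const Table& a, const Table& b) {
  Table result;
  for (const auto& kv : a) {
    auto it = b.find(kv.first);
    if (it != b.end()) result[kv.first] = kv.second + it->second;
  }
  return result;
}

// Counter of `key` in `t`, or 0 when it is absent.
std::size_t CountIn(const Table& t, const std::string& key) {
  auto it = t.find(key);
  return it == t.end() ? 0 : it->second;
}

// If statement: keep what is computed in at least two of the three blocks, then
// reset each counter to the plain sum over the blocks, because the union and the
// intersections above over-count.
Table ComputationsOfIf(const Table& cond, const Table& then_b, const Table& else_b) {
  Table result = Intersect(then_b, else_b);
  UnionInto(result, Intersect(cond, then_b));
  UnionInto(result, Intersect(cond, else_b));
  for (auto& kv : result)
    kv.second = CountIn(cond, kv.first) + CountIn(then_b, kv.first) + CountIn(else_b, kv.first);
  return result;
}
```

With the counts the comment gives for [if (x+y>z) then a = x+y else a = b-x] (x+y seen once in the condition and once in the Then branch, b-x once in the Else branch), `ComputationsOfIf` keeps x+y with a counter of 2 and drops b-x. For a computation seen once in each of the three blocks, the union of the three intersections yields 6, and the recomputation brings it back to 3.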




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r749732215



##########
File path: src/tir/analysis/check_contains.cc
##########
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file check_contains.cc
+ * \brief Implementation of the analysis that tells if an expression contains
+ *        a node that satisfies a given predicate.
+ */
+
+#include "check_contains.h"
+
+#include <tvm/tir/expr.h>
+
+#include <vector>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Toplevel (static) function that tells if an expression contains a subexpression that
+          satisfies a given predicate.
+ * \param expr The expression to check
+ * \param predicate The predicate that must be satisfied
+ * \return Whether `expr` contains a subexpression that satisfies `predicate`
+ */
+bool CheckContains::ExprContains(const PrimExpr& expr,
+                                 std::function<bool(const PrimExpr&)> predicate) {
+  CheckContains check_contains(predicate);
+  check_contains.VisitExpr(expr);
+  return check_contains.contains_it_;
+}
+
+/*!
+ * \brief Toplevel (static) function that tells if a statement contains a subexpression that
+          satisfies a given predicate.
+ * \param stmt The statement to check
+ * \param predicate The predicate that must be satisfied
+ * \return Whether `stmt` contains a subexpression that satisfies `predicate`
+ */
+bool CheckContains::StmtContains(const Stmt& stmt, std::function<bool(const PrimExpr&)> predicate) {
+  CheckContains check_contains(predicate);
+  check_contains.VisitStmt(stmt);
+  return check_contains.contains_it_;
+}
+
+/*!
+ * \brief Protected constructor of CheckContains.
+ * \param predicate The predicate that must be satisfied
+ */
+CheckContains::CheckContains(std::function<bool(const PrimExpr&)> predicate)
+    : predicate_(predicate) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions.
+ * \param expr The expression to visit
+ */
+void CheckContains::VisitExpr(const PrimExpr& expr) {
+  // If the predicate holds on `expr`, we know `expr` contains something which makes
+  // the predicate hold
+  if (predicate_(expr)) {
+    contains_it_ = true;
+  } else {
+    // Otherwise we continue to look for it recursively by calling the dispatcher
+    StmtExprVisitor::VisitExpr(expr);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements.
+ * \param stmt The statement to visit
+ */
+void CheckContains::VisitStmt(const Stmt& stmt) {
+  // We keep exploring only if `contains_it_` is false
+  if (!contains_it_) {
+    // and in order to do that we call the general dispatcher
+    StmtExprVisitor::VisitStmt(stmt);
+  }
+  // As otherwise we already have our answer
+}
+
+}  // namespace tir
+}  // namespace tvm

Review comment:
       Thanks! Done.
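`CheckContains` follows a common visitor shape: a static entry point builds a private instance that holds the predicate, visits the tree, and returns a flag. A hypothetical standalone C++ sketch of the same shape on a toy tree (`Node`, `NodePredicate` and `ContainsChecker` are illustrative names, not TVM classes) could look like this. One deliberate difference: the reviewed code only stops early in `VisitStmt()`, whereas this sketch also stops descending into expressions once a match has been found.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy node standing in for a TIR expression: an operator label and children.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> kids;
};
using NodePtr = std::shared_ptr<Node>;
using NodePredicate = std::function<bool(const NodePtr&)>;

class ContainsChecker {
 public:
  // Toplevel (static) entry point: does some node of `root` satisfy `predicate`?
  static bool Contains(const NodePtr& root, NodePredicate predicate) {
    ContainsChecker checker(std::move(predicate));
    checker.Visit(root);
    return checker.contains_it_;
  }

 private:
  explicit ContainsChecker(NodePredicate predicate) : predicate_(std::move(predicate)) {}

  void Visit(const NodePtr& node) {
    if (contains_it_) return;  // answer already known: stop exploring
    if (predicate_(node)) {
      contains_it_ = true;  // this node itself satisfies the predicate
      return;
    }
    for (const auto& kid : node->kids) Visit(kid);  // otherwise keep looking below
  }

  NodePredicate predicate_;
  bool contains_it_ = false;
};
```

Passing a predicate such as `n->op == "Call"` mirrors how the CSE pass calls `CheckContains::ExprContains` with `ForbiddenComputation` to reject expressions like (x + f(y)).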







[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1028584009


   @FranckQC You can just do `tests/lint/git-clang-format.sh -i upstream/main` to fix c++ lint errors.





[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797182546



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as we need deeply equal terms to be hashed to the same value.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not do variables remapping), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of two hashtables, which respectively associate
+          to each statement and to each expression of the program its table of computations.
+          Its purpose is to prevent the CSE pass from repeatedly recomputing the same tables of
+          computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements
+  // It maps each known statement to its table of computations
+  std::unordered_map<Stmt, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_stmt_table_computations_;
+
+  // Part of the cache for expressions
+  // It maps each known expression to its table of computations
+  std::unordered_map<PrimExpr, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_expr_table_computations_;
+};
+
+/*!
+ * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
+          or by a statement.
+ * \note Computations here are considered syntactically, meaning that semantically equivalent
+          computations that are not syntactically the same are not merged together.
+ */
+class ComputationsDoneBy : public StmtExprVisitor {
+ public:
+  // Toplevel (static) methods
+  static TableOfComputations GetComputationsDoneBy(

Review comment:
       And for the verbosity, it's only called by the main functions of the CSE pass (i.e. two times), so I think that's not a big issue.
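The `TableOfComputations` comment in the header above rests on one requirement of `std::unordered_map`: any two keys that the equality functor considers equal must get the same hash. Since `ExprDeepEqual` is stricter than `StructuralEqual`, everything it identifies is also identified by `StructuralEqual` and thus hashed identically by `StructuralHash`. `CacheOfComputations`, in contrast, is keyed by object identity (`ObjectPtrHash`, `ObjectPtrEqual`). The standalone sketch below contrasts the two keying schemes; the toy `Expr`, `StructHash` and `DeepEqual` are hypothetical stand-ins for the TVM types, not their implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy expression (not TVM's PrimExpr).
struct Expr {
  std::string op;
  std::vector<std::shared_ptr<const Expr>> kids;
};
using ExprRef = std::shared_ptr<const Expr>;

// Toy structural hash: depends only on labels and shape, never on addresses.
struct StructHash {
  std::size_t operator()(const ExprRef& e) const {
    std::size_t h = std::hash<std::string>()(e->op);
    for (const auto& k : e->kids) h = h * 31 + (*this)(k);
    return h;
  }
};

// Toy deep equality: same label and pairwise deep-equal children.
// Whatever it calls equal necessarily gets the same StructHash value.
struct DeepEqual {
  bool operator()(const ExprRef& a, const ExprRef& b) const {
    if (a->op != b->op || a->kids.size() != b->kids.size()) return false;
    for (std::size_t i = 0; i < a->kids.size(); ++i)
      if (!(*this)(a->kids[i], b->kids[i])) return false;
    return true;
  }
};

// Like TableOfComputations: structurally equal expressions share one counter.
using Table = std::unordered_map<ExprRef, std::size_t, StructHash, DeepEqual>;
// Like CacheOfComputations: keyed by object identity (std::hash of a shared_ptr
// hashes the stored pointer, and operator== compares pointers).
using Cache = std::unordered_map<ExprRef, Table>;
```

Two separately built copies of x+y share one counter in `Table`, while the identity-keyed `Cache` only hits for the very object that was inserted, which matches the per-node caching described in the `ComputationsDoneBy` comment.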

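For expressions, the `ComputationsDoneBy` class comment earlier in this thread describes a simple recursion: report the expression itself if it is eligible, otherwise recurse into its children when `can_contain_computations` allows it. A hypothetical standalone C++ sketch on a toy tree can reproduce that comment's two examples. It assumes printed strings as table keys and illustrative predicates (leaves and anything containing a Load are ineligible, every node may be entered); none of these names come from TVM.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical toy expression: op is "leaf", "+", "-" or "Load".
struct Expr {
  std::string op;
  std::string name;                         // text of a leaf (variable or constant)
  std::vector<std::shared_ptr<Expr>> kids;  // operands
};
using ExprPtr = std::shared_ptr<Expr>;
using Table = std::unordered_map<std::string, std::size_t>;  // printed expr -> times seen
using Pred = std::function<bool(const ExprPtr&)>;

ExprPtr Leaf(const std::string& name) { return std::make_shared<Expr>(Expr{"leaf", name, {}}); }
ExprPtr Op(const std::string& op, std::vector<ExprPtr> kids) {
  return std::make_shared<Expr>(Expr{op, "", std::move(kids)});
}

// Printed form, used as the table key in place of structural hashing.
std::string Show(const ExprPtr& e) {
  if (e->op == "leaf") return e->name;
  if (e->op == "Load") return "Load[" + Show(e->kids[0]) + "]";
  return "(" + Show(e->kids[0]) + e->op + Show(e->kids[1]) + ")";
}

bool ContainsOp(const ExprPtr& e, const std::string& op) {
  if (e->op == op) return true;
  for (const auto& k : e->kids)
    if (ContainsOp(k, op)) return true;
  return false;
}

// Illustrative predicates: leaves and anything containing a Load are not eligible,
// and every node may be entered.
bool IsEligible(const ExprPtr& e) { return e->op != "leaf" && !ContainsOp(e, "Load"); }
bool CanContain(const ExprPtr&) { return true; }

void CollectComputations(const ExprPtr& e, const Pred& is_eligible, const Pred& can_contain,
                         Table& out) {
  if (is_eligible(e)) {
    out[Show(e)] += 1;  // an eligible expression is reported as a whole
    return;
  }
  if (!can_contain(e)) return;  // not allowed to look inside this node
  for (const auto& k : e->kids) CollectComputations(k, is_eligible, can_contain, out);
}
```

The whole of ((w+x)+(y+z)) is reported once, while for Load[i1+i2] only (i1+i2) is reported, since the Load node is ineligible but can still be entered.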






[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797243979



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpression Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For std::find_if and std::sort
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden from being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also do not even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             whether recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+}
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the one we are having now.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param stmt The statement (body of the function) on which CSE will be performed
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is being computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing complexity (i.e., size)
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](const std::pair<PrimExpr, size_t>& a, const std::pair<PrimExpr, size_t>& b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note: safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e., present in context_).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction is performed when a computation uses an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry so as not to carry out-of-scope declarations,
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If neither the `value` nor the `body` of the let-in has been modified
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Yes we do need both the `VisitExpr()` and `VisitStmt()` as we are doing commoning both inside expressions (think about `let x = bigComp+bigComp`) and inside statements (think about `Mem[i] = bigComp+bigComp`).
   
   The basic implementation `StmtExprMutator::VisitStmt` does indeed visit sub `Expr`s. And actually we do rely on `StmtExprMutator::VisitExpr()` and `StmtExprMutator::VisitStmt()` for recursing inside the children, see the end of the two methods [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R307) and [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R477).
   
   The algorithm implemented for performing commoning inside expressions and inside statements is of course basically the same (it identifies redundancies, creates variables and then builds new expression/statement using let-in), but many things will be different type-wise when treating a statement versus an expression.
   There are many of these subtle differences, but one of them is that for an expression the pass will build a let-in expression and for a statement it will build a let-in statement.
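To make the shared part of the algorithm concrete, here is a minimal single-round sketch of commoning on a hypothetical toy expression type (only variables and additions), not on TVM's `PrimExpr`. It builds a table of computations (a printed string key stands in for StructuralHash/ExprDeepEqual), picks the largest computation that occurs at least twice, replaces it with a fresh variable and wraps the result in a let-in. The names `Node`, `Key`, `CommonOnce` and `cse_var_1` are illustrative assumptions; the real pass additionally checks that the variables are in scope, consults a predicate, and then retries on smaller subexpressions.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical toy expression: a variable leaf or a binary addition.
struct Node {
  std::string var;                 // variable name, set for leaves only
  std::shared_ptr<Node> lhs, rhs;  // operands, set for additions only
};
using Expr = std::shared_ptr<Node>;

Expr Var(const std::string& name) { return std::make_shared<Node>(Node{name, nullptr, nullptr}); }

Expr Add(Expr a, Expr b) {
  assert(a && b);
  return std::make_shared<Node>(Node{"", a, b});
}

// Syntactic key of an expression; stands in for StructuralHash + ExprDeepEqual.
std::string Key(const Expr& e) {
  if (!e->lhs) return e->var;
  return "(" + Key(e->lhs) + "+" + Key(e->rhs) + ")";
}

// Table of computations: how many times each non-leaf computation occurs.
void Count(const Expr& e, std::map<std::string, int>& table) {
  if (!e->lhs) return;
  ++table[Key(e)];
  Count(e->lhs, table);
  Count(e->rhs, table);
}

// Replace every occurrence of the computation identified by `key` with `v`.
Expr Replace(const Expr& e, const std::string& key, const Expr& v) {
  if (!e->lhs) return e;
  if (Key(e) == key) return v;
  return Add(Replace(e->lhs, key, v), Replace(e->rhs, key, v));
}

// One round of commoning: bind the largest redundant computation with a let-in.
// "Largest" is approximated here by the length of the printed key.
std::string CommonOnce(const Expr& e) {
  std::map<std::string, int> table;
  Count(e, table);
  std::string best;
  for (const auto& [key, n] : table)
    if (n >= 2 && key.size() > best.size()) best = key;
  if (best.empty()) return Key(e);  // nothing is computed twice
  Expr v = Var("cse_var_1");
  return "let cse_var_1 = " + best + " in " + Key(Replace(e, best, v));
}
```

For instance, applying `CommonOnce` to `(w+x)+((w+x)+u)` yields `let cse_var_1 = (w+x) in (cse_var_1+(cse_var_1+u))`. A statement version would differ mainly in producing a let-in statement instead of a let-in expression.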
   
   I also thought that it should be possible to factorize them and replace them with something polymorphic. But when I actually tried (quite some time ago), I was not convinced that it could be done easily without adding complexity everywhere.
   
   In particular, a possible approach would be to:
   - 1 - Create a main function for the algorithm currently duplicated in `VisitStmt()` and `VisitExpr()`. Let's call it "`Visit()`" for now. This function takes an `Either<Expr,Stmt>` and returns an `Either<Expr,Stmt>`. In C++, that is `std::variant`.
   - 2 - In `VisitExpr(const PrimExpr& expr)`: build an `std::variant<Expr, Stmt>` from the given `expr`, call the "generic" `Visit()`, and unpack its result (an Expr) to return it.
   - 3 - Same thing for `VisitStmt(const Stmt& stmt)`: build an `std::variant<Expr, Stmt>` from `stmt`, call the "generic" `Visit()`, and unpack its result (a Stmt) to return it.
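Steps 1 to 3 can be sketched with `std::variant` roughly as follows, using hypothetical stand-in `Expr`/`Stmt` structs rather than TVM's classes, and a simplified `GetComputationsDoneBy` that only returns a string. `std::visit` with a generic lambda selects the right overload without an explicit branch, but building a `Let` versus a `LetStmt` still needs an `if constexpr` on the alternative: the type-specific parts do not disappear, they move inside the generic function.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <variant>

// Hypothetical stand-ins for tvm::PrimExpr and tvm::tir::Stmt.
struct Expr { std::string text; };
struct Stmt { std::string text; };
using ExprOrStmt = std::variant<Expr, Stmt>;

// Overloads with the same name for both node kinds, as step 1 requires.
std::string GetComputationsDoneBy(const Expr& e) { return "computations of expr " + e.text; }
std::string GetComputationsDoneBy(const Stmt& s) { return "computations of stmt " + s.text; }

// Step 1: a single generic Visit(). The generic lambda picks the overload.
ExprOrStmt Visit(const ExprOrStmt& node) {
  return std::visit(
      [](const auto& n) -> ExprOrStmt {
        using T = std::decay_t<decltype(n)>;
        std::string table = GetComputationsDoneBy(n);
        // The type-specific part remains: an Expr is rebuilt as a Let
        // expression, a Stmt as a LetStmt.
        if constexpr (std::is_same_v<T, Expr>) {
          return Expr{"Let(" + table + ")"};
        } else {
          return Stmt{"LetStmt(" + table + ")"};
        }
      },
      node);
}

// Steps 2 and 3: thin wrappers that pack the input and unpack the result.
Expr VisitExpr(const Expr& e) { return std::get<Expr>(Visit(e)); }
Stmt VisitStmt(const Stmt& s) { return std::get<Stmt>(Visit(s)); }
```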
   
   In order to achieve 1 (which is the hard part here), we would need to:
   - Use polymorphic functions for all the functions used by the current `VisitStmt()` and `VisitExpr()` (i.e. functions with the same name for treating a `PrimExpr` and a `Stmt`). This is already the case for `GetComputationsDoneBy()`, but we would need to do the same for `ReplaceSelectedExprInStmt()` and `ReplaceSelectedExprInExpr()`, which are currently not polymorphic methods, and for quite a few other methods.
   This part is absolutely not the problem.
   
   - But now, when we try to write the new `Visit()` function, it seems that we are a bit stuck.
   We get as input this `expr_or_stmt` of type `std::variant<Expr, Stmt>`.
   We can of course check what we got:
   ```
   Expr expr;
   Stmt stmt;
   if (std::holds_alternative<Expr>(expr_or_stmt))
     expr = std::get<Expr>(expr_or_stmt);
   else
     stmt = std::get<Stmt>(expr_or_stmt);
   ```
   
   And then what?
   For instance, what should replace the beginning of the algorithm, which currently does:
   ```
   ComputationsDoneBy::GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   
   Sure, we could check whether we're working with an Expr or a Stmt and, within a conditional, call either
   ```
   GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   or
   ```
   GetComputationsDoneBy(
         stmt, IsEligibleComputation, CanContainEligibleComputations);
   ```
   depending on which case we're in. But that replaces a duplication of methods with duplication **within** the method... And as the algorithm was not trivial to follow even with the comments, I thought it was probably better not to add complexity to it.
   
   I will try to think about it again soon, but my current view is that it is not easy to replace these two methods with a single one without adding complexity everywhere.
   In addition, in its current form, the algorithm can easily be changed if we want it to do something slightly different for an Expr than for a Stmt. That would be more complicated if everything gets unified.
     




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031900087


   Regarding the VTA trouble, I'm trying to make sure that `vta.build()` actually respects the disabled passes.
   It seems that it mostly just calls `tvm.build()`, and I had no problem with the tests that were calling `tvm.build()` or `tvm.lower()`.
   
   Am I missing something here?
   (found in tvm/vta/python/vta/build_module.py)
   ![image](https://user-images.githubusercontent.com/89943638/152868450-3a785ea4-2761-4a7f-a9be-70854209fc2a.png)
   





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034443098


   Sorry, I forgot to answer the other parts of your message @mbs-octoml. Many thanks for it, by the way!
   
   > Thanks so much for the clear and complete comments, very much appreciate that.
   > Echo @Hzfengsy's suggestion to switch to TVMScript for your tests -- perhaps just pretty printing what you already construct would be a shortcut.
   
   Yes, I didn't know about TVMScript before. When writing the tests, I initially looked at some other test files and took inspiration from them. Unfortunately, the ones I looked at might not have used the most up-to-date way of doing things, sorry about that! :-(
   I guess we can still improve the tests later on. They already accomplish their main purpose, which I think is the most important, even though they are probably a bit more verbose than they would be with TVMScript, where I could directly write the expected TIR code instead of unfolding the TIR code that the CSE pass has produced.
    
   > The code duplication across PrimExpr & Stmt in CommonSubexpressionEliminator is unfortunate but suspect removing that would only obscure things.
   
   Yes, I agree that this duplication is a little bit unfortunate. @masahi did point it out [here](https://github.com/apache/tvm/pull/9482#discussion_r790201064). I was also a little bit annoyed by it at the beginning, so I tried to factorize it out a few times, including an attempt described in my answer [here](https://github.com/apache/tvm/pull/9482#discussion_r797243979). But all my attempts ended up with something much too complicated for what we would gain.
   
   In fact, from an algorithmic point of view we happen to want almost exactly the same treatment for an expression and for a statement, but from a data-type point of view quite a few things are still different. That's a pretty rare situation. In the end, I decided not to force things and to leave it like that.
   And the positive side of not having VisitStmt() and VisitExpr() factorized by some weird magic is that we can still easily customize the algorithm differently for expressions and statements if we ever need to (I can't think of any reason we would want that, but still! :) )
   
   Many thanks again!
   





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790177909



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as deeply equal terms must hash to the same value.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not do variable remapping), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
+          to each statement or expression of the program its table of computations. Its purpose is
+          to prevent the CSE pass from repeatedly recomputing the same tables of computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements
+  // It maps each known statement to its table of computations
+  std::unordered_map<Stmt, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_stmt_table_computations_;
+
+  // Part of the cache for expressions
+  // It maps each known expression to its table of computations
+  std::unordered_map<PrimExpr, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_expr_table_computations_;
+};
+
+/*!
+ * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
+          or by a statement.
+ * \note Computations here are considered syntactically, meaning that semantically equivalent
+          computations that are not syntactically the same are not merged together.
+ */
+class ComputationsDoneBy : public StmtExprVisitor {

Review comment:
       Except for this class and probably a few others, most of the classes and functions exposed in this header are only used inside `common_subexpr_elim_tools.cc`. Please leave only what's necessary for `common_subexpr_elim.cc` here.







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790200526



##########
File path: src/tir/analysis/check_contains.h
##########
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file check_contains.h
+ * \brief Interface of the analysis that tells if an expression contains
+           a node that satisfies a given predicate.
+ */
+
+#ifndef TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_
+#define TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_
+
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Visitor which tells if a given expression or statement contains a subexpression
+          that satisfies a given predicate
+ */
+class CheckContains : public StmtExprVisitor {

Review comment:
       This class looks very similar to https://github.com/apache/tvm/blob/5012462ef88acfd6a84b3f28135b361a8788f257/src/tir/analysis/var_touch.cc#L30
   
   Is it possible to merge them? e.g. Refactor `VarTouch` as a special case of `CheckContains`.
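A minimal sketch of what such a merge could look like, on a hypothetical toy IR rather than TVM's `StmtExprVisitor`: a generic `CheckContains` parameterized by a predicate, and a `VarTouch`-like check (does the tree use a variable from a given set?) written as one particular predicate. The node layout and the name `UsesAnyVar` are assumptions for illustration only.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR node: an operator name plus children.
struct Node {
  std::string op;    // e.g. "var", "add", "load"
  std::string name;  // variable name when op == "var"
  std::vector<std::shared_ptr<Node>> children;
};
using NodeRef = std::shared_ptr<Node>;

// Generic analysis: does any node in the tree satisfy `pred`?
// The traversal stops as soon as one match is found.
bool CheckContains(const NodeRef& root, const std::function<bool(const Node&)>& pred) {
  assert(root != nullptr);
  if (pred(*root)) return true;
  for (const NodeRef& child : root->children)
    if (CheckContains(child, pred)) return true;
  return false;
}

// A VarTouch-like analysis expressed as a special case of CheckContains.
bool UsesAnyVar(const NodeRef& root, const std::vector<std::string>& vars) {
  return CheckContains(root, [&vars](const Node& n) {
    if (n.op != "var") return false;
    for (const auto& v : vars)
      if (v == n.name) return true;
    return false;
  });
}
```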







[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034315546









[GitHub] [tvm] tmoreau89 commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031990492


   @masahi I can't really say without spending the cycles investigating; it's been some time since I've stared at this code. It is odd that this is breaking VTA even though CSE has been disabled. Do we have a feeling that this error is flaky?
   





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031857978


   There is also some trouble with two of the [python3: i386] tests: test_large_grpah (??) and test_large_graph.
   
   Together with the trouble in test_vta_insn.py on test_gemm, these are the only remaining issues.
   Should I revert my changes to test_vta_insn.py, which did not work?





[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r796249295



##########
File path: include/tvm/tir/transform.h
##########
@@ -458,6 +458,14 @@ TVM_DLL Pass FlattenBuffer();
  */
 TVM_DLL Pass TextureFlatten();
 
+/*!
+ * \brief Implements a Common Subexpression Elimination (CSE)
+ *        which introduces let-in bindings for duplicated sub-expressions.
+ * \param enable_cse_tir Whether common subexpression elimination is enabled.
+ * \return The pass.
+ */
+TVM_DLL Pass CommonSubexprElim(bool enable_cse_tir = true);

Review comment:
       At first I was going to make it explicit in the name that it's for TIR, but in the end decided otherwise, as the file is in the /tir/ subfolder.
   But you're right, it's probably better to make that explicit in the name!







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r796177571



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as deeply equal terms must hash to the same value.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not do variable remapping), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
+          to each statement or expression of the program its table of computations. Its purpose is
+          to prevent the CSE pass from repeatedly recomputing the same tables of computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements

Review comment:
       Good idea!







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797181949



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as deeply equal terms must hash to the same value.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not do variable remapping), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
+          to each statement or expression of the program its table of computations. Its purpose is
+          to prevent the CSE pass from repeatedly recomputing the same tables of computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements
+  // It maps each known statement to its table of computations
+  std::unordered_map<Stmt, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_stmt_table_computations_;
+
+  // Part of the cache for expressions
+  // It maps each known expression to its table of computations
+  std::unordered_map<PrimExpr, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_expr_table_computations_;
+};
+
+/*!
+ * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
+          or by a statement.
+ * \note Computations here are considered syntactically, meaning that semantically equivalent
+          computations that are not syntactically the same are not merged together.
+ */
+class ComputationsDoneBy : public StmtExprVisitor {
+ public:
+  // Toplevel (static) methods
+  static TableOfComputations GetComputationsDoneBy(

Review comment:
       Actually, it looks like both ways of doing things already exist in the code base.
   For instance, the following passes use a static function as the entry point to the pass:
   - https://github.com/apache/tvm/blob/main/src/tir/transforms/convert_blocks_to_opaque.cc#L39
   - https://github.com/apache/tvm/blob/main/src/tir/transforms/flatten_buffer.cc#L53
   - https://github.com/apache/tvm/blob/main/src/tir/transforms/lower_cross_thread_reduction.cc#L154
   These are just a few among many other occurrences.
   
   I actually discovered that both styles coexist when I had just started working on the pass, and I wasn't sure which one to follow.
   I tend to prefer having everything self-contained in the class, to follow object-oriented principles and to avoid mixing imperative and OO styles, so I followed this one. But in the end, either way is fine with me, to be honest.
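For readers less familiar with the two styles discussed above, here is a minimal sketch in plain C++ contrasting them. It uses a hypothetical toy visitor (not TVM's StmtExprVisitor); `Node`, `AddCounter`, `AddCounterImpl` and `CountAdds` are illustrative names only. Both entry points compute the same result; the difference is only where the entry point lives.

```cpp
#include <cassert>
#include <memory>

// Hypothetical toy IR (not TVM's): a leaf holding an int, or an addition of two sub-nodes.
struct Node {
  bool is_add;
  int value;  // meaningful only for leaves
  std::shared_ptr<Node> lhs;
  std::shared_ptr<Node> rhs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr Leaf(int v) { return std::make_shared<Node>(Node{false, v, nullptr, nullptr}); }
NodePtr Add(NodePtr a, NodePtr b) { return std::make_shared<Node>(Node{true, 0, a, b}); }

// Style 1: everything is self-contained in the class, which exposes a static entry point.
// The constructor is private, so callers cannot use a half-initialised visitor.
class AddCounter {
 public:
  static int CountAdds(const NodePtr& root) {
    AddCounter counter;
    counter.Visit(root);
    return counter.count_;
  }

 private:
  AddCounter() = default;
  void Visit(const NodePtr& n) {
    if (!n || !n->is_add) return;
    ++count_;
    Visit(n->lhs);
    Visit(n->rhs);
  }
  int count_ = 0;
};

// Style 2: the visitor class is an implementation detail and a free function is the entry point.
class AddCounterImpl {
 public:
  void Visit(const NodePtr& n) {
    if (!n || !n->is_add) return;
    ++count_;
    Visit(n->lhs);
    Visit(n->rhs);
  }
  int count_ = 0;
};

int CountAdds(const NodePtr& root) {
  AddCounterImpl impl;
  impl.Visit(root);
  return impl.count_;
}
```

On `Add(Add(Leaf(1), Leaf(2)), Leaf(3))`, both `AddCounter::CountAdds` and `CountAdds` return 2.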




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1029590772


   ![image](https://user-images.githubusercontent.com/89943638/152464025-9be63301-c629-4ca5-af3d-8cf168f420ff.png)
   
   Alright, so now we have the following that are OK:
   - Both the windows and macOS builds
   - All the sanity checks, including the linters and clang-format (they do not make any change to the code base, yay!)
   - All the front-ends and integration tests
   
   Only 6 additional tests (test_lower_build_te_schedule, test_lower_build_tir_func, test_lower_build_tir_module, test_lower_build_lowered_module, test_tensor_compute1 and test_tensor_compute2) that I did not have are failing. In these tests, the CSE pass does perform some commoning, which breaks the syntactic comparison between the obtained TIR and the expected TIR.
   I'll check that the TIR code obtained after lowering with the CSE enabled is indeed what we expect, and then update the expected TIR accordingly.
   Or perhaps I should simply update these tests to lower with the CSE pass disabled?
   
   Hoping to see this merged soon; I think we are not far!
   :)





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU - for example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the `log2(N)` expression that appears both in the host and in the GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note the many calls to `call_spirv_pure_glsl450`, which would be totally unnecessary if we had TIR-level CSE) https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass will be useful to you.
   Yes, in principle, every redundant subterm that is eligible for being commoned out (i.e., which does not contain function calls, etc.) will be commoned out. There are also a few restrictions which are due to some specifics of TVM, but these are rare.
   Do you have a small snippet of your TIR code that has some redundancies? I can try to tell if the CSE pass will be able to optimize it.
   Also, please do not hesitate to play with the pass and to let me know if it does what you would hope to obtain. I can help, of course.
   
   Kind regards.





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972549506


   > Does that answer your question?
   
   @FranckQC Thanks, yes absolutely. I can work on extending this pass to support my use case, after this is merged.





[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972608426


   Many thanks for the interesting questions and comments, I really appreciate it.
   
   > I have another three questions about the pass.
   > 
   >     1. Since function calls are non-eligible, it seems to prevent many common non-stateful builtin computations from being optimized, such as `tir.floordiv`, `tir.shift_right`, `tir.likely`, `tir.log`, etc. TVM already has op annotations like `PURE`:
   >        https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
   >        What about adding an annotation type like `DETERMINISTIC` to mark functions that return the same output on the same input? It should be
   >        safe to eliminate common calls if the call op is both pure and deterministic.
   
   Yes, I definitely agree that we could, in the future, relax the restriction which currently prevents function calls from being commoned out, for some functions that we know to be pure, including some builtins.
   I think just the "pure" tag would be enough for this purpose, if it has the meaning that I imagine, which is this one (quoted from Wikipedia):
   
   <<
   a pure function is a function that has the following properties:[1][2]
   
       1. The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments or input streams).
       2. The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments or input/output streams).
   
   Thus a pure function is a computational analogue of a mathematical function. Some authors, particularly from the imperative language community, use the term "pure" for all functions that just have the above property 2. 
   >>
   What is crucial for us when commoning out function calls is only condition 1. As long as the function always returns the same output for the same input, we are fine with putting the result of applying the function to some input into a variable (or any kind of "cache").
   
   Please note that later in the Wikipedia article they mention that condition 2 is needed for performing CSE, but they are talking about something different (and aren't very precise about it). They are not talking about commoning out function calls there, but about commoning out redundant terms between which some function calls happen, as in:
   x = a+b+c;
   f(); // Here, if f can change the content of some variables like a, b or c, it prevents the CSE on the redundant term (a+b+c)
   y = a+b+c;
   
   But this kind of consideration does not concern us, as TIR is in a (weak) SSA form: the content of variables won't be updated, and can be assumed to be the same everywhere throughout the execution of the program. Thus, only condition 1 matters to us.
   
   If the kPure tag has both meanings 1 and 2, then we will be able to simply use it in the future when relaxing the restriction.
   If it only means condition 2, then we will indeed need another tag.
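A small illustration in plain C++ (hypothetical functions, not TIR builtins) of why condition 1 is what matters when a call is replaced by a variable. `Square` always returns the same output for the same input, so binding its result to a fresh variable preserves the result. `NextTicket` returns a different value on each call (achieved here through a global counter, which is also a side effect, since a standalone example needs some hidden state), so commoning it changes the result.

```cpp
#include <cassert>

// Hypothetical deterministic function: same output for the same input.
int Square(int x) { return x * x; }

// Hypothetical non-deterministic function: its output depends on hidden, mutable state.
int g_counter = 0;
int NextTicket(int x) { return x + (++g_counter); }

// Before commoning: the call appears twice.
int TwiceSquareOriginal(int x) { return Square(x) + Square(x); }

// After commoning: the call result is bound once to a fresh variable and reused.
int TwiceSquareCommoned(int x) {
  const int cse_var_1 = Square(x);
  return cse_var_1 + cse_var_1;
}

int TwiceTicketOriginal(int x) { return NextTicket(x) + NextTicket(x); }

int TwiceTicketCommoned(int x) {
  const int cse_var_1 = NextTicket(x);
  return cse_var_1 + cse_var_1;
}
```

With `g_counter` reset to 0 before each call, the original ticket version yields 23 (11 + 12) while the commoned one yields 22 (11 + 11), whereas both `Square` versions agree.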
   
   >     2. Are the let-binding semantics in TIR explicitly what we want? Is the value expr specified to be evaluated at let scope entry? Otherwise there may be no reuse effect if let-binding allows the value expr to get lazily expanded in the let body.
   
   I must admit that, now that I think of it, I am not entirely sure... Can anyone confirm?
   
   >     3. Is it possible for some complex (but common) expr to be lifted out of a control-flow scope? E.g.
   >        ```python
   >        if cond0:
   >            tir.evaluate(complex_e)  # slow path
   >        else:
   >           if cond2: 
   >                tir.evaluate(complex_e)  # slow path
   >           else:
   >                ...  # general fast path
   >        ```
   >        after CSE, this might become:
   >        ```python
   >        x = tir.Var("int32")
   >        with tir.let(x, complex_e):  # always evaluate slow path
   >            if cond0:
   >                tir.evaluate(x) 
   >            else:
   >                if cond2: 
   >                    tir.evaluate(x)
   >                else:
   >                    ... # general fast path
   >        ```
   
   Yes, that is possible. Although, now that I think of it, it would probably be better not to do it in the exact example that you provided, as the expression `complex_e` of the slow path now always gets evaluated, even when we end up taking the general fast path...
   We should only lift things through "if" statements if both the then branch and the else branch compute them. This restriction doesn't improve semantics preservation (which was already fine), but it will avoid this kind of anti-optimization. That should be easily fixable.
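A minimal sketch of the restriction described above, under simplifying assumptions: computations are represented by their printed form, and each block (condition, then branch, else branch) has already been summarised into a table counting how often each computation is seen. It keeps a computation only when at least two of the three blocks compute it, which mirrors the (Then Inter Else) U (Cond Inter Then) U (Cond Inter Else) combination described in the code comments of this PR. `Table`, `TimesSeen` and `HoistableFromIf` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <string>
#include <unordered_map>

// Simplified stand-in for a table of computations: printed computation -> times seen.
using Table = std::unordered_map<std::string, std::size_t>;

// Number of times `key` is seen in `table` (0 if absent).
std::size_t TimesSeen(const Table& table, const std::string& key) {
  auto it = table.find(key);
  return it == table.end() ? 0 : it->second;
}

// Computations of an if-then-else that may be introduced before the whole `if`.
// The condition is always evaluated, so hoisting something it already computes never
// penalises a branch; a computation seen in only one block is left where it is.
Table HoistableFromIf(const Table& cond, const Table& then_branch, const Table& else_branch) {
  Table result;
  for (const Table* block : {&cond, &then_branch, &else_branch}) {
    for (const auto& entry : *block) {
      const std::string& key = entry.first;
      int nb_blocks = (TimesSeen(cond, key) > 0) + (TimesSeen(then_branch, key) > 0) +
                      (TimesSeen(else_branch, key) > 0);
      if (nb_blocks >= 2) {
        // Counter = total number of occurrences across the three blocks.
        result[key] = TimesSeen(cond, key) + TimesSeen(then_branch, key) +
                      TimesSeen(else_branch, key);
      }
    }
  }
  return result;
}
```

In the nested example, `complex_e` is seen only in the slow path of the inner `if`, so nothing is reported for the inner `if`, and therefore nothing is hoisted out of the outer one either: the fast path no longer pays for `complex_e`.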
   
   Many thanks!





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU - for example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the `log2(N)` expression that appears both in the host and in the GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note the many calls to `call_spirv_pure_glsl450`, which would be totally unnecessary if we had TIR-level CSE) https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
   In principle, every redundant subterm that is eligible for being commoned out (i.e., which does not contain function calls, etc.) will be commoned out. There are also a few restrictions which are due to some specifics of TVM, but these are rare.
   
   Unfortunately, in the TIR code snippet that you have uploaded, it seems that the redundant subexpression is just a function call. Such calls can't be commoned out into a new variable by the CSE pass, as it has no guarantee that the function has no side effects and always produces the same outputs for the same inputs. Without this guarantee, commoning out such function calls could change the program's semantics, so it's not done, as preserving the semantics of the program is vital.
   
   I can imagine that for functions that are guaranteed not to have any side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future, and it would rely on some "NoSideEffect" tag on functions.
   
   However, please note that if you had some other redundancies, this CSE pass would common out whatever eligible redundancies you have.
   For instance, assume that the term (f(42) + f(42)) + (x*y + x*y) appears somewhere.
   It is not eligible in its entirety, as it contains function calls.
   But its subterm (x*y + x*y) is eligible, so this subpart will be commoned out.
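The selection described above (take a term whole when it is eligible, otherwise look inside its children) can be sketched on a toy expression type. This is an illustration under stated assumptions, not the pass's code: `Expr`, `CollectComputations` and the eligibility rule "not a leaf and contains no call" are hypothetical simplifications of the predicates the pass actually uses.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy expression: a kind, a label, and its children.
struct Expr {
  enum Kind { kVar, kConst, kAdd, kMul, kCall };
  Kind kind;
  std::string name;  // variable name, constant text, operator symbol, or callee name
  std::vector<std::shared_ptr<Expr>> args;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Var(const std::string& n) { return std::make_shared<Expr>(Expr{Expr::kVar, n, {}}); }
ExprPtr Const(int v) { return std::make_shared<Expr>(Expr{Expr::kConst, std::to_string(v), {}}); }
ExprPtr AddE(ExprPtr a, ExprPtr b) { return std::make_shared<Expr>(Expr{Expr::kAdd, "+", {a, b}}); }
ExprPtr MulE(ExprPtr a, ExprPtr b) { return std::make_shared<Expr>(Expr{Expr::kMul, "*", {a, b}}); }
ExprPtr Call(const std::string& f, ExprPtr a) { return std::make_shared<Expr>(Expr{Expr::kCall, f, {a}}); }

std::string Print(const ExprPtr& e) {
  switch (e->kind) {
    case Expr::kVar:
    case Expr::kConst:
      return e->name;
    case Expr::kCall:
      return e->name + "(" + Print(e->args[0]) + ")";
    default:  // binary operators
      return "(" + Print(e->args[0]) + e->name + Print(e->args[1]) + ")";
  }
}

bool ContainsCall(const ExprPtr& e) {
  if (e->kind == Expr::kCall) return true;
  for (const auto& a : e->args) {
    if (ContainsCall(a)) return true;
  }
  return false;
}

// Toy eligibility predicate: a real computation (not a leaf) without any call inside.
bool IsEligible(const ExprPtr& e) {
  return e->kind != Expr::kVar && e->kind != Expr::kConst && !ContainsCall(e);
}

using Table = std::map<std::string, std::size_t>;

// Records `e` itself when eligible; otherwise recurses into its children.
void CollectComputations(const ExprPtr& e, Table& table) {
  if (IsEligible(e)) {
    ++table[Print(e)];
    return;
  }
  for (const auto& a : e->args) CollectComputations(a, table);
}
```

For the example above, the collected table contains only `((x*y)+(x*y))`, seen once.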
   
   Does that answer your question?
   Please do not hesitate to tell me if you need help with trying the pass.
   
   Kind regards.





[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031844098


   Oh I see, sorry!
   Before I push again, would you mind looking at the error in test_vta_insn.py?
   I disabled the CSE pass for it (as I did for some other tests), but for some reason this one is still failing, and I don't really understand what's going on there.





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034251858


   > Overall looks good to me and we should merge as is, @masahi do you mind adding some follow up issues to track the TODOs generated during review?
   
   Sure! In addition to https://github.com/apache/tvm/pull/9482#discussion_r790200526, I think we can refactor our `Substitute` function to be a special case of the substituter with a predicate added in this PR.
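One possible reading of this refactoring idea, sketched on a hypothetical toy expression type. This is not TVM's `Substitute` nor the substituter from this PR, whose exact signatures are not shown here. The sketch has a generic substituter driven by a selection predicate (`ReplaceIf`), with plain variable substitution recovered as the special case where the predicate selects the variables present in a map.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical toy expression: a variable leaf (empty `op`) or a binary operation.
struct Expr {
  std::string op;    // "" for a variable leaf, otherwise e.g. "+" or "*"
  std::string name;  // variable name, used only for leaves
  std::shared_ptr<Expr> lhs;
  std::shared_ptr<Expr> rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr V(const std::string& n) { return std::make_shared<Expr>(Expr{"", n, nullptr, nullptr}); }
ExprPtr Bin(const std::string& op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{op, "", a, b});
}

std::string Print(const ExprPtr& e) {
  if (e->op.empty()) return e->name;
  return "(" + Print(e->lhs) + e->op + Print(e->rhs) + ")";
}

using Predicate = std::function<bool(const ExprPtr&)>;
using Replacement = std::function<ExprPtr(const ExprPtr&)>;

// Generic substituter: every subexpression selected by `select` is replaced by
// `replace(subexpression)`; everything else is rebuilt unchanged.
ExprPtr ReplaceIf(const ExprPtr& e, const Predicate& select, const Replacement& replace) {
  if (select(e)) return replace(e);
  if (e->op.empty()) return e;
  return Bin(e->op, ReplaceIf(e->lhs, select, replace), ReplaceIf(e->rhs, select, replace));
}

// Variable substitution as a special case: the predicate selects variables found in `vmap`.
ExprPtr Substitute(const ExprPtr& e, const std::map<std::string, ExprPtr>& vmap) {
  return ReplaceIf(
      e, [&vmap](const ExprPtr& n) { return n->op.empty() && vmap.count(n->name) != 0; },
      [&vmap](const ExprPtr& n) { return vmap.at(n->name); });
}
```

The same `ReplaceIf` also covers the CSE use case, e.g. replacing every occurrence of `(x*y)` by a fresh variable `cse_var_1`.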





[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034315546


   > Thanks so much for the clear and complete comments, very much appreciate that.
   > 
   > Echoing @Hzfengsy's suggestion to switch to TVMScript for your tests -- perhaps just pretty-printing what you already construct would be a shortcut.
   > 
   > The code duplication across PrimExpr & Stmt in CommonSubexpressionEliminator is unfortunate, but I suspect removing it would only obscure things.
   > 
   > In your experience, is this mostly firing on the affine index sub-expressions, or do you see CSE over actual data sub-expressions?
   
   Thank you so much for the compliment, I really appreciate it. It makes me happy to know that the code is easy to read!
   
   If I recall correctly, I saw quite a lot of indices (mostly from loop unrolling), just like what @wrongtest had here https://github.com/apache/tvm/pull/9482#issuecomment-972463175. 
   
   Also some indices due to the lowering of memory accesses, for instance:
   C[x,y] = C[x,y] + A[x,k] * B[y,k]
   which can be lowered (2D to 1D) to:
   C[x*128+y] = C[x*128+y] + A[x*128+k]*B[y*128+k]
   which gives the opportunity to create first:
   cse_var_1 = x*128+y
   and then in cascade:
   cse_var_2 = x*128
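To make the cascade concrete, here is a plain C++ sketch of the lowered update before and after commoning. It is not generated TIR: `kN`, `UpdateOriginal` and `UpdateCommoned` are illustrative names. Since `cse_var_1` ends up expressed in terms of `cse_var_2`, the latter has to be bound first in the rewritten code.

```cpp
#include <cassert>
#include <vector>

constexpr int kN = 128;  // row length used in the example above

// Lowered body before CSE: C[x*128+y] = C[x*128+y] + A[x*128+k] * B[y*128+k]
void UpdateOriginal(std::vector<float>& C, const std::vector<float>& A,
                    const std::vector<float>& B, int x, int y, int k) {
  C[x * kN + y] = C[x * kN + y] + A[x * kN + k] * B[y * kN + k];
}

// After CSE: x*128+y is introduced into cse_var_1, then x*128 (shared with the index of A)
// is commoned in cascade into cse_var_2, which cse_var_1 now reuses.
void UpdateCommoned(std::vector<float>& C, const std::vector<float>& A,
                    const std::vector<float>& B, int x, int y, int k) {
  const int cse_var_2 = x * kN;
  const int cse_var_1 = cse_var_2 + y;
  C[cse_var_1] = C[cse_var_1] + A[cse_var_2 + k] * B[y * kN + k];
}
```

Both versions perform the same floating-point operations in the same order, so running them over the same inputs produces identical `C` arrays; only the index arithmetic is shared.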
   
   And I also recall a lot of random commoning, like:
   ![result-CSE-realProgram-better](https://user-images.githubusercontent.com/89943638/153308601-ce7edc8e-67be-4b8c-b892-7df78746eaaf.png)
   
   I'll post more if I can find my notes where I had more interesting commonings performed in test files.
   





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034955150


   Follow-up items in https://github.com/apache/tvm/issues/10211





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201528



##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* \file common_subexpr_elim_tools.cc
+* \brief Implementation of analysis tools and utility functions used
+          by the Common Subexpression Elimination (CSE) pass.
+*/
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the declaration of the pass
+
+#include <algorithm>      // For std::find_if
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static data member of the class ComputationsDoneBy, and C++ requires it to be
+// defined here (outside the class), otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactical entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
+   is the number of times that this computation is being seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates : `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one helps to select only "eligible" computations, and the second one helps to only
+   select computations that are located at appropriate location (i.e., it tells in which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for their own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+   its child nodes (i.e., the union of the results of the recursive calls).
+   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+   On some nodes, it will return something more complicated that uses the intersection of the
+   computations done by the children nodes.
+   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+   (x+y) seen twice but it won't report b-x as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+   for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
+   (often because it's a load node or a function call node for instance), in which case it will
+   return the union of the recursive calls on its children, as long as the other predicate
+   `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+   With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+   itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
+   might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyses the program at the toplevel (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily previously obtained when computing the computations done by the
+   root (as the computations done by the root are by definition the union of the computations done
+   by the children nodes).
+
+   The somewhat difficult aspect of the implementation is the interaction between this caching of
+   results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are
+   void methods which can't return anything, and instead need to accumulate a result into a member
+   variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods) just
+   call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't
+   want to override each of these specialized methods to change this behaviour, then
+   `table_of_computations_` will necessarily be shared by all the children of a given node.
+   That requires being careful when trying to write into the cache.
+*/
+
+/*!
+ * \brief Does the union of two tables of computations.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Does it directly in the first argument A for efficiency, as the union of A and B
+ *       necessarily gives something which contains A, so we avoid its copy.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Does the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+            const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of two tables of computations.
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the intersection for N tables, we need to
+ *       know how to do it for two. That's because we would compute for N tables using the
+ *       associativity of the intersection : T1 Inter T2 Inter T3 ... Inter Tn
+ *       = ((T1 Inter T2) Inter T3) ... Inter Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic intersection
+ *       over N tables.
+ */
+TableOfComputations IntersectionOf3TablesOfComputations(const TableOfComputations& table1,
+                  const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = IntersectionOf2TablesOfComputations(table1, table2);
+  result = IntersectionOf2TablesOfComputations(result, table3);
+  return result;
+}
+
+/*!
+ * \brief Recompute the number of times that each computation in table_main
+          is being seen in table_bloc1, table_bloc2 and table_bloc3. It sets
+          each element to the sum of the times it is seen in each individual bloc.
+ * \param table_main The main table, for which we recompute the counters.
+ * \param table_bloc1 One of the three tables, which won't change.
+ * \param table_bloc2 One of the three tables, which won't change.
+ * \param table_bloc3 One of the three tables, which won't change.
+ * \note This function is needed because both the intersection (A Inter B) and the union
+ *        (A U B U C) add the individual counters found in A, B and C. So when we treat for
+ *        instance an If (which contains a Cond, a Then branch and an Else branch),
+ *        it will compute (Then Inter Else) U (Cond Inter Then) U (Cond Inter Else).
+ *        In order to get back to the appropriate number (for instance, 3 if seen one time in each
+ *        bloc), it is therefore necessary to recompute the counters afterwards, which is what this
+ *        function does.
+ */
+void RecomputeNbTimesSeenInThreeBlocs(TableOfComputations& table_main,
+    const TableOfComputations& table_bloc1, const TableOfComputations& table_bloc2,
+    const TableOfComputations& table_bloc3) {

Review comment:
       I prefer taking a vector of `ComputationTable` here and remove `ThreeBlocs` suffix.
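A minimal sketch of what the suggested signatures could look like, assuming string keys as a simplified stand-in for `PrimExpr` and the hypothetical names `IntersectComputationTables` and `RecomputeNbTimesSeen`. The N-ary intersection folds the two-table one pairwise and, like the existing `IntersectionOf2TablesOfComputations`, sums the counters; the recount sets each counter to its total over all blocks, as the three-bloc version does.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Simplified stand-in for the table type: the key would be a PrimExpr in the actual pass.
using ComputationTable = std::unordered_map<std::string, std::size_t>;

// Intersection of N tables: keeps the computations present in every table and, like the
// two-table version, sums their counters.
ComputationTable IntersectComputationTables(const std::vector<ComputationTable>& tables) {
  if (tables.empty()) return {};
  ComputationTable result = tables[0];
  for (std::size_t i = 1; i < tables.size(); ++i) {
    ComputationTable next;
    for (const auto& entry : result) {
      auto it = tables[i].find(entry.first);
      if (it != tables[i].end()) next[entry.first] = entry.second + it->second;
    }
    result = std::move(next);
  }
  return result;
}

// Sets each counter of `table_main` to the total number of times the computation is seen
// across all `blocs` (generalising the three-bloc version).
void RecomputeNbTimesSeen(ComputationTable& table_main,
                          const std::vector<ComputationTable>& blocs) {
  for (auto& entry : table_main) {
    std::size_t total = 0;
    for (const ComputationTable& bloc : blocs) {
      auto it = bloc.find(entry.first);
      if (it != bloc.end()) total += it->second;
    }
    entry.second = total;
  }
}
```

Taking a vector keeps a single code path for any number of blocks, at the cost of copying the tables into the vector at the call site (a vector of pointers would avoid that).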







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790178147



##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* \file common_subexpr_elim_tools.cc
+* \brief Implementation of analysis tools and utility functions used
+          by the Common Subexpression Elimination (CSE) pass.
+*/
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the declaration of the pass
+
+#include <algorithm>      // For std::find_if
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static member of the class ComputationsDoneBy, and C++ requires such a static
+// member to be defined here, otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactic entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
+   is the number of times that this computation is seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates : `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one selects only "eligible" computations, and the second one only selects
+   computations that are located at appropriate locations (i.e., it tells into which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for their own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+   its child nodes (i.e., the union of the results of the recursive calls).
+   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+   On some nodes, it will return something more complicated that uses the intersection of the
+   computations done by the child nodes.
+   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+   (x+y) seen twice but it won't report b-x, as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+   for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
+   (often because it's a load node or a function call node for instance), in which case it will
+   return the union of the recursive calls on its children, as long as the other predicate
+   `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+   With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+   itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
+   might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyses the program at the toplevel (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily obtained previously when computing the computations done by the
+   root (as the computations done by the root are by definition the union of the computations done
+   by the children nodes).
+
+   The somewhat difficult aspect of the implementation is the interaction between this caching of
+   results and the fact that the VisitStmt()/VisitExpr() methods of an analyzer (a StmtExprVisitor)
+   are void methods which can't return anything, and instead need to accumulate a result into a
+   member variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods) just
+   call VisitStmt()/VisitExpr() on all the child nodes within the same instance, then unless we
+   override each of these specialized methods to change this behaviour,
+   `table_of_computations_` will necessarily be shared by all the children of a given node.
+   This requires care when writing into the cache.
+*/
+
+/*!
+ * \brief Does the union of two tables of computations.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Does it directly in the first argument A for efficiency, as the union of A and B
+ *       necessarily gives something which contains A, so we avoid its copy.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Does the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least not yet) a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+            const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of two tables of computations.
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least not yet) a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the intersection for N tables, we need to
+ *       know how to do it for two. That's because we would compute for N tables using the
+ *       associativity of the intersection : T1 Inter T2 Inter T3 ... Inter Tn
+ *       = ((T1 Inter T2) Inter T3) ... Inter Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic intersection
+ *       over N tables.
+ */
+TableOfComputations IntersectionOf3TablesOfComputations(const TableOfComputations& table1,

Review comment:
       How about renaming both `IntersectionOf2TablesOfComputations` and `IntersectionOf3TablesOfComputations` to `IntersectComputationTables` to make it less verbose.
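Here is a sketch of how the renamed API could look. It uses the same hypothetical string-keyed `ComputationTable` stand-in as before. The two-table overload keeps the semantics of `IntersectionOf2TablesOfComputations` shown above: a computation kept in the result has the counters from both tables added up. The N-table overload folds the two-table version pairwise, which is the associativity argument made in the doc comment. This is an illustration, not the PR's code.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for TableOfComputations (string key instead of PrimExpr).
using ComputationTable = std::unordered_map<std::string, std::size_t>;

// Two-table intersection with the same semantics as IntersectionOf2TablesOfComputations:
// keep the computations present in both tables and add up their counters.
ComputationTable IntersectComputationTables(const ComputationTable& table1,
                                            const ComputationTable& table2) {
  ComputationTable result;
  for (const auto& current : table1) {
    auto it = table2.find(current.first);
    if (it != table2.end()) {
      result[current.first] = current.second + it->second;
    }
  }
  return result;
}

// N-table overload obtained by folding the two-table version, using associativity:
// T1 Inter T2 Inter ... Inter Tn = ((T1 Inter T2) Inter T3) ... Inter Tn
ComputationTable IntersectComputationTables(const std::vector<ComputationTable>& tables) {
  if (tables.empty()) {
    return {};  // Choice made in this sketch: no tables, no common computation
  }
  ComputationTable result = tables[0];
  for (std::size_t i = 1; i < tables.size(); ++i) {
    result = IntersectComputationTables(result, tables[i]);
  }
  return result;
}
```

Returning an empty table for an empty vector is a choice made only in this sketch. The intersection of zero tables has no natural value, so asserting a non-empty input would be a reasonable alternative.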







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201064



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpression Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithm std::find
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden from being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also don't even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+}
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is called for each PrimFunc definition, we create a new instance
+  // for the current one.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param stmt The body of the function that will be analyzed (stored as `initial_body_`)
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing complexity
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](const std::pair<PrimExpr, size_t>& a, const std::pair<PrimExpr, size_t>& b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e., present in `context_`).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction was performed when a computation was using an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry, so as not to carry out-of-scope declarations,
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If the `value` and the `body` of the let-in have been rewritten to the same thing
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Do we need this specialization for `Stmt` at all? `StmtExprMutator::VisitStmt` visits sub `Expr`s, I think.
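The greedy strategy of `VisitExpr()` above works as follows: collect the eligible computations, consider them from biggest to smallest, bind a redundant one with a let-in to a fresh `cse_var_i`, and substitute it. Here is a toy sketch of that strategy on a tiny expression language. Everything in it (`Expr`, `Count`, `Replace`, `FreshName`, `ToyCSE`) is hypothetical and simplified:

- Syntactic equality is approximated by printed strings.
- Complexity is approximated by printed length.
- Instead of re-queuing direct subexpressions, it recounts occurrences after each rewrite.

It assumes a let-free input whose free variables are all in scope, so the `UndefinedVars()` check and the context are not modelled.

```cpp
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression language (hypothetical stand-in for TVM's PrimExpr):
// variables, additions and let-in bindings.
struct Expr {
  enum Kind { kVar, kAdd, kLet } kind;
  std::string name;                // variable name (kVar) or bound variable (kLet)
  std::shared_ptr<Expr> lhs, rhs;  // operands (kAdd) or value/body (kLet)
};
using ExprPtr = std::shared_ptr<Expr>;
// Printed form of a computation -> (one occurrence of it, number of times it is seen)
using Table = std::map<std::string, std::pair<ExprPtr, int>>;

ExprPtr MkVar(const std::string& n) {
  return std::make_shared<Expr>(Expr{Expr::kVar, n, nullptr, nullptr});
}
ExprPtr MkAdd(ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{Expr::kAdd, "", std::move(a), std::move(b)});
}
ExprPtr MkLet(const std::string& v, ExprPtr value, ExprPtr body) {
  return std::make_shared<Expr>(Expr{Expr::kLet, v, std::move(value), std::move(body)});
}

// Printed form, used here as a stand-in for syntactic equality.
std::string Print(const ExprPtr& e) {
  switch (e->kind) {
    case Expr::kVar: return e->name;
    case Expr::kAdd: return "(" + Print(e->lhs) + "+" + Print(e->rhs) + ")";
    default: return "let " + e->name + " = " + Print(e->lhs) + " in " + Print(e->rhs);
  }
}

// Counts every eligible computation (here: every addition; variables are atoms),
// recursing into all children, including let values and bodies.
void Count(const ExprPtr& e, Table& table) {
  if (e->kind == Expr::kVar) return;
  if (e->kind == Expr::kAdd) {
    auto& slot = table[Print(e)];
    slot.first = e;
    slot.second++;
  }
  Count(e->lhs, table);
  Count(e->rhs, table);
}

// Replaces every occurrence of the computation printed as `key` by `var`.
ExprPtr Replace(const ExprPtr& e, const std::string& key, const ExprPtr& var) {
  if (e->kind == Expr::kVar) return e;
  if (Print(e) == key) return var;
  return std::make_shared<Expr>(
      Expr{e->kind, e->name, Replace(e->lhs, key, var), Replace(e->rhs, key, var)});
}

// In the spirit of GenerateNewVar(): tries cse_var_1, cse_var_2, ... and skips any name that
// already appears in the program (a substring test, so it may skip a few names needlessly).
std::string FreshName(const std::string& program_text, int& counter) {
  while (true) {
    std::string name = "cse_var_" + std::to_string(++counter);
    if (program_text.find(name) == std::string::npos) return name;
  }
}

// Greedy toy CSE on a let-free input whose free variables are all in scope.
ExprPtr ToyCSE(const ExprPtr& input) {
  Table table;
  Count(input, table);
  std::vector<ExprPtr> candidates;
  for (const auto& entry : table) candidates.push_back(entry.second.first);
  // Biggest computations first, as in VisitExpr(); printed length stands in for complexity.
  std::stable_sort(candidates.begin(), candidates.end(), [](const ExprPtr& x, const ExprPtr& y) {
    return Print(x).size() > Print(y).size();
  });
  ExprPtr result = input;
  int counter = 0;
  for (const ExprPtr& comp : candidates) {
    const std::string key = Print(comp);
    Table current;
    Count(result, current);  // Recount: an earlier rewrite may have absorbed occurrences
    auto it = current.find(key);
    if (it == current.end() || it->second.second < 2) continue;  // Not redundant (anymore)
    ExprPtr var = MkVar(FreshName(Print(result), counter));
    result = MkLet(var->name, comp, Replace(result, key, var));
  }
  return result;
}
```

On the example from the PR description, `((w+x)+(y+z)) + ((w+x)+u)`, this sketch binds `(w+x)` once and yields `let cse_var_1 = (w+x) in ((cse_var_1+(y+z))+(cse_var_1+u))`.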







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r798122379



##########
File path: python/tvm/tir/transform/transform.py
##########
@@ -310,6 +310,15 @@ def BF16TypeLowering():
     """
     return _ffi_api.BF16TypeLowering()  # type: ignore
 
+def CommonSubexprElim(enable_cse_tir: bool = True):

Review comment:
       Many thanks for the very detailed review, @masahi; I really appreciate your comments and your time.
   OK, so for now I will keep the boolean that makes it easy to disable the pass, and we can still remove it in the near future if people really don't want it anymore.
   
   I'm almost ready to send a final commit for this PR.
   Many thanks again!







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797131033



##########
File path: include/tvm/tir/transform.h
##########
@@ -458,6 +458,14 @@ TVM_DLL Pass FlattenBuffer();
  */
 TVM_DLL Pass TextureFlatten();
 
+/*!
+ * \brief Implements a Common Subexpression Elimination (CSE)
+ *        which introduces let-in bindings for duplicated sub-expressions.
+ * \param enable_cse_tir Whether common subexpression elimination is enabled.
+ * \return The pass.
+ */
+TVM_DLL Pass CommonSubexprElim(bool enable_cse_tir = true);

Review comment:
       I'll change it with a "TIR" suffix, to make that clear :)







[GitHub] [tvm] masahi edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-964681835


   Hi @FranckQC, I have wanted TIR-level CSE for a long time, so I'm very excited to see this!
   
   What I wanted to do is to eliminate common expressions that span across the host and GPU. For example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the `log2(N)` expression that appears both in the host and in the GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note the many calls to `call_spirv_pure_glsl450`, which would be totally unnecessary if we had TIR-level CSE): https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   
   Is this PR going to solve my problem?
   





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > 
   > 
   > Hi @FranckQC, I have wanted TIR-level CSE for a long time, so I'm very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU. For example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the `log2(N)` expression that appears both in the host and in the GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note the many calls to `call_spirv_pure_glsl450`, which would be totally unnecessary if we had TIR-level CSE): https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
   In principle, every redundant subterm that is eligible for being commoned out (i.e., that does not contain function calls, etc.) will be commoned out. There are also a few other minor restrictions due to some specifics of TVM, but these are rare.
   
   Unfortunately, in the TIR code snippet that you uploaded, it seems that the redundant subexpression is just a function call. Function calls can't be commoned out into a new variable by the CSE pass, as it has no guarantee that the function has no side effects and always produces the same outputs for the same inputs. Without this guarantee, commoning out such function calls could change the program's semantics, so it is not done: preserving the semantics of the program is vital.
   
   I can imagine that for functions that are guaranteed to have no side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future, and it would rely on some "NoSideEffect" tag on functions.
   
   However, please note that if you had some other redundancies, this CSE pass would common out whichever of them are eligible.
   For instance, assume that the term (f(42) + f(42)) + (x*y + x*y) appears somewhere.
   It is not eligible in its entirety, as it contains function calls.
   But its subterm (x*y + x*y) is eligible, so this subpart will be processed and the redundant computation x*y inside it will be commoned out.
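   
   To make the eligibility rule concrete, here is a minimal sketch in plain Python. It is a toy model, not the pass's implementation or TVM's API: expressions are encoded as nested tuples (an assumption for illustration), calls are treated as ineligible, and every eligible subterm is counted at every depth, which is a simplification. The optional `pure_ops` parameter is hypothetical and only illustrates the future relaxation for side-effect-free functions mentioned above.

```python
from collections import Counter

# Toy encoding (an assumption for illustration, not TVM's data structures):
#   ("add", a, b), ("mul", a, b), ("call", callee_name, arg)
# Variables are strings and constants are ints.


def children(expr):
    """Return the sub-expressions of a compound term."""
    if expr[0] == "call":
        return expr[2:]  # skip the callee name
    return expr[1:]


def is_eligible(expr, pure_ops=frozenset()):
    """A term is eligible if it contains no call, except calls to ops listed
    in `pure_ops` (a hypothetical relaxation; by default no call is allowed)."""
    if not isinstance(expr, tuple):
        return True
    if expr[0] == "call" and expr[1] not in pure_ops:
        return False
    return all(is_eligible(c, pure_ops) for c in children(expr))


def count_eligible(expr, table, pure_ops=frozenset()):
    """Count every eligible compound subterm of `expr` into `table`."""
    if not isinstance(expr, tuple):
        return
    if is_eligible(expr, pure_ops):
        table[expr] += 1
    for c in children(expr):
        count_eligible(c, table, pure_ops)


def redundant(table):
    """Computations seen at least twice are candidates for commoning out."""
    return {e for e, n in table.items() if n >= 2}


f42 = ("call", "f", 42)
xy = ("mul", "x", "y")
term = ("add", ("add", f42, f42), ("add", xy, xy))

default = Counter()
count_eligible(term, default)
print(redundant(default))  # {('mul', 'x', 'y')}

relaxed = Counter()
count_eligible(term, relaxed, pure_ops={"f"})
print(sorted(redundant(relaxed)))  # [('call', 'f', 42), ('mul', 'x', 'y')]
```

   With the default rule only `x*y` is seen twice, so it is the only candidate; treating `f` as pure would also make `f(42)` a candidate.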
   
   Does that answer your question?
   Please do not hesitate to tell me if you need help with trying the pass.
   
   Kind regards.





[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972579646


   > 
   > 
   > Hi @FranckQC, I am also very interested in the CSE pass. In our circumstances we suffer from duplicate index computations produced by loop unroll, like
   > 
   > ```
   > A[i*256 + j*16 + 0] = B[i*256 + j*16 + 4096]
   > A[i*256 + j*16 + 1] = B[i*256 + j*16 + 4097]
   > ...
   > ```
   > 
   > We have to depend on the target backend's optimization abilities, which may or may not optimize them out. It would be great if CSE could handle some of these things at the TIR level.
   
   Hi @wrongtest 
   Thank you, I'm very glad to know that the pass might be useful to you!
   
   Yes, these kinds of redundancies should definitely be commoned out by this new CSE pass.
   For the specific example that you have given, I expect the result to be something like:
   
   ```
   Let cse_var_1 = i*256 + j*16 in
       A[cse_var_1 + 0] = B[cse_var_1 + 4096];
       A[cse_var_1 + 1] = B[cse_var_1 + 4097];
   ```
   
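   The rewrite expected above can be modeled in a few lines of Python. This is a hypothetical toy (tuple-encoded expressions, a single variable introduced for the whole block, the largest redundant term picked first), not the actual implementation of the pass; it also evaluates both versions to check that the computed indices are unchanged.

```python
from collections import Counter

# Toy encoding (assumption): ("add", a, b) / ("mul", a, b); strings are variables.
OPS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}


def subterms(expr):
    """Yield every compound subterm of `expr`, including `expr` itself."""
    if isinstance(expr, tuple):
        yield expr
        for child in expr[1:]:
            yield from subterms(child)


def size(expr):
    """Number of nodes in `expr`."""
    if isinstance(expr, tuple):
        return 1 + sum(size(child) for child in expr[1:])
    return 1


def substitute(expr, target, var):
    """Replace every occurrence of `target` in `expr` by the variable `var`."""
    if expr == target:
        return var
    if isinstance(expr, tuple):
        return (expr[0],) + tuple(substitute(c, target, var) for c in expr[1:])
    return expr


def common_out(stmts, var="cse_var_1"):
    """Bind the largest term seen at least twice to `var` and rewrite `stmts`."""
    counts = Counter(t for stmt in stmts for side in stmt for t in subterms(side))
    candidates = [t for t, n in counts.items() if n >= 2]
    if not candidates:
        return None, stmts
    best = max(candidates, key=size)
    return (var, best), [tuple(substitute(s, best, var) for s in stmt) for stmt in stmts]


def evaluate(expr, env):
    """Evaluate a toy expression with variable values taken from `env`."""
    if isinstance(expr, tuple):
        return OPS[expr[0]](*(evaluate(c, env) for c in expr[1:]))
    return env[expr] if isinstance(expr, str) else expr


def index(k):
    # i*256 + j*16 + k, parsed left-associatively as ((i*256 + j*16) + k)
    return ("add", ("add", ("mul", "i", 256), ("mul", "j", 16)), k)


# Each statement is (index into A, index into B), as in the unrolled code above.
stmts = [(index(0), index(4096)), (index(1), index(4097))]
binding, rewritten = common_out(stmts)
print(binding)  # ('cse_var_1', ('add', ('mul', 'i', 256), ('mul', 'j', 16)))
print(rewritten[0])  # (('add', 'cse_var_1', 0), ('add', 'cse_var_1', 4096))

env = {"i": 3, "j": 5}
env_cse = dict(env, cse_var_1=evaluate(binding[1], env))  # bound value computed once
before = [tuple(evaluate(s, env) for s in st) for st in stmts]
after = [tuple(evaluate(s, env_cse) for s in st) for st in rewritten]
assert before == after  # same indices before and after the rewrite
```
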
   Do not hesitate to try the pass out and to let me know if it does what we hope. Of course, I'd be happy to help if needed.
   
   Best regards.





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972608426


   Many thanks for the interesting questions and comments, I really appreciate it.
   
   > 
   > 
   > I have another three questions about the pass.
   > 
   >     1. Since function calls are non-eligible, it seems to prevent many common non-stateful builtin computations from being optimized, such as `tir.floordiv`, `tir.shift_right`, `tir.likely`, `tir.log`, etc. TVM already has op annotations like `PURE`:
   >        https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
   >        What about adding an annotation type like `DETERMINISTIC` to mark that the function returns the same output for the same input? It should be
   >        safe to eliminate common calls if the call op is both pure and deterministic.
   
   Yes, I definitely agree that we could, in the future, relax the restriction that currently prevents function calls from being commoned out, for some functions that we know to be pure, including some builtins.
   I think the "pure" tag alone would be enough for this purpose, if it has the meaning that I imagine, which is this one (quoted from Wikipedia):
   
   #################
   a pure function is a function that has the following properties:
   
   1. The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments or input streams).
   2. The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments or input/output streams).
   
   Thus a pure function is a computational analogue of a mathematical function. Some authors, particularly from the imperative language community, use the term "pure" for all functions that just have the above property 2. 
   #################
   
   For commoning out function calls, only condition 1 is crucial for us. As long as the function always returns the same output for the same input, we are fine with putting the result of applying the function to some input into a variable (or any kind of "cache").
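   
   As a small illustration of why condition 1 is what matters for caching a call, the following plain-Python sketch (hypothetical helper names, not TVM code) compares evaluating a call at every occurrence with commoning it out into a cache. The results only agree when the function's return value depends on nothing but its arguments.

```python
import itertools


def evaluate_each_time(fn, args):
    """Evaluate fn(a) at every occurrence, as the original program does."""
    return [fn(a) for a in args]


def evaluate_commoned(fn, args):
    """Evaluate fn(a) once per distinct argument and reuse the result,
    which is what commoning out a redundant call into a variable amounts to."""
    cache = {}
    results = []
    for a in args:
        if a not in cache:
            cache[a] = fn(a)
        results.append(cache[a])
    return results


def square(x):
    # Condition 1 holds: the result depends only on the argument.
    return x * x


def make_stateful():
    ticks = itertools.count()
    # Condition 1 fails: the result also depends on hidden, changing state.
    return lambda x: x + next(ticks)


args = [3, 3]
print(evaluate_each_time(square, args), evaluate_commoned(square, args))  # [9, 9] [9, 9]
print(evaluate_each_time(make_stateful(), args), evaluate_commoned(make_stateful(), args))  # [3, 4] [3, 3]
```
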
   
   Please note that the Wikipedia article later mentions that condition 2 is needed for performing CSE, but it is talking about a different problem from the one we are discussing (and it isn't very precise about it). It is not talking about commoning out function calls there, but about commoning out redundant terms between which some function calls happen, as in:
   ```
   x = a+b+c;
   f(); // Here, if f can change the content of some variables like a, b or c, it prevents the CSE on the redundant term (a+b+c)
   y = a+b+c;
   ```
   
   But this kind of consideration does not concern us, as TIR is always in a (weak) SSA form: the content of variables won't be updated, and can be assumed to be the same throughout the execution of the program. Thus, only condition 1 matters to us.
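   
   The difference can be seen with a toy Python model of the snippet above (hypothetical, for illustration only): if `f` may reassign `a`, reusing the first value of `a+b+c` changes the result; if variables are never reassigned, as in TIR's (weak) SSA form, reusing it is safe.

```python
def run(f_reassigns_a, reuse_first_value):
    """Return (x, y) for: x = a+b+c; f(); y = a+b+c."""
    env = {"a": 1, "b": 2, "c": 3}

    def f():
        if f_reassigns_a:
            env["a"] = 10  # only possible when variables can be reassigned

    x = env["a"] + env["b"] + env["c"]
    f()
    # With CSE, y would simply reuse the value already computed for x.
    y = x if reuse_first_value else env["a"] + env["b"] + env["c"]
    return x, y


print(run(True, False), run(True, True))  # (6, 15) (6, 6): CSE changes the result
print(run(False, False), run(False, True))  # (6, 6) (6, 6): CSE is safe
```
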
   
   If the kPure tag has both meanings 1 and 2, then we will just be able to use it in the future when relaxing the restriction.
   If it's only the condition 2, then we will indeed need another tag.
   
   >     2. Is the let binding semantic in tir explicitly what we want? Is the value expr specified to be evaluated at let scope entry? Otherwise there may be no reuse effect if let-binding allows the value expr to get lazily expanded in the let body.
   
   Now that I think of it, I must admit that I am not entirely sure... Could anyone confirm that?
   
   >     3. Is it possible for some complex (but common) expr to be lifted out of a control flow scope? e.g.
   >        ```python
   >        if cond0:
   >            tir.evaluate(complex_e)  # slow path
   >        else:
   >            if cond2:
   >                tir.evaluate(complex_e)  # slow path
   >            else:
   >                ...  # general fast path
   >        ```
   >
   >        after CSE maybe comes to
   >        ```python
   >        x = tir.Var("x", "int32")
   >        with tir.let(x, complex_e):  # always evaluate slow path
   >            if cond0:
   >                tir.evaluate(x) 
   >            else:
   >                if cond2: 
   >                    tir.evaluate(x)
   >                else:
   >                    ... # general fast path
   >        ```
   
   Yes, that is possible. Although, now that I think of it, it would probably be better not to do it in the exact example that you provided, as the `complex_e` expression of the slow path now always gets evaluated, even if we end up going into the general fast branch...
   We should only lift things through "if" statements if both the then branch and the else branch compute them. This restriction is not needed for preserving the semantics (which was already fine), but it will avoid this kind of anti-optimization. That should be easily fixable.
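   
   A minimal sketch of the proposed rule, in plain Python with a hypothetical helper: the computations an `if` reports are those of its condition plus those done by *both* branches, with counts summed on the intersection (as `IntersectionOf2TablesOfComputations` does in this PR). With this rule, `complex_e` from the example above is not lifted, because the nested fast path does not compute it.

```python
from collections import Counter


def computations_of_if(cond, then_branch, else_branch):
    """Computations an `if` may report to its parent: those of the condition
    (always evaluated) plus those computed by both branches."""
    both = Counter(
        {e: then_branch[e] + else_branch[e] for e in then_branch if e in else_branch}
    )
    return cond + both


# cond0 and cond2 are plain variables here, so they contribute no computation.
no_cond = Counter()
inner = computations_of_if(no_cond, Counter({"complex_e": 1}), Counter({"fast_path": 1}))
outer = computations_of_if(no_cond, Counter({"complex_e": 1}), inner)
print(outer)  # Counter(): complex_e is not lifted above the outer `if`

# If both branches computed complex_e, lifting it would never add work.
print(computations_of_if(no_cond, Counter({"complex_e": 1}), Counter({"complex_e": 1})))
# Counter({'complex_e': 2})
```
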
   
   Many thanks!





[GitHub] [tvm] Hzfengsy commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r746255687



##########
File path: tests/python/unittest/test_tir_transform_common_subexpr_elim.py
##########
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm

Review comment:
       Have you considered writing the unit test with TVMScript? It would be easy to write an original prim_func and the expected prim_func after the pass, then compare them with `tvm.ir.assert_structural_equal`.

##########
File path: src/tir/analysis/check_contains.cc
##########
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file check_contains.cc
+ * \brief Implementation of the analysis that tells if an expression contains
+ *        a node that satisfies a given predicate.
+ */
+
+#include "check_contains.h"
+
+#include <tvm/tir/expr.h>
+
+#include <vector>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Toplevel (static) function that tells if an expression contains a subexpression that
+ *        satisfies a given predicate.
+ * \param expr The expression to check
+ * \param predicate The predicate that must be satisfied
+ * \return Whether `expr` contains a subexpression that satisfies `predicate`
+ */
+bool CheckContains::ExprContains(const PrimExpr& expr,
+                                 std::function<bool(const PrimExpr&)> predicate) {
+  CheckContains check_contains(predicate);
+  check_contains.VisitExpr(expr);
+  return check_contains.contains_it_;
+}
+
+/*!
+ * \brief Toplevel (static) function that tells if a statement contains a subexpression that
+ *        satisfies a given predicate.
+ * \param stmt The statement to check
+ * \param predicate The predicate that must be satisfied
+ * \return Whether `stmt` contains a subexpression that satisfies `predicate`
+ */
+bool CheckContains::StmtContains(const Stmt& stmt, std::function<bool(const PrimExpr&)> predicate) {
+  CheckContains check_contains(predicate);
+  check_contains.VisitStmt(stmt);
+  return check_contains.contains_it_;
+}
+
+/*!
+ * \brief Protected constructor of CheckContains.
+ * \param predicate The predicate that must be satisfied
+ */
+CheckContains::CheckContains(std::function<bool(const PrimExpr&)> predicate)
+    : predicate_(predicate) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions.
+ * \param expr The expression to visit
+ */
+void CheckContains::VisitExpr(const PrimExpr& expr) {
+  // If the predicate holds on `expr`, we know `expr` contains something which makes
+  // the predicate hold
+  if (predicate_(expr)) {
+    contains_it_ = true;
+  } else {
+    // Otherwise we continue to look for it recursively by calling the dispatcher
+    StmtExprVisitor::VisitExpr(expr);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements.
+ * \param stmt The statement to visit
+ */
+void CheckContains::VisitStmt(const Stmt& stmt) {
+  // We keep exploring only if `contains_it_` is false
+  if (!contains_it_) {
+    // and in order to do that we call the general dispatcher
+    StmtExprVisitor::VisitStmt(stmt);
+  }
+  // As otherwise we already have our answer
+}
+
+}  // namespace tir
+}  // namespace tvm

Review comment:
       Please add a newline at the end of every file.
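   
   For readers less familiar with `StmtExprVisitor`, the logic of `CheckContains::ExprContains` above can be summarized by the following Python analogue. It is only a sketch over a toy tuple encoding of expressions (an assumption), not TVM's Python API: the predicate is checked on a node first, and the children are only visited when it does not hold. Unlike the C++ `VisitExpr` above, which only skips the children of a matching node, `any` here also stops at the first match among siblings.

```python
def expr_contains(expr, predicate):
    """Return True if `expr` or any of its sub-expressions satisfies `predicate`."""
    if predicate(expr):
        return True  # no need to look inside a matching node
    if isinstance(expr, tuple):
        # any() stops at the first child that contains a match
        return any(expr_contains(child, predicate) for child in expr[1:])
    return False


def is_call(expr):
    return isinstance(expr, tuple) and expr[0] == "call"


# Toy encoding (assumption): ("op", operands...); ("call", name, args...).
print(expr_contains(("add", "x", ("call", "f", 42)), is_call))  # True
print(expr_contains(("add", "x", ("mul", "y", 2)), is_call))  # False
```
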







[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969387318


   > 
   > 
   > A quick round of review.
   > 
   > The changes to vta-hw are not necessary.
   
   Thanks a lot for the review!
   I added the missing empty new lines at the end of each new file.
   
   I'll roll back the pointer to the submodule in another commit.
   
   Many thanks again!





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1001389312


   @FranckQC Shall we try getting this in? Please take a look at the lint issues; you need to run `clang-format` and `black`. You can use the scripts in https://github.com/apache/tvm/tree/main/tests/lint





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790178147



##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.cc
+ * \brief Implementation of analysis tools and utility functions used
+ *        by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the declaration of the pass
+
+#include <algorithm>      // For std::find_if
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static variable of the class ComputationsDoneBy, and C++ requires such a static
+// attribute to be defined here (outside the class), otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactical entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
+   is the number of times that this computation is being seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates : `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one helps to select only "eligible" computations, and the second one helps to only
+   select computations that are located at appropriate location (i.e., it tells in which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for their own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+   its child nodes (i.e., the union of the results of the recursive calls).
+   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+   On some nodes, it will return something more complicated that uses the intersection of the
+   computations done by the children nodes.
+   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+   (x+y) seen twice but it won't report b-x as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+   for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
+   (often because it's a load node or a function call node for instance), in which case it will
+   return the union of the recursive calls on its children, as long as the other predicate
+   `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+   With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+   itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
+   might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyses the program at the toplevel (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily previously obtained when computing the computations done by the
+   root (as the computations done by the root are by definition the union of the computations done
+   by the children nodes).
+
+   The somewhat difficult aspect of the implementation is the interaction between this caching of
+   results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are
+   void methods which can't return anything, and instead need to accumulate a result into a member
+   variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods) just
+   call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't
+   want to override each of these specialized methods to change this behaviour, then
+   `table_of_computations_` will necessarily be shared by all the children of a given node.
+   That requires being careful when trying to write into the cache.
+*/
+
+/*!
+ * \brief Computes the union of two tables of computations.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Writes the result directly into the first argument (A) for efficiency: as the union of A
+ *       and B necessarily contains A, this avoids copying A.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Computes the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+            const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of two tables of computations.
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the intersection for N tables, we need to
+ *       know how to do it for two. That's because we would compute for N tables using the
+ *       associativity of the intersection : T1 Inter T2 Inter T3 ... Inter Tn
+ *       = ((T1 Inter T2) Inter T3) ... Inter Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic intersection
+ *       over N tables.
+ */
+TableOfComputations IntersectionOf3TablesOfComputations(const TableOfComputations& table1,

Review comment:
       How about renaming both `IntersectionOf2TablesOfComputations` and `IntersectionOf3TablesOfComputations` to `IntersectComputationTables` to make it less verbose. The same comment for `UnionOf3TablesOfComputations`.
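A minimal, self-contained sketch of what the suggested shorter names could look like as N-ary helpers, assuming the same semantics as the quoted code (both union and intersection sum the occurrence counts). `ComputationTable`, `UnionOfComputationTables` and `IntersectComputationTables` are hypothetical names here, and the key is a `std::string` instead of a `PrimExpr` only so that the snippet compiles without TVM:

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for TableOfComputations: in the PR the key is a PrimExpr;
// a std::string is used here only so that the sketch compiles on its own.
using ComputationTable = std::unordered_map<std::string, std::size_t>;

// Union of any number of tables: every computation of every table, counts summed.
ComputationTable UnionOfComputationTables(std::initializer_list<ComputationTable> tables) {
  ComputationTable result;
  for (const ComputationTable& table : tables) {
    for (const auto& [computation, count] : table) {
      result[computation] += count;
    }
  }
  return result;
}

// Intersection of any number of tables: only the computations present in every
// table are kept, with their counts summed across all the tables.
ComputationTable IntersectComputationTables(std::initializer_list<ComputationTable> tables) {
  ComputationTable result;
  if (tables.size() == 0) return result;
  const ComputationTable& first = *tables.begin();
  for (const auto& [computation, count] : first) {
    std::size_t total = count;
    bool present_in_all = true;
    for (auto it = tables.begin() + 1; it != tables.end(); ++it) {
      auto found = it->find(computation);
      if (found == it->end()) {
        present_in_all = false;
        break;
      }
      total += found->second;
    }
    if (present_in_all) result[computation] = total;
  }
  return result;
}
```

Taking a `std::initializer_list` removes the need for separate 2- and 3-table variants (and for the note explaining why both exist). The trade-off is that the tables are copied into the list, so the in-place two-table union may still be preferable on hot paths.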




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790200238



##########
File path: python/tvm/tir/transform/transform.py
##########
@@ -310,6 +310,15 @@ def BF16TypeLowering():
     """
     return _ffi_api.BF16TypeLowering()  # type: ignore
 
+def CommonSubexprElim(enable_cse_tir: bool = True):

Review comment:
       There is a `disabled_pass` option, so you don't need the `enable_cse_tir` flag. 
   
   See https://github.com/apache/tvm/blob/52039c910aca2d54772b8960d7d868fb75f8936e/tests/python/relay/utils/external_codegen.py#L76
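To illustrate the design point, here is a toy pipeline (hypothetical types, not TVM's `PassContext` API) in which a pass is skipped because its name appears in a context-level disabled list, so the pass itself needs no per-pass enable flag such as `enable_cse_tir`:

```cpp
#include <cassert>
#include <functional>
#include <set>
#include <string>
#include <vector>

// Toy stand-ins (hypothetical, not TVM types): a "program" is just an int and a
// pass is a named function transforming it.
struct ToyPass {
  std::string name;
  std::function<int(int)> run;
};

// Runs every pass whose name is not listed in `disabled`. Disabling is decided by
// the pipeline configuration, so the passes themselves stay flag-free.
int RunPipeline(const std::vector<ToyPass>& passes, const std::set<std::string>& disabled,
                int program) {
  for (const ToyPass& pass : passes) {
    if (disabled.count(pass.name) != 0) continue;  // skipped by configuration
    program = pass.run(program);
  }
  return program;
}
```

With an empty disabled set every pass runs; adding a pass name to the set skips that pass without changing its signature.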







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201656



##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* \file common_subexpr_elim_tools.cc
+* \brief Implementation of analysis tools and utility functions used
+          by the Common Subexpression Elimination (CSE) pass.
+*/
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the declaration of the pass
+
+#include <algorithm>      // For std::find_if
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static variable of the class ComputationsDoneBy, and C++ requires such a static
+// attribute to be defined here, otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactical entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
+   is the number of times that this computation is seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates : `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one helps to select only "eligible" computations, and the second one helps to only
+   select computations that are located at appropriate locations (i.e., it tells in which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for their own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+   its child nodes (i.e., the union of the results of the recursive calls).
+   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+   On some nodes, it will return something more complicated that uses the intersection of the
+   computations done by the children nodes.
+   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+   (x+y) seen twice but it won't report b-x as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+   for being introduced by the CSE pass into a variable according to `is_eligible_computation`
+   (often because it's a load node or a function call node for instance), in which case it will
+   return the union of the recursive calls on its children, as long as the other predicate
+   `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+   With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+   itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
+   might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyses the program at the toplevel (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily previously obtained when computing the computations done by the
+   root (as the computations done by the root are by definition the union of the computations done
+   by the children nodes).
+
+   The somewhat difficult aspect of the implementation is the interaction between this caching of
+   results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are
+   void methods which can't return anything, and instead need to accumulate a result into a member
+   variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods) just
+   call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't
+   want to override each of these specialized methods to change this behaviour, then
+   `table_of_computations_` will necessarily be shared by all the children of a given node.
+   That requires care when writing into the cache.
+*/
+
+/*!
+ * \brief Does the union of two tables of computations.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Does it directly in the first argument A for efficiency, as the union of A and B
+ *       necessarily gives something which contains A, so we avoid its copy.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Does the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+            const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of two tables of computations.
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this

Review comment:
       Don't need to repeat this explanation
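The `if` example in the `ComputationsDoneBy` comment quoted above ((x+y) reported twice, b-x not reported) is what motivates the three-table helpers. One combination rule that reproduces that example is to keep, for a node with three children, the union of the pairwise intersections of the children's tables. The sketch below assumes that rule; it is not necessarily the exact combination used by the PR, and `TableForThreeChildren` and the `std::string` keys are hypothetical stand-ins:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for TableOfComputations (string keys instead of PrimExpr).
using ComputationTable = std::unordered_map<std::string, std::size_t>;

// Count-summing intersection of two tables (same semantics as the quoted
// IntersectionOf2TablesOfComputations).
ComputationTable Intersect(const ComputationTable& t1, const ComputationTable& t2) {
  ComputationTable result;
  for (const auto& [computation, count] : t1) {
    auto it = t2.find(computation);
    if (it != t2.end()) result[computation] = count + it->second;
  }
  return result;
}

// Count-summing union written into `table_main` (same semantics as the quoted
// UnionOf2TablesOfComputations).
void UnionInto(ComputationTable& table_main, const ComputationTable& table_aux) {
  for (const auto& [computation, count] : table_aux) table_main[computation] += count;
}

// Assumed rule for a node with three children (e.g. condition / then / else):
// keep what at least two children compute, via the union of pairwise intersections.
ComputationTable TableForThreeChildren(const ComputationTable& child1,
                                       const ComputationTable& child2,
                                       const ComputationTable& child3) {
  ComputationTable result = Intersect(child1, child2);
  UnionInto(result, Intersect(child1, child3));
  UnionInto(result, Intersect(child2, child3));
  return result;
}
```

With cond = {x+y: 1}, then = {x+y: 1} and else = {b-x: 1}, only the cond/then intersection is non-empty, which gives {x+y: 2}. In this sketch a computation present in all three children is counted once per pair, so it is over-counted.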







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201222



##########
File path: src/tir/transforms/replace_expr_selected.h
##########
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file replace_expr_selected.h
+ * \brief Interface of the mutator that replaces, in a statement
+           or expression, all the subexpressions selected
+           by a predicate with another expression.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_REPLACE_SELECTED_H_
+#define TVM_TIR_TRANSFORMS_REPLACE_SELECTED_H_
+
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprMutator
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Mutator for replacing the expressions selected by a predicate in a statement and/or
+          in an expression, which only replaces inside nodes in which replacements are allowed
+          (as given by a second predicate)
+ */
+class ReplaceExprSelected : public StmtExprMutator {

Review comment:
       looks like `ExprMutator` is enough for the base class.
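The selection and recursion logic described in the class comment (replace whatever a first predicate selects, but only recurse where a second predicate allows) can be sketched on a toy expression tree. This is a self-contained illustration with hypothetical types, not TVM's `PrimExpr`/`ExprMutator`; a structural `Equal` stands in for `EquivalentTerms`:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy expression node (not TVM's PrimExpr): an operator or leaf name
// plus its children.
struct Expr {
  std::string op;
  std::vector<std::shared_ptr<const Expr>> args;
};
using ExprPtr = std::shared_ptr<const Expr>;
using ExprPredicate = std::function<bool(const ExprPtr&)>;

ExprPtr Make(std::string op, std::vector<ExprPtr> args = {}) {
  return std::make_shared<Expr>(Expr{std::move(op), std::move(args)});
}

// Structural equality, standing in for EquivalentTerms.
bool Equal(const ExprPtr& a, const ExprPtr& b) {
  if (a->op != b->op || a->args.size() != b->args.size()) return false;
  for (std::size_t i = 0; i < a->args.size(); ++i) {
    if (!Equal(a->args[i], b->args[i])) return false;
  }
  return true;
}

// Replaces every subexpression selected by `is_selected` with `replacement`, but only
// recurses into nodes for which `can_recurse_into` holds (other nodes are kept as-is).
ExprPtr ReplaceSelected(const ExprPtr& e, const ExprPredicate& is_selected,
                        const ExprPtr& replacement, const ExprPredicate& can_recurse_into) {
  if (is_selected(e)) return replacement;
  if (!can_recurse_into(e)) return e;
  std::vector<ExprPtr> new_args;
  new_args.reserve(e->args.size());
  for (const ExprPtr& arg : e->args) {
    new_args.push_back(ReplaceSelected(arg, is_selected, replacement, can_recurse_into));
  }
  return Make(e->op, std::move(new_args));
}
```

With a recursion predicate that refuses to enter `load` nodes, `(w+x) + load(w+x)` becomes `cse_var_1 + load(w+x)`; with one that accepts every node, both occurrences are replaced.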







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797245853



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpression Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithms std::find_if and std::sort
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also do not even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+}
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the one we are having now.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e., present in context_).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction was performed when a computation was using an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at the entrance to not carry out-of-scope declarations
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If the `value` and the `body` of the let-in have been rewritten to the same thing
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Thanks for the detailed write-up.
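The candidate loop in the `VisitExpr()` excerpt quoted above works as follows:

- It takes the largest computations first.
- It introduces a variable when a computation is selected and all of its variables are in scope.
- Otherwise, it queues only the computation's direct subexpressions.

The toy Python model below illustrates this loop. Everything in it is an assumption for illustration, not the PR's code:

- Expressions are nested tuples rather than TIR nodes.
- Equivalence is plain syntactic equality.
- Node count stands in for `CalculateExprComplexity()`.
- The introduction predicate is assumed to be "seen at least twice".
- The worklist is assumed to start from the maximal computation only.
- Reusing context variables and checking that fresh names are unused are both omitted.

```python
import itertools

# Toy expressions (an assumption, not TVM's TIR):
#   int -> constant, str -> variable, (op, a, b) -> binary node,
#   ("let", var, value, body) -> let-in produced by the pass.


def is_atom(e):
    return isinstance(e, (int, str))


def size(e):
    """Node count, standing in for CalculateExprComplexity()."""
    return 1 if is_atom(e) else 1 + size(e[1]) + size(e[2])


def free_vars(e):
    if isinstance(e, str):
        return {e}
    if isinstance(e, int):
        return set()
    return free_vars(e[1]) | free_vars(e[2])


def direct_subexprs(e):
    """Eligible (non-atomic) direct children of a binary node."""
    return [c for c in (e[1], e[2]) if not is_atom(c)]


def insert_sorted(work, exprs):
    """Merge exprs into work (a list of [expr, count]) kept sorted by
    decreasing size; an expression already present gets its count bumped."""
    for e in exprs:
        for entry in work:
            if entry[0] == e:
                entry[1] += 1
                break
        else:
            pos = 0
            while pos < len(work) and size(work[pos][0]) >= size(e):
                pos += 1
            work.insert(pos, [e, 1])


def replace(e, target, var):
    """Replace every syntactic occurrence of target with var."""
    if e == target:
        return var
    if is_atom(e):
        return e
    if e[0] == "let":
        return ("let", e[1], replace(e[2], target, var), replace(e[3], target, var))
    return (e[0], replace(e[1], target, var), replace(e[2], target, var))


def cse_expr(expr, known_vars):
    names = (f"cse_var_{n}" for n in itertools.count(1))
    result = expr
    # Assumption: start from the maximal computation; smaller ones are queued on demand.
    work = [] if is_atom(expr) else [[expr, 1]]
    i = 0
    while i < len(work):
        comp, nb = work[i]
        if free_vars(comp) <= known_vars and nb >= 2:  # assumed predicate
            var = next(names)
            result = ("let", var, comp, replace(result, comp, var))
        else:
            # Only *direct* subexpressions: they are smaller than comp, so they
            # land after position i and are considered later in the loop.
            insert_sorted(work, direct_subexprs(comp))
        i += 1
    return result
```

Take `((w+x)+(y+z)) * ((w+x)+u)` as an example. Neither operand of `*` repeats, so their direct subexpressions are queued. `w+x` is then seen twice and gets `cse_var_1`, while `y+z`, seen once, does not.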




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797943017



##########
File path: python/tvm/tir/transform/transform.py
##########
@@ -310,6 +310,15 @@ def BF16TypeLowering():
     """
     return _ffi_api.BF16TypeLowering()  # type: ignore
 
+def CommonSubexprElim(enable_cse_tir: bool = True):

Review comment:
       Thanks for this remark! I am a little confused about how passes are activated and deactivated at the moment.
   If I understand correctly, there are two ways to disable a pass:
   
   1) In driver_api.cc, we have the registration of options:
   ```
   // Register build pipeline related options
   [...]
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
   ```
   
   Then in `Array<tvm::transform::Pass> CreatePassList()`, we get these booleans from the `config`:
   ```
     bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
     bool disable_storage_rewrite =
         pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite", Bool(false)).value();
     bool instrument_bound_checkers =
         pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
     bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
   ```
   
   Finally, when we push the passes into the `pass_list`, we use the booleans obtained above.
   For instance:
   ```
   [...]
    pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
    pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir));
   ```
   
   That makes sense to me.
   Regarding what you said about the list of disabled passes, I assume that when a pass is mentioned in the disabled pass list, the context somehow receives this information, which then flows through the `GetConfig()` calls shown above.
   Am I correct?
   
   2) In addition, some passes, like Vectorize, expose a boolean in the Python interface for activating/deactivating them.
   In transform.py:
   ```
   def VectorizeLoop(enable_vectorize: bool = True):
       """Lower vectorization loops.
       [...]
       return _ffi_api.VectorizeLoop(enable_vectorize)  # type: ignore
   ```
   
   I decided to do the same for the CSE. I thought that would make it easier to disable the pass from the Python bindings, for instance for experiments. If I remember correctly, I was told that exposing the boolean would be useful (perhaps for end-to-end compilation?).
   Why exactly is it done for Vectorize? Should I keep it for the CSE?
   
   Side note :
   The LoopPartition pass is unusual: its activation/deactivation boolean is passed directly to the `CreatePassList()` method:
   `Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition)`
   Why doesn't it get its boolean from the config, like all the other passes do?
   
   Any clarification would be much appreciated :) Thanks!
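Mechanism 1 works in two steps: a registered config option is read with a default, then its negation becomes the pass's enable flag when the pass list is built. The hypothetical plain-Python model below mirrors this. `Pass`, `get_config` and `create_pass_list` are illustrative stand-ins, not TVM functions. In real code, such options would be supplied through `tvm.transform.PassContext(config={...})`.

```python
from dataclasses import dataclass


@dataclass
class Pass:
    """Hypothetical stand-in for a pass object: a name plus its enable flag."""
    name: str
    enabled: bool


def get_config(config, key, default):
    """Mirrors pass_ctx->GetConfig<Bool>(key, Bool(default)).value()."""
    return bool(config.get(key, default))


def create_pass_list(config):
    """Mirrors the CreatePassList() excerpt: read each 'disable' option
    and pass its negation to the pass constructor."""
    disable_vectorize = get_config(config, "tir.disable_vectorize", False)
    disable_cse_tir = get_config(config, "tir.disable_cse_tir", False)
    return [
        Pass("VectorizeLoop", enabled=not disable_vectorize),
        Pass("CommonSubexprElimTIR", enabled=not disable_cse_tir),
    ]
```

In this model, the Python-level `enable_cse_tir` argument of mechanism 2 corresponds to setting the `enabled` field directly, bypassing the config lookup.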




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031722927


   Ok, so right now I'm running into two problems with the CI:
   
   - The doc build fails (and I don't know why).
   - The Windows CI no longer runs. Previously, it ran and was green with my changes.
   
   In theory, all the other issues should now be resolved. I had to update quite a few tests to explicitly disable the CSE pass: it was performing commoning, which broke some comparisons/tests because the produced TIR code no longer matched the old expected code (which is good, it means the pass did some work!).


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034303962


   > > Overall looks good to me and we should merge as is, @masahi do you mind adding some follow up issues to track the TODOs generated during review?
   > 
   > Sure! In addition to [#9482 (comment)](https://github.com/apache/tvm/pull/9482#discussion_r790200526), I think we can refactor our `Substitute` function to be a special case of the substituter with a predicate added in this PR.
   
   I would be very happy to look into that as soon as this is merged. Hopefully the current run of the CI should be the last one!
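The refactoring suggested in the quoted reply would express `Substitute` as a special case of a predicate-driven replacer. The PR's `ReplaceExprSelected` mutator is one such replacer, and the sketch below models the idea on a toy tuple-based expression type.

This is a hypothetical Python model, not TVM code: `replace_selected` and `substitute` are illustrative names. The PR's replacer takes a single replacement expression. Here the replacement is generalized to a callback on the matched node, on the assumption that substitution (one value per variable) needs that.

```python
def is_atom(e):
    return isinstance(e, (int, str))


def replace_selected(e, predicate, make_replacement, can_dive=lambda e: True):
    """Replace each subexpression selected by `predicate` with
    make_replacement(subexpr); only recurse into nodes where can_dive holds.
    Replacements are not revisited, so all matches are rewritten at once."""
    if predicate(e):
        return make_replacement(e)
    if is_atom(e) or not can_dive(e):
        return e
    op, *args = e
    return (op, *(replace_selected(a, predicate, make_replacement, can_dive) for a in args))


def substitute(e, vmap):
    """Variable substitution as a special case: select the variables in vmap."""
    return replace_selected(
        e,
        predicate=lambda sub: isinstance(sub, str) and sub in vmap,
        make_replacement=lambda sub: vmap[sub],
    )
```

The callback result is not visited again, so the swap `{x: y, y: x}` applied to `x + y` yields `y + x`. Applying a single-replacement call once per variable would instead give `x + x`, because the second call would also rewrite the `y` introduced by the first.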


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1033901512


   Current status:
   ![image](https://user-images.githubusercontent.com/89943638/153235990-b25ebe2a-7e05-431e-b254-c5b1c5731e94.png)
   
   The only remaining failures are the VTA ones (both for python3: i386 and for unittest: CPU), which should now work with the trick kindly given by @masahi.
   I did not realize that `vta.build()` actually overwrites the disabled passes of the config. I'll try to fix that right after, as it's a nasty side effect. For now, let's verify that everything is green.
   So, restarting the CI...


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031844098


   Oh I see, sorry!
   Before I push again, would you mind looking at the [unittest: CPU] error in the file test_vta_insn.py, please?
   I disabled the CSE pass there (as I did for some other tests), but for some reason this one still fails and I don't really understand what's going on.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031727949


   I tried to re-run the CI (by adding a change and then removing it; is there an easier way, by the way?), and the Windows CI is still not running.
   
   I just have to check that test_vta_insn.py passes this time; if it does, I think it will be ready to merge from my end.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797243979



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpressions Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithm std::find
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only we won't want
+          to collect them for the CSE pass, but we also won't even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+};
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the one we are having now.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is being computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e, present in context_).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction was performed when a computation was using an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry, so as not to carry out-of-scope declarations:
+  // the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If the `value` and the `body` of the let-in have been rewritten to the same thing
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Yes, we do need both the `VisitExpr()` and `VisitStmt()` as we are doing commoning both inside expressions (think about `let x = bigComp+bigComp`) and inside statements (think about `Mem[i] = bigComp+bigComp`).
   
   The basic implementation `StmtExprMutator::VisitStmt` does indeed visit sub `Expr`s. And actually we do rely on `StmtExprMutator::VisitExpr()` and `StmtExprMutator::VisitStmt()` for recursing inside the children, see the end of the two methods [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R307) and [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R477).
   
   The algorithm implemented for performing commoning inside expressions and inside statements is of course basically the same (it identifies redundancies, creates variables and then builds new expression/statement using let-in), but many things will be different type-wise when treating a statement versus an expression.
   There are many of these subtle differences, but one of them is that for an expression the pass will build a let-in expression and for a statement it will build a let-in statement.
   
   I also thought that it should be possible to factor them out and replace them with something polymorphic. But when I actually tried (quite some time ago), I was not sure that this could be done easily without adding complexity everywhere.
   
   In particular, a possible approach would be to:
   - 1 - Create a main function for the algorithm currently duplicated in `VisitStmt()` and `VisitExpr()`. Let's call it `Visit()` for now. This function takes an `Either<Expr, Stmt>` and returns an `Either<Expr, Stmt>` (in C++, that is `std::variant<Expr, Stmt>`).
   - 2 - In `VisitExpr(const PrimExpr& expr)`: we now build a `std::variant<Expr, Stmt>` from the given `expr`, then call the generic `Visit()`, and unpack its result (an `Expr`) to return it.
   - 3 - Same thing for `VisitStmt(const Stmt& stmt)`: we now build a `std::variant<Expr, Stmt>` from `stmt`, then call the generic `Visit()`, and unpack its result (a `Stmt`) to return it.
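
   A minimal sketch of steps 1 to 3, assuming toy `Expr`/`Stmt` structs in place of TVM's real `PrimExpr` and `Stmt`. The names `ExprOrStmt` and `Visit` are hypothetical, and the body of `Visit()` is only a placeholder that tags its input, not the CSE algorithm:

   ```cpp
   #include <cassert>
   #include <string>
   #include <variant>

   // Hypothetical stand-ins for TVM's PrimExpr and Stmt (not the real TVM classes).
   struct Expr { std::string text; };
   struct Stmt { std::string text; };

   // `Either<Expr, Stmt>` spelled the C++ way.
   using ExprOrStmt = std::variant<Expr, Stmt>;

   // Step 1: the single generic entry point that would hold the shared algorithm.
   // In this sketch it only wraps the text, so that the round trip is observable.
   ExprOrStmt Visit(const ExprOrStmt& expr_or_stmt) {
     if (const Expr* e = std::get_if<Expr>(&expr_or_stmt)) {
       return Expr{"visited(" + e->text + ")"};
     }
     return Stmt{"visited(" + std::get<Stmt>(expr_or_stmt).text + ")"};
   }

   // Step 2: wrap the expression into the variant, call Visit(), unpack an Expr.
   Expr VisitExpr(const Expr& expr) { return std::get<Expr>(Visit(ExprOrStmt{expr})); }

   // Step 3: same thing for statements, unpacking a Stmt.
   Stmt VisitStmt(const Stmt& stmt) { return std::get<Stmt>(Visit(ExprOrStmt{stmt})); }
   ```

   The wrappers themselves are trivial; as discussed next, all the difficulty sits inside `Visit()`.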
   
   In order to achieve 1 (which is the hard part here), we would need to:
   - Use polymorphic functions for all the functions used by the current `VisitStmt()` and `VisitExpr()` (i.e. functions with the same name for treating a `PrimExpr` and a `Stmt`). This is already the case for `GetComputationsDoneBy()`, but we would need to do the same for `ReplaceExprSelectedInStmt()` and `ReplaceExprSelectedInExpr()`, which are currently not polymorphic methods. The same goes for quite a few other methods.
   This part is not a problem at all.
   
   - But now, when we try to write the new `Visit()` function, it seems that we are a bit stuck.
   We get as input this `expr_or_stmt` of type `std::variant<Expr, Stmt>`.
   We can of course check what we got:
   ```
   Expr expr;
   Stmt stmt;
   if (std::holds_alternative<Expr>(expr_or_stmt))
     expr = std::get<Expr>(expr_or_stmt);
   else
     stmt = std::get<Stmt>(expr_or_stmt);
   ```
   
   And then what?
   For instance, what should replace the beginning of the algorithm, which currently does:
   ```
   ComputationsDoneBy::GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   Sure, we could check whether we're working with an `Expr` or a `Stmt` and, within a conditional, call either
   ```
   GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   or
   ```
   GetComputationsDoneBy(
         stmt, IsEligibleComputation, CanContainEligibleComputations);
   ```
   depending on which case we're in. But that replaces a duplication of methods with a duplication **within** the method... And as the algorithm was not trivial to follow even with the comments, I thought it was probably better not to add complexity to it.
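
   For the record, one way to avoid writing that branch by hand is `std::visit` with a generic lambda that forwards to an overload set. This is only a sketch with toy types: `ComputationsDoneBy` and `CollectComputations` are hypothetical names, and the toy functions return a string instead of a table of computations. It does not remove the harder problem of the later steps producing different result types (a let-in expression vs. a let-in statement):

   ```cpp
   #include <cassert>
   #include <string>
   #include <variant>

   struct Expr { std::string text; };  // hypothetical stand-in for PrimExpr
   struct Stmt { std::string text; };  // hypothetical stand-in for Stmt
   using ExprOrStmt = std::variant<Expr, Stmt>;

   // Hypothetical overload set: one name, one overload per node kind, in the
   // spirit of GetComputationsDoneBy(). Strings make the dispatch easy to check.
   std::string ComputationsDoneBy(const Expr& e) { return "expr:" + e.text; }
   std::string ComputationsDoneBy(const Stmt& s) { return "stmt:" + s.text; }

   // std::visit instantiates the generic lambda once per alternative; overload
   // resolution then picks the right ComputationsDoneBy() at compile time, so no
   // explicit if/else on the variant is written by hand.
   std::string CollectComputations(const ExprOrStmt& expr_or_stmt) {
     return std::visit([](const auto& node) { return ComputationsDoneBy(node); }, expr_or_stmt);
   }
   ```

   Note that the duplication does not disappear; it moves into the overload set, which is close to what the two existing methods already provide.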
   
   I will try to think about that again soon, but my current view is that, without adding complexity everywhere, it's not so easy to replace these two methods with a single one doing the job.
   In addition, in its current form, it's easy to change the algorithm if we want it to do something slightly different for an `Expr` versus a `Stmt`. That would be more complicated if everything gets unified.
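
   If the two methods were ever unified, per-type differences such as building a let-in expression versus a let-in statement could still be kept in one place with `if constexpr`. A sketch under the same toy assumptions: `IntroduceBinding` is hypothetical, and strings stand in for TVM's `Let` and `LetStmt` nodes:

   ```cpp
   #include <cassert>
   #include <string>
   #include <type_traits>
   #include <variant>

   struct Expr { std::string text; };  // hypothetical stand-in for PrimExpr
   struct Stmt { std::string text; };  // hypothetical stand-in for Stmt
   using ExprOrStmt = std::variant<Expr, Stmt>;

   // Hypothetical: binds `var` to `value` around `node`, producing a let-in
   // expression for an Expr and a let-in statement for a Stmt.
   ExprOrStmt IntroduceBinding(const ExprOrStmt& node, const std::string& var,
                               const std::string& value) {
     return std::visit(
         [&](const auto& n) -> ExprOrStmt {
           using T = std::decay_t<decltype(n)>;
           if constexpr (std::is_same_v<T, Expr>) {
             // Expression case: the result must itself be an expression.
             return Expr{"(let " + var + " = " + value + " in " + n.text + ")"};
           } else {
             // Statement case: the result must itself be a statement.
             return Stmt{"let " + var + " = " + value + "; " + n.text};
           }
         },
         node);
   }
   ```

   Each branch remains easy to tweak on its own, although the caller still has to unpack the variant afterwards.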
     




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1019346372


   Since this is a large PR with many TIR-related additions, I'd like more eyes @vinx13 @junrushao1994 @tqchen 





[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797247187



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpression Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithms std::find_if and std::sort
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also do not want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+}
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the PrimFunc currently being processed.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
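+ * \param stmt The statement (the body of a PrimFunc) on which CSE will be performed. It is kept
+              as `initial_body_`, which is used for checking which variable names are taken
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed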
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is being computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e., present in `context_`).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction was performed when a computation was using an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry, so as not to carry out-of-scope declarations:
+  // the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If the `value` and the `body` of the let-in have been rewritten to the same thing
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       My pleasure! Many thanks for your very interesting review! I just have two last things to do:
   - Fix the unneeded boolean that you mentioned (last point not resolved).
   - Run clang-format again, and do the final cleanups.
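
   The candidate loop of the quoted `VisitExpr()` can be illustrated in isolation. This is a simplified sketch, not the PR's code: `Candidate`, `InsertSorted` and `SelectIntroductions` are hypothetical, the flag `introducible` stands for the combined test "no undefined variables and `PredicateIntroVarForComputation()` accepts it", and the counts, the replacement and the let-in construction are left out. It shows why queueing only the *direct* subexpressions is enough: being smaller, they are inserted after position `i` and are examined later in the same pass.

   ```cpp
   #include <algorithm>
   #include <cassert>
   #include <string>
   #include <vector>

   // Hypothetical candidate computation (the PR uses std::pair<PrimExpr, size_t>).
   struct Candidate {
     std::string text;
     int complexity;                         // e.g. the number of nodes
     bool introducible;                      // "no undefined vars && predicate says yes"
     std::vector<Candidate> direct_subexprs; // eligible direct children only
   };

   // Inserts each new candidate so that `work` stays sorted by decreasing complexity.
   // upper_bound places it after every candidate of greater or equal complexity.
   void InsertSorted(std::vector<Candidate>& work, const std::vector<Candidate>& news) {
     for (const Candidate& c : news) {
       auto pos = std::upper_bound(
           work.begin(), work.end(), c,
           [](const Candidate& a, const Candidate& b) { return a.complexity > b.complexity; });
       work.insert(pos, c);
     }
   }

   // Returns the candidates that would be introduced into variables, in picking order.
   std::vector<std::string> SelectIntroductions(std::vector<Candidate> work) {
     std::sort(work.begin(), work.end(), [](const Candidate& a, const Candidate& b) {
       return a.complexity > b.complexity;
     });
     std::vector<std::string> introduced;
     for (size_t i = 0; i < work.size(); ++i) {
       Candidate current = work[i];  // copy: the insertion below may reallocate `work`
       if (current.introducible) {
         introduced.push_back(current.text);
       } else {
         // Not introducible here: queue only its direct subexpressions. They are
         // smaller, so they land after position i and are visited later in this loop.
         InsertSorted(work, current.direct_subexprs);
       }
     }
     return introduced;
   }
   ```

   For example, with `(w+x)+(y+z)` blocked (say `y` is not yet in scope), `w+x` is still picked up once the big computation has been split into its direct children.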







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797245853



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpression Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithms std::find_if and std::sort
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also do not want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so it fails if they have been put into variables)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order
+  // not to harm the indexing modes of the CPU, but as we are still far from assembly code, we
+  // eventually decided to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+};
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the one we are having now.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param stmt The statement (function body) being analyzed, kept in `initial_body_`
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e., present in `context_`).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not possible to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (as seen in VisitExpr(), no introduction is performed when a computation uses an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry, so as not to carry out-of-scope declarations,
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If neither the `value` nor the `body` of the let-in has been modified
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Thanks for the detailed write-up, let's keep the current implementation then.
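The expression-level loop of `VisitExpr` in the diff above (collect the eligible computations, process them from biggest to smallest, introduce a fresh `cse_var_i`, replace the occurrences) can be modelled with a small self-contained Python sketch. It is not the pass itself, and several things are assumptions of the sketch: expressions are plain tuples, a node count stands in for `CalculateExprComplexity`, equivalence is plain syntactic equality, anything that occurs at least twice is introduced (standing in for `PredicateIntroVarForComputation`), and only the body is rewritten after each introduction, so a repetition shared between a let value and the body is not detected. Bindings are listed outermost first, so each one may refer to the earlier ones and every variable stays in scope.

```python
from collections import Counter

# Toy encoding (assumption of this sketch, not TVM's IR):
#   int / str              -> constant / variable (atoms, never eligible)
#   (op, lhs, rhs)         -> binary operation such as "+" or "*"
#   ("call", name, *args)  -> function call (forbidden, like CallNode)


def children(e):
    if isinstance(e, (int, str)):
        return ()
    return e[2:] if e[0] == "call" else e[1:]


def subexprs(e):
    # Pre-order walk over e and all of its sub-expressions.
    yield e
    for c in children(e):
        yield from subexprs(c)


def is_eligible(e):
    # Rough analogue of IsEligibleComputation: not an atom, and neither
    # a call nor containing a call.
    if isinstance(e, (int, str)):
        return False
    return not any(isinstance(s, tuple) and s[0] == "call" for s in subexprs(e))


def complexity(e):
    # Stand-in for CalculateExprComplexity: number of nodes.
    return 1 + sum(complexity(c) for c in children(e))


def replace(e, target, var):
    # Replace every occurrence of `target` in `e` by the variable name `var`.
    if e == target:
        return var
    if isinstance(e, (int, str)):
        return e
    head = e[:2] if e[0] == "call" else e[:1]
    return head + tuple(replace(c, target, var) for c in children(e))


def fresh_name(used, counter):
    # Like GenerateNewVar: try cse_var_1, cse_var_2, ... skipping taken names.
    while True:
        counter += 1
        name = f"cse_var_{counter}"
        if name not in used:
            return name, counter


def toy_cse(expr):
    """Return (bindings, body); bindings are (var, value) lets, outermost first."""
    used = {s for s in subexprs(expr) if isinstance(s, str)}
    bindings, counter = [], 0
    while True:
        counts = Counter(s for s in subexprs(expr) if is_eligible(s))
        repeated = [s for s, n in counts.items() if n > 1]
        if not repeated:
            return bindings, expr
        target = max(repeated, key=complexity)  # biggest computation first
        var, counter = fresh_name(used, counter)
        used.add(var)
        expr = replace(expr, target, var)
        bindings.append((var, target))


# (w+x) + (y+z) and (w+x) + u share the subexpression (w+x).
example = ("*", ("+", ("+", "w", "x"), ("+", "y", "z")), ("+", ("+", "w", "x"), "u"))
bindings, body = toy_cse(example)
print(bindings)  # [('cse_var_1', ('+', 'w', 'x'))]
print(body)      # ('*', ('+', 'cse_var_1', ('+', 'y', 'z')), ('+', 'cse_var_1', 'u'))
```

The real pass additionally checks, with `UndefinedVars()` and the context, that a computation can be introduced at the current level, and otherwise falls back to its direct subexpressions.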




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797943017



##########
File path: python/tvm/tir/transform/transform.py
##########
@@ -310,6 +310,15 @@ def BF16TypeLowering():
     """
     return _ffi_api.BF16TypeLowering()  # type: ignore
 
+def CommonSubexprElim(enable_cse_tir: bool = True):

Review comment:
       Thanks for this remark! I am a little bit confused about things regarding activating/deactivating a pass at the moment.
   If I understand correctly, there are two ways to disable a pass:
   
   1) In driver_api.cc, we have the registration of options:
   ```
   // Register build pipeline related options
   [...]
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
   TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
   ```
   
   Then in `Array<tvm::transform::Pass> CreatePassList()`, we get these booleans from the `config`:
   ```
     bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
     bool disable_storage_rewrite =
         pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite", Bool(false)).value();
     bool instrument_bound_checkers =
         pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
     bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
   ```
   
   And finally, when we push the passes into the `pass_list`, we use the booleans that we got.
   For instance:
   ```
   [...]
    pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
    pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir));
   ```
   
   That makes sense to me.
   Regarding what you said about the list of disabled passes, I assume that when a pass is mentioned in the disabled pass list, the context will get this information (somehow), and so the information will later flow through the GetConfig() shown above.
   Am I correct on this?
   
   2) In addition, some passes, like Vectorize, expose a boolean to the Python interface for activating/deactivating them.
   In transform.py:
   ```
   def VectorizeLoop(enable_vectorize: bool = True):
       """Lower vectorization loops.
       [...]
       return _ffi_api.VectorizeLoop(enable_vectorize)  # type: ignore
   ```
   
   I decided to do the same for the CSE. I thought that would make it easier to disable the pass from the Python bindings, for instance for doing some experiments. If I remember correctly, I was told that exposing the boolean would be useful (perhaps for the end-to-end compilation?).
   Why is it done for Vectorize exactly? Should I keep it for the CSE?
   
   Side note:
   The pass LoopPartition is quite weird, as its boolean for activation/deactivation is passed directly to the `CreatePassList()` method:
   `Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition)`
   Why doesn't it get its boolean from the config, like all the other passes do?
   
   Any clarification would be very much appreciated :) Thanks!
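The two mechanisms described in this question can be contrasted with a toy Python model that does not import TVM. Everything in it is an assumption for illustration: a plain dict stands in for the `PassContext` config, passes are `(label, enable_flag)` tuples rather than real pass objects, and the labels are illustrative rather than the registered pass names. As far as we can tell (worth double-checking against `PassContext`), a config option such as `tir.disable_cse_tir` is read once while the pass list is built and becomes the pass's enable flag, whereas a name in the disabled-pass list makes the pass sequence skip that pass at run time, without going through `GetConfig()`.

```python
def get_config(config, key, default=False):
    # Stand-in for pass_ctx->GetConfig<Bool>(key, Bool(default)).value()
    return bool(config.get(key, default))


def create_pass_list(config):
    # Mirrors the pattern quoted above: read the flags, then build each
    # pass with its enable flag (labels are illustrative).
    disable_vectorize = get_config(config, "tir.disable_vectorize")
    disable_cse_tir = get_config(config, "tir.disable_cse_tir")
    return [
        ("VectorizeLoop", not disable_vectorize),
        ("CommonSubexprElimTIR", not disable_cse_tir),
    ]


def run_sequence(config, disabled_pass=()):
    """Return the (label, enable_flag) pairs of the passes that would run."""
    ran = []
    for label, enabled in create_pass_list(config):
        if label in disabled_pass:
            # Disabled by name: the pass is skipped entirely.
            continue
        # Built with enabled=False: the pass still runs, and it is up to
        # the pass itself to decide what "disabled" means.
        ran.append((label, enabled))
    return ran


print(run_sequence({"tir.disable_cse_tir": True}))
# [('VectorizeLoop', True), ('CommonSubexprElimTIR', False)]
print(run_sequence({}, disabled_pass={"CommonSubexprElimTIR"}))
# [('VectorizeLoop', True)]
```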







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201297



##########
File path: src/tir/transforms/replace_expr_selected.h
##########
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file replace_expr_selected.h
+ * \brief Interface of the pass that replaces in a statement
+           or expression all the subexpressions that are selected
+           with a predicate by another expression.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_REPLACE_SELECTED_H_
+#define TVM_TIR_TRANSFORMS_REPLACE_SELECTED_H_
+
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprMutator
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Mutator for replacing the expressions selected by a predicate in a statement and/or
+          in an expression, which only replaces inside nodes in which it is allowed to perform
+          replacements (given by a second predicate)
+ */
+class ReplaceExprSelected : public StmtExprMutator {

Review comment:
       `ReplaceSelectedExpr`







[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > 
   > 
   > Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU - for example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the `log2(N)` expression that appears both in the host and GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note a lot of calls to `call_spirv_pure_glsl450`, which are totally unnecessary if we had TIR-level CSE) https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass will be useful to you.
   Yes, in principle, every redundant subterm that is eligible for being commoned out (i.e., that does not contain function calls, etc.) will be commoned out. There are also a few restrictions which are due to some specifics of TVM, but these are rare.
   Do you have a little snippet of the TIR code that you have which has some redundancies? I can try to tell if the CSE pass will be able to optimize it.
   Also please do not hesitate to play with the pass and to let me know if it does what you would hope to obtain. I can help of course.
   
   Kind regards.





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972608426









[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > 
   > 
   > Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU - for example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the `log2(N)` expression that appears both in the host and GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note a lot of calls to `call_spirv_pure_glsl450`, which are totally unnecessary if we had TIR-level CSE) https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
   In principle, every redundant subterm that is eligible for being commoned out (i.e., that does not contain function calls, etc.) will be commoned out. There are also a few other minor restrictions which are due to some specifics of TVM, but these are rare.
   
   Unfortunately, in the TIR code snippet that you have uploaded, it seems that the redundant subexpression is just a function call. Such calls can't be commoned out into a new variable by the CSE pass, as it has no guarantee that the function has no side effects and always produces the same outputs for the same inputs. Without this guarantee, commoning out such function calls could change the program's semantics, so it's not done, as preserving the semantics of the program is vital.
   
   I can imagine that for functions that are guaranteed to have no side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future, and it would rely on some "NoSideEffect" tag on functions.
   
   However, please note that if you had some other redundancies, this CSE pass would common out whatever redundancies you have that are eligible.
   For instance:
   Assume you have the term (f(42) + f(42)) + (x*y + x*y) that appears somewhere.
   It is not eligible in its entirety, as it contains function calls.
   But its subterm (x*y + x*y) is eligible, so the redundancy inside it (x*y) will be commoned out.
   In short, the CSE pass, as implemented, always tries to common out all the redundant computations that are eligible, and it does so by looking from bigger subterms to smaller subterms.
   
   Does that answer your question?
   Please do not hesitate to tell me if you need help for trying the pass.
   
   Kind regards.
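The eligibility rule in this answer can be checked on the `(f(42) + f(42)) + (x*y + x*y)` example with a toy Python model. The tuple encoding and the helper names are assumptions of this sketch; the predicate only mimics `ForbiddenComputation` / `IsEligibleComputation` from the diff (a call is never eligible, and neither is anything that contains a call), and "redundant" here simply means "occurs at least twice".

```python
from collections import Counter

# Toy encoding (assumption): atoms are ints/strings, (op, lhs, rhs) is a
# binary op, and ("call", name, *args) is a function call.


def children(e):
    if isinstance(e, (int, str)):
        return ()
    return e[2:] if e[0] == "call" else e[1:]


def subexprs(e):
    yield e
    for c in children(e):
        yield from subexprs(c)


def is_call(e):
    return isinstance(e, tuple) and e[0] == "call"


def is_eligible(e):
    # Not an atom, not a call, and no call anywhere inside.
    return isinstance(e, tuple) and not any(is_call(s) for s in subexprs(e))


def repeated(expr, predicate):
    counts = Counter(s for s in subexprs(expr) if predicate(s))
    return {s: n for s, n in counts.items() if n > 1}


f42 = ("call", "f", 42)
xy = ("*", "x", "y")
term = ("+", ("+", f42, f42), ("+", xy, xy))

print(is_eligible(term))            # False: it contains calls
print(is_eligible(("+", xy, xy)))   # True
print(repeated(term, is_eligible))  # {('*', 'x', 'y'): 2}
print(repeated(term, is_call))      # {('call', 'f', 42): 2} (repeated, but not eligible)
```

In this toy, the only eligible repeated computation is `x*y`, which sits inside the eligible subterm `(x*y + x*y)`, while the repeated `f(42)` is left alone.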





[GitHub] [tvm] wrongtest edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
wrongtest edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972538596


   I have three more questions about the pass.
   
   1.  Since function calls are non-eligible, it seems to prevent many common non-stateful builtin computations from being optimized, such as `tir.floordiv`, `tir.shift_right`, `tir.likely`, `tir.log`, etc.  TVM already has op annotations like `PURE`:
       https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
       what if we added an annotation type like `DETERMINISTIC` to mark functions that return the same output for the same input? It should be safe to eliminate common calls if the call op is both pure and deterministic.
   
   2. Are the let-binding semantics in TIR explicitly what we want? Is the value expr specified to be evaluated at let scope entry? Otherwise there may be no reuse effect if let-binding allows the value expr to get lazily expanded in the let body.
   
   3. Is it possible for some complex (but common) expr to be lifted out of a control-flow scope? E.g.:
       ```python
       if cond0:
           tir.evaluate(complex_e)  # slow path
       else:
           if cond2:
               tir.evaluate(complex_e)  # slow path
           else:
               ...  # general fast path
       ```
       after CSE, it may become:
       ```python
       x = tir.Var("x", "int32")
       with tir.let(x, complex_e):  # always evaluate slow path
           if cond0:
               tir.evaluate(x)
           else:
               if cond2:
                   tir.evaluate(x)
               else:
                   ...  # general fast path
       ```
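Question 1 suggests relaxing the rule for calls whose op is both pure and deterministic. A hypothetical sketch of such a relaxed predicate, in a toy tuple encoding of expressions, is shown below. The allow-list is invented for illustration: no `DETERMINISTIC` annotation or such table exists in TVM, and `my_extern` stands for any op that is not known to be safe.

```python
from collections import Counter

# Hypothetical: ops assumed to be both pure and deterministic. A real
# implementation would read this from op attributes instead.
PURE_AND_DETERMINISTIC = {"tir.log", "tir.shift_right"}


def children(e):
    if isinstance(e, (int, str)):
        return ()
    return e[2:] if e[0] == "call" else e[1:]


def subexprs(e):
    yield e
    for c in children(e):
        yield from subexprs(c)


def is_forbidden(e, relaxed):
    if not (isinstance(e, tuple) and e[0] == "call"):
        return False
    # Strict mode (current pass): every call is forbidden.
    # Relaxed mode: only calls to ops outside the allow-list are forbidden.
    return not (relaxed and e[1] in PURE_AND_DETERMINISTIC)


def is_eligible(e, relaxed=False):
    return isinstance(e, tuple) and not any(is_forbidden(s, relaxed) for s in subexprs(e))


def repeated_eligible(expr, relaxed=False):
    counts = Counter(s for s in subexprs(expr) if is_eligible(s, relaxed))
    return {s: n for s, n in counts.items() if n > 1}


log_n = ("call", "tir.log", "n")
opaque = ("call", "my_extern", "n")  # hypothetical op, not in the allow-list
expr = ("+", ("*", log_n, 2), ("+", log_n, ("+", opaque, opaque)))

print(repeated_eligible(expr))                # {}
print(repeated_eligible(expr, relaxed=True))  # {('call', 'tir.log', 'n'): 2}
```

The repeated `my_extern` call stays excluded in both modes, which is the property the proposal relies on.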





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > 
   > 
   > Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU - for example, in the GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the `log2(N)` expression that appears both in the host and GPU kernel, right now the GPU sort kernel is littered with `log2(N)` computations like this (note a lot of calls to `call_spirv_pure_glsl450`, which are totally unnecessary if we had TIR-level CSE) https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
   In principle, every redundant subterm that is eligible for being commoned out (i.e., that does not contain function calls, etc.) will be commoned out. There are also a few restrictions which are due to some specifics of TVM, but these are rare.
   
   Unfortunately, in the TIR code snippet that you have uploaded, it seems that the redundant subexpression is just a function call. Such calls can't be commoned out into a new variable by the CSE pass, as it has no guarantee that the function has no side effects and always produces the same outputs for the same inputs. Without this guarantee, commoning out such function calls could change the program's semantics, so it's not done, as preserving the semantics of the program is vital.
   
   I can imagine that for functions that are guaranteed to have no side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future, and it would rely on some "NoSideEffect" tag on functions.
   
   However, please note that if your code had other redundancies, this CSE pass would common out whichever of them are eligible.
   For instance:
   Assume that the term (f(42) + f(42)) + (x*y + x*y) appears somewhere.
   It is not eligible in its entirety, as it contains function calls.
   But its subterm (x*y + x*y) is eligible, so the redundancy inside it (x*y) will be commoned out.
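A minimal Python sketch of this eligibility rule follows. It assumes a toy representation (nested tuples, with `("call", name, arg)` for a function call) that is not TIR's. It counts every non-leaf subterm and keeps only the repeated ones that contain no call. On the term above, `f(42)` is skipped even though it appears twice, while `x*y` is reported. The helper names are hypothetical, not the pass's API.

```python
from collections import Counter

# Toy expression encoding (an assumption, not TIR):
#   ("+", lhs, rhs), ("*", lhs, rhs)  binary operations
#   ("call", name, arg)               a function call
#   leaves are variable names (str) or constants (int)

def is_leaf(e):
    return not isinstance(e, tuple)

def contains_call(e):
    """True if a function call occurs anywhere inside e."""
    if is_leaf(e):
        return False
    if e[0] == "call":
        return True
    return any(contains_call(c) for c in e[1:])

def eligible_redundancies(expr):
    """Map each repeated, call-free, non-leaf subterm of expr to its count."""
    counts = Counter()

    def visit(e):
        if is_leaf(e):
            return
        counts[e] += 1
        for child in e[1:]:
            visit(child)

    visit(expr)
    # Calls might be impure, so any subterm containing one is not eligible.
    return {e: n for e, n in counts.items() if n >= 2 and not contains_call(e)}

f42 = ("call", "f", 42)
xy = ("*", "x", "y")
term = ("+", ("+", f42, f42), ("+", xy, xy))
print(eligible_redundancies(term))  # {('*', 'x', 'y'): 2}
```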
   
   Does that answer your question?
   Please do not hesitate to tell me if you need help trying the pass.
   
   Kind regards.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r796192872



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as we need deeply equal terms to get the same hash.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not do variable remapping), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
+          to each statement or expression of the program its table of computations. Its purpose is
+          to prevent the CSE pass from repeatedly recomputing the same tables of computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements
+  // It maps each known statement to its table of computations
+  std::unordered_map<Stmt, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_stmt_table_computations_;
+
+  // Part of the cache for expressions
+  // It maps each known expression to its table of computations
+  std::unordered_map<PrimExpr, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_expr_table_computations_;
+};
+
+/*!
+ * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
+          or by a statement.
+ * \note Computations here are considered syntactically, meaning that semantically equivalent
+          computations that are not syntactically the same are not merged together.
+ */
+class ComputationsDoneBy : public StmtExprVisitor {

Review comment:
       To make sure I understand: you would prefer to move everything that is used neither by common_subexpr_elim.cc nor by common_subexpr_elim.h into the _tools.cc file, right? Even the class definitions like ComputationsDoneBy and DirectSubexpr?
   
   To be honest, the files _tools.h and _tools.cc were initially structured as a proper analysis/transformation (with the declarations in the .h and the implementations in the .cc), so that one could still reuse the auxiliary analyses ComputationsDoneBy and DirectSubexpr and the other things, if that is needed at some point for another pass. Clearly none of these auxiliary analyses deserved its own file, so I thought that was a middle ground.
   
   Putting them (and everything not directly used by the main .h and .cc) into the _tools.cc file would mean that nobody else would ever be able to use these things.
   
   I don't have any strong opinion on that, but I don't see any problem with having everything that is defined in the _tools.cc declared in the _tools.h (as is the case now), even if nobody needs this interface at the present time.
   I even tend to find that, when discovering a new codebase, looking at the headers first helps to get a broad idea of things.
   No issue with me either way of course.
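In the header quoted above, the table of computations is keyed structurally (`StructuralHash` with `ExprDeepEqual`), while the cache is keyed by object identity (`ObjectPtrHash` with `ObjectPtrEqual`). A small Python sketch of that distinction follows, using a hypothetical `Node` class rather than TVM's types: Python's default object hashing plays the role of pointer hashing, and a tuple built from the tree plays the role of the structural key.

```python
from collections import Counter

class Node:
    """Hypothetical stand-in for an expression node. Python's default
    __eq__/__hash__ compare object identity, like ObjectPtrEqual/ObjectPtrHash."""

    def __init__(self, op, *children):
        self.op = op
        self.children = children

    def structural_key(self):
        """Hashable key that is equal for deeply equal trees
        (plays the role of StructuralHash + ExprDeepEqual)."""
        return (self.op,) + tuple(
            c.structural_key() if isinstance(c, Node) else c for c in self.children
        )

def table_of_computations(nodes):
    """Count computations structurally: deeply equal terms share one entry."""
    return Counter(n.structural_key() for n in nodes)

def identity_cache(nodes):
    """Cache keyed by object identity: equal but distinct objects get separate entries."""
    return {n: n.structural_key() for n in nodes}

a = Node("*", "x", "y")
b = Node("*", "x", "y")  # built separately, structurally identical to a
print(table_of_computations([a, b]))  # Counter({('*', 'x', 'y'): 2})
print(len(identity_cache([a, b])))    # 2
```

The identity-keyed cache is cheap to look up and is only meant to avoid recomputing tables for the very same node, whereas the structural table is what lets two separately built occurrences count as one redundant computation.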







[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031840409


   Yes, the docs were fixed about 30 minutes ago (I also got the same error in my PR), so you need to push again. But there is still a Windows issue.





[GitHub] [tvm] masahi edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031864766


   I would check if `vta.build(...)` is really honoring `disabled_passes=[...]` when it runs. You can build VTA and run this test locally. 
   
   Also, what do you think about making the TIR CSE pass optional? I have a feeling that it might bring some disruption if enabled globally, without much benefit for standard backends such as LLVM and CUDA, which until today didn't see a need for TIR-level CSE (the example I gave, CSE across CPU & GPU, is the only good use case I'm aware of). That would also remove the need to fix existing tests.
   
   One way is to make TIR-CSE run only when `opt-level=4`, but I hope there is a better way to enable certain TIR passes (i.e., the opposite of `disabled_passes=[...]`).
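`tvm.transform.PassContext` accepts `required_pass` and `disabled_pass` lists alongside `opt_level`, which may cover the "opposite of disabled passes" case. The following is only a simplified pure-Python model of how such gating can be decided, assuming that an explicit disable wins, an explicit requirement comes next, and otherwise the pass runs when the context's opt level reaches the pass's own level. The pass name `tir.CSE` and its level of 4 are hypothetical.

```python
from dataclasses import dataclass, field

@dataclass
class PassInfo:
    name: str
    opt_level: int  # minimum opt level at which the pass runs by default

@dataclass
class Context:
    opt_level: int = 2
    required_pass: set = field(default_factory=set)
    disabled_pass: set = field(default_factory=set)

def pass_enabled(ctx, info):
    """Simplified gating model: disabled wins, then required, then opt level."""
    if info.name in ctx.disabled_pass:
        return False
    if info.name in ctx.required_pass:
        return True
    return ctx.opt_level >= info.opt_level

cse = PassInfo("tir.CSE", opt_level=4)  # hypothetical name and level
print(pass_enabled(Context(opt_level=3), cse))                             # False
print(pass_enabled(Context(opt_level=3, required_pass={"tir.CSE"}), cse))  # True
print(pass_enabled(Context(opt_level=4, disabled_pass={"tir.CSE"}), cse))  # False
```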





[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1034443098


   Sorry, I forgot to answer the other parts of your message, @mbs-octoml. Many thanks for it, by the way!
   
   > Thanks so much for the clear and complete comments, very much appreciate that.
   > Echo @Hzfengsy suggestion to switch to TVMScript for your tests -- perhaps just pretty printing what you already construct would be a shortcut.
   
   Yes, I didn't know about TVMScript before. When writing the tests, I initially looked at some other test files and got inspiration from there. The ones I looked at might not have used the most up-to-date way of doing things, sorry for that! :-(
   We can still improve the tests later on, but for now they accomplish their main function, which I think is the most important thing, even though they are probably a bit more verbose than they would be with TVMScript, where I could directly write the expected TIR code instead of unfolding the TIR code that the CSE pass has produced.
    
   > The code duplication across PrimExpr & Stmt in CommonSubexpressionEliminator is unfortunate but I suspect removing that would only obscure things.
   
   Yes, I agree that this duplication is a little bit unfortunate. @masahi did point it out [here](https://github.com/apache/tvm/pull/9482#discussion_r790201064). It also annoyed me a little at the beginning, so I tried to factorize it out a few times, including an attempt described in my answer [here](https://github.com/apache/tvm/pull/9482#discussion_r797243979). But all my attempts ended up with something much too complicated for what we would gain.
   
   In fact, from an algorithmic point of view we want to do almost exactly the same treatment for Expr() and Stmt(), but from a data-type point of view quite a few things still differ. That's a pretty rare situation. In the end, I decided not to force things and to leave it like that.
   And the positive point of having the VisitStmt() and VisitExpr() not factorized by some weird magic is that we can still easily customize the algorithm if at some point we need to, between expressions and statements (I can't think of any reason we would want that, but still! :) )
   
   Many thanks again!
   





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972579646


   > Hi @FranckQC, I am also very interested in the CSE pass. In our circumstances we suffer from duplicate index computations produced by loop unrolling, like
   > 
   > ```
   > A[i*256 + j*16 + 0] = B[i*256 + j*16 + 4096]
   > A[i*256 + j*16 + 1] = B[i*256 + j*16 + 4097]
   > ...
   > ```
   > 
   > We have to depend on the target backend's optimization abilities, which may or may not optimize them out. It would be great if CSE could handle part of these things at the TIR level.
   
   Hi @wrongtest 
   Thank you, I'm very glad to know that the pass might be useful to you too!
   
   Yes, these kinds of redundancies should definitely be commoned out by this new CSE pass.
   For the specific example that you have given, I expect the result to be something like:
   
   ```
   Let cse_var_1 = i*256 + j*16 in
       A[cse_var_1 + 0] = B[cse_var_1 + 4096];
       A[cse_var_1 + 1] = B[cse_var_1 + 4097];
   ```
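A toy Python sketch of the transformation described above follows. It assumes expressions are nested tuples and uses a deliberately simple policy (bind only the largest subterm that occurs at least twice to a fresh variable). It is not the pass's actual selection heuristic; `common_out` and the tuple encoding are hypothetical.

```python
# Toy encoding (an assumption): ("+", a, b), ("*", a, b); leaves are str or int.

def subterms(e):
    """Yield every non-leaf subterm of e, including e itself."""
    if isinstance(e, tuple):
        yield e
        for c in e[1:]:
            yield from subterms(c)

def replace(e, target, var):
    """Rebuild e with every occurrence of target replaced by var."""
    if e == target:
        return var
    if isinstance(e, tuple):
        return (e[0],) + tuple(replace(c, target, var) for c in e[1:])
    return e

def size(e):
    return 1 + sum(size(c) for c in e[1:]) if isinstance(e, tuple) else 1

def common_out(exprs, var="cse_var_1"):
    """Bind the largest subterm occurring at least twice to var, then substitute it."""
    counts = {}
    for e in exprs:
        for s in subterms(e):
            counts[s] = counts.get(s, 0) + 1
    repeated = [s for s, n in counts.items() if n >= 2]
    if not repeated:
        return None, exprs
    best = max(repeated, key=size)
    return (var, best), [replace(e, best, var) for e in exprs]

base = ("+", ("*", "i", 256), ("*", "j", 16))
indices = [("+", base, 0), ("+", base, 4096), ("+", base, 1), ("+", base, 4097)]
binding, rewritten = common_out(indices)
print(binding)    # ('cse_var_1', ('+', ('*', 'i', 256), ('*', 'j', 16)))
print(rewritten)  # [('+', 'cse_var_1', 0), ('+', 'cse_var_1', 4096), ('+', 'cse_var_1', 1), ('+', 'cse_var_1', 4097)]
```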
       
   Do not hesitate to try the pass out, and to let me know if it does what we hope. I'd be happy to help of course if that's needed.
   
   Best regards.





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-972608426


   Many thanks for the interesting questions and comments, I really appreciate it.
   
   > I have three more questions about the pass.
   > 
   >     1. Since function calls are non-eligible, it seems to prevent many common non-stateful builtin computations from being optimized, such as `tir.floordiv`, `tir.shift_right`, `tir.likely`, `tir.log`, etc. TVM already has op annotations like `PURE`:
   >        https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
   >        What if we add an annotation type like `DETERMINISTIC` to mark that the function returns the same output for the same input? It should be
   >        safe to eliminate common calls if the call op is both pure and deterministic.
   
   Yes, I definitely agree that we could, in the future, relax the restriction which currently prevents function calls from being commoned out, for some functions that we know to be pure, including some builtins.
   I think just the "pure" tag would be enough for this purpose, if it has the meaning that I imagine, which is this one (quoted from Wikipedia):
   
   > A pure function is a function that has the following properties:
   > 
   > 1. The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments or input streams).
   > 2. The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments or input/output streams).
   > 
   > Thus a pure function is a computational analogue of a mathematical function. Some authors, particularly from the imperative language community, use the term "pure" for all functions that just have the above property 2.
   
   For commoning out function calls, only condition 1 is crucial. As long as the function always returns the same output for the same input, we are fine with putting the result of applying the function to some input into a variable (or any kind of "cache").
   
   Please note that the Wikipedia article later mentions that condition 2 is needed for performing CSE, but it is talking about a different problem from the one we are discussing (and it isn't very precise about it). It is not talking about commoning out function calls, but about commoning out redundant terms between which some function calls happen, as in:
   x = a+b+c;
   f(); // Here, if f can change the content of some variables like a, b or c, it prevents the CSE on the redundant term (a+b+c)
   y = a+b+c;
   
   But this kind of consideration does not concern us, as TIR is always in a (weak) SSA form: the content of variables won't be updated, and can be assumed to be the same everywhere throughout the execution of the program. Thus, only condition 1 matters to us.
   
   If the kPure tag has both meanings 1 and 2, then we will simply be able to use it in the future when relaxing the condition.
   If it only means condition 2, then we will indeed need another tag.
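To illustrate why condition 1 is the one that matters for commoning out calls, here is a Python sketch with a hypothetical attribute table: `floordiv` always returns the same result for the same arguments, while `next_id` returns a new value on each call. Computing a value twice with two calls, or once through a shared variable, gives the same pair only for the deterministic function. The op names and the `DETERMINISTIC` table are assumptions, not TVM's op attributes.

```python
import itertools

def make_impls():
    """Fresh implementations for each run, so results do not depend on earlier runs."""
    counter = itertools.count()
    return {
        "floordiv": lambda a, b: a // b,   # deterministic: same args, same result
        "next_id": lambda: next(counter),  # not deterministic: new result per call
    }

DETERMINISTIC = {"floordiv": True, "next_id": False}  # hypothetical attribute table

def can_common_out(op):
    # Unknown ops are conservatively treated as non-eligible.
    return DETERMINISTIC.get(op, False)

def run(op, args, commoned):
    """Compute (op(*args), op(*args)), either with two calls or one shared value."""
    f = make_impls()[op]
    if commoned:
        v = f(*args)
        return (v, v)
    return (f(*args), f(*args))

for op, args in [("floordiv", (7, 2)), ("next_id", ())]:
    same = run(op, args, commoned=False) == run(op, args, commoned=True)
    print(op, can_common_out(op), same)
# floordiv True True
# next_id False False
```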
   
   >     2. Are the let-binding semantics in TIR explicitly what we want? Is the value expression specified to be evaluated at let-scope entry? Otherwise there may be no reuse benefit if let-binding allows the value expression to be lazily expanded in the let body.
   
   I must admit that, now that I think of it, I am not entirely sure... Can anyone confirm this, please?
   
   >     3. Is it possible for some complex (but common) expression to be lifted out of a control-flow scope? E.g.
   >        ```python
   >        if cond0:
   >            tir.evaluate(complex_e)  # slow path
   >        else:
   >            if cond2:
   >                tir.evaluate(complex_e)  # slow path
   >            else:
   >                ...  # general fast path
   >        ```
   >
   >        after CSE, this may become
   >        ```python
   >        x = tir.Var("int32")
   >        with tir.let(x, complex_e):  # always evaluate slow path
   >            if cond0:
   >                tir.evaluate(x)
   >            else:
   >                if cond2:
   >                    tir.evaluate(x)
   >                else:
   >                    ...  # general fast path
   >        ```
   
   Yes, that is possible. Although, now that I think of it, it would probably be better not to do it in the exact example that you provided, as the slow-path expression `complex_e` would now always be evaluated, even when we end up in the general fast branch...
   We should only lift things through "if" statements if both the then branch and the else branch compute them. This restriction is not needed for semantics preservation (which was already fine), but it will avoid this kind of anti-optimization. That should be easily fixable.
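The restriction above amounts to lifting an expression out of an `if` only when every path through it computes that expression. Here is a minimal Python sketch under assumed encodings (`("eval", e)`, `("seq", ...)`, `("if", cond, then, else)`), not the PR's implementation: the set computed by an `if` is its condition plus the intersection of what both branches compute.

```python
# Assumed statement encoding: ("eval", e), ("seq", s1, s2, ...),
# ("if", cond, then_stmt, else_stmt_or_None). Expressions are plain strings.

def computed_on_all_paths(stmt):
    """Set of expressions evaluated on every execution path through stmt."""
    kind = stmt[0]
    if kind == "eval":
        return {stmt[1]}
    if kind == "seq":
        out = set()
        for s in stmt[1:]:
            out |= computed_on_all_paths(s)
        return out
    if kind == "if":
        cond, then_s, else_s = stmt[1], stmt[2], stmt[3]
        if else_s is None:
            # Without an else branch, only the condition is always evaluated.
            return {cond}
        # The condition is always evaluated; a branch expression is only
        # guaranteed when BOTH branches evaluate it.
        return {cond} | (computed_on_all_paths(then_s) & computed_on_all_paths(else_s))
    raise ValueError(f"unknown statement kind: {kind}")

slow = ("if", "cond0", ("eval", "complex_e"),
        ("if", "cond2", ("eval", "complex_e"), ("eval", "fast_e")))
both = ("if", "c", ("eval", "complex_e"), ("eval", "complex_e"))
print("complex_e" in computed_on_all_paths(slow))  # False: the fast path skips it
print("complex_e" in computed_on_all_paths(both))  # True: worth lifting above the if
```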
   
   Many thanks!





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790170969



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########

Review comment:
       I prefer
   
   `TableOfComputations` -> `ComputationTable`
   
   `CacheOfComputations` -> `ComputationCache` 







[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1029300708


   @FranckQC You can disable the offending lint rule by `# pylint: disable=...`





[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r796196735



##########
File path: src/tir/transforms/common_subexpr_elim_tools.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.h
+ * \brief Interface of analysis tools and utility functions used
+           by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
+
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A table of computations is a hashtable which associates to each expression being computed
+          a number (which is the number of times that it is computed).
+          It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash)
+          as we need deeply equal terms to get the same hash.
+          The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
+          not do variable remapping), so it is compatible with StructuralHash (intended to be used
+          with StructuralEqual).
+ */
+using TableOfComputations = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
+
+/*!
+ * \brief A cache of computations is made of two hashtables, which respectively associate
+          to each statement or expression of the program its table of computations. Its purpose is
+          to prevent the CSE pass from repeatedly recomputing the same tables of computations.
+ */
+struct CacheOfComputations {
+  // Part of the cache for statements
+  // It maps each known statement to its table of computations
+  std::unordered_map<Stmt, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_stmt_table_computations_;
+
+  // Part of the cache for expressions
+  // It maps each known expression to its table of computations
+  std::unordered_map<PrimExpr, TableOfComputations, ObjectPtrHash, ObjectPtrEqual>
+      cache_expr_table_computations_;
+};
+
+/*!
+ * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
+          or by a statement.
+ * \note Computations here are considered syntactically, meaning that semantically equivalent
+          computations that are not syntactically the same are not merged together.
+ */
+class ComputationsDoneBy : public StmtExprVisitor {

Review comment:
       Ah ok, if these functions / classes are intended to be reusable and not internal to `common_subexpr_elim_tools.cc`, I'm fine with that. I'm more of a "minimize public API" camp.
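A minimal standalone sketch of why `TableOfComputations` combines a structural hash with a deep-equality comparator instead of `ObjectPtrHash`. It assumes a hypothetical toy expression type (`Node`/`Expr`), with `DeepHash` and `DeepEqual` standing in for `StructuralHash` and `ExprDeepEqual`; it is not TVM code. Two separately allocated copies of `(w+x)` land on the same key, so they are counted together, as in the `(w+x)+(y+z)` / `(w+x)+u` example from the PR description.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical toy expression: a leaf variable or a binary operation.
struct Node {
  std::string op;    // "var" for leaves, otherwise an operator such as "+"
  std::string name;  // variable name (leaves only)
  std::shared_ptr<const Node> a, b;
};
using Expr = std::shared_ptr<const Node>;

Expr Var(const std::string& n) { return std::make_shared<Node>(Node{"var", n, nullptr, nullptr}); }
Expr Bin(const std::string& op, Expr a, Expr b) { return std::make_shared<Node>(Node{op, "", a, b}); }

// Stand-in for StructuralHash: hashes the content of the tree, not its address.
struct DeepHash {
  std::size_t operator()(const Expr& e) const {
    std::size_t h = std::hash<std::string>()(e->op + "|" + e->name);
    if (e->a) h = h * 31 + (*this)(e->a);
    if (e->b) h = h * 31 + (*this)(e->b);
    return h;
  }
};

// Stand-in for ExprDeepEqual: recursive comparison, no variable remapping.
struct DeepEqual {
  bool operator()(const Expr& x, const Expr& y) const {
    if (!x || !y) return x == y;  // both null (missing children) or one missing
    return x->op == y->op && x->name == y->name && (*this)(x->a, y->a) && (*this)(x->b, y->b);
  }
};

using TableOfComputations = std::unordered_map<Expr, std::size_t, DeepHash, DeepEqual>;

// Counts every non-leaf subexpression of `e` in `table`.
void CountComputations(const Expr& e, TableOfComputations& table) {
  if (e->op == "var") return;
  ++table[e];
  CountComputations(e->a, table);
  CountComputations(e->b, table);
}
```

With pointer identity as the key, the two `(w+x)` nodes would be two entries with a count of 1 each, so no redundancy would be detected. Pointer identity remains the right key for `CacheOfComputations`, whose job is to remember the table already computed for a given node object.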




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r796246946



##########
File path: src/tir/transforms/replace_expr_selected.h
##########
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file replace_expr_selected.h
+ * \brief Interface of the pass that replaces in a statement
+           or expression all the subexpressions that are selected
+           with a predicate by another expression.
+ */
+
+#ifndef TVM_TIR_TRANSFORMS_REPLACE_SELECTED_H_
+#define TVM_TIR_TRANSFORMS_REPLACE_SELECTED_H_
+
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprMutator
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Mutator for replacing the expressions selected by a predicate in a statement and/or
+          in an expression, which only replaces inside of nodes in which it is allowed to perform
+          replacements (given by a second predicate).
+ */
+class ReplaceExprSelected : public StmtExprMutator {

Review comment:
       Thanks!
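For readers of the thread, a minimal sketch of the behaviour described in the `ReplaceExprSelected` doc comment, on a hypothetical toy AST rather than TVM's `StmtExprMutator`: every node accepted by `select` is replaced, and recursion stops at nodes rejected by `can_dive`. Structural matching is approximated by comparing printed forms; the `load` node and all names here are illustrative assumptions.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>

// Hypothetical toy AST: "var" leaves, binary operators, and a unary "load".
struct Node {
  std::string op;
  std::string name;
  std::shared_ptr<const Node> a, b;  // b is null for unary nodes
};
using Expr = std::shared_ptr<const Node>;
using Pred = std::function<bool(const Expr&)>;

Expr Var(const std::string& n) { return std::make_shared<Node>(Node{"var", n, nullptr, nullptr}); }
Expr Op(const std::string& op, Expr a, Expr b = nullptr) {
  return std::make_shared<Node>(Node{op, "", a, b});
}

std::string Print(const Expr& e) {
  if (e->op == "var") return e->name;
  if (!e->b) return e->op + "(" + Print(e->a) + ")";
  return "(" + Print(e->a) + e->op + Print(e->b) + ")";
}

// Replaces every subexpression selected by `select` with `replacement`,
// without recursing into nodes for which `can_dive` is false.
Expr ReplaceSelected(const Expr& e, const Pred& select, const Expr& replacement,
                     const Pred& can_dive) {
  if (select(e)) return replacement;
  if (e->op == "var" || !can_dive(e)) return e;
  Expr a = ReplaceSelected(e->a, select, replacement, can_dive);
  Expr b = nullptr;
  if (e->b) b = ReplaceSelected(e->b, select, replacement, can_dive);
  if (a == e->a && b == e->b) return e;  // unchanged: keep sharing the original node
  return Op(e->op, a, b);
}
```

Returning the original node when no child changed mirrors the `same_as` check the pass performs for let-in nodes, and keeps unchanged subtrees shared.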







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r796255239



##########
File path: src/tir/analysis/check_contains.h
##########
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file check_contains.h
+ * \brief Interface of the analysis that tells if an expression contains
+           a node that satisfies a given predicate.
+ */
+
+#ifndef TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_
+#define TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_
+
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Visitor which tells if a given expression or statement contains a subexpression
+          that satisfies a given predicate
+ */
+class CheckContains : public StmtExprVisitor {

Review comment:
       I'll check whether VarTouch is a special case of my new analyzer CheckContains.
   But if that's the case, I would prefer to make the corresponding changes to VarTouch (or to the places that use it) in a separate PR, to keep this one clear and minimal.
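As a sketch of the question raised here, under stated assumptions: on a hypothetical toy AST, a generic `Contains(expr, predicate)` walk (modelled on the `CheckContains` description, not on its code) can answer a VarTouch-style query, "does this expression use any variable from a given set?", simply by passing a predicate that matches such variables. Whether TVM's actual `VarTouch` maps onto `CheckContains` this directly is what remains to be checked.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <set>
#include <string>

// Hypothetical toy AST: "var" and "const" leaves, binary operators otherwise.
struct Node {
  std::string op;
  std::string name;
  std::shared_ptr<const Node> a, b;
};
using Expr = std::shared_ptr<const Node>;

Expr Leaf(const std::string& kind, const std::string& n) {
  return std::make_shared<Node>(Node{kind, n, nullptr, nullptr});
}
Expr Bin(const std::string& op, Expr a, Expr b) { return std::make_shared<Node>(Node{op, "", a, b}); }

// Returns true as soon as one node of `e` satisfies `pred` (pre-order, short-circuiting).
bool Contains(const Expr& e, const std::function<bool(const Expr&)>& pred) {
  if (!e) return false;
  if (pred(e)) return true;
  return Contains(e->a, pred) || Contains(e->b, pred);
}

// A VarTouch-style query expressed as a special case of Contains.
bool UsesAnyVar(const Expr& e, const std::set<std::string>& vars) {
  return Contains(e, [&vars](const Expr& n) { return n->op == "var" && vars.count(n->name) > 0; });
}
```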
   







[GitHub] [tvm] masahi commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r796515819



##########
File path: include/tvm/tir/transform.h
##########
@@ -458,6 +458,14 @@ TVM_DLL Pass FlattenBuffer();
  */
 TVM_DLL Pass TextureFlatten();
 
+/*!
+ * \brief Implements a Common Subexpression Elimination (CSE)
+ *        which introduces let-in bindings for duplicated sub-expressions.
+ * \param enable_cse_tir Whether common subexpression elimination is enabled.
+ * \return The pass.
+ */
+TVM_DLL Pass CommonSubexprElim(bool enable_cse_tir = true);

Review comment:
       I see. Since it is under the `tvm::tir::transform` namespace, we might not need to say TIR in the function name. Up to you.







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797243979



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpression Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithm std::find
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also do not even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+}
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the one we are having now.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is being computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e., present in `context_`).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction was performed when a computation was using an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry, so as not to carry out-of-scope declarations,
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If neither the `value` nor the `body` of the let-in has been changed by the rewriting
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Yes we do need both the `VisitExpr()` and `VisitStmt()` as we are doing commoning both inside expressions (think about `let x = bigComp+bigComp`) and inside statements (think about `Mem[i] = bigComp+bigComp`).
   
   The basic implementation `StmtExprMutator::VisitStmt` does indeed visit sub `Expr`s. And actually we do rely on `StmtExprMutator::VisitExpr()` and `StmtExprMutator::VisitStmt()` for recursing inside the children, see the end of the two methods [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R307) and [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R477).
   
   The algorithm implemented for performing commoning inside expressions and inside statements is of course basically the same (it identifies redundancies, creates variables and then builds new expression/statement using let-in), but many things will be different type-wise when treating a statement versus an expression.
   There are many of these subtle differences, but one of them is that for an expression the pass will build a let-in expression and for a statement it will build a let-in statement.
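To make the shared algorithm concrete, here is a deliberately simplified, expression-only sketch on a hypothetical toy AST; it is not the pass's code. It counts non-leaf subexpressions by printed form, takes the largest one computed at least twice, replaces it with a fresh `cse_var_N`, records the binding, and repeats. Unlike the pass, it recounts after each introduction, has no notion of scope, eligibility or `PredicateIntroVarForComputation`, and never looks inside the values it binds.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy AST: "var" leaves and binary operators.
struct Node {
  std::string op;
  std::string name;
  std::shared_ptr<const Node> a, b;
};
using Expr = std::shared_ptr<const Node>;

Expr Var(const std::string& n) { return std::make_shared<Node>(Node{"var", n, nullptr, nullptr}); }
Expr Bin(const std::string& op, Expr a, Expr b) { return std::make_shared<Node>(Node{op, "", a, b}); }

// Printing doubles as a structural key: equal strings <=> syntactically equal trees.
std::string Print(const Expr& e) {
  if (e->op == "var") return e->name;
  return "(" + Print(e->a) + e->op + Print(e->b) + ")";
}

int Size(const Expr& e) { return e->op == "var" ? 1 : 1 + Size(e->a) + Size(e->b); }

// Maps each non-leaf subexpression (by printed form) to a representative and its count.
void Count(const Expr& e, std::map<std::string, std::pair<Expr, int>>& table) {
  if (e->op == "var") return;
  auto& entry = table[Print(e)];
  entry.first = e;
  ++entry.second;
  Count(e->a, table);
  Count(e->b, table);
}

Expr Replace(const Expr& e, const std::string& key, const Expr& v) {
  if (Print(e) == key) return v;
  if (e->op == "var") return e;
  return Bin(e->op, Replace(e->a, key, v), Replace(e->b, key, v));
}

// Greedy, largest-first commoning. Returns the rewritten body; `lets` receives the
// (variable name, bound expression) pairs in the order they were introduced.
Expr Commoning(Expr body, std::vector<std::pair<std::string, std::string>>& lets) {
  for (int n = 1;; ++n) {
    std::map<std::string, std::pair<Expr, int>> table;
    Count(body, table);
    Expr best;
    for (auto& kv : table)
      if (kv.second.second >= 2 && (!best || Size(kv.second.first) > Size(best)))
        best = kv.second.first;
    if (!best) return body;  // nothing is computed twice any more
    std::string name = "cse_var_" + std::to_string(n);
    lets.push_back({name, Print(best)});
    body = Replace(body, Print(best), Var(name));
  }
}
```

In a case like `((x+y)*(x+y))+((x+y)*(x+y))`, the repeated `(x+y)` inside the bound value is left alone by this sketch. The pass instead keeps going through the smaller computations of its initial table and applies its replacements to the whole current `result`, earlier let-in values included, so this particular limitation does not apply to it.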
   
   I also thought that it should be possible to factorize them and to replace them with something polymorphic. But when I actually tried (quite some time ago), I was left unsure that this could be done easily without putting complexity everywhere.
   
   In particular, a possible approach would be to:
   - 1 - Create a main function for the algorithm currently duplicated in `VisitStmt()` and in `VisitExpr()`. Let's call it "`Visit()`" for now. This function takes an `Either<Expr,Stmt>` and returns an `Either<Expr,Stmt>`, which in C++ is `std::variant`.
   - 2 - In `VisitExpr(const PrimExpr& expr)` : we now build an `std::variant<Expr, Stmt>` from the given `expr`, then call the "generic" `Visit()`, and unpack its result (an Expr) to return it.
   - 3 - Same thing for `VisitStmt(const Stmt& stmt)` : we now build an std::variant<Expr, Stmt> from `stmt`, then call the "generic" `Visit()`, and unpack its result (a Stmt) to return it.
   
   In order to achieve 1 (which is the hard part here), we will need to :
   - Use polymorphic functions for all the functions being used by the current `VisitStmt()` and `VisitExpr()` (i.e. functions with the same name for treating a `PrimExpr `and a `Stmt`). It is already the case for `GetComputationsDoneBy()`, but we need to do the same for `ReplaceSelectedExprInStmt()` and `ReplaceSelectedExprInExpr()` which are currently not polymorphic methods. And the same goes for quite a few other methods.
   This part is absolutely not the problem.
   
   - But now, when we try to write the new `Visit()` function, it seems that we are a bit stuck :
   We get as input this expr_or_stmt of type std::variant<Expr, Stmt>.
   We can of course check what we got :
   ```
   Expr expr;
   Stmt stmt;
   if (std::holds_alternative<Expr>(expr_or_stmt))
     expr = std::get<Expr>(expr_or_stmt);
   else
     stmt = std::get<Stmt>(expr_or_stmt);
   ```
   
   And then what?
   For instance, by what should we replace the beginning of the algorithm, which currently does:
   ```
   ComputationsDoneBy::GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   
   Sure, we could check whether we're working with an Expr or a Stmt and, within a conditional, call either
   ```
   GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   or
   ```
   GetComputationsDoneBy(
         stmt, IsEligibleComputation, CanContainEligibleComputations);
   ```
   depending on which case we're in. But that replaces a duplication of methods with duplication **within** the method... And as the algorithm was not trivial to follow even with the comments, I thought it was probably better not to add complexity to it.
   
   I will try to think about that again soon, but my current view is that, without adding complexity everywhere, it's not so easy to replace these two methods with a single one doing the job.
   In addition, in its current form, it's possible to easily change/fine-tune the algorithm if we want it to do something slightly different for an Expr versus for a Stmt in the future. That would be more complicated if everything gets unified under layers of polymorphism.
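For reference, a minimal sketch of the `std::variant` route discussed above, using hypothetical stand-in types (`Expr`, `Stmt`) and stub functions (`HasRedundancy`, `WrapInLet`) rather than TVM's classes. `std::visit` with a generic lambda picks the right overload at compile time, so the shared step is written once without an explicit `if`; the type-specific step (let-in expression vs let-in statement) still needs one overload per type.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical stand-ins for PrimExpr and Stmt.
struct Expr { std::string text; };
struct Stmt { std::string text; };
using ExprOrStmt = std::variant<Expr, Stmt>;

// Stub analysis (hypothetical): pretends there is something to common when the text is non-empty.
bool HasRedundancy(const Expr& e) { return !e.text.empty(); }
bool HasRedundancy(const Stmt& s) { return !s.text.empty(); }

// Type-specific rebuilding: a let-in expression vs a let-in statement.
Expr WrapInLet(const std::string& var, const Expr& e) { return {"let " + var + " in " + e.text}; }
Stmt WrapInLet(const std::string& var, const Stmt& s) { return {"LetStmt " + var + ": " + s.text}; }

// One generic entry point: overload resolution on the static type of `n`
// replaces the explicit "is it an Expr or a Stmt?" conditional.
ExprOrStmt Visit(const ExprOrStmt& node) {
  return std::visit(
      [](const auto& n) -> ExprOrStmt {
        if (HasRedundancy(n)) return WrapInLet("cse_var_1", n);
        return n;
      },
      node);
}
```

This removes the conditional at the call site, but not the per-type overloads, which is consistent with the concern that the complexity moves rather than disappears.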
     




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797243979



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpressions Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithms std::find_if and std::sort
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also do not even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+}
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the current one.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is being computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note: safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e., present in context_).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it: if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note: we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction was performed when a computation was using an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry so as not to carry out-of-scope declarations,
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If neither the `value` nor the `body` of the let-in was changed by the rewriting
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Yes, we do need both `VisitExpr()` and `VisitStmt()`, as we perform commoning both inside expressions (think of `let x = bigComp+bigComp`) and inside statements (think of `Mem[i] = bigComp+bigComp`).
   
   The base implementation `StmtExprMutator::VisitStmt` does indeed visit sub-`Expr`s, and we actually rely on `StmtExprMutator::VisitExpr()` and `StmtExprMutator::VisitStmt()` for recursing into the children: see the end of the two methods [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R307) and [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R477).
   
   The algorithm that performs commoning inside expressions and inside statements is of course basically the same (it identifies redundancies, creates variables, and then builds a new expression or statement using let-in), but many things differ type-wise when treating a statement versus an expression.
   There are many such subtle differences; one of them is that for an expression the pass builds a let-in expression (`Let`), whereas for a statement it builds a let-in statement (`LetStmt`).
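To make "identifies redundancies" concrete, here is a minimal sketch of that first step on a toy expression tree. It assumes hypothetical names (`Node`, `Leaf()`, `Bin()`, `Key()`, `RedundantComputations()`), none of which are TVM APIs: it counts every non-atomic subexpression, keeps those occurring at least twice, and orders them by decreasing size. A printed string key stands in for syntactic equality; the actual pass uses `ComputationsDoneBy::GetComputationsDoneBy()`, `EquivalentTerms()` and `CalculateExprComplexity()` instead.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy binary expression tree (hypothetical, not TIR).
struct Node {
  std::string op;    // empty for a leaf
  std::string leaf;  // variable or constant name, only used by leaves
  std::shared_ptr<Node> lhs, rhs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr Leaf(const std::string& name) { return std::make_shared<Node>(Node{"", name, nullptr, nullptr}); }
NodePtr Bin(const std::string& op, NodePtr a, NodePtr b) {
  return std::make_shared<Node>(Node{op, "", std::move(a), std::move(b)});
}

// Printed form, used here as a stand-in for syntactic equality.
std::string Key(const NodePtr& n) {
  if (n->op.empty()) return n->leaf;
  return "(" + Key(n->lhs) + n->op + Key(n->rhs) + ")";
}

// Number of nodes, used as a stand-in for the complexity of an expression.
int Size(const NodePtr& n) { return n->op.empty() ? 1 : 1 + Size(n->lhs) + Size(n->rhs); }

// Counts every non-atomic subexpression (atoms are not eligible, as in the pass).
void Count(const NodePtr& n, std::map<std::string, std::pair<NodePtr, int>>& table) {
  if (n->op.empty()) return;
  auto& entry = table[Key(n)];
  entry.first = n;
  entry.second++;
  Count(n->lhs, table);
  Count(n->rhs, table);
}

// Returns the keys of the computations done at least twice, biggest first.
std::vector<std::string> RedundantComputations(const NodePtr& root) {
  std::map<std::string, std::pair<NodePtr, int>> table;
  Count(root, table);
  std::vector<std::pair<std::string, int>> repeated;  // (key, size)
  for (const auto& [key, entry] : table)
    if (entry.second >= 2) repeated.push_back({key, Size(entry.first)});
  std::stable_sort(repeated.begin(), repeated.end(),
                   [](const auto& a, const auto& b) { return a.second > b.second; });
  std::vector<std::string> keys;
  for (const auto& p : repeated) keys.push_back(p.first);
  return keys;
}
```

For `((a+b)*c) + ((a+b)*c)` this returns `((a+b)*c)` before `(a+b)`. Considering the biggest candidate first matters: once `((a+b)*c)` is introduced into a variable, the inner `(a+b)` occurrences disappear with it, which is why the pass only adds the *direct* subexpressions of a computation when that computation cannot be introduced.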
   
   I also thought that it should be possible to factor them out and replace them with something polymorphic. But when I actually tried (quite some time ago), I was not convinced that this could be done easily without adding complexity everywhere.
   
   In particular, a possible approach would be to:
   - 1 - Create a main function for the algorithm currently duplicated in `VisitStmt()` and in `VisitExpr()`. Let's call it "`Visit()`" for now. This function takes an `Either<Expr,Stmt>` and returns an `Either<Expr,Stmt>`; in C++, this is `std::variant`.
   - 2 - In `VisitExpr(const PrimExpr& expr)`: build a `std::variant<Expr, Stmt>` from the given `expr`, then call the "generic" `Visit()`, and unpack its result (an Expr) to return it.
   - 3 - Same thing for `VisitStmt(const Stmt& stmt)`: build a `std::variant<Expr, Stmt>` from `stmt`, then call the "generic" `Visit()`, and unpack its result (a Stmt) to return it.
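Steps 2 and 3 are mechanical. A minimal sketch of them, assuming hypothetical toy types (`ToyExpr`, `ToyStmt` and `ToyEliminator` are stand-ins, not TVM classes) and a placeholder `Visit()` body, could look like this:

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical stand-ins for tir::PrimExpr and tir::Stmt (not the TVM classes).
struct ToyExpr {
  std::string text;
};
struct ToyStmt {
  std::string text;
};
using ExprOrStmt = std::variant<ToyExpr, ToyStmt>;

class ToyEliminator {
 public:
  // Step 2: wrap the expression, run the shared algorithm, unpack an expression.
  ToyExpr VisitExpr(const ToyExpr& expr) {
    ExprOrStmt result = Visit(ExprOrStmt{expr});
    assert(std::holds_alternative<ToyExpr>(result));  // an Expr must come back as an Expr
    return std::get<ToyExpr>(result);
  }

  // Step 3: the same for statements.
  ToyStmt VisitStmt(const ToyStmt& stmt) {
    ExprOrStmt result = Visit(ExprOrStmt{stmt});
    assert(std::holds_alternative<ToyStmt>(result));  // a Stmt must come back as a Stmt
    return std::get<ToyStmt>(result);
  }

 private:
  // Step 1: the single "generic" entry point. Placeholder body that only tags
  // the node, standing in for the real collect/replace/let-in algorithm.
  ExprOrStmt Visit(const ExprOrStmt& expr_or_stmt) {
    if (std::holds_alternative<ToyExpr>(expr_or_stmt)) {
      return ToyExpr{"visited(" + std::get<ToyExpr>(expr_or_stmt).text + ")"};
    }
    return ToyStmt{"visited(" + std::get<ToyStmt>(expr_or_stmt).text + ")"};
  }
};
```

Even in this placeholder, `Visit()` has to branch on the active alternative to build a result of the right kind, which is the difficulty raised in the next point.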
   
   In order to achieve 1 (which is the hard part here), we will need to:
   - Use polymorphic functions for all the functions used by the current `VisitStmt()` and `VisitExpr()` (i.e. functions with the same name for treating a `PrimExpr` and a `Stmt`). This is already the case for `GetComputationsDoneBy()`, but we need to do the same for `ReplaceExprSelectedInStmt()` and `ReplaceExprSelectedInExpr()`, which are currently not polymorphic methods. The same goes for quite a few other methods.
   This part is not the problem at all.
   
   - But when we try to write the new `Visit()` function itself, we get a bit stuck.
   We get as input `expr_or_stmt`, of type `std::variant<Expr, Stmt>`.
   We can of course check which alternative we got:
   ```
   Expr expr;
   Stmt stmt;
   if (std::holds_alternative<Expr>(expr_or_stmt))
     expr = std::get<Expr>(expr_or_stmt);
   else
     stmt = std::get<Stmt>(expr_or_stmt);
   ```
   
   And then what?
   For instance, what should replace the beginning of the algorithm, which currently does:
   ```
   ComputationsDoneBy::GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   Sure, we could check whether we are working with an Expr or a Stmt and, within a conditional, call either
   ```
   GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   or
   ```
   GetComputationsDoneBy(
         stmt, IsEligibleComputation, CanContainEligibleComputations);
   ```
   depending on which case we're in. But that just replaces a duplication of methods with a duplication **within** the method. And as the algorithm is not trivial to follow even with the comments, I thought it was probably better not to add complexity to it.
   
   I will try to think about that again soon, but my current view is that, without adding complexity everywhere, it is not easy to replace these two methods with a single one doing the job.
   In addition, in its current form, it is easy to change the algorithm if we want it to do something slightly different for an Expr than for a Stmt. That would be more complicated if everything were unified.
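To illustrate the trade-off in the last point, here is a minimal sketch of the kind of extra layer a unified version would need: the shared skeleton is written once as a template, and the Expr/Stmt-specific rebuilding lives in per-type trait specializations. It assumes hypothetical names throughout (`ToyExpr`, `ToyStmt`, `CseTraits`, `ReplaceAll()`, `CommonOut()`); it is not how the PR is structured, and string replacement is only a stand-in for the real replacement mutator.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical stand-ins for tir::PrimExpr and tir::Stmt (not the TVM classes).
struct ToyExpr {
  std::string text;
};
struct ToyStmt {
  std::string text;
};

// Primary template intentionally left undefined: each node kind supplies its own hooks.
template <typename T>
struct CseTraits;

template <>
struct CseTraits<ToyExpr> {
  // An expression is rebuilt as a let-in expression.
  static ToyExpr WrapInLet(const std::string& var, const std::string& value, const ToyExpr& body) {
    return ToyExpr{"let " + var + " = " + value + " in " + body.text};
  }
};

template <>
struct CseTraits<ToyStmt> {
  // A statement is rebuilt as a let-in statement.
  static ToyStmt WrapInLet(const std::string& var, const std::string& value, const ToyStmt& body) {
    return ToyStmt{"let " + var + " = " + value + "; " + body.text};
  }
};

// Shared helper: replaces every occurrence of `from` in `s` with `to`.
std::string ReplaceAll(std::string s, const std::string& from, const std::string& to) {
  if (from.empty()) return s;  // avoid looping forever on an empty pattern
  for (std::size_t pos = s.find(from); pos != std::string::npos;
       pos = s.find(from, pos + to.size())) {
    s.replace(pos, from.size(), to);
  }
  return s;
}

// Shared skeleton, written once: replace the computation by the variable,
// then let the per-type traits build the appropriate let-in.
template <typename T>
T CommonOut(const T& node, const std::string& computation, const std::string& var) {
  T body{ReplaceAll(node.text, computation, var)};
  return CseTraits<T>::WrapInLet(var, computation, body);
}
```

Per-kind fine-tuning stays possible by editing a specialization, but only through the additional traits indirection, which is the kind of complexity the comment is wary of.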
     







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797243979



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpressions Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithm std::find
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only we won't want
+          to collect them for the CSE pass, but we also won't even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internals TVM
+      // constraints (which check for these node explicitely without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+};
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is sce_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the one we are having now.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is being computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e, present in context_).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note : we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction was performed when a computation was using an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at the entrance to not carry out-of-scope declarations
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If the `value` and the `body` of the let-in have been rewritten to the same thing
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {

Review comment:
       Yes, we do need both `VisitExpr()` and `VisitStmt()`, as we are doing commoning both inside expressions (think about `let x = bigComp+bigComp`) and inside statements (think about `Mem[i] = bigComp+bigComp`).
   
   The basic implementation `StmtExprMutator::VisitStmt` does indeed visit sub `Expr`s. And actually we do rely on `StmtExprMutator::VisitExpr()` and `StmtExprMutator::VisitStmt()` for recursing inside the children, see the end of the two methods [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R307) and [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R477).
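To make this delegation pattern concrete, here is a minimal sketch with hypothetical toy types (`Node`, `ToyMutator`, `RootRecorder`), not TVM's real classes. The derived class overrides the generic `VisitExpr()`, does its own work at the current root, then calls the base-class `VisitExpr()`. The base class's per-node handler recurses into the children through the virtual `VisitExpr()`, so the override is re-entered at every level.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression tree: a leaf (variable name) or a "+" node with two children.
// Hypothetical stand-in for PrimExpr, for illustration only.
struct Node {
  std::string name;                // set for leaves
  std::shared_ptr<Node> lhs, rhs;  // set for "+" nodes
};
using ExprPtr = std::shared_ptr<Node>;

ExprPtr Leaf(const std::string& n) { return std::make_shared<Node>(Node{n, nullptr, nullptr}); }
ExprPtr Add(ExprPtr a, ExprPtr b) {
  return std::make_shared<Node>(Node{"", std::move(a), std::move(b)});
}

// Hypothetical stand-in for StmtExprMutator: the generic VisitExpr() dispatches
// to a per-node handler, which recurses into the children through the *virtual*
// VisitExpr(), so any override is re-entered for every child.
class ToyMutator {
 public:
  virtual ~ToyMutator() = default;
  virtual ExprPtr VisitExpr(const ExprPtr& e) { return e->lhs ? VisitAdd(e) : e; }

 protected:
  ExprPtr VisitAdd(const ExprPtr& e) {
    ExprPtr l = VisitExpr(e->lhs);
    ExprPtr r = VisitExpr(e->rhs);
    // Keep the same node when nothing changed (like the same_as() check
    // in VisitExpr_(const LetNode*)).
    return (l == e->lhs && r == e->rhs) ? e : Add(l, r);
  }
};

// Does its own work at the current root, then delegates the recursion
// into the children to the base class.
class RootRecorder : public ToyMutator {
 public:
  std::vector<int> depths;  // depth of each root seen by this override

  ExprPtr VisitExpr(const ExprPtr& e) override {
    depths.push_back(depth_);
    ++depth_;
    ExprPtr result = ToyMutator::VisitExpr(e);
    --depth_;
    return result;
  }

 private:
  int depth_ = 0;
};
```

For `(w+x)+y`, `depths` ends up as `{0, 1, 2, 2, 1}`: the override sees the root, then `w+x`, then `w` and `x`, then `y`.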
   
   The algorithm implemented for performing commoning inside expressions and inside statements is of course basically the same (it identifies redundancies, creates variables and then builds new expression/statement using let-in), but many things will be different type-wise when treating a statement versus an expression.
   There are many of these subtle differences, but one of them is that for an expression the pass will build a let-in expression and for a statement it will build a let-in statement.
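The shared idea (identify redundancies, bind each one to a fresh variable with a let-in, replace its occurrences) can be illustrated on a toy expression tree. This is a simplified sketch under stated assumptions, not the pass's implementation: equivalence is purely syntactic (same printed form), the printed length stands in for `CalculateExprComplexity()`, every non-leaf node is eligible, there is no scope check, and the value of each binding is not itself recursed into. `ToyExpr`, `Var`, `Bin`, `Key`, `Count`, `Replace` and `ToyCSE` are all hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression: a leaf variable, or a binary node (lhs op rhs).
// Hypothetical stand-in for PrimExpr, for illustration only.
struct ToyExpr {
  std::string leaf;                   // variable name, for leaves
  char op = 0;                        // '+' or '*', for binary nodes
  std::shared_ptr<ToyExpr> lhs, rhs;  // children, for binary nodes
};
using E = std::shared_ptr<ToyExpr>;

E Var(const std::string& name) {
  auto e = std::make_shared<ToyExpr>();
  e->leaf = name;
  return e;
}

E Bin(char op, E a, E b) {
  auto e = std::make_shared<ToyExpr>();
  e->op = op;
  e->lhs = std::move(a);
  e->rhs = std::move(b);
  return e;
}

// Syntactic equivalence: two subtrees are "the same computation" iff they print the same.
std::string Key(const E& e) {
  if (!e->lhs) return e->leaf;
  return "(" + Key(e->lhs) + e->op + Key(e->rhs) + ")";
}

// Table of computations: printed form -> number of occurrences (leaves are not eligible).
void Count(const E& e, std::map<std::string, int>& table) {
  if (!e->lhs) return;
  table[Key(e)]++;
  Count(e->lhs, table);
  Count(e->rhs, table);
}

// Replace every subtree printing as `key` by the variable `v`.
E Replace(const E& e, const std::string& key, const E& v) {
  if (!e->lhs) return e;
  if (Key(e) == key) return v;
  return Bin(e->op, Replace(e->lhs, key, v), Replace(e->rhs, key, v));
}

// One level of toy CSE, biggest computations first. Returns "let ... in body".
std::string ToyCSE(E body) {
  std::map<std::string, int> table;
  Count(body, table);
  std::vector<std::string> candidates;
  for (const auto& kv : table) candidates.push_back(kv.first);
  // Printed length is a crude stand-in for CalculateExprComplexity().
  std::stable_sort(candidates.begin(), candidates.end(),
                   [](const std::string& a, const std::string& b) { return a.size() > b.size(); });
  std::string lets;
  int n = 0;
  for (const std::string& k : candidates) {
    // Recount on the current body: an earlier replacement may have removed occurrences.
    std::map<std::string, int> now;
    Count(body, now);
    if (now[k] < 2) continue;  // not redundant (any more)
    E v = Var("cse_var_" + std::to_string(++n));
    lets += "let " + v->leaf + " = " + k + " in ";
    body = Replace(body, k, v);
  }
  return lets + Key(body);
}
```

On `((w+x)+(y+z))*((w+x)+u)` this returns `let cse_var_1 = (w+x) in ((cse_var_1+(y+z))*(cse_var_1+u))`. On `((a+b)*c)+((a+b)*c)` it binds the bigger `((a+b)*c)` first, after which `(a+b)` no longer occurs twice in the body, so it is not introduced. Unlike the pass, which wraps each new `Let` around the current result, the toy simply lists the bindings in the order they were introduced; that is harmless here because every bound value only uses the original variables.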
   
   I also thought that it should be possible to factorize them and replace them with something polymorphic. But when I actually tried a little (quite some time ago), I ended up unsure that this could be done easily without adding complexity everywhere.
   
   In particular, a possible approach would be to:
   - 1 - Create a main function for the algorithm currently duplicated in `VisitStmt()` and `VisitExpr()`. Let's call it "`Visit()`" for now. This function takes an `Either<Expr, Stmt>` and returns an `Either<Expr, Stmt>`; in C++, that is `std::variant<Expr, Stmt>`.
   - 2 - In `VisitExpr(const PrimExpr& expr)`: we now build a `std::variant<Expr, Stmt>` from the given `expr`, then call the "generic" `Visit()`, and unpack its result (an `Expr`) to return it.
   - 3 - Same thing for `VisitStmt(const Stmt& stmt)`: we now build a `std::variant<Expr, Stmt>` from `stmt`, then call the "generic" `Visit()`, and unpack its result (a `Stmt`) to return it.
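The wrap/call/unpack plumbing of steps 2 and 3 can be sketched with `std::variant`. `ToyExpr`, `ToyStmt` and `ExprOrStmt` are hypothetical stand-ins for `PrimExpr` and `Stmt`; the sketch only shows the plumbing, and that the branch on the held alternative ends up inside `Visit()`.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical stand-ins for PrimExpr and Stmt.
struct ToyExpr { std::string text; };
struct ToyStmt { std::string text; };
using ExprOrStmt = std::variant<ToyExpr, ToyStmt>;

// Step 1: the single "generic" Visit(). It still has to branch on which
// alternative it holds (let-in expression vs. let-in statement).
ExprOrStmt Visit(const ExprOrStmt& node) {
  if (std::holds_alternative<ToyExpr>(node)) {
    return ToyExpr{"let-in expr over " + std::get<ToyExpr>(node).text};
  }
  return ToyStmt{"let-in stmt over " + std::get<ToyStmt>(node).text};
}

// Step 2: wrap the expression, call Visit(), unpack an expression.
ToyExpr VisitExpr(const ToyExpr& e) { return std::get<ToyExpr>(Visit(ExprOrStmt{e})); }

// Step 3: same for statements.
ToyStmt VisitStmt(const ToyStmt& s) { return std::get<ToyStmt>(Visit(ExprOrStmt{s})); }
```

`std::get` throws `std::bad_variant_access` if `Visit()` ever returned the wrong alternative, so a mismatch between the two kinds would surface at run time rather than at compile time.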
   
   In order to achieve 1 (which is the hard part here), we will need to:
   - Use polymorphic functions for all the functions used by the current `VisitStmt()` and `VisitExpr()` (i.e. functions with the same name for treating a `PrimExpr` and a `Stmt`). This is already the case for `GetComputationsDoneBy()`, but we would need to do the same for `ReplaceExprSelectedInStmt()` and `ReplaceExprSelectedInExpr()`, which are currently not polymorphic methods. The same goes for quite a few other methods.
   This part is absolutely not the problem.
   
   - But now, when we try to write the new `Visit()` function, it seems that we are a bit stuck.
   We get as input this `expr_or_stmt` of type `std::variant<Expr, Stmt>`.
   We can of course check what we got:
   ```
   Expr expr;
   Stmt stmt;
   if (std::holds_alternative<Expr>(expr_or_stmt))
     expr = std::get<Expr>(expr_or_stmt);
   else
     stmt = std::get<Stmt>(expr_or_stmt);
   ```
   
   And then what?
   For instance, what should replace the beginning of the algorithm, which currently does:
   ```
   ComputationsDoneBy::GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   Sure, we could check whether we're working with an `Expr` or a `Stmt` and, within a conditional, call either
   ```
   GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
   ```
   or
   ```
   GetComputationsDoneBy(
         stmt, IsEligibleComputation, CanContainEligibleComputations);
   ```
   depending on which case we're in. But that replaces a duplication of methods with a duplication **within** the method... And as the algorithm was not trivial to follow even with the comments, I thought it was probably better not to add complexity to it.
   
   I will try to think about it again soon, but my current view is that, without adding complexity everywhere, it's not so easy to replace these two methods with a single one doing the job.
   In addition, in its current form, it's easy to change the algorithm if we want it to do something slightly different for an `Expr` versus a `Stmt`. That would be more complicated if everything gets unified.
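One standard C++17 way to avoid writing the `if`/`else` by hand is `std::visit` with a generic lambda: overload resolution picks the matching helper for each alternative, and `if constexpr` isolates the steps that genuinely differ (let-in expression vs. let-in statement). The sketch below uses hypothetical toy types and a hypothetical overloaded helper `Kind()` standing in for something like `GetComputationsDoneBy()`; it does not show that the real pass could be unified this easily, since the per-type differences still have to be written somewhere.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <variant>

// Hypothetical stand-ins for PrimExpr and Stmt.
struct ToyExpr { std::string text; };
struct ToyStmt { std::string text; };
using ExprOrStmt = std::variant<ToyExpr, ToyStmt>;

// Hypothetical overloaded helper, in the spirit of GetComputationsDoneBy(),
// which already exists for both expressions and statements.
std::string Kind(const ToyExpr&) { return "expr"; }
std::string Kind(const ToyStmt&) { return "stmt"; }

ExprOrStmt Visit(const ExprOrStmt& node) {
  return std::visit(
      [](const auto& n) -> ExprOrStmt {
        using T = std::decay_t<decltype(n)>;
        // Shared step: overload resolution selects the right helper, no if/else.
        std::string tag = Kind(n);
        // Step that differs per type: isolated with if constexpr.
        if constexpr (std::is_same_v<T, ToyExpr>) {
          return ToyExpr{"Let[" + tag + "](" + n.text + ")"};
        } else {
          return ToyStmt{"LetStmt[" + tag + "](" + n.text + ")"};
        }
      },
      node);
}
```

Every helper called from the shared part must exist for both types, which is the "polymorphic functions everywhere" requirement discussed above.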
     







[GitHub] [tvm] FranckQC commented on a change in pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r797243979



##########
File path: src/tir/transforms/common_subexpr_elim.cc
##########
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim.cc
+ * \brief Implementation of the Common Subexpressions Elimination (CSE) pass
+           which rewrites statements and expressions in order to eliminate
+           redundant computations. In order to achieve that, common (sub-)
+           expressions are introduced into variables with let-in bindings,
+           and the places where the expression was used are replaced with
+           the freshly introduced variable.
+ */
+
+#include "common_subexpr_elim.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the decl of the function returning the pass
+
+#include <algorithm>  // For the algorithm std::find
+#include <iostream>
+#include <unordered_map>  // For the hashtable datatype
+#include <utility>        // For std::pair and std::move
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the visitor CheckContains
+#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
+#include "replace_expr_selected.h"       // For the mutator ReplaceExprSelected
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a computation is forbidden for being treated by the CSE pass.
+          The important thing about forbidden computations is that not only do we not want
+          to collect them for the CSE pass, but we also don't even want to collect computations
+          that contain them.
+          The reason is that reusing such computations would change the semantics of the program,
+          and therefore before doing any introduction of variable or any reuse of already introduced
+          variables, we will make sure that the computation being considered is not forbidden, and
+          that it does not even contain a forbidden computation.
+ * \param expr The expression to check
+ * \return Whether `expr` is a forbidden computation or not
+ */
+bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
+  // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
+  // variables would change the semantics of the program.
+  return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
+          expr.as<BufferLoadNode>() != nullptr);
+}
+
+/*!
+ * \brief Predicate used for verifying that a computation is eligible for being treated by
+          the CSE pass, i.e. for being introduced into a variable / for being replaced by a
+          variable.
+          Being eligible is a conjunction of a few conditions, like not being an atom (constant
+          or variable), not being a forbidden node, not containing a forbidden node, etc.
+ * \param expr The expression to check
+ * \return Whether `expr` is an eligible computation or not
+ */
+bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
+  return (
+      // In order to be eligible, the given expression should not be a constant
+      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
+      (expr.as<StringImmNode>() == nullptr)
+      // and it should not be a variable
+      && (expr.as<VarNode>() == nullptr)
+      // and it should not be a forbidden computation (function calls and loads)
+      && (!ForbiddenComputation(expr))
+      // and it should not even contain a forbidden computation (function calls and loads)
+      // the reason is that we don't want to register expressions like (x + f(y)) or
+      // (x + Mem[i]) as introducing them into variables could change the semantics
+      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
+      // and it should not be a ramp node or a broadcast node due to some internal TVM
+      // constraints (which check for these nodes explicitly without performing any
+      // evaluation first, so if they have been put into variables it fails)
+      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+}
+
+/*!
+ * \brief Predicate used (when considering eligible computations) for only diving into
+          expressions that are allowed to contain eligible computations. Customize this predicate
+          if you want to make it forbidden to rewrite inside a specific node, like inside
+          a Load node for instance.
+ * \param expr The expression to check
+ * \return Whether `expr` can contain some eligible computations or not, and therefore
+             if recursing inside `expr` is necessary.
+ */
+bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
+  // Uncomment the next line to prevent the collection and the replacement of eligible computations
+  // inside the index of Load nodes. We initially thought that this would be needed in order to
+  // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
+  // finally want to perform such simplifications, which tend to happen fairly frequently.
+
+  // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
+  return true;
+};
+
+/*!
+ * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \param type_annotation The type of the new variable to generate
+ * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
+            integer.
+ */
+Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
+  // Increase `num_last_try_` for this new attempt
+  num_last_try_++;
+  // Builds the variable name, which is cse_var_i where i will go up from 1
+  std::string prefix = "cse_var_";
+  std::string name = prefix.append(std::to_string(num_last_try_));
+  // Builds a String using the std::string
+  String string_name(name);
+
+  // Check that the name that we want to use for the new variable isn't already being used
+  // (names don't really have to be unique as they are just hints, and having the same name
+  // doesn't mean that it's the same variable, but it's clearer for dumps)
+  if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
+    // If the name is already used, call ourselves recursively for trying with the next one
+    return GenerateNewVar(type_annotation);
+  }
+
+  // Increase `nb_var_` for this new generation of variable that we have just done
+  nb_var_++;
+
+  // Return a new Variable using the name built and the given type_annotation
+  return (Var(string_name, type_annotation));
+}
+
+/*!
+ * \brief Gives the number of variables generated by the CSE on the current function
+           (i.e., getter for `nb_var_`).
+ * \return A copy of `nb_var_`
+ */
+int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
+
+/*!
+ * \brief Toplevel (static) method that performs Common Subexpression Elimination on
+          a given statement (which should be the body of a PrimFunc). This method should be
+          called for each PrimFunc definition.
+ * \param stmt The statement of the function being analyzed, on which we want to perform CSE
+ * \param context_init The initial context, which should contain the formal parameters
+                          of the function being analyzed
+ * \return A new statement where CSE has been performed
+ */
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+  // As this function is being called for each PrimFunc definition, we create a new instance
+  // for the one we are having now.
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  return common_subexpression_eliminator.VisitStmt(stmt);
+}
+
+/*!
+ * \brief Protected constructor of CommonSubexpressionEliminator.
+ * \param context_init The context at the beginning of the CSE pass. It should contain the
+                        formal parameters of the function that will be analyzed
+ */
+CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
+                                                             const Context& context_init)
+    : initial_body_(stmt), context_(context_init) {}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+          Entry point to the common subexpression elimination mutator for expressions.
+ * \param expr The expression to mutate
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
+  PrimExpr result = expr;
+
+  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
+  // a TableOfComputations, which is a mapping from PrimExpr to size_t, where the size_t is the
+  // number of times this exact syntactic computation is being computed.
+  TableOfComputations table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      expr, IsEligibleComputation, CanContainEligibleComputations);
+
+  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
+  // containing *semantic* entities, i.e. where equivalent computations are merged.
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+
+  // Sort the vector of semantic entities by decreasing size
+  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
+            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
+            });
+
+  // For each computation done (considering them from biggest to smallest)
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+
+    // The predicate later used (when doing replacements) to select expressions that are
+    // equivalent to the current computation (`computation_and_nb.first`)
+    std::function<bool(const PrimExpr&)> predicate_selector =
+        [computation_and_nb](const PrimExpr& current_expr) {
+          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
+          // that `current_expr` is an eligible computation even if we know that
+          // `computation_and_nb.first` is eligible by construction, in case that one day the
+          // equivalence relation would not preserve the eligibility any more (even though that
+          // would probably be a very weird equivalence).
+          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+                  IsEligibleComputation(current_expr));
+        };
+
+    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
+    // equivalent to `computation_and_nb.first`
+    auto it_on_var = std::find_if(
+        context_.begin(), context_.end(),
+        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+          // Note : safe to call value() as we check has_value() just before
+          return (var_and_value.second.has_value() &&
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+        });
+
+    // Case where we have a perfectly equivalent computation already available in a variable
+    // introduced (i.e, present in context_).
+    // Note that this case is needed when the user has written something like
+    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
+    // an already existing variable holding A, when such a variable happens to exist.
+    if (it_on_var != context_.end()) {
+      // Replace in the current `result` everything that is selected by the selector with
+      // the existing variable, without diving into expressions in which we don't have the
+      // right to dive.
+      result = ReplaceExprSelected::ReplaceExprSelectedInExpr(
+          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
+    } else {
+      // The current computation is not equivalent to a computation already done. We will
+      // need to see if we want to introduce it.
+
+      // --- Chunk needed for reusing the UndefinedVars() analysis ---
+      // 1 - Wraps the computation into a statement
+      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
+      // 2.1 - Transform the context into a vector of variables instead of pairs
+      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
+          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
+      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
+      // 2.2 - Transform the std::vector into an Array
+      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
+      // --- End of chunk needed for reusing the UndefinedVars() analysis ---
+
+      // We use the UndefinedVars() analysis to get the undefined vars of the computation
+      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
+
+      // Check if we can introduce it : if it contains no undefined variables and if we want
+      // to introduce it according to the predicate
+      if (vars_undefined.empty() &&
+          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
+        // Create a new variable for this computation
+        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
+        // Replace in the current `result` everything that is selected by the selector with
+        // the new variable, without diving into expressions in which we don't have the
+        // right to dive.
+        result = ReplaceExprSelected::ReplaceExprSelectedInExpr(result, predicate_selector, new_var,
+                                                                CanContainEligibleComputations);
+        // Build a let-in that introduces the new variable in the current `result`
+        result = Let(new_var, computation_and_nb.first, result);
+        // We don't add the variable to the context because the invariant is that the
+        // context is the context in which 'result' makes sense, and we've just updated it.
+      } else {
+        // Here it's not doable to introduce (via a let-in) the computation at this level
+        // as it contains variables that are not yet declared, and/or because the predicate
+        // did not select it.
+        // Either way, we will simply add to the vector of computations the direct subexprs
+        // of the current computation, as these ones might be good candidates
+        // for being introduced into variables.
+        // Note that we don't need to add all of its subexpressions, but only its *direct*
+        // subexpressions as we consider them from biggest to smallest, and if they were
+        // all added at once, then there could be dependencies between them, as commoning
+        // one of them could remove some other possibilities.
+
+        // Computing the direct subexpressions will return a small number of direct
+        // subexpressions (typically 0 to 3)
+        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
+        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
+        // decreasing size/complexity), and it will only insert at locations > i as the
+        // direct subexprs are necessarily smaller than the current computation.
+        InsertVectorToSortedSemanticComputations(semantic_comp_done_by_expr, direct_subexprs);
+      }
+    }
+    // Note: we do not remove the current element, as we never look back in the local vector
+  }  // End of for loop
+
+  // Calling the dispatcher to the specific treatments, which will update the context
+  // appropriately before doing the recursive calls on the child nodes
+  result = StmtExprMutator::VisitExpr(result);
+
+  return result;
+}
+
+/*!
+ * \brief The method which overrides the specific treatment for a LetNode
+ */
+PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
+  // At this point, we have already done the generic treatment of introducing (via let-in) what
+  // was doable at the toplevel of the given let-in.
+
+  // Save the context at the entry of the function
+  Context context_at_entry = context_;
+
+  // Recurse on the `value` field for potentially rewriting it
+  PrimExpr value_new = VisitExpr(op->value);
+
+  // Augment the context with the association (`var`, `value`) for preparing the next recursion
+  // on the `body`
+  context_.push_back({op->var, MaybeValue(op->value)});
+
+  // Recurse on the `body` (with this extended context)
+  // The recursive call will have potentially done new simplifications, because in this recursive
+  // call `var` will be a part of the context.
+  // (see in VisitExpr() that no introduction is performed when a computation uses an
+  // undefined variable, as that would lead to ill-formed code)
+  PrimExpr body_new = VisitExpr(op->body);
+
+  // Restore the context to its content at entry so as not to carry out-of-scope declarations
+  // as the variable introduced by the let-in is not in scope outside of its body
+  context_ = context_at_entry;
+
+  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
+  // have been done.
+
+  // If the `value` and the `body` of the let-in have been rewritten to the same thing
+  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
+    // then return a reference to the same node
+    return GetRef<PrimExpr>(op);
+  } else {
+    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
+    // have just been obtained
+    return Let(op->var, value_new, body_new, op->span);
+  }
+}
+
+/*!
+ * \brief The method which overrides the generic dispatcher of StmtExprMutator.
+ *        Entry point to the common subexpression elimination mutator for statements.
+ * \param stmt The statement to mutate.
+ */
+Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
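The context bookkeeping in the hunk above (looking up a variable that already holds a computation, checking that every variable a computation reads is in scope, and saving/restoring the context around a let-in body) can be sketched with toy types. This is a minimal sketch, not the PR's code: variables and computations are plain strings, and `Computation`, `FindVarHolding`, `AllVarsInScope` and `WithLetBinding` are hypothetical names chosen for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins (hypothetical, not TVM's Var/PrimExpr): a variable is a name and a
// computation is its printed text plus the set of variables it reads.
using MaybeValue = std::optional<std::string>;
using Context = std::vector<std::pair<std::string, MaybeValue>>;

struct Computation {
  std::string text;            // e.g. "a + b"
  std::set<std::string> vars;  // variables the computation reads
};

// Like the lookup producing `it_on_var`: is some variable already bound to `comp`?
std::optional<std::string> FindVarHolding(const Context& ctx, const Computation& comp) {
  auto it = std::find_if(ctx.begin(), ctx.end(), [&](const auto& var_and_value) {
    return var_and_value.second.has_value() && *var_and_value.second == comp.text;
  });
  if (it == ctx.end()) return std::nullopt;
  return it->first;
}

// Like the UndefinedVars() check: `comp` may only be introduced here if every
// variable it reads is bound in the context.
bool AllVarsInScope(const Context& ctx, const Computation& comp) {
  for (const std::string& v : comp.vars) {
    bool known = std::any_of(ctx.begin(), ctx.end(),
                             [&](const auto& entry) { return entry.first == v; });
    if (!known) return false;
  }
  return true;
}

// Like VisitExpr_(const LetNode*): bind `var` only while visiting the body, then
// restore the saved context so the binding does not leak out of its scope.
template <typename BodyVisitor>
void WithLetBinding(Context& ctx, const std::string& var, const std::string& value,
                    BodyVisitor visit_body) {
  Context context_at_entry = ctx;
  ctx.push_back({var, MaybeValue(value)});
  visit_body(ctx);
  ctx = context_at_entry;
}
```

Copying the whole context on entry, as the diff does with `context_at_entry`, keeps restoration trivial; a scoped push/pop would avoid the copy but must be unwound correctly on every exit path.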

Review comment:
       Yes, we do need both `VisitExpr()` and `VisitStmt()`, as we are doing commoning both inside expressions (think about `let x = bigComp+bigComp`) and inside statements (think about `Mem[i] = bigComp+bigComp`).
   
   The basic implementation `StmtExprMutator::VisitStmt` does indeed visit sub `Expr`s. And actually we do rely on `StmtExprMutator::VisitExpr()` and `StmtExprMutator::VisitStmt()` for recursing inside the children, see the end of the function [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R307) and [here](https://github.com/apache/tvm/pull/9482/files#diff-2cfef2190c8093f3b70d5ed20a1e9e75a74cd040a6a0987d5648f511949efce9R477).
   
   The algorithm implemented for performing commoning inside expressions and inside statements is of course basically the same (it identifies redundancies, creates variables and then builds new expression/statement using let-in), but many things will be different type-wise when treating a statement versus an expression.
   There are many of these subtle differences, but one of them is that for an expression the pass will build a let-in expression and for a statement it will build a let-in statement.
   
   I also thought that it should be possible to factorize them and replace them with something polymorphic. But when I actually tried (quite some time ago), I was not convinced that this could be done easily without adding complexity everywhere.
   
   In particular, a possible approach would be to:
   - 1 - Create a main function for the algorithm currently duplicated in `VisitStmt()` and in `VisitExpr()`. Let's call it "`Visit()`" for now. This function takes an `Either<Expr, Stmt>` and returns an `Either<Expr, Stmt>` (in C++, that is `std::variant`).
   - 2 - In `VisitExpr(const PrimExpr& expr)`: we now build a `std::variant<Expr, Stmt>` from the given `expr`, then call the "generic" `Visit()`, and unpack its result (an Expr) to return it.
   - 3 - Same thing for `VisitStmt(const Stmt& stmt)`: we now build a `std::variant<Expr, Stmt>` from `stmt`, then call the "generic" `Visit()`, and unpack its result (a Stmt) to return it.
   
   In order to achieve 1 (which is the hard part here), we will need to:
   - use polymorphic functions for all the functions used by the current `VisitStmt()` and `VisitExpr()` (i.e. functions with the same name for treating a `PrimExpr` and a `Stmt`). This is already the case for `GetComputationsDoneBy()`, but we need to do the same for `ReplaceSelectedExprInStmt()` and `ReplaceSelectedExprInExpr()`, which are currently not polymorphic methods. The same goes for quite a few other methods.
   This part is absolutely not the problem.
   
   - But now, when we try to write the new `Visit()` function, it seems that we are a bit stuck.
   We get as input this `expr_or_stmt` of type `std::variant<Expr, Stmt>`.
   We can of course check what we got:
   Expr expr;
   Stmt stmt;
   if (std::holds_alternative<Expr>(expr_or_stmt))
     expr = std::get<Expr>(expr_or_stmt);
   else
     stmt = std::get<Stmt>(expr_or_stmt);
   
   And then what?
   For instance, what should replace the beginning of the algorithm, which currently does:
   ComputationsDoneBy::GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);
         
   Sure, we could check if we're working with an Expr or a Stmt and within a conditional call either 
   ```
   GetComputationsDoneBy(
         expr, IsEligibleComputation, CanContainEligibleComputations);  // on the Expr
   ```
   or
   ```
   GetComputationsDoneBy(
         stmt, IsEligibleComputation, CanContainEligibleComputations);  // on the Stmt
   ```
   depending on which case we're in. But that replaces a duplication of methods with a duplication **within** the method... And as the algorithm was not trivial to follow even with the comments, I thought it was probably better not to add complexity to it.
   
   I will try to think about that again soon, but my current way of seeing things is that without adding complexity everywhere, it's not so easy to actually replace these two methods by just a single one doing the job.
   In addition, in its current form, it's easy to change the algorithm if we want it to do something slightly different for an Expr versus a Stmt. That would be more complicated if everything gets unified.
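A rough sketch of steps 1 to 3 with `std::variant`, assuming every helper already has one overload per node kind. `Expr`, `Stmt`, `BuildLetIn` and the string contents of the nodes are hypothetical placeholders, not TVM's classes; the CSE logic is reduced to wrapping the node in a let-in, so this only illustrates the dispatch shape.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Placeholder node types (hypothetical, standing in for TVM's PrimExpr and Stmt).
struct Expr { std::string text; };
struct Stmt { std::string text; };
using ExprOrStmt = std::variant<Expr, Stmt>;

// Same-named overloads for both kinds: a let-in expression vs. a let-in statement.
Expr BuildLetIn(const std::string& var, const std::string& value, const Expr& body) {
  return Expr{"Let(" + var + ", " + value + ", " + body.text + ")"};
}
Stmt BuildLetIn(const std::string& var, const std::string& value, const Stmt& body) {
  return Stmt{"LetStmt(" + var + ", " + value + ", " + body.text + ")"};
}

// The shared "Visit()": std::visit instantiates the generic lambda for each
// alternative, and overload resolution picks the matching BuildLetIn at compile
// time, so no explicit if/else on the node kind is written here.
ExprOrStmt Visit(const ExprOrStmt& node, const std::string& var, const std::string& value) {
  return std::visit(
      [&](const auto& n) -> ExprOrStmt { return BuildLetIn(var, value, n); }, node);
}

// Thin wrappers corresponding to steps 2 and 3: wrap, call Visit(), unwrap.
Expr VisitExpr(const Expr& e, const std::string& var, const std::string& value) {
  return std::get<Expr>(Visit(e, var, value));
}
Stmt VisitStmt(const Stmt& s, const std::string& var, const std::string& value) {
  return std::get<Stmt>(Visit(s, var, value));
}
```

Here `std::visit` with a generic lambda lets overload resolution choose the `Expr` or `Stmt` version of each helper, which is one way to avoid the explicit `if`/`else` discussed above. It does not remove the need to write both overloads, nor does it address the other type-level differences mentioned in the comment.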
     




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031887221


   Well, to be honest, I'd love to see this pass enabled by default, as it's a target-independent optimization from which everybody could benefit without having to wonder whether the back-end is going to take care of redundant computations. I see optional passes as more specialized or more target-dependent.
   
   I also believe that our TVM CSE is more aggressive than the LLVM CSE, which can't deal with things like divisions because it faces restrictions for semantics preservation that are less strict in our specific DSL case.
   
   As only 4 or 5 tests were problematic, I think we should still be able to make it a default pass. I could already fix most of them by simply disabling the CSE pass (the majority were in `test_lower_build.py`, which was expected). Another possibility would have been to rewrite the tests to assert equality between the new expected output and what we got. But I think it's better to be minimalist and to avoid running other passes that "complicate" the output.
   
   I'll try to fix the last two issues and we should be good.





[GitHub] [tvm] FranckQC commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1032077140


   Current status before running it again:
   ![image](https://user-images.githubusercontent.com/89943638/152893889-3d53456e-c28c-4698-bef7-1b2ce2dfe3b7.png)
   
   Hopefully, docs should be ok now.
   Perhaps also the 1386 should be ok.
   
   But we will still have to deal with the failure on the VTA test.





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031864766


   I would check if `vta.build(...)` is really honoring `disabled_pass=[...]` when it runs. You can build VTA and run this test locally. 
   
   Also, what do you think about making the TIR CSE pass optional? I have a feeling that it might bring some disruption if enabled globally, without much benefit for standard backends such as LLVM and CUDA, which until today didn't see a need for TIR-level CSE (the example I gave, CSE across CPU & GPU, is the only good use case I'm aware of). That would also remove the need to fix existing tests.
   
   One way is to make TIR-CSE run only when `opt-level=4`, but I hope there is a better way to enable certain TIR passes. 





[GitHub] [tvm] masahi edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031840409


   Yes, the doc build was fixed about 30 min ago (I also got the same error in my PR), so you need to push again. But there is still a Windows issue.





[GitHub] [tvm] masahi commented on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-1031832770


   > I tried to re-run the CI (by adding a change and removing it - is there an easier way than that by the way?), and the Windows CI is still not running
   
   You can fetch, rebase against `upstream/main`, and push your branch. Or do an empty commit, that would also trigger a CI job.
   
   The doc build was fixed in https://github.com/apache/tvm/pull/10181, but the Windows issue is apparently due to ongoing GitHub maintenance today. I don't know what our mitigation is yet.
   
   ![image](https://user-images.githubusercontent.com/1776403/152857273-14dfc01f-538c-48c3-a1a7-5964687b03fa.png)
   





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU - for example, in GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE `log2(N)` expression that appears both in the host and GPU kernel, right now the GPU sort kernel is littered with `log2(N)` compute like this (note a lot of calls to `call_spirv_pure_glsl450` which is totally unnecessary if we had TIR-level CSE) https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
   In principle, every redundant subterm that is eligible for being commoned out (i.e., that does not contain function calls, etc.) will be commoned out. There are also a few other minor restrictions which are due to some specifics of TVM, but these are rare.
   
   Unfortunately, in the TIR code snippet that you uploaded, it seems that the redundant subexpression is just a function call. Such calls can't be commoned out into a new variable by the CSE pass, as it has no guarantee that the function is free of side effects, i.e., that it always produces the same output for the same inputs and does nothing else. Without this guarantee, commoning out such function calls could change the program's semantics, which is not acceptable since preserving the semantics of the program is vital.
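A minimal illustration of why calls are not commoned: with a hypothetical function `f` that has a side effect (it increments a global counter), evaluating `f(42) + f(42)` and evaluating the naively commoned `let v = f(42) in v + v` give different results. This is a plain C++ stand-in, not TIR.

```cpp
#include <cassert>

// Hypothetical impure function: every call increments a global counter, so two
// calls with the same argument return different values.
int call_count = 0;
int f(int x) { return x + (++call_count); }

// As written: f is called twice.
int AsWritten() { return f(42) + f(42); }

// After (incorrectly) commoning f(42) into a variable: f is called once.
int NaivelyCommoned() {
  int v = f(42);
  return v + v;
}
```

Starting from `call_count = 0`, `AsWritten()` returns 43 + 44 = 87 after two calls, while `NaivelyCommoned()` returns 43 + 43 = 86 after one call, which is exactly the kind of semantic change the pass must avoid.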
   
   I can imagine that for functions that are guaranteed to have no side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future. And it would rely on some "NoSideEffect" tag on functions.
   
   However, please note that if you had some other redundancies, this CSE pass would common out whatever eligible redundancies you have.
   For instance:
   Assume you have the term (f(42) + f(42)) + (x*y + x*y) that appears somewhere.
   It is not eligible in its entirety, as it contains function calls.
   But its subterm (x*y) is eligible, so this subpart will be commoned out.
   In short, the CSE pass, as implemented, always tries to common out all the redundant computations that are eligible, and it does so by looking from bigger subterms to smaller subterms.
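The biggest-to-smallest strategy on the example above can be sketched with a toy AST. This is a hypothetical simplification, not the PR's implementation: `Node`, `IsEligible` and `CommonedSubterms` are illustrative names, syntactic equality is approximated by comparing printed forms, any subtree containing a call is ineligible, occurrence counts are computed once on the original tree, and the let-in rewrite is omitted, so the function only reports which subterms would be commoned.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy AST (not TIR): leaves, opaque calls, and binary operators.
struct Node;
using NodePtr = std::shared_ptr<Node>;
struct Node {
  enum Kind { kLeaf, kCall, kBin };
  Kind kind;
  std::string op;  // leaf name, callee name, or operator symbol
  std::vector<NodePtr> kids;
};

NodePtr Leaf(std::string s) { return std::make_shared<Node>(Node{Node::kLeaf, std::move(s), {}}); }
NodePtr Call(std::string f, NodePtr a) {
  return std::make_shared<Node>(Node{Node::kCall, std::move(f), {std::move(a)}});
}
NodePtr Bin(std::string op, NodePtr a, NodePtr b) {
  return std::make_shared<Node>(Node{Node::kBin, std::move(op), {std::move(a), std::move(b)}});
}

// Syntactic equality is approximated by comparing printed forms.
std::string Print(const NodePtr& n) {
  if (n->kind == Node::kLeaf) return n->op;
  if (n->kind == Node::kCall) return n->op + "(" + Print(n->kids[0]) + ")";
  return "(" + Print(n->kids[0]) + " " + n->op + " " + Print(n->kids[1]) + ")";
}

int Size(const NodePtr& n) {
  int size = 1;
  for (const NodePtr& k : n->kids) size += Size(k);
  return size;
}

// Eligible: an operator node whose subtree contains no call (a call may have side effects).
bool IsEligible(const NodePtr& n) {
  if (n->kind != Node::kBin) return false;
  for (const NodePtr& k : n->kids)
    if (k->kind == Node::kCall || (k->kind == Node::kBin && !IsEligible(k))) return false;
  return true;
}

void CountOccurrences(const NodePtr& n, std::map<std::string, int>& counts) {
  ++counts[Print(n)];
  for (const NodePtr& k : n->kids) CountOccurrences(k, counts);
}

// Visits candidates from biggest to smallest. A redundant eligible candidate is
// recorded (and not descended into); otherwise its direct subterms are inserted
// into the worklist, which stays sorted by decreasing size after position i.
std::vector<std::string> CommonedSubterms(const NodePtr& root) {
  std::map<std::string, int> counts;
  CountOccurrences(root, counts);
  std::vector<NodePtr> work = {root};
  std::set<std::string> queued = {Print(root)};
  std::vector<std::string> commoned;
  for (size_t i = 0; i < work.size(); ++i) {
    const NodePtr n = work[i];  // copy: inserting below may reallocate `work`
    if (IsEligible(n) && counts[Print(n)] >= 2) {
      commoned.push_back(Print(n));  // would become `let v = <n> in ...`
      continue;
    }
    for (const NodePtr& k : n->kids) {
      if (k->kind == Node::kLeaf || !queued.insert(Print(k)).second) continue;
      auto pos = std::upper_bound(work.begin() + i + 1, work.end(), k,
                                  [](const NodePtr& a, const NodePtr& b) { return Size(a) > Size(b); });
      work.insert(pos, k);
    }
  }
  return commoned;
}
```

On `(f(42) + f(42)) + (x*y + x*y)`, the whole term and `(f(42) + f(42))` are rejected because they contain calls, and `(x*y + x*y)` occurs only once, so the worklist descends to `x*y`, which occurs twice and is reported. If a larger eligible term such as `(a*b + a*b)` is itself repeated, it is reported first and this sketch does not examine its inner `a*b` again, whereas the `VisitExpr_(const LetNode*)` shown in the diff does recurse into the introduced let-in's `value`.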
   
   Does that answer your question?
   Please do not hesitate to tell me if you need help trying the pass.
   
   Kind regards.





[GitHub] [tvm] FranckQC edited a comment on pull request #9482: Implementation of Common Subexpression Elimination for TIR

Posted by GitBox <gi...@apache.org>.
FranckQC edited a comment on pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#issuecomment-969151523


   > Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!
   > 
   > What I wanted to do is to eliminate common expressions that span across the host and GPU - for example, in GPU `sort` kernel, I need to make `log2(N)` GPU kernel calls from the host to sort the input bottom up. In principle, `log2(N)` needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE `log2(N)` expression that appears both in the host and GPU kernel, right now the GPU sort kernel is littered with `log2(N)` compute like this (note a lot of calls to `call_spirv_pure_glsl450` which is totally unnecessary if we had TIR-level CSE) https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f
   > 
   > Is this PR going to solve my problem?
   
   Hi @masahi 
   Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
   Yes, in principle, every redundant subterm that is eligible for being commoned out (i.e., that does not contain function calls, etc.) will be commoned out. There are also a few restrictions which are due to some specifics of TVM, but these are rare.
   Do you have a small snippet of your TIR code that has some redundancies? I can try to tell whether the CSE pass will be able to optimize it.
   Also please do not hesitate to play with the pass and to let me know if it does what you would hope to obtain. I can help of course.
   
   Kind regards.

