You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/05/24 00:29:27 UTC

[tvm] branch unity updated: [Unity] Allow eliminating only call nodes in CSE pass (#14895)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6bb531fcb1 [Unity] Allow eliminating only call nodes in CSE pass (#14895)
6bb531fcb1 is described below

commit 6bb531fcb18b2fe8ea638367b5b4d0d6fa2c15a2
Author: masahi <ma...@gmail.com>
AuthorDate: Wed May 24 09:29:16 2023 +0900

    [Unity] Allow eliminating only call nodes in CSE pass (#14895)
    
    Allow eliminating only call nodes in CSE pass
---
 include/tvm/relax/transform.h                   |  3 ++-
 python/tvm/relax/transform/transform.py         |  9 +++++--
 src/relax/transform/eliminate_common_subexpr.cc | 21 +++++++++-------
 tests/python/relax/test_transform_cse.py        | 32 +++++++++++++++++++++++--
 4 files changed, 52 insertions(+), 13 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 138720ec13..6f9841ba7a 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -169,8 +169,9 @@ TVM_DLL Pass CanonicalizeBindings();
  *
  * \note For functions local to dataflow blocks, this pass performs
  * CSE *within* those functions.
+ * \param call_only If true, only call nodes are eliminated; other duplicates are left as-is.
  */
-TVM_DLL Pass EliminateCommonSubexpr();
+TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
 
 /*!
  * \brief Bind params of function of the module to constant tensors.
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 278f66fc40..6013073a37 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -290,18 +290,23 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass:
     return _ffi_api.CanonicalizeBindings()  # type: ignore
 
 
-def EliminateCommonSubexpr() -> DataflowBlockPass:
+def EliminateCommonSubexpr(call_only=False) -> DataflowBlockPass:
     """Eliminate common subexpressions within dataflow blocks.
 
     Note: For functions local to dataflow blocks, this pass performs
     CSE *within* those functions
 
+    Parameters
+    ----------
+    call_only : bool
+        If True, only call nodes are eliminated; other duplicated expressions are left as-is.
+
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass that eliminates common subexpressions.
     """
-    return _ffi_api.EliminateCommonSubexpr()  # type: ignore
+    return _ffi_api.EliminateCommonSubexpr(call_only)  # type: ignore
 
 
 def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 6c772d2e20..3087c409ac 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -69,16 +69,20 @@ class SubexprCounter : public ExprVisitor {
 };
 
 // forward declaration
-DataflowBlock EliminateCommonSubexpr(const DataflowBlock&);
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock&, bool call_only);
 
 class CommonSubexprEliminator : public ExprMutator {
  public:
   explicit CommonSubexprEliminator(
-      const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map)
-      : count_map_(count_map) {}
+      const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map,
+      bool call_only = false)
+      : count_map_(count_map), call_only_(call_only) {}
 
   // overriding here ensures we visit every subexpression
   Expr VisitExpr(const Expr& e) override {
+    if (call_only_ && !e->IsInstance<CallNode>()) {
+      return ExprMutator::VisitExpr(e);
+    }
     if (count_map_.count(e) && count_map_.at(e) > 1) {
       // if we already have a mapping for it, get it
       if (replacements_.count(e)) {
@@ -116,7 +120,7 @@ class CommonSubexprEliminator : public ExprMutator {
     // apply CSE within dataflow blocks only
     for (auto block : seq->blocks) {
       if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
-        auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block));
+        auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block), call_only_);
         if (!new_df_block.same_as(block)) {
           new_blocks.push_back(new_df_block);
           all_unchanged = false;
@@ -182,21 +186,22 @@ class CommonSubexprEliminator : public ExprMutator {
 
   const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map_;
   std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
+  bool call_only_{false};
 };
 
-DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block) {
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block, bool call_only) {
   SubexprCounter counter;
   auto count_map = counter.Count(df_block);
-  CommonSubexprEliminator eliminator(count_map);
+  CommonSubexprEliminator eliminator(count_map, call_only);
   return Downcast<DataflowBlock>(eliminator.VisitBindingBlock(df_block));
 }
 
 namespace transform {
 
-Pass EliminateCommonSubexpr() {
+Pass EliminateCommonSubexpr(bool call_only) {
   runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
       [=](DataflowBlock df_block, IRModule m, PassContext pc) {
-        return Downcast<DataflowBlock>(EliminateCommonSubexpr(df_block));
+        return Downcast<DataflowBlock>(EliminateCommonSubexpr(df_block, call_only));
       };
   return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {});
 }
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 4ee9653ead..94897c1eae 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -23,8 +23,8 @@ from tvm.script.parser import ir as I, relax as R, tir as T
 import numpy as np
 
 
-def verify(input, expected):
-    tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected)
+def verify(input, expected, call_only=False):
+    tvm.ir.assert_structural_equal(EliminateCommonSubexpr(call_only)(input), expected)
 
 
 def test_simple():
@@ -182,5 +182,33 @@ def test_inner_function():
     verify(Before, Expected)
 
 
+def test_call_only():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((160,), dtype="float32")):
+            with R.dataflow():
+                lv1 = R.arange(R.prim_value(0), R.prim_value(160), R.prim_value(1), dtype="float32")
+                lv2 = R.arange(R.prim_value(0), R.prim_value(160), R.prim_value(1), dtype="float32")
+                lv3 = R.add(x, lv1)
+                out = R.add(lv3, lv2)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((160,), dtype="float32")) -> R.Tensor((160,), dtype="float32"):
+            with R.dataflow():
+                lv1 = R.arange(R.prim_value(0), R.prim_value(160), R.prim_value(1), dtype="float32")
+                lv2 = lv1
+                lv3 = R.add(x, lv1)
+                out = R.add(lv3, lv2)
+                R.output(out)
+            return out
+
+    verify(Before, Expected, call_only=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()