You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/05/24 00:29:27 UTC
[tvm] branch unity updated: [Unity] Allow eliminating only call nodes in CSE pass (#14895)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6bb531fcb1 [Unity] Allow eliminating only call nodes in CSE pass (#14895)
6bb531fcb1 is described below
commit 6bb531fcb18b2fe8ea638367b5b4d0d6fa2c15a2
Author: masahi <ma...@gmail.com>
AuthorDate: Wed May 24 09:29:16 2023 +0900
[Unity] Allow eliminating only call nodes in CSE pass (#14895)
Add a `call_only` option to the CSE pass (EliminateCommonSubexpr); when enabled, only call nodes are considered for elimination.
---
include/tvm/relax/transform.h | 3 ++-
python/tvm/relax/transform/transform.py | 9 +++++--
src/relax/transform/eliminate_common_subexpr.cc | 21 +++++++++-------
tests/python/relax/test_transform_cse.py | 32 +++++++++++++++++++++++--
4 files changed, 52 insertions(+), 13 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 138720ec13..6f9841ba7a 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -169,8 +169,9 @@ TVM_DLL Pass CanonicalizeBindings();
*
* \note For functions local to dataflow blocks, this pass performs
* CSE *within* those functions.
+ * \param call_only If true, only call nodes are considered for elimination.
*/
-TVM_DLL Pass EliminateCommonSubexpr();
+TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
/*!
* \brief Bind params of function of the module to constant tensors.
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 278f66fc40..6013073a37 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -290,18 +290,23 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass:
return _ffi_api.CanonicalizeBindings() # type: ignore
-def EliminateCommonSubexpr() -> DataflowBlockPass:
+def EliminateCommonSubexpr(call_only=False) -> DataflowBlockPass:
"""Eliminate common subexpressions within dataflow blocks.
Note: For functions local to dataflow blocks, this pass performs
CSE *within* those functions
+ Parameters
+ ----------
+ call_only : bool
+ If True, only call nodes are considered for elimination.
+
Returns
-------
ret : tvm.transform.Pass
The registered pass that eliminates common subexpressions.
"""
- return _ffi_api.EliminateCommonSubexpr() # type: ignore
+ return _ffi_api.EliminateCommonSubexpr(call_only) # type: ignore
def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 6c772d2e20..3087c409ac 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -69,16 +69,20 @@ class SubexprCounter : public ExprVisitor {
};
// forward declaration
-DataflowBlock EliminateCommonSubexpr(const DataflowBlock&);
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock&, bool call_only);
class CommonSubexprEliminator : public ExprMutator {
public:
explicit CommonSubexprEliminator(
- const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map)
- : count_map_(count_map) {}
+ const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map,
+ bool call_only = false)
+ : count_map_(count_map), call_only_(call_only) {}
// overriding here ensures we visit every subexpression
Expr VisitExpr(const Expr& e) override {
+ if (call_only_ && !e->IsInstance<CallNode>()) {
+ return ExprMutator::VisitExpr(e);
+ }
if (count_map_.count(e) && count_map_.at(e) > 1) {
// if we already have a mapping for it, get it
if (replacements_.count(e)) {
@@ -116,7 +120,7 @@ class CommonSubexprEliminator : public ExprMutator {
// apply CSE within dataflow blocks only
for (auto block : seq->blocks) {
if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
- auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block));
+ auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block), call_only_);
if (!new_df_block.same_as(block)) {
new_blocks.push_back(new_df_block);
all_unchanged = false;
@@ -182,21 +186,22 @@ class CommonSubexprEliminator : public ExprMutator {
const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map_;
std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
+ bool call_only_{false};
};
-DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block) {
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block, bool call_only) {
SubexprCounter counter;
auto count_map = counter.Count(df_block);
- CommonSubexprEliminator eliminator(count_map);
+ CommonSubexprEliminator eliminator(count_map, call_only);
return Downcast<DataflowBlock>(eliminator.VisitBindingBlock(df_block));
}
namespace transform {
-Pass EliminateCommonSubexpr() {
+Pass EliminateCommonSubexpr(bool call_only) {
runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
[=](DataflowBlock df_block, IRModule m, PassContext pc) {
- return Downcast<DataflowBlock>(EliminateCommonSubexpr(df_block));
+ return Downcast<DataflowBlock>(EliminateCommonSubexpr(df_block, call_only));
};
return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {});
}
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 4ee9653ead..94897c1eae 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -23,8 +23,8 @@ from tvm.script.parser import ir as I, relax as R, tir as T
import numpy as np
-def verify(input, expected):
- tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected)
+def verify(input, expected, call_only=False):
+ tvm.ir.assert_structural_equal(EliminateCommonSubexpr(call_only)(input), expected)
def test_simple():
@@ -182,5 +182,33 @@ def test_inner_function():
verify(Before, Expected)
+def test_call_only():
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((160,), dtype="float32")):
+ with R.dataflow():
+ lv1 = R.arange(R.prim_value(0), R.prim_value(160), R.prim_value(1), dtype="float32")
+ lv2 = R.arange(R.prim_value(0), R.prim_value(160), R.prim_value(1), dtype="float32")
+ lv3 = R.add(x, lv1)
+ out = R.add(lv3, lv2)
+ R.output(out)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def foo(x: R.Tensor((160,), dtype="float32")) -> R.Tensor((160,), dtype="float32"):
+ with R.dataflow():
+ lv1 = R.arange(R.prim_value(0), R.prim_value(160), R.prim_value(1), dtype="float32")
+ lv2 = lv1
+ lv3 = R.add(x, lv1)
+ out = R.add(lv3, lv2)
+ R.output(out)
+ return out
+
+ verify(Before, Expected, call_only=True)
+
+
if __name__ == "__main__":
tvm.testing.main()