You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ss...@apache.org on 2023/03/27 17:58:21 UTC
[tvm] branch unity updated: [Unity][Transform] Common Subexpression Elimination (#14361)
This is an automated email from the ASF dual-hosted git repository.
sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f4d5964653 [Unity][Transform] Common Subexpression Elimination (#14361)
f4d5964653 is described below
commit f4d5964653f1fc197fac02edc197deb46ee4dfed
Author: Steven S. Lyubomirsky <sl...@octoml.ai>
AuthorDate: Mon Mar 27 13:58:13 2023 -0400
[Unity][Transform] Common Subexpression Elimination (#14361)
* [Unity][Pass] Add pass for CSE within dataflow
* Fill in CSE definition and test cases
* Missing trailing newline
---------
Co-authored-by: Prakalp Srivastava <pr...@octoml.ai>
---
include/tvm/relax/transform.h | 9 +
python/tvm/relax/transform/transform.py | 14 ++
src/relax/transform/eliminate_common_subexpr.cc | 209 ++++++++++++++++++++++++
tests/python/relax/test_transform_cse.py | 186 +++++++++++++++++++++
4 files changed, 418 insertions(+)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 4f45ba9c25..f6acf80beb 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -137,6 +137,15 @@ TVM_DLL Pass Normalize();
*/
TVM_DLL Pass CanonicalizeBindings();
+/*!
+ * \brief Eliminate common subexpressions within dataflow blocks.
+ * \return The pass that eliminates common subexpressions.
+ *
+ * \note For functions local to dataflow blocks, this pass performs
+ * CSE *within* those functions.
+ */
+TVM_DLL Pass EliminateCommonSubexpr();
+
/*!
* \brief Bind params of function of the module to constant tensors.
*
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 18321e8dba..049ac2947f 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -100,6 +100,20 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass:
return _ffi_api.CanonicalizeBindings() # type: ignore
+def EliminateCommonSubexpr() -> DataflowBlockPass:
+ """Eliminate common subexpressions within dataflow blocks.
+
+ Note: For functions local to dataflow blocks, this pass performs
+ CSE *within* those functions.
+
+ Returns
+ -------
+ ret : DataflowBlockPass
+ The registered pass that eliminates common subexpressions.
+ """
+ return _ffi_api.EliminateCommonSubexpr() # type: ignore
+
+
def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
"""Convert all reshape-like call_tir to VM reshape operator call.
The VM reshape operator calls will be further lowered to a CreateView
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
new file mode 100644
index 0000000000..9c9252ddfa
--- /dev/null
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/eliminate_common_subexpr.cc
+ * \brief Eliminate common subexpression pass.
+ *
+ * Currently it removes common subexpressions within a DataflowBlock.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Visitor that counts how often each subexpression occurs in a dataflow block.
+ *
+ * Expressions are keyed with StructuralHash/StructuralEqual, so two distinct
+ * nodes that are structurally equal share one counter. Inner functions and
+ * struct info fields are not traversed.
+ */
+class SubexprCounter : public ExprVisitor {
+ public:
+  // overriding VisitExpr ensures we do this for every subexpression
+  void VisitExpr(const Expr& e) override {
+    if (IsCountable(e)) {
+      // operator[] value-initializes a missing entry to 0, so a single lookup
+      // (and a single structural hash of `e`) is enough to bump the counter.
+      ++count_map_[e];
+    }
+    ExprVisitor::VisitExpr(e);
+  }
+
+  // do not visit inner functions: we will do CSE within those
+  void VisitExpr_(const FunctionNode* func) override {}
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
+
+  /*!
+   * \brief Visit every binding of the block and report the accumulated counts.
+   * \param df_block The dataflow block to scan.
+   * \return A copy of the map from subexpression to number of occurrences.
+   */
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(
+      const DataflowBlock& df_block) {
+    for (const auto& binding : df_block->bindings) {
+      VisitBinding(binding);
+    }
+    return count_map_;
+  }
+
+ private:
+  /*!
+   * \brief Whether occurrences of `e` should be counted.
+   *
+   * Cases we ignore because we will not substitute them:
+   * 1. Vars of all kinds
+   * 2. Op nodes (nothing we can do)
+   * 3. Scalar constants (not much benefit from binding to a var)
+   */
+  static bool IsCountable(const Expr& e) {
+    if (e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
+        e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>()) {
+      return false;
+    }
+    const auto* constant = e.as<ConstantNode>();
+    return !(constant && constant->is_scalar());
+  }
+
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
+};
+
+// Forward declaration: the definition follows CommonSubexprEliminator, which
+// calls this function on dataflow blocks found inside local functions.
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock&);
+
+/*!
+ * \brief Mutator that replaces repeated subexpressions with variables.
+ *
+ * An expression is considered repeated when its entry in the count map given
+ * to the constructor is greater than 1. The first occurrence is bound to a
+ * var (either the var of the binding it appears in, or a freshly emitted
+ * one); later occurrences are replaced by that var.
+ *
+ * \note The count map is held by reference, so it must outlive this object.
+ */
+class CommonSubexprEliminator : public ExprMutator {
+ public:
+ explicit CommonSubexprEliminator(
+ const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map)
+ : count_map_(count_map) {}
+
+ // overriding here ensures we visit every subexpression
+ Expr VisitExpr(const Expr& e) override {
+ if (count_map_.count(e) && count_map_.at(e) > 1) {
+ // if we already have a mapping for it, get it
+ if (replacements_.count(e)) {
+ return replacements_.at(e);
+ }
+ // Otherwise, insert a new binding for the current expression.
+ // Visit before emitting to do inner replacements
+ Expr new_e = ExprMutator::VisitExpr(e);
+ Var v = builder_->Emit(new_e);
+ // keyed on the original expression `e`, which is what later lookups use
+ replacements_[e] = v;
+ return v;
+ }
+ return ExprMutator::VisitExpr(e);
+ }
+
+ // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+ StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
+ return struct_info;
+ }
+
+ // Rewrites only the body of an inner function; params, return struct info,
+ // attrs and span are carried over unchanged.
+ Expr VisitExpr_(const FunctionNode* func) override {
+ // for an inner function, we will do CSE on its body
+ Expr new_body = ExprMutator::VisitExpr(func->body);
+ if (new_body.same_as(func->body)) {
+ return GetRef<Expr>(func);
+ }
+ return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span);
+ }
+
+ // this should happen only for the inner function case
+ Expr VisitExpr_(const SeqExprNode* seq) override {
+ bool all_unchanged = true;
+ Array<BindingBlock> new_blocks;
+ // apply CSE within dataflow blocks only
+ for (auto block : seq->blocks) {
+ if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
+ // each nested dataflow block gets its own independent count/rewrite run
+ auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block));
+ if (!new_df_block.same_as(block)) {
+ new_blocks.push_back(new_df_block);
+ all_unchanged = false;
+ continue;
+ }
+ }
+ // non-dataflow blocks and unchanged dataflow blocks are kept as they are
+ new_blocks.push_back(block);
+ }
+
+ if (all_unchanged) {
+ return GetRef<Expr>(seq);
+ }
+ // do not visit the body
+ return SeqExpr(new_blocks, seq->body, seq->span);
+ }
+
+ // Rewrites the bound value and re-emits the binding under the same var.
+ void VisitBinding_(const VarBindingNode* binding) override {
+ // no need to visit var def because the struct info isn't going to change
+ Expr new_value = RegisterBoundValue(binding->var, binding->value);
+
+ if (new_value.same_as(binding->value)) {
+ builder_->EmitNormalized(GetRef<VarBinding>(binding));
+ } else {
+ // no need to renormalize new_value because all replacements are with vars
+ builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span));
+ }
+ }
+
+ // Same as the VarBinding case, but the binding's struct_info is preserved.
+ void VisitBinding_(const MatchCastNode* binding) override {
+ // no need to visit var def because the struct info isn't going to change
+ Expr new_value = RegisterBoundValue(binding->var, binding->value);
+
+ // re-emit old binding if nothing changes
+ if (new_value.same_as(binding->value)) {
+ builder_->EmitNormalized(GetRef<MatchCast>(binding));
+ } else {
+ // no need to renormalize new_value because all replacements are with vars
+ builder_->EmitNormalized(
+ MatchCast(binding->var, new_value, binding->struct_info, binding->span));
+ }
+ }
+
+ private:
+ // Computes the rewritten value for a binding `var = bound_value`.
+ // If `bound_value` is a repeated expression with no replacement yet, `var`
+ // itself becomes its replacement, so no extra binding has to be emitted.
+ Expr RegisterBoundValue(Var var, Expr bound_value) {
+ // special case: if we are processing a binding
+ // and this is the first time we've encountered it,
+ // we will use the binding's var for the mapping
+ bool newly_replaced = false;
+ if (count_map_.count(bound_value) && count_map_.at(bound_value) > 1 &&
+ !replacements_.count(bound_value)) {
+ replacements_[bound_value] = var;
+ newly_replaced = true;
+ }
+
+ if (newly_replaced) {
+ // If we've just added the mapping, using the overridden visitor will
+ // just return the var, which we don't want, so we will use
+ // the superclass VisitExpr to do inner substitutions
+ return ExprMutator::VisitExpr(bound_value);
+ }
+ return VisitExpr(bound_value);
+ }
+
+ // occurrence counts per subexpression; not owned (reference to the caller's map)
+ const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map_;
+ // repeated subexpression -> var that now stands for it
+ std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
+};
+
+/*!
+ * \brief Run common subexpression elimination on a single dataflow block.
+ *
+ * Works in two passes: first count how often each subexpression occurs,
+ * then rewrite the block using those counts.
+ *
+ * \param df_block The dataflow block to process.
+ * \return The rewritten dataflow block.
+ */
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block) {
+  auto occurrence_counts = SubexprCounter().Count(df_block);
+  // the rewriter keeps a reference to the counts, so they must stay alive here
+  CommonSubexprEliminator rewriter(occurrence_counts);
+  auto rewritten_block = rewriter.VisitBindingBlock(df_block);
+  return Downcast<DataflowBlock>(rewritten_block);
+}
+
+namespace transform {
+
+/*!
+ * \brief Build the dataflow-block pass that performs common subexpression elimination.
+ * \return The pass, named "EliminateCommonSubexpr".
+ */
+Pass EliminateCommonSubexpr() {
+  using BlockPassFunc =
+      runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>;
+  // the module and pass context are not needed: CSE looks at one block at a time
+  BlockPassFunc run_cse = [](DataflowBlock block, IRModule mod, PassContext ctx) {
+    // qualified so that the block-level function in namespace relax is called,
+    // not this zero-argument pass constructor
+    return relax::EliminateCommonSubexpr(block);
+  };
+  return CreateDataflowBlockPass(run_cse, 1, "EliminateCommonSubexpr", {});
+}
+
+// Register the pass constructor as the global function
+// "relax.transform.EliminateCommonSubexpr".
+TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr")
+ .set_body_typed(EliminateCommonSubexpr);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
new file mode 100644
index 0000000000..4ee9653ead
--- /dev/null
+++ b/tests/python/relax/test_transform_cse.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test eliminate common subexpr pass"""
+import tvm
+import tvm.testing
+from tvm.relax.transform import EliminateCommonSubexpr
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+import numpy as np
+
+
+def verify(input, expected):
+ tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected)
+
+
+def test_simple():
+ """Two identical bindings: the second is rewritten to reuse the first one's var."""
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+ with R.dataflow():
+ lv0 = R.add(x, y)
+ lv1 = R.add(x, y)
+ gv = R.multiply(lv0, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+ with R.dataflow():
+ lv0 = R.add(x, y)
+ # can combine with canonicalizing bindings
+ # and getting rid of unused bindings to eliminate this line too
+ lv1 = lv0
+ gv = R.multiply(lv0, lv1)
+ R.output(gv)
+ return gv
+
+ verify(Before, Expected)
+
+
+def test_constants():
+ """Repeated scalar constants are left inline; a repeated non-scalar constant is bound once."""
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+ with R.dataflow():
+ # we are not going to bind the constant 1 to a var
+ lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+ # we expect to bind the repeated large constants
+ lv1 = R.add(
+ R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+ R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+ )
+ gv = (lv0, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+ with R.dataflow():
+ lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+ # the repeated (2, 2) constant gets its own binding, used twice below
+ lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
+ lv2 = R.add(lv1, lv1)
+ gv = (lv0, lv2)
+ R.output(gv)
+ return gv
+
+ verify(Before, Expected)
+
+
+def test_repeated_inner_tuples():
+ """Nested repeated tuples are bound innermost-first, each exactly once."""
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ with R.dataflow():
+ # repeated units: (x, x), (x, (x, x)), ((x, x), (x, (x, x)))
+ tup = (((x, x), (x, (x, x))), ((x, x), (x, (x, x))), (x, (x, x)))
+ gv = tup[0][0][1]
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ with R.dataflow():
+ # t1..t3 are the three repeated units; t4 is the original tuple
+ t1 = (x, x)
+ t2 = (x, t1)
+ t3 = (t1, t2)
+ t4 = (t3, t3, t2)
+ gv = t4[0][0][1]
+ R.output(gv)
+ return gv
+
+ verify(Before, Expected)
+
+
+def test_inner_function():
+ """CSE applies inside a local function's dataflow block and to repeated calls of that function."""
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ with R.dataflow():
+ # we are going to do CSE inside the local function
+ @R.function
+ def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ # not in dataflow: should not be touched
+ z = R.add(R.add(y, y), R.add(y, y))
+ with R.dataflow():
+ # writing this out in ANF to illustrate why CSE behaves as it does
+ # result of ANF transforming R.add(R.add(y, y), R.add(y, y))
+ lv0 = R.add(y, y)
+ lv1 = R.add(y, y)
+ lv2 = R.add(lv0, lv1)
+ gv = lv2
+ R.output(gv)
+ return R.add(z, gv)
+
+ # also making the ANF explicit to better illustrate the result of CSE
+ # result of ANF transforming R.add(R.add(bar(x), bar(x)), R.add(bar(x), bar(x)))
+ lv0 = bar(x)
+ lv1 = bar(x)
+ lv2 = R.add(lv0, lv1)
+ lv3 = bar(x)
+ lv4 = bar(x)
+ lv5 = R.add(lv3, lv4)
+ lv6 = R.add(lv2, lv5)
+ gv = lv6
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ with R.dataflow():
+
+ @R.function
+ def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ # outside the dataflow block: left exactly as written
+ z = R.add(R.add(y, y), R.add(y, y))
+ with R.dataflow():
+ lv0 = R.add(y, y)
+ lv1 = lv0
+ lv2 = R.add(lv0, lv1)
+ gv = lv2
+ R.output(gv)
+ return R.add(z, gv)
+
+ # can further clean this up
+ # using canonicalize bindings, eliminate unused bindings, and CSE again
+ lv0 = bar(x)
+ lv1 = lv0
+ lv2 = R.add(lv0, lv1)
+ lv3 = lv0
+ lv4 = lv0
+ lv5 = R.add(lv3, lv4)
+ lv6 = R.add(lv2, lv5)
+ gv = lv6
+ R.output(gv)
+ return gv
+
+ verify(Before, Expected)
+
+
+# When executed directly as a script, hand control to tvm.testing.main().
+if __name__ == "__main__":
+ tvm.testing.main()