You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ss...@apache.org on 2023/03/27 17:58:21 UTC

[tvm] branch unity updated: [Unity][Transform] Common Subexpression Elimination (#14361)

This is an automated email from the ASF dual-hosted git repository.

sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f4d5964653 [Unity][Transform] Common Subexpression Elimination (#14361)
f4d5964653 is described below

commit f4d5964653f1fc197fac02edc197deb46ee4dfed
Author: Steven S. Lyubomirsky <sl...@octoml.ai>
AuthorDate: Mon Mar 27 13:58:13 2023 -0400

    [Unity][Transform] Common Subexpression Elimination (#14361)
    
    * [Unity][Pass] Add pass for CSE within dataflow
    
    * Fill in CSE definition and test cases
    
    * Missing trailing newline
    
    ---------
    
    Co-authored-by: Prakalp Srivastava <pr...@octoml.ai>
---
 include/tvm/relax/transform.h                   |   9 +
 python/tvm/relax/transform/transform.py         |  14 ++
 src/relax/transform/eliminate_common_subexpr.cc | 209 ++++++++++++++++++++++++
 tests/python/relax/test_transform_cse.py        | 186 +++++++++++++++++++++
 4 files changed, 418 insertions(+)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 4f45ba9c25..f6acf80beb 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -137,6 +137,15 @@ TVM_DLL Pass Normalize();
  */
 TVM_DLL Pass CanonicalizeBindings();
 
+/*!
+ * \brief Eliminate common subexpressions within dataflow blocks.
+ * \return The pass that eliminates common subexpressions.
+ *
+ * \note For functions local to dataflow blocks, this pass performs
+ * CSE *within* those functions.
+ */
+TVM_DLL Pass EliminateCommonSubexpr();
+
 /*!
  * \brief Bind params of function of the module to constant tensors.
  *
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 18321e8dba..049ac2947f 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -100,6 +100,20 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass:
     return _ffi_api.CanonicalizeBindings()  # type: ignore
 
 
+def EliminateCommonSubexpr() -> DataflowBlockPass:
+    """Eliminate common subexpressions within dataflow blocks.
+
+    Note: For functions local to dataflow blocks, this pass performs
+    CSE *within* those functions
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass that eliminates common subexpressions.
+    """
+    return _ffi_api.EliminateCommonSubexpr()  # type: ignore
+
+
 def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
     """Convert all reshape-like call_tir to VM reshape operator call.
     The VM reshape operator calls will be further lowered to a CreateView
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
new file mode 100644
index 0000000000..9c9252ddfa
--- /dev/null
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/eliminate_common_subexpr.cc
+ * \brief Common subexpression elimination pass.
+ *
+ * Currently it removes common subexpressions within a DataflowBlock.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Counts how many times each (structurally distinct) subexpression
+ *        occurs among the bindings of a single dataflow block.
+ *
+ * Expressions are keyed with StructuralHash/StructuralEqual, so two
+ * separately constructed but structurally equal expressions share a count.
+ */
+class SubexprCounter : public ExprVisitor {
+ public:
+  // overriding VisitExpr ensures we do this for every subexpression
+  void VisitExpr(const Expr& e) override {
+    if (!IsIgnored(e)) {
+      // A single structural-hash lookup: operator[] value-initializes a
+      // missing entry to 0, so this both inserts and increments. (Computing
+      // a StructuralHash walks the whole subtree, so avoid repeating it.)
+      ++count_map_[e];
+    }
+    ExprVisitor::VisitExpr(e);
+  }
+
+  // do not visit inner functions: we will do CSE within those
+  void VisitExpr_(const FunctionNode* func) override {}
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
+
+  /*!
+   * \brief Count subexpression occurrences over every binding of \p df_block.
+   * \param df_block The dataflow block to scan.
+   * \return Map from expression (compared structurally) to occurrence count.
+   */
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(
+      const DataflowBlock& df_block) {
+    for (const auto& binding : df_block->bindings) {
+      VisitBinding(binding);
+    }
+    return count_map_;
+  }
+
+ private:
+  /*!
+   * \brief Cases we ignore because we will not substitute them:
+   * 1. Vars of all kinds
+   * 2. Global vars and Op nodes (nothing we can do)
+   * 3. Scalar constants (not much benefit from binding to a var)
+   */
+  static bool IsIgnored(const Expr& e) {
+    if (e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
+        e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>()) {
+      return true;
+    }
+    const auto* constant = e.as<ConstantNode>();
+    return constant != nullptr && constant->is_scalar();
+  }
+
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
+};
+
+// Declared ahead of CommonSubexprEliminator, which recursively applies CSE to
+// the dataflow blocks of inner functions; defined after the class.
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block);
+
+/*!
+ * \brief Rewrites a dataflow block so each repeated subexpression is computed
+ *        once and later occurrences refer to the variable bound to it.
+ *
+ * The first time a repeated expression (count > 1 in the supplied map) is seen,
+ * it is either associated with the variable of the binding whose value it is,
+ * or, when it occurs nested inside another expression, emitted as a fresh
+ * binding. Every later occurrence is replaced with that variable.
+ */
+class CommonSubexprEliminator : public ExprMutator {
+ public:
+  /*!
+   * \param count_map Occurrence counts for the block being rewritten. It is
+   *        held by reference, so it must outlive this mutator.
+   */
+  explicit CommonSubexprEliminator(
+      const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map)
+      : count_map_(count_map) {}
+
+  // overriding here ensures we visit every subexpression
+  Expr VisitExpr(const Expr& e) override {
+    if (!IsRepeated(e)) {
+      return ExprMutator::VisitExpr(e);
+    }
+    // if we already have a mapping for it, get it
+    auto it = replacements_.find(e);
+    if (it != replacements_.end()) {
+      return it->second;
+    }
+    // Otherwise, insert a new binding for the current expression.
+    // Visit before emitting to do inner replacements. (Visiting may insert
+    // into replacements_, so `it` must not be reused below.)
+    Expr new_e = ExprMutator::VisitExpr(e);
+    Var v = builder_->Emit(new_e);
+    replacements_[e] = v;
+    return v;
+  }
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
+    return struct_info;
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) override {
+    // for an inner function, we will do CSE on its body
+    Expr new_body = ExprMutator::VisitExpr(func->body);
+    if (new_body.same_as(func->body)) {
+      return GetRef<Expr>(func);
+    }
+    return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span);
+  }
+
+  // this should happen only for the inner function case
+  Expr VisitExpr_(const SeqExprNode* seq) override {
+    bool all_unchanged = true;
+    Array<BindingBlock> new_blocks;
+    // apply CSE within dataflow blocks only; each block gets its own counts
+    for (const auto& block : seq->blocks) {
+      if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
+        auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block));
+        if (!new_df_block.same_as(block)) {
+          new_blocks.push_back(new_df_block);
+          all_unchanged = false;
+          continue;
+        }
+      }
+      new_blocks.push_back(block);
+    }
+
+    if (all_unchanged) {
+      return GetRef<Expr>(seq);
+    }
+    // do not visit the body
+    return SeqExpr(new_blocks, seq->body, seq->span);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // no need to visit var def because the struct info isn't going to change
+    Expr new_value = RegisterBoundValue(binding->var, binding->value);
+
+    if (new_value.same_as(binding->value)) {
+      builder_->EmitNormalized(GetRef<VarBinding>(binding));
+    } else {
+      // no need to renormalize new_value because all replacements are with vars
+      builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span));
+    }
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) override {
+    // no need to visit var def because the struct info isn't going to change
+    Expr new_value = RegisterBoundValue(binding->var, binding->value);
+
+    // re-emit old binding if nothing changes
+    if (new_value.same_as(binding->value)) {
+      builder_->EmitNormalized(GetRef<MatchCast>(binding));
+    } else {
+      // no need to renormalize new_value because all replacements are with vars
+      builder_->EmitNormalized(
+          MatchCast(binding->var, new_value, binding->struct_info, binding->span));
+    }
+  }
+
+ private:
+  /*!
+   * \brief Rewrite the value bound to \p var.
+   *
+   * Special case: if \p bound_value is repeated and this is the first time it
+   * is encountered, \p var itself becomes its replacement, so later
+   * occurrences refer to this binding rather than to a newly emitted one.
+   */
+  Expr RegisterBoundValue(const Var& var, const Expr& bound_value) {
+    if (IsRepeated(bound_value) && replacements_.emplace(bound_value, var).second) {
+      // We just added the mapping, so the overridden visitor would simply
+      // return `var`; use the superclass VisitExpr to do inner substitutions.
+      return ExprMutator::VisitExpr(bound_value);
+    }
+    return VisitExpr(bound_value);
+  }
+
+  /*! \brief Whether \p e occurs more than once in the block being rewritten. */
+  bool IsRepeated(const Expr& e) const {
+    auto it = count_map_.find(e);
+    return it != count_map_.end() && it->second > 1;
+  }
+
+  const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map_;
+  std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
+};
+
+/*!
+ * \brief Run CSE on a single dataflow block.
+ *
+ * First counts subexpression occurrences in the block, then rewrites the
+ * block so that repeated subexpressions are computed once.
+ */
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block) {
+  // The eliminator keeps a reference to the counts, so they must remain alive
+  // for as long as the eliminator is in use.
+  const auto counts = SubexprCounter().Count(df_block);
+  CommonSubexprEliminator eliminator(counts);
+  BindingBlock rewritten = eliminator.VisitBindingBlock(df_block);
+  return Downcast<DataflowBlock>(rewritten);
+}
+
+namespace transform {
+
+/*!
+ * \brief Create a dataflow-block pass (opt level 1) that applies CSE to each
+ *        dataflow block it is run on.
+ */
+Pass EliminateCommonSubexpr() {
+  runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
+      [](DataflowBlock df_block, IRModule m, PassContext pc) {
+        // Qualify explicitly: inside `transform`, unqualified lookup stops at
+        // this zero-argument overload, and the block-level helper would be
+        // reached only through argument-dependent lookup.
+        return relax::EliminateCommonSubexpr(df_block);
+      };
+  return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {});
+}
+
+// Expose the pass constructor through the FFI; the Python wrapper
+// tvm.relax.transform.EliminateCommonSubexpr calls it via _ffi_api.
+TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr")
+    .set_body_typed(EliminateCommonSubexpr);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
new file mode 100644
index 0000000000..4ee9653ead
--- /dev/null
+++ b/tests/python/relax/test_transform_cse.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test eliminate common subexpr pass"""
+import tvm
+import tvm.testing
+from tvm.relax.transform import EliminateCommonSubexpr
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+import numpy as np
+
+
+def verify(input, expected):
+    tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected)
+
+
+def test_simple():
+    """Two structurally identical `R.add(x, y)` bindings: the second one is
+    rewritten to alias the variable bound by the first."""
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                lv1 = R.add(x, y)
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                # can combine with canonicalizing bindings
+                # and getting rid of unused bindings to eliminate this line too
+                lv1 = lv0
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_constants():
+    """Scalar constants are left in place, while a repeated non-scalar constant
+    is bound to a variable once and that variable is reused."""
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                # we are not going to bind the constant 1 to a var
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                # we expect to bind the repeated large constants
+                lv1 = R.add(
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                )
+                gv = (lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                # the repeated (2, 2) constant is emitted as its own binding
+                # just before the binding that first uses it
+                lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
+                lv2 = R.add(lv1, lv1)
+                gv = (lv0, lv2)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_repeated_inner_tuples():
+    """Each repeated nested tuple is bound once, innermost first, and the
+    enclosing tuples are rebuilt from those bindings."""
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # repeated units: (x, x), (x, (x, x)), ((x, x), (x, (x, x)))
+                tup = (((x, x), (x, (x, x))), ((x, x), (x, (x, x))), (x, (x, x)))
+                gv = tup[0][0][1]
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # t1..t3 are the three repeated units above; t4 is the outer tuple
+                t1 = (x, x)
+                t2 = (x, t1)
+                t3 = (t1, t2)
+                t4 = (t3, t3, t2)
+                gv = t4[0][0][1]
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_inner_function():
+    """CSE is applied inside the dataflow block of a local function as well as
+    in the enclosing dataflow block; the local function's non-dataflow binding
+    `z` is left untouched."""
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # we are going to do CSE inside the local function
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+                    # not in dataflow: should not be touched
+                    z = R.add(R.add(y, y), R.add(y, y))
+                    with R.dataflow():
+                        # writing this out in ANF to illustrate why CSE behaves as it does
+                        # result of ANF transforming R.add(R.add(y, y), R.add(y, y))
+                        lv0 = R.add(y, y)
+                        lv1 = R.add(y, y)
+                        lv2 = R.add(lv0, lv1)
+                        gv = lv2
+                        R.output(gv)
+                    return R.add(z, gv)
+
+                # also making the ANF explicit to better illustrate the result of CSE
+                # result of ANF transforming R.add(R.add(bar(x), bar(x)), R.add(bar(x), bar(x)))
+                lv0 = bar(x)
+                lv1 = bar(x)
+                lv2 = R.add(lv0, lv1)
+                lv3 = bar(x)
+                lv4 = bar(x)
+                lv5 = R.add(lv3, lv4)
+                lv6 = R.add(lv2, lv5)
+                gv = lv6
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+                    z = R.add(R.add(y, y), R.add(y, y))
+                    with R.dataflow():
+                        lv0 = R.add(y, y)
+                        lv1 = lv0
+                        lv2 = R.add(lv0, lv1)
+                        gv = lv2
+                        R.output(gv)
+                    return R.add(z, gv)
+
+                # can further clean this up
+                # using canonicalize bindings, eliminate unused bindings, and CSE again
+                lv0 = bar(x)
+                lv1 = lv0
+                lv2 = R.add(lv0, lv1)
+                lv3 = lv0
+                lv4 = lv0
+                lv5 = R.add(lv3, lv4)
+                lv6 = R.add(lv2, lv5)
+                gv = lv6
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()