You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wr...@apache.org on 2023/01/09 04:01:01 UTC
[tvm] branch main updated: [TIR][Arith] Add common sub expr analyzer (#13702)
This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a435cbb3b1 [TIR][Arith] Add common sub expr analyzer (#13702)
a435cbb3b1 is described below
commit a435cbb3b1484e6f347421444168ccc312ef41d3
Author: multiverstack <39...@users.noreply.github.com>
AuthorDate: Mon Jan 9 12:00:55 2023 +0800
[TIR][Arith] Add common sub expr analyzer (#13702)
* [TIR][Arith] Add common sub expr analyzer
* Update python/tvm/arith/pattern.py
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
* Update src/arith/detect_common_subexpr.cc
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
* Update python/tvm/arith/pattern.py
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
* Update python/tvm/arith/pattern.py
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
* Update src/arith/detect_common_subexpr.cc
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
* Update detect_common_subexpr.cc
* Update pattern.py
* Update pattern.py
* Update pattern.py
* Update pattern.py
Co-authored-by: Min Chen <ch...@intellif.com>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
python/tvm/arith/__init__.py | 2 +-
python/tvm/arith/pattern.py | 23 +++++++
src/arith/detect_common_subexpr.cc | 74 ++++++++++++++++++++++
src/tir/transforms/common_subexpr_elim_tools.cc | 6 +-
src/tir/transforms/common_subexpr_elim_tools.h | 3 +-
.../python/unittest/test_arith_detect_cse.py | 35 +++++-----
6 files changed, 119 insertions(+), 24 deletions(-)
diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index 03c0769850..423aafe5d6 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -25,7 +25,7 @@ from .int_set import (
)
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
-from .pattern import detect_linear_equation, detect_clip_bound
+from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr
from .int_solver import solve_linear_equations, solve_linear_inequalities
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import (
diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py
index 53f8eb62b6..3c822dc523 100644
--- a/python/tvm/arith/pattern.py
+++ b/python/tvm/arith/pattern.py
@@ -15,6 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Detect common patterns."""
+
+from typing import Dict
+
+from tvm.tir import PrimExpr
from . import _ffi_api
@@ -58,3 +62,22 @@ def detect_clip_bound(expr, var_list):
An empty list if the match failed.
"""
return _ffi_api.DetectClipBound(expr, var_list)
+
+
+def detect_common_subexpr(expr: PrimExpr, threshold: int) -> Dict[PrimExpr, int]:
+ """Detect common sub expressions that occur at least `threshold` times.
+
+ Parameters
+ ----------
+ expr : PrimExpr
+ The expression to be analyzed.
+
+ threshold : int
+ The minimum number of occurrences for a sub expression to be reported.
+
+ Returns
+ -------
+ cse_dict : Dict[PrimExpr, int]
+ A dict mapping each detected common sub expression to its occurrence count.
+ """
+ return _ffi_api.DetectCommonSubExpr(expr, threshold)
diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc
new file mode 100644
index 0000000000..b496e7fefc
--- /dev/null
+++ b/src/arith/detect_common_subexpr.cc
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file detect_common_subexpr.cc
+ * \brief Utility to detect common sub expressions.
+ */
+#include <tvm/tir/expr.h>
+
+#include <limits>
+
+#include "../tir/transforms/common_subexpr_elim_tools.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+/*!
+ * \brief Detect the sub expressions of `e` that are computed at least `thresh` times.
+ * \param e The expression to be analyzed.
+ * \param thresh The minimum occurrence count for a sub expression to be reported.
+ * Must be non-negative.
+ * \return A map from each detected common sub expression to its occurrence count.
+ */
+Map<PrimExpr, Integer> DetectCommonSubExpr(const PrimExpr& e, int thresh) {
+ // A non-negative int always fits in size_t, so rejecting negative values is the only range
+ // check needed. Comparing `thresh` directly against size_t limits converts it to an unsigned
+ // value first, which makes such a comparison succeed for every input, negative ones included.
+ CHECK_GE(thresh, 0) << "The threshold must be non-negative, but got " << thresh;
+ size_t repeat_thr = static_cast<size_t>(thresh);
+ auto IsEligibleComputation = [](const PrimExpr& expr) {
+ return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 &&
+ (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+ };
+
+ // Analyze the sub expressions
+ ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+ e, IsEligibleComputation, [](const PrimExpr& expr) { return true; });
+
+ std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+ SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, true);
+
+ // For each computation seen fewer than `repeat_thr` times, add its direct sub expressions to
+ // the list, crediting each of them with the occurrence count of that computation.
+ for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+ // Work on a copy: the insertion below may grow the vector and invalidate references into it.
+ const std::pair<PrimExpr, size_t> computation_and_nb = semantic_comp_done_by_expr[i];
+ if (computation_and_nb.second < repeat_thr) {
+ std::vector<PrimExpr> direct_subexprs =
+ DirectSubexpr::GetDirectSubexpressions(computation_and_nb.first, IsEligibleComputation,
+ [](const PrimExpr& expr) { return true; });
+ InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, true,
+ computation_and_nb.second);
+ }
+ }
+
+ // Return the common sub expr that occur at least thresh times
+ Map<PrimExpr, Integer> results;
+ for (auto& it : semantic_comp_done_by_expr) {
+ if (it.second >= repeat_thr) results.Set(it.first, it.second);
+ }
+ return results;
+}
+
+TVM_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr);
+} // namespace arith
+} // namespace tvm
diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc
index 130004c51c..c118d1db7d 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.cc
+++ b/src/tir/transforms/common_subexpr_elim_tools.cc
@@ -902,7 +902,7 @@ void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size
*/
void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
const std::vector<PrimExpr>& vec_to_add,
- bool identify_equiv_terms) {
+ bool identify_equiv_terms, size_t increase_count) {
if (sorted_vec == nullptr) {
return;
}
@@ -918,10 +918,10 @@ void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, si
// If we found `elem_to_add` (or an equivalent expression) already in sorted_vec
if (it_found != sorted_vec->end()) {
// then we just increase its associated count
- it_found->second++;
+ it_found->second += increase_count;
} else {
// Otherwise we add the pair (`elem_to_add`, `increase_count`) at the right place
- InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, 1});
+ InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, increase_count});
}
}
}
diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h
index 0871fd0091..841f1d65a6 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.h
+++ b/src/tir/transforms/common_subexpr_elim_tools.h
@@ -210,9 +210,10 @@ template std::vector<Var> VectorMap(const std::vector<std::pair<Var, MaybeValue>
void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
const std::pair<PrimExpr, size_t>& pair);
+
void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
const std::vector<PrimExpr>& vec_to_add,
- bool identify_equiv_terms);
+ bool identify_equiv_terms, size_t increase_count = 1);
} // namespace tir
} // namespace tvm
diff --git a/python/tvm/arith/__init__.py b/tests/python/unittest/test_arith_detect_cse.py
similarity index 54%
copy from python/tvm/arith/__init__.py
copy to tests/python/unittest/test_arith_detect_cse.py
index 03c0769850..eba0920cb2 100644
--- a/python/tvm/arith/__init__.py
+++ b/tests/python/unittest/test_arith_detect_cse.py
@@ -14,23 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Integer bound analysis, simplification and pattern detection."""
+import tvm
+import tvm.testing
+from tvm.script import tir as T
-from .int_set import (
- IntSet,
- IntervalSet,
- estimate_region_lower_bound,
- estimate_region_strict_bound,
- estimate_region_upper_bound,
-)
-from .analyzer import ModularSet, ConstIntBound, Analyzer
-from .bound import deduce_bound
-from .pattern import detect_linear_equation, detect_clip_bound
-from .int_solver import solve_linear_equations, solve_linear_inequalities
-from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
-from .iter_affine_map import (
- detect_iter_map,
- normalize_iter_map_to_expr,
- subspace_divide,
- inverse_affine_iter_map,
-)
+
+def test_detect_cs():
+ x = T.Var("x", dtype="int32")
+ y = T.Var("y", dtype="int32")
+ z = T.Var("z", dtype="int32")
+ c = T.floor(x + y + 0.5) + x + z * (T.floor(x + y + 0.5))
+ m = tvm.arith.detect_common_subexpr(c, 2)
+ assert c.a.a in m
+ assert m[c.a.a] == 2
+
+
+if __name__ == "__main__":
+ tvm.testing.main()