You are viewing a plain text version of this content; the link to the canonical version is not preserved in this copy.
Posted to commits@tvm.apache.org by wr...@apache.org on 2023/01/09 04:01:01 UTC

[tvm] branch main updated: [TIR][Arith] Add common sub expr analyzer (#13702)

This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a435cbb3b1 [TIR][Arith] Add common sub expr analyzer (#13702)
a435cbb3b1 is described below

commit a435cbb3b1484e6f347421444168ccc312ef41d3
Author: multiverstack <39...@users.noreply.github.com>
AuthorDate: Mon Jan 9 12:00:55 2023 +0800

    [TIR][Arith] Add common sub expr analyzer (#13702)
    
    * [TIR][Arith] Add common sub expr analyzer
    
    * Update python/tvm/arith/pattern.py
    
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    
    * Update src/arith/detect_common_subexpr.cc
    
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    
    * Update python/tvm/arith/pattern.py
    
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    
    * Update python/tvm/arith/pattern.py
    
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    
    * Update src/arith/detect_common_subexpr.cc
    
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    
    * Update detect_common_subexpr.cc
    
    * Update pattern.py
    
    * Update pattern.py
    
    * Update pattern.py
    
    * Update pattern.py
    
    Co-authored-by: Min Chen <ch...@intellif.com>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
 python/tvm/arith/__init__.py                       |  2 +-
 python/tvm/arith/pattern.py                        | 23 +++++++
 src/arith/detect_common_subexpr.cc                 | 74 ++++++++++++++++++++++
 src/tir/transforms/common_subexpr_elim_tools.cc    |  6 +-
 src/tir/transforms/common_subexpr_elim_tools.h     |  3 +-
 .../python/unittest/test_arith_detect_cse.py       | 35 +++++-----
 6 files changed, 119 insertions(+), 24 deletions(-)

diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index 03c0769850..423aafe5d6 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -25,7 +25,7 @@ from .int_set import (
 )
 from .analyzer import ModularSet, ConstIntBound, Analyzer
 from .bound import deduce_bound
-from .pattern import detect_linear_equation, detect_clip_bound
+from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr
 from .int_solver import solve_linear_equations, solve_linear_inequalities
 from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
 from .iter_affine_map import (
diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py
index 53f8eb62b6..3c822dc523 100644
--- a/python/tvm/arith/pattern.py
+++ b/python/tvm/arith/pattern.py
@@ -15,6 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 """Detect common patterns."""
+
+from typing import Dict
+
+from tvm.tir import PrimExpr
 from . import _ffi_api
 
 
@@ -58,3 +62,22 @@ def detect_clip_bound(expr, var_list):
         An empty list if the match failed.
     """
     return _ffi_api.DetectClipBound(expr, var_list)
+
+
+def detect_common_subexpr(expr: PrimExpr, threshold: int) -> Dict[PrimExpr, int]:
+    """Detect common sub expressions that occur at least ``threshold`` times.
+
+    Parameters
+    ----------
+    expr : PrimExpr
+        The expression to be analyzed.
+
+    threshold : int
+        Minimum number of occurrences for a sub expression to be reported.
+
+    Returns
+    -------
+    cse_dict : Dict[PrimExpr, int]
+        Maps each detected sub expression to its number of occurrences.
+    """
+    return _ffi_api.DetectCommonSubExpr(expr, threshold)
diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc
new file mode 100644
index 0000000000..b496e7fefc
--- /dev/null
+++ b/src/arith/detect_common_subexpr.cc
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file detect_common_subexpr.cc
+ * \brief Utility to detect common sub expressions.
+ */
+#include <tvm/tir/expr.h>
+
+#include <limits>
+
+#include "../tir/transforms/common_subexpr_elim_tools.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+Map<PrimExpr, Integer> DetectCommonSubExpr(const PrimExpr& e, int thresh) {
+  // Reject negative thresholds up front. Comparing the int against size_t limits converts it
+  // to size_t first, so a negative value would pass such checks and wrap to a huge threshold.
+  CHECK_GE(thresh, 0) << "DetectCommonSubExpr: threshold must be non-negative, got " << thresh;
+  size_t repeat_thr = static_cast<size_t>(thresh);
+  auto IsEligibleComputation = [](const PrimExpr& expr) {
+    return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 &&
+            (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+  };
+
+  // Collect eligible (pure, complexity > 1, non-Ramp/Broadcast) computations of `e` with counts
+  ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
+      e, IsEligibleComputation, [](const PrimExpr&) { return true; });
+
+  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, true);
+
+  // A computation seen fewer than thresh times passes its count down to its direct sub exprs
+  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
+    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
+    if (computation_and_nb.second < repeat_thr) {
+      std::vector<PrimExpr> direct_subexprs =
+          DirectSubexpr::GetDirectSubexpressions(computation_and_nb.first, IsEligibleComputation,
+                                                 [](const PrimExpr&) { return true; });
+      InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, true,
+                                               computation_and_nb.second);
+    }
+  }
+
+  // Return the common sub exprs that occur at least thresh times
+  Map<PrimExpr, Integer> results;
+  for (auto& it : semantic_comp_done_by_expr) {
+    if (it.second >= repeat_thr) results.Set(it.first, it.second);
+  }
+  return results;
+}
+
+TVM_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr);
+}  // namespace arith
+}  // namespace tvm
diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc
index 130004c51c..c118d1db7d 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.cc
+++ b/src/tir/transforms/common_subexpr_elim_tools.cc
@@ -902,7 +902,7 @@ void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size
  */
 void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
                                               const std::vector<PrimExpr>& vec_to_add,
-                                              bool identify_equiv_terms) {
+                                              bool identify_equiv_terms, size_t increase_count) {
   if (sorted_vec == nullptr) {
     return;
   }
@@ -918,10 +918,10 @@ void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, si
     // If we found `elem_to_add` (or an equivalent expression) already in sorted_vec
     if (it_found != sorted_vec->end()) {
       // then we just increase its associated count
-      it_found->second++;
+      it_found->second += increase_count;
     } else {
       // Otherwise we add the pair (`elem_to_add`,1) at the right place
-      InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, 1});
+      InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, increase_count});
     }
   }
 }
diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h
index 0871fd0091..841f1d65a6 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.h
+++ b/src/tir/transforms/common_subexpr_elim_tools.h
@@ -210,9 +210,10 @@ template std::vector<Var> VectorMap(const std::vector<std::pair<Var, MaybeValue>
 
 void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
                                             const std::pair<PrimExpr, size_t>& pair);
+
 void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
                                               const std::vector<PrimExpr>& vec_to_add,
-                                              bool identify_equiv_terms);
+                                              bool identify_equiv_terms, size_t increase_count = 1);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/arith/__init__.py b/tests/python/unittest/test_arith_detect_cse.py
similarity index 54%
copy from python/tvm/arith/__init__.py
copy to tests/python/unittest/test_arith_detect_cse.py
index 03c0769850..eba0920cb2 100644
--- a/python/tvm/arith/__init__.py
+++ b/tests/python/unittest/test_arith_detect_cse.py
@@ -14,23 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Integer bound analysis, simplification and pattern detection."""
+import tvm
+import tvm.testing
+from tvm.script import tir as T
 
-from .int_set import (
-    IntSet,
-    IntervalSet,
-    estimate_region_lower_bound,
-    estimate_region_strict_bound,
-    estimate_region_upper_bound,
-)
-from .analyzer import ModularSet, ConstIntBound, Analyzer
-from .bound import deduce_bound
-from .pattern import detect_linear_equation, detect_clip_bound
-from .int_solver import solve_linear_equations, solve_linear_inequalities
-from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
-from .iter_affine_map import (
-    detect_iter_map,
-    normalize_iter_map_to_expr,
-    subspace_divide,
-    inverse_affine_iter_map,
-)
+
+def test_detect_cs():
+    x = T.Var("x", dtype="int32")
+    y = T.Var("y", dtype="int32")
+    z = T.Var("z", dtype="int32")
+    c = T.floor(x + y + 0.5) + x + z * (T.floor(x + y + 0.5))  # T.floor(...) occurs twice
+    m = tvm.arith.detect_common_subexpr(c, 2)
+    assert c.a.a in m  # c is (floor + x) + (z * floor), so c.a.a is the first floor term
+    assert m[c.a.a] == 2  # one count per occurrence of T.floor(x + y + 0.5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()