You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/03/24 17:30:28 UTC

[tvm] branch main updated: [TIR] not estimating the flops when there is a default estimated flops as attr (#14379)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c5075dc30f [TIR] not estimating the flops when there is a default estimated flops as attr (#14379)
c5075dc30f is described below

commit c5075dc30fd6cca3eeab3535cf45dfb43998b0ce
Author: Farshid Salemi Parizi <fp...@octoml.ai>
AuthorDate: Fri Mar 24 10:30:13 2023 -0700

    [TIR] not estimating the flops when there is a default estimated flops as attr (#14379)
    
    * not estimating the flops when there is a default estimated flops as attr
    
    * add unittests
    
    * lint fix
    
    * make unittest simpler
---
 src/tir/analysis/estimate_flops.cc                 | 11 +++++---
 .../test_tir_analysis_estimate_tir_flops.py        | 30 ++++++++++++++++++++++
 2 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc
index d158a001b2..ce9d5eaaf8 100644
--- a/src/tir/analysis/estimate_flops.cc
+++ b/src/tir/analysis/estimate_flops.cc
@@ -208,10 +208,15 @@ double EstimateTIRFlops(const Stmt& stmt) {
 double EstimateTIRFlops(const IRModule& mod) {
   FlopEstimator counter;
   TResult result;
-  VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) {
-    result += counter.VisitStmt(f->body);  //
+  double cached_result = 0;  // sum of per-function "estimated_flops" attr values
+  VisitPrimFuncs(mod, [&result, &counter, &cached_result](const PrimFuncNode* f) {
+    if (auto cached = f->attrs.GetAttr<Integer>("estimated_flops")) {  // use precomputed value
+      cached_result += cached.value()->value;  // NOTE(review): attr is read as Integer only
+    } else {
+      result += counter.VisitStmt(f->body);  // no attr: estimate by analyzing the body
+    }
   });
-  return PostprocessResults(result);
+  return PostprocessResults(result) + cached_result;  // cached sum added after postprocessing
 }
 
 TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double {
diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
index 06f6fe3127..489db287f3 100644
--- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
+++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
@@ -77,5 +77,35 @@ def test_flops_with_if():
     assert flops == 16
 
 
+@T.prim_func
+def flops_with_forloop_as_expression(A: T.Buffer(1)):  # inner loop extent depends on i
+    for i in T.serial(0, 16):
+        for k in T.serial(0, i):
+            A[0] = A[0] + 1
+
+
+@T.prim_func
+def flops_override(A: T.Buffer(16, "float32")):
+    T.func_attr({"estimated_flops": 32})
+    for i in range(16):
+        A[0] = A[0] + 1
+
+
+def test_estimate_flops_forloop_as_experssion():
+    flops = estimate_tir_flops(
+        IRModule({"main": flops_with_forloop_as_expression.with_attr("estimated_flops", 32)})
+    )
+    assert flops == 32
+
+    # an "estimated_flops" attr set via T.func_attr should also override the estimate
+    flops = estimate_tir_flops(IRModule({"main": flops_override}))
+    assert flops == 32
+
+
+def test_exception():
+    with pytest.raises(tvm.TVMError):  # no attr: estimation is expected to fail here
+        flops = estimate_tir_flops(IRModule({"main": flops_with_forloop_as_expression}))
+
+
 if __name__ == "__main__":
     tvm.testing.main()