You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/03/24 17:30:28 UTC
[tvm] branch main updated: [TIR] not estimating the flops when there is a default estimated flops as attr (#14379)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c5075dc30f [TIR] not estimating the flops when there is a default estimated flops as attr (#14379)
c5075dc30f is described below
commit c5075dc30fd6cca3eeab3535cf45dfb43998b0ce
Author: Farshid Salemi Parizi <fp...@octoml.ai>
AuthorDate: Fri Mar 24 10:30:13 2023 -0700
[TIR] not estimating the flops when there is a default estimated flops as attr (#14379)
* not estimating the flops when there is a default estimated flops as attr
* add unittests
* lint fix
* make unittest simpler
---
src/tir/analysis/estimate_flops.cc | 11 +++++---
.../test_tir_analysis_estimate_tir_flops.py | 30 ++++++++++++++++++++++
2 files changed, 38 insertions(+), 3 deletions(-)
diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc
index d158a001b2..ce9d5eaaf8 100644
--- a/src/tir/analysis/estimate_flops.cc
+++ b/src/tir/analysis/estimate_flops.cc
@@ -208,10 +208,15 @@ double EstimateTIRFlops(const Stmt& stmt) {
double EstimateTIRFlops(const IRModule& mod) {
FlopEstimator counter;
TResult result;
- VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) {
- result += counter.VisitStmt(f->body); //
+ double cached_result = 0;
+ VisitPrimFuncs(mod, [&result, &counter, &cached_result](const PrimFuncNode* f) {
+ if (auto cached = f->attrs.GetAttr<Integer>("estimated_flops")) {
+ cached_result += cached.value()->value;
+ } else {
+ result += counter.VisitStmt(f->body); //
+ }
});
- return PostprocessResults(result);
+ return PostprocessResults(result) + cached_result;
}
TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double {
diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
index 06f6fe3127..489db287f3 100644
--- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
+++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
@@ -77,5 +77,35 @@ def test_flops_with_if():
assert flops == 16
+@T.prim_func
+def flops_with_forloop_as_expression(A: T.Buffer(1)):
+ for i in T.serial(0, 16):
+ for k in T.serial(0, i):
+ A[0] = A[0] + 1
+
+
+@T.prim_func
+def flops_override(A: T.Buffer(16, "float32")):
+ T.func_attr({"estimated_flops": 32})
+ for i in range(16):
+ A[0] = A[0] + 1
+
+
+def test_estimate_flops_forloop_as_experssion():
+ flops = estimate_tir_flops(
+ IRModule({"main": flops_with_forloop_as_expression.with_attr("estimated_flops", 32)})
+ )
+ assert flops == 32
+
+ # test whether the user-supplied estimated flops override the computed estimate
+ flops = estimate_tir_flops(IRModule({"main": flops_override}))
+ assert flops == 32
+
+
+def test_exception():
+ with pytest.raises(tvm.TVMError):
+ flops = estimate_tir_flops(IRModule({"main": flops_with_forloop_as_expression}))
+
+
if __name__ == "__main__":
tvm.testing.main()