You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/01/24 01:02:45 UTC

[tvm] branch main updated: [Metaschedule] get_top_k should not return not built records (#13824)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e79fac6300 [Metaschedule] get_top_k should not return not built records (#13824)
e79fac6300 is described below

commit e79fac6300bea5ba45982d3f087855cb71be0f53
Author: Alexey <av...@gmail.com>
AuthorDate: Tue Jan 24 04:02:38 2023 +0300

    [Metaschedule] get_top_k should not return not built records (#13824)
    
    * [Metaschedule] get_top_k should not return not built records
    
    * [Metaschedule][NFC] GetTopK extra polishing
---
 src/meta_schedule/database/json_database.cc        |  9 ++++-
 src/meta_schedule/database/memory_database.cc      | 44 +++++++++-------------
 .../python/unittest/test_meta_schedule_database.py | 16 ++++++--
 3 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index b0fba5adb5..0e51e262df 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -127,7 +127,14 @@ class JSONDatabaseNode : public DatabaseNode {
     Array<TuningRecord> results;
     results.reserve(top_k);
     for (const TuningRecord& record : this->tuning_records_) {
-      if (!record->run_secs.defined() || record->run_secs.value().empty()) {
+      auto run_secs = record->run_secs;
+      if (!run_secs.defined() || run_secs.value().empty() ||
+          std::all_of(run_secs.value().begin(), run_secs.value().end(),
+                      // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
+                      [](tvm::FloatImm v) {
+                        return v.defined() &&
+                               v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
+                      })) {
         continue;
       }
       if (record->workload.same_as(workload) ||
diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc
index 19178a35f4..8cbde46f83 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -65,42 +65,34 @@ class MemoryDatabaseNode : public DatabaseNode {
     if (top_k == 0) {
       return {};
     }
-    std::vector<std::pair<double, TuningRecord>> results;
+    std::vector<TuningRecord> results;
     results.reserve(records.size());
     for (const TuningRecord& record : records) {
-      if (!record->run_secs.defined()) {
-        continue;
-      }
-      Array<FloatImm> run_secs = record->run_secs.value();
-      if (run_secs.empty()) {
+      auto run_secs = record->run_secs;
+      if (!run_secs.defined() || run_secs.value().empty() ||
+          std::all_of(run_secs.value().begin(), run_secs.value().end(),
+                      // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
+                      [](tvm::FloatImm v) {
+                        return v.defined() &&
+                               v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
+                      })) {
         continue;
       }
       if (record->workload.same_as(workload) ||
           WorkloadEqual(GetModuleEquality())(record->workload, workload)) {
-        double sum = 0.0;
-        for (const FloatImm& i : run_secs) {
-          sum += i->value;
-        }
-        results.emplace_back(sum / run_secs.size(), record);
+        results.emplace_back(record);
       }
     }
-    std::sort(results.begin(), results.end());
-    auto begin = results.begin();
-    auto end = results.end();
+    std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs());
     if (results.size() > static_cast<size_t>(top_k)) {
-      end = begin + top_k;
-    }
-    Array<TuningRecord> ret;
-    ret.reserve(end - begin);
-    while (begin != end) {
-      ret.push_back(begin->second);
-      ++begin;
-    }
-    if (ret.size() < static_cast<size_t>(top_k)) {
-      LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
-                      "enough valid records in the database for this workload.";
+      return {results.begin(), results.begin() + top_k};
+    } else {
+      if (results.size() < static_cast<size_t>(top_k)) {
+        LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
+                        "enough valid records in the database for this workload.";
+      }
+      return results;
     }
-    return ret;
   }
 
   Array<TuningRecord> GetAllTuningRecords() final { return records; }
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index 4ec10b556c..d4681d4011 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -554,10 +554,14 @@ def call_get_top_k(run_secs_list, database, k):
 
 @pytest.mark.parametrize(
     "k,expected",
-    [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
+    [
+        (0, []),
+        (4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
+        (5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
+    ],
 )
 def test_memory_database_get_top_k(k, expected):
-    run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
+    run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0], [3.0, 1e10], [1e10]]
     database = ms.database.MemoryDatabase()
     result = call_get_top_k(run_secs_list, database, k)
     assert result == expected
@@ -565,10 +569,14 @@ def test_memory_database_get_top_k(k, expected):
 
 @pytest.mark.parametrize(
     "k,expected",
-    [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
+    [
+        (0, []),
+        (4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
+        (5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
+    ],
 )
 def test_json_database_get_top_k(k, expected):
-    run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
+    run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0], [3.0, 1e10], [1e10]]
     with tempfile.TemporaryDirectory() as tmpdir:
         database = _create_tmp_database(tmpdir)
         result = call_get_top_k(run_secs_list, database, k)