You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/01/24 01:02:45 UTC
[tvm] branch main updated: [Metaschedule] get_top_k should not return not built records (#13824)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e79fac6300 [Metaschedule] get_top_k should not return not built records (#13824)
e79fac6300 is described below
commit e79fac6300bea5ba45982d3f087855cb71be0f53
Author: Alexey <av...@gmail.com>
AuthorDate: Tue Jan 24 04:02:38 2023 +0300
[Metaschedule] get_top_k should not return not built records (#13824)
* [Metaschedule] get_top_k should not return not built records
* [Metaschedule][NFC] GetTopK extra polishing
---
src/meta_schedule/database/json_database.cc | 9 ++++-
src/meta_schedule/database/memory_database.cc | 44 +++++++++-------------
.../python/unittest/test_meta_schedule_database.py | 16 ++++++--
3 files changed, 38 insertions(+), 31 deletions(-)
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index b0fba5adb5..0e51e262df 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -127,7 +127,14 @@ class JSONDatabaseNode : public DatabaseNode {
Array<TuningRecord> results;
results.reserve(top_k);
for (const TuningRecord& record : this->tuning_records_) {
- if (!record->run_secs.defined() || record->run_secs.value().empty()) {
+ auto run_secs = record->run_secs;
+ if (!run_secs.defined() || run_secs.value().empty() ||
+ std::all_of(run_secs.value().begin(), run_secs.value().end(),
+ // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
+ [](tvm::FloatImm v) {
+ return v.defined() &&
+ v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
+ })) {
continue;
}
if (record->workload.same_as(workload) ||
diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc
index 19178a35f4..8cbde46f83 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -65,42 +65,34 @@ class MemoryDatabaseNode : public DatabaseNode {
if (top_k == 0) {
return {};
}
- std::vector<std::pair<double, TuningRecord>> results;
+ std::vector<TuningRecord> results;
results.reserve(records.size());
for (const TuningRecord& record : records) {
- if (!record->run_secs.defined()) {
- continue;
- }
- Array<FloatImm> run_secs = record->run_secs.value();
- if (run_secs.empty()) {
+ auto run_secs = record->run_secs;
+ if (!run_secs.defined() || run_secs.value().empty() ||
+ std::all_of(run_secs.value().begin(), run_secs.value().end(),
+ // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
+ [](tvm::FloatImm v) {
+ return v.defined() &&
+ v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
+ })) {
continue;
}
if (record->workload.same_as(workload) ||
WorkloadEqual(GetModuleEquality())(record->workload, workload)) {
- double sum = 0.0;
- for (const FloatImm& i : run_secs) {
- sum += i->value;
- }
- results.emplace_back(sum / run_secs.size(), record);
+ results.emplace_back(record);
}
}
- std::sort(results.begin(), results.end());
- auto begin = results.begin();
- auto end = results.end();
+ std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs());
if (results.size() > static_cast<size_t>(top_k)) {
- end = begin + top_k;
- }
- Array<TuningRecord> ret;
- ret.reserve(end - begin);
- while (begin != end) {
- ret.push_back(begin->second);
- ++begin;
- }
- if (ret.size() < static_cast<size_t>(top_k)) {
- LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
- "enough valid records in the database for this workload.";
+ return {results.begin(), results.begin() + top_k};
+ } else {
+ if (results.size() < static_cast<size_t>(top_k)) {
+ LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
+ "enough valid records in the database for this workload.";
+ }
+ return results;
}
- return ret;
}
Array<TuningRecord> GetAllTuningRecords() final { return records; }
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index 4ec10b556c..d4681d4011 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -554,10 +554,14 @@ def call_get_top_k(run_secs_list, database, k):
@pytest.mark.parametrize(
"k,expected",
- [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
+ [
+ (0, []),
+ (4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
+ (5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
+ ],
)
def test_memory_database_get_top_k(k, expected):
- run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
+ run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0], [3.0, 1e10], [1e10]]
database = ms.database.MemoryDatabase()
result = call_get_top_k(run_secs_list, database, k)
assert result == expected
@@ -565,10 +569,14 @@ def test_memory_database_get_top_k(k, expected):
@pytest.mark.parametrize(
"k,expected",
- [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
+ [
+ (0, []),
+ (4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
+ (5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
+ ],
)
def test_json_database_get_top_k(k, expected):
- run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
+ run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0], [3.0, 1e10], [1e10]]
with tempfile.TemporaryDirectory() as tmpdir:
database = _create_tmp_database(tmpdir)
result = call_get_top_k(run_secs_list, database, k)