You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jw...@apache.org on 2022/05/17 23:17:53 UTC
[tvm] branch main updated: [Relay] Bug fix when applying history using an iterator or records. (#11306)
This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 75c31cae75 [Relay] Bug fix when applying history using an iterator or records. (#11306)
75c31cae75 is described below
commit 75c31cae75fe31af9e0901210ba7fa597e6f153a
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue May 17 16:17:48 2022 -0700
[Relay] Bug fix when applying history using an iterator or records. (#11306)
* Bug fix when applying history using an iterator or records.
* I forgot strings are iterables.
---
python/tvm/auto_scheduler/dispatcher.py | 3 ++-
python/tvm/autotvm/task/dispatcher.py | 5 +++--
tests/python/relay/test_auto_scheduler_tuning.py | 7 +++++++
tests/python/unittest/test_autotvm_record.py | 5 +++++
4 files changed, 17 insertions(+), 3 deletions(-)
diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py
index eceeba38e0..98566f8636 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -25,6 +25,7 @@ as a schedule configuration here.
import logging
import pathlib
+from collections.abc import Iterable
import numpy as np
@@ -199,7 +200,7 @@ class ApplyHistoryBest(DispatchContext):
if it is not None, only load the first `n_lines` lines of log
"""
joint_records = []
- if not isinstance(records, (list, tuple)):
+ if not isinstance(records, Iterable) or isinstance(records, str):
records = [records]
for rec in records:
diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py
index ffff50b9dc..6c072dc1fa 100644
--- a/python/tvm/autotvm/task/dispatcher.py
+++ b/python/tvm/autotvm/task/dispatcher.py
@@ -31,6 +31,7 @@ of the DispatchContext base class.
from __future__ import absolute_import as _abs
import logging
+from collections.abc import Iterable
import numpy as np
@@ -212,7 +213,7 @@ class ApplyHistoryBest(DispatchContext):
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. If it is a list
- it can either be a list of paths to logs that will loaded jointly or
+ it can either be a list of paths to logs that will be loaded jointly or
an iterator of measurement results.
"""
# pylint: disable=import-outside-toplevel
@@ -220,7 +221,7 @@ class ApplyHistoryBest(DispatchContext):
from ..record import load_from_file
joint_records = []
- if not isinstance(records, (list, tuple)):
+ if not isinstance(records, Iterable) or isinstance(records, str):
records = [records]
for rec in records:
diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py
index c9ce5b59ff..735486ef27 100644
--- a/tests/python/relay/test_auto_scheduler_tuning.py
+++ b/tests/python/relay/test_auto_scheduler_tuning.py
@@ -62,6 +62,13 @@ def tune_network(network, target):
best, auto_scheduler.dispatcher.ApplyHistoryBest
), "Unable to load multiple log files jointly."
+ # Confirm iterables can be directly loaded.
+ loaded_recs = auto_scheduler.dispatcher.load_records(log_file)
+ with auto_scheduler.ApplyHistoryBest(iter(loaded_recs)) as best:
+ assert isinstance(
+ best, auto_scheduler.dispatcher.ApplyHistoryBest
+ ), "Unable to ingest logs from an interator."
+
# Sample a schedule when missing
with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
with tvm.transform.PassContext(
diff --git a/tests/python/unittest/test_autotvm_record.py b/tests/python/unittest/test_autotvm_record.py
index 2ee75cf18c..147122ff10 100644
--- a/tests/python/unittest/test_autotvm_record.py
+++ b/tests/python/unittest/test_autotvm_record.py
@@ -91,6 +91,11 @@ def test_apply_history_best():
x = hist_best.query(target, tsk.workload)
assert str(x) == str(tsk.config_space.get(2))
+ # Confirm same functionality for iterators.
+ hist_best = ApplyHistoryBest(iter(records))
+ x = hist_best.query(target, tsk.workload)
+ assert str(x) == str(tsk.config_space.get(2))
+
if __name__ == "__main__":
test_load_dump()