You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jw...@apache.org on 2022/05/17 23:17:53 UTC

[tvm] branch main updated: [Relay] Bug fix when applying history using an iterator or records. (#11306)

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 75c31cae75 [Relay] Bug fix when applying history using an iterator or records. (#11306)
75c31cae75 is described below

commit 75c31cae75fe31af9e0901210ba7fa597e6f153a
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue May 17 16:17:48 2022 -0700

    [Relay] Bug fix when applying history using an iterator or records. (#11306)
    
    * Bug fix when applying history using an iterator or records.
    
    * I forgot strings are iterables.
---
 python/tvm/auto_scheduler/dispatcher.py          | 3 ++-
 python/tvm/autotvm/task/dispatcher.py            | 5 +++--
 tests/python/relay/test_auto_scheduler_tuning.py | 7 +++++++
 tests/python/unittest/test_autotvm_record.py     | 5 +++++
 4 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py
index eceeba38e0..98566f8636 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -25,6 +25,7 @@ as a schedule configuration here.
 
 import logging
 import pathlib
+from collections.abc import Iterable
 
 import numpy as np
 
@@ -199,7 +200,7 @@ class ApplyHistoryBest(DispatchContext):
             if it is not None, only load the first `n_lines` lines of log
         """
         joint_records = []
-        if not isinstance(records, (list, tuple)):
+        if not isinstance(records, Iterable) or isinstance(records, str):
             records = [records]
 
         for rec in records:
diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py
index ffff50b9dc..6c072dc1fa 100644
--- a/python/tvm/autotvm/task/dispatcher.py
+++ b/python/tvm/autotvm/task/dispatcher.py
@@ -31,6 +31,7 @@ of the DispatchContext base class.
 from __future__ import absolute_import as _abs
 
 import logging
+from collections.abc import Iterable
 
 import numpy as np
 
@@ -212,7 +213,7 @@ class ApplyHistoryBest(DispatchContext):
             Collection of tuning records.
             If is str, then it should be the filename of a records log file.
             Each row of this file is an encoded record pair. If it is a list
-            it can either be a list of paths to logs that will loaded jointly or
+            it can either be a list of paths to logs that will be loaded jointly or
             an iterator of measurement results.
         """
         # pylint: disable=import-outside-toplevel
@@ -220,7 +221,7 @@ class ApplyHistoryBest(DispatchContext):
         from ..record import load_from_file
 
         joint_records = []
-        if not isinstance(records, (list, tuple)):
+        if not isinstance(records, Iterable) or isinstance(records, str):
             records = [records]
 
         for rec in records:
diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py
index c9ce5b59ff..735486ef27 100644
--- a/tests/python/relay/test_auto_scheduler_tuning.py
+++ b/tests/python/relay/test_auto_scheduler_tuning.py
@@ -62,6 +62,13 @@ def tune_network(network, target):
                 best, auto_scheduler.dispatcher.ApplyHistoryBest
             ), "Unable to load multiple log files jointly."
 
+        # Confirm iterables can be directly loaded.
+        loaded_recs = auto_scheduler.dispatcher.load_records(log_file)
+        with auto_scheduler.ApplyHistoryBest(iter(loaded_recs)) as best:
+            assert isinstance(
+                best, auto_scheduler.dispatcher.ApplyHistoryBest
+            ), "Unable to ingest logs from an interator."
+
         # Sample a schedule when missing
         with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
             with tvm.transform.PassContext(
diff --git a/tests/python/unittest/test_autotvm_record.py b/tests/python/unittest/test_autotvm_record.py
index 2ee75cf18c..147122ff10 100644
--- a/tests/python/unittest/test_autotvm_record.py
+++ b/tests/python/unittest/test_autotvm_record.py
@@ -91,6 +91,11 @@ def test_apply_history_best():
     x = hist_best.query(target, tsk.workload)
     assert str(x) == str(tsk.config_space.get(2))
 
+    # Confirm same functionality for iterators.
+    hist_best = ApplyHistoryBest(iter(records))
+    x = hist_best.query(target, tsk.workload)
+    assert str(x) == str(tsk.config_space.get(2))
+
 
 if __name__ == "__main__":
     test_load_dump()