You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/03 12:23:24 UTC

[tvm] branch main updated: [Target] Fix empty target and host for autotvm task (#7791)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f59139  [Target] Fix empty target and host for autotvm task (#7791)
1f59139 is described below

commit 1f59139db2003bc718159d1d87e7a7d36522961c
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Sat Apr 3 05:23:14 2021 -0700

    [Target] Fix empty target and host for autotvm task (#7791)
---
 python/tvm/autotvm/task/task.py             |  4 ++--
 python/tvm/target/target.py                 |  3 +++
 tests/python/integration/test_tuning.py     |  4 ++--
 tests/python/unittest/test_target_target.py | 33 ++++++++++++++++++++++++++++-
 4 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index 0d60ca9..668832b 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -185,7 +185,7 @@ class Task(object):
             "config_space": self.config_space,
             "flop": self.flop,
             "target": self.target,
-            "target_host": self.target.host,
+            "target_host": self.target_host,
             "func": cloudpickle.dumps(self.func),
         }
 
@@ -465,7 +465,7 @@ def create(task_name, args, target, target_host=None):
 
     ret.flop = ret.config_space.flop or compute_flop(sch)
     ret.target = target
-    ret.target_host = target.host
+    ret.target_host = target_host
 
     return ret
 
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 6d0a063..baf0760 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -182,6 +182,9 @@ class Target(Object):
         target_is_dict_key : Bool
             When the type of target is dict, whether Target is the key (Otherwise the value)
         """
+        if target is None:
+            assert host is None, "Target host is not empty when target is empty."
+            return target, host
         if isinstance(target, dict) and "kind" not in target:
             new_target = {}
             for tgt, mod in target.items():
diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py
index 45e0958..55c8e56 100644
--- a/tests/python/integration/test_tuning.py
+++ b/tests/python/integration/test_tuning.py
@@ -30,6 +30,7 @@ from tvm import te
 
 from tvm import autotvm
 from tvm.autotvm.tuner import RandomTuner
+from tvm.target import Target
 
 import tvm.testing
 
@@ -131,8 +132,7 @@ def teardown_module():
 
 
 def get_sample_task(target=tvm.target.cuda(), target_host=None):
-    target = tvm.target.Target(target, target_host)
-    target_host = target.host
+    target, target_host = Target.check_and_update_host_consist(target, target_host)
     """return a sample task for testing"""
     task = autotvm.task.create(
         "testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 2f885d3..98a9edc 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -18,7 +18,7 @@ import json
 import sys
 import pytest
 import tvm
-from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost
+from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, Target
 
 
 @tvm.target.generic_func
@@ -268,5 +268,36 @@ def test_target_with_host():
     assert tgt.host.attrs["registers_per_block"] == 32768
 
 
+def test_check_and_update_host_consist_0():
+    target = None
+    host = None
+    target, host = Target.check_and_update_host_consist(target, host)
+
+
+def test_check_and_update_host_consist_1():
+    target = None
+    host = "llvm"
+    with pytest.raises(AssertionError, match=r"Target host is not empty when target is empty."):
+        target, host = Target.check_and_update_host_consist(target, host)
+
+
+def test_check_and_update_host_consist_2():
+    target = Target("cuda")
+    host = Target("llvm")
+    target, host = Target.check_and_update_host_consist(target, host)
+    assert target.kind.name == "cuda"
+    assert target.host.kind.name == "llvm"
+
+
+def test_check_and_update_host_consist_3():
+    target = Target(target="cuda", host="llvm")
+    host = None
+    target, host = Target.check_and_update_host_consist(target, host)
+    assert target.kind.name == "cuda"
+    assert target.host.kind.name == "llvm"
+    assert host.kind.name == "llvm"
+    assert target.host == host
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))