You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/03 12:23:24 UTC
[tvm] branch main updated: [Target] Fix empty target and host for autotvm task (#7791)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1f59139 [Target] Fix empty target and host for autotvm task (#7791)
1f59139 is described below
commit 1f59139db2003bc718159d1d87e7a7d36522961c
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Sat Apr 3 05:23:14 2021 -0700
[Target] Fix empty target and host for autotvm task (#7791)
---
python/tvm/autotvm/task/task.py | 4 ++--
python/tvm/target/target.py | 3 +++
tests/python/integration/test_tuning.py | 4 ++--
tests/python/unittest/test_target_target.py | 33 ++++++++++++++++++++++++++++-
4 files changed, 39 insertions(+), 5 deletions(-)
diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index 0d60ca9..668832b 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -185,7 +185,7 @@ class Task(object):
"config_space": self.config_space,
"flop": self.flop,
"target": self.target,
- "target_host": self.target.host,
+ "target_host": self.target_host,
"func": cloudpickle.dumps(self.func),
}
@@ -465,7 +465,7 @@ def create(task_name, args, target, target_host=None):
ret.flop = ret.config_space.flop or compute_flop(sch)
ret.target = target
- ret.target_host = target.host
+ ret.target_host = target_host
return ret
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 6d0a063..baf0760 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -182,6 +182,9 @@ class Target(Object):
target_is_dict_key : Bool
When the type of target is dict, whether Target is the key (Otherwise the value)
"""
+ if target is None:
+ assert host is None, "Target host is not empty when target is empty."
+ return target, host
if isinstance(target, dict) and "kind" not in target:
new_target = {}
for tgt, mod in target.items():
diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py
index 45e0958..55c8e56 100644
--- a/tests/python/integration/test_tuning.py
+++ b/tests/python/integration/test_tuning.py
@@ -30,6 +30,7 @@ from tvm import te
from tvm import autotvm
from tvm.autotvm.tuner import RandomTuner
+from tvm.target import Target
import tvm.testing
@@ -131,8 +132,7 @@ def teardown_module():
def get_sample_task(target=tvm.target.cuda(), target_host=None):
- target = tvm.target.Target(target, target_host)
- target_host = target.host
+ target, target_host = Target.check_and_update_host_consist(target, target_host)
"""return a sample task for testing"""
task = autotvm.task.create(
"testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 2f885d3..98a9edc 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -18,7 +18,7 @@ import json
import sys
import pytest
import tvm
-from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost
+from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, Target
@tvm.target.generic_func
@@ -268,5 +268,36 @@ def test_target_with_host():
assert tgt.host.attrs["registers_per_block"] == 32768
+def test_check_and_update_host_consist_0():
+ target = None
+ host = None
+ target, host = Target.check_and_update_host_consist(target, host)
+
+
+def test_check_and_update_host_consist_1():
+ target = None
+ host = "llvm"
+ with pytest.raises(AssertionError, match=r"Target host is not empty when target is empty."):
+ target, host = Target.check_and_update_host_consist(target, host)
+
+
+def test_check_and_update_host_consist_2():
+ target = Target("cuda")
+ host = Target("llvm")
+ target, host = Target.check_and_update_host_consist(target, host)
+ assert target.kind.name == "cuda"
+ assert target.host.kind.name == "llvm"
+
+
+def test_check_and_update_host_consist_3():
+ target = Target(target="cuda", host="llvm")
+ host = None
+ target, host = Target.check_and_update_host_consist(target, host)
+ assert target.kind.name == "cuda"
+ assert target.host.kind.name == "llvm"
+ assert host.kind.name == "llvm"
+ assert target.host == host
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))