You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/12/31 20:57:46 UTC
[tvm] branch main updated: [Target][BugFix] Convert dict and str to TVM object (#9807)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a5ac362 [Target][BugFix] Convert dict and str to TVM object (#9807)
a5ac362 is described below
commit a5ac362fc3e84bc0bbd3ca4592b58edd0b77830a
Author: Colin Y. Li <cy...@live.com>
AuthorDate: Sat Jan 1 04:57:14 2022 +0800
[Target][BugFix] Convert dict and str to TVM object (#9807)
* [Target][BugFix] Convert dict and str to TVM object
* Add tests
---
python/tvm/target/target.py | 22 +++++++++++++-------
tests/python/relay/test_build_module.py | 26 ++++++++++++++++++++++-
tests/python/unittest/test_target_target.py | 32 +++++++++++++++++++++++++++++
3 files changed, 72 insertions(+), 8 deletions(-)
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 723cb91..ebeb437 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -22,7 +22,9 @@ import warnings
import tvm._ffi
from tvm._ffi import register_func as _register_func
-from tvm.runtime import Object
+from tvm.runtime import Object, convert
+from tvm.runtime.container import String
+from tvm.ir.container import Map
from . import _ffi_api
@@ -107,10 +109,14 @@ class Target(Object):
When using a dictionary or json string to configure target, the possible values are
same as target.
"""
- if target is None or not isinstance(target, (dict, str, Target)):
+ if isinstance(target, (dict, str)):
+ target = convert(target)
+ if isinstance(host, (dict, str)):
+ host = convert(host)
+ if target is None or not isinstance(target, (Map, String, Target)):
raise ValueError("target has to be a string or dictionary.")
if host is not None:
- if not isinstance(host, (dict, str, Target)):
+ if not isinstance(host, (Map, String, Target)):
raise ValueError("target host has to be a string or dictionary.")
self.__init_handle_by_constructor__(_ffi_api.Target, Target(target), Target(host))
else:
@@ -221,15 +227,19 @@ class Target(Object):
target_is_dict_key : Bool
When the type of target is dict, whether Target is the key (Otherwise the value)
"""
+ if isinstance(target, (dict, str)):
+ target = convert(target)
+ if isinstance(host, (dict, str)):
+ host = convert(host)
if target is None:
assert host is None, "Target host is not empty when target is empty."
return target, host
- if isinstance(target, dict) and "kind" not in target:
+ if isinstance(target, Map) and "kind" not in target:
new_target = {}
for tgt, mod in target.items():
if not target_is_dict_key:
tgt, mod = mod, tgt
- if isinstance(tgt, (dict, str, Target)):
+ if isinstance(tgt, (Map, String, Target)):
tgt, host = Target.check_and_update_host_consist(tgt, host)
if not target_is_dict_key:
tgt, mod = mod, tgt
@@ -242,8 +252,6 @@ class Target(Object):
# TODO(@tvm-team): Deprecate the helper functions below. Encourage the usage of config dict instead.
-
-
def _merge_opts(opts, new_opts):
"""Helper function to merge options"""
if isinstance(new_opts, str):
diff --git a/tests/python/relay/test_build_module.py b/tests/python/relay/test_build_module.py
index d812ad8..7470622 100644
--- a/tests/python/relay/test_build_module.py
+++ b/tests/python/relay/test_build_module.py
@@ -17,8 +17,10 @@
import pytest
+import tvm
+from tvm import relay
from tvm.target.target import Target
-from tvm.relay.backend import Runtime, Executor
+from tvm.relay.backend import Runtime, Executor, graph_executor_codegen
from tvm.relay.build_module import _reconstruct_from_deprecated_options
@@ -58,5 +60,27 @@ def test_deprecated_target_parameters(target, executor, runtime):
assert runtime == actual_runtime
+def test_build_relay_graph_():
+ """Test to build a simple relay graph by using APIs directly"""
+
+ def build_graph(mod, target):
+ target = relay.build_module.build_target_by_device_type_map(target)
+ target, target_host = tvm.target.Target.check_and_update_host_consist(target)
+ mod, _ = relay.optimize(mod, target, None)
+ grc = graph_executor_codegen.GraphExecutorCodegen(None, target)
+ _, lowered_funcs, _ = grc.codegen(mod, mod["main"])
+ _ = relay.backend._backend.build(lowered_funcs, target, target_host)
+
+ def add(shape, dtype):
+ lhs = relay.var("A", shape=shape, dtype=dtype)
+ rhs = relay.var("B", shape=shape, dtype=dtype)
+ out = relay.add(lhs, rhs)
+ expr = relay.Function((lhs, rhs), out)
+ mod = tvm.IRModule.from_expr(expr)
+ return mod
+
+ build_graph(add((1, 8), "float32"), tvm.target.Target("llvm"))
+
+
if __name__ == "__main__":
pytest.main()
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 199721b..33f9a96 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -319,6 +319,17 @@ def test_target_host_merge_2():
assert tgt.host.kind.name == "llvm"
+def test_target_tvm_object():
+ """Test creating Target by using TVM Objects"""
+ String = tvm.runtime.container.String
+ tgt = tvm.target.Target(target=String("cuda --host llvm"))
+ assert tgt.kind.name == "cuda"
+ assert tgt.host.kind.name == "llvm"
+ tgt = tvm.target.Target(target=String("cuda"), host=String("llvm"))
+ assert tgt.kind.name == "cuda"
+ assert tgt.host.kind.name == "llvm"
+
+
@pytest.mark.skip(reason="Causing infinite loop because of pytest and handle issue")
def test_target_host_merge_3():
with pytest.raises(ValueError, match=r"target host has to be a string or dictionary."):
@@ -372,6 +383,27 @@ def test_check_and_update_host_consist_3():
assert target.host == host
+def test_check_and_update_host_consist_4():
+ """Test `check_and_update_host_consist` by using TVM Objects"""
+ cuda_device_type = tvm.device("cuda").device_type
+ target = {cuda_device_type: Target(target="cuda", host="llvm")}
+ host = None
+ target_1, host_1 = Target.check_and_update_host_consist(target, host)
+ assert isinstance(target_1, dict)
+ assert target_1[cuda_device_type].kind.name == "cuda"
+ assert target_1[cuda_device_type].host.kind.name == "llvm"
+ assert host_1 is None
+
+ target = {cuda_device_type: Target(tvm.runtime.container.String("cuda"))}
+ host = Target(tvm.runtime.container.String("llvm"))
+ target = tvm.runtime.convert(target)
+ assert isinstance(target, tvm.ir.container.Map)
+ target_2, host_2 = Target.check_and_update_host_consist(target, host)
+ assert isinstance(target_2, dict)
+ assert target_2[cuda_device_type].kind.name == "cuda"
+ assert host_2.kind.name == "llvm"
+
+
def test_target_attr_bool_value():
target0 = Target("vulkan --supports_float16=True")
assert target0.attrs["supports_float16"] == 1