You are viewing a plain text version of this content. The canonical link to the original was a hyperlink and is not preserved in this plain text copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/12/31 20:57:46 UTC

[tvm] branch main updated: [Target][BugFix] Convert dict and str to TVM object (#9807)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a5ac362  [Target][BugFix] Convert dict and str to TVM object (#9807)
a5ac362 is described below

commit a5ac362fc3e84bc0bbd3ca4592b58edd0b77830a
Author: Colin Y. Li <cy...@live.com>
AuthorDate: Sat Jan 1 04:57:14 2022 +0800

    [Target][BugFix] Convert dict and str to TVM object (#9807)
    
    * [Target][BugFix] Convert dict and str to TVM object
    
    * Add tests
---
 python/tvm/target/target.py                 | 22 +++++++++++++-------
 tests/python/relay/test_build_module.py     | 26 ++++++++++++++++++++++-
 tests/python/unittest/test_target_target.py | 32 +++++++++++++++++++++++++++++
 3 files changed, 72 insertions(+), 8 deletions(-)

diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 723cb91..ebeb437 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -22,7 +22,9 @@ import warnings
 
 import tvm._ffi
 from tvm._ffi import register_func as _register_func
-from tvm.runtime import Object
+from tvm.runtime import Object, convert
+from tvm.runtime.container import String
+from tvm.ir.container import Map
 
 from . import _ffi_api
 
@@ -107,10 +109,14 @@ class Target(Object):
             When using a dictionary or json string to configure target, the possible values are
             same as target.
         """
-        if target is None or not isinstance(target, (dict, str, Target)):
+        if isinstance(target, (dict, str)):
+            target = convert(target)
+        if isinstance(host, (dict, str)):
+            host = convert(host)
+        if target is None or not isinstance(target, (Map, String, Target)):
             raise ValueError("target has to be a string or dictionary.")
         if host is not None:
-            if not isinstance(host, (dict, str, Target)):
+            if not isinstance(host, (Map, String, Target)):
                 raise ValueError("target host has to be a string or dictionary.")
             self.__init_handle_by_constructor__(_ffi_api.Target, Target(target), Target(host))
         else:
@@ -221,15 +227,19 @@ class Target(Object):
         target_is_dict_key : Bool
             When the type of target is dict, whether Target is the key (Otherwise the value)
         """
+        if isinstance(target, (dict, str)):
+            target = convert(target)
+        if isinstance(host, (dict, str)):
+            host = convert(host)
         if target is None:
             assert host is None, "Target host is not empty when target is empty."
             return target, host
-        if isinstance(target, dict) and "kind" not in target:
+        if isinstance(target, Map) and "kind" not in target:
             new_target = {}
             for tgt, mod in target.items():
                 if not target_is_dict_key:
                     tgt, mod = mod, tgt
-                if isinstance(tgt, (dict, str, Target)):
+                if isinstance(tgt, (Map, String, Target)):
                     tgt, host = Target.check_and_update_host_consist(tgt, host)
                 if not target_is_dict_key:
                     tgt, mod = mod, tgt
@@ -242,8 +252,6 @@ class Target(Object):
 
 
 # TODO(@tvm-team): Deprecate the helper functions below. Encourage the usage of config dict instead.
-
-
 def _merge_opts(opts, new_opts):
     """Helper function to merge options"""
     if isinstance(new_opts, str):
diff --git a/tests/python/relay/test_build_module.py b/tests/python/relay/test_build_module.py
index d812ad8..7470622 100644
--- a/tests/python/relay/test_build_module.py
+++ b/tests/python/relay/test_build_module.py
@@ -17,8 +17,10 @@
 
 import pytest
 
+import tvm
+from tvm import relay
 from tvm.target.target import Target
-from tvm.relay.backend import Runtime, Executor
+from tvm.relay.backend import Runtime, Executor, graph_executor_codegen
 from tvm.relay.build_module import _reconstruct_from_deprecated_options
 
 
@@ -58,5 +60,27 @@ def test_deprecated_target_parameters(target, executor, runtime):
     assert runtime == actual_runtime
 
 
+def test_build_relay_graph_():
+    """Test to build a simple relay graph by using APIs directly"""
+
+    def build_graph(mod, target):
+        target = relay.build_module.build_target_by_device_type_map(target)
+        target, target_host = tvm.target.Target.check_and_update_host_consist(target)
+        mod, _ = relay.optimize(mod, target, None)
+        grc = graph_executor_codegen.GraphExecutorCodegen(None, target)
+        _, lowered_funcs, _ = grc.codegen(mod, mod["main"])
+        _ = relay.backend._backend.build(lowered_funcs, target, target_host)
+
+    def add(shape, dtype):
+        lhs = relay.var("A", shape=shape, dtype=dtype)
+        rhs = relay.var("B", shape=shape, dtype=dtype)
+        out = relay.add(lhs, rhs)
+        expr = relay.Function((lhs, rhs), out)
+        mod = tvm.IRModule.from_expr(expr)
+        return mod
+
+    build_graph(add((1, 8), "float32"), tvm.target.Target("llvm"))
+
+
 if __name__ == "__main__":
     pytest.main()
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 199721b..33f9a96 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -319,6 +319,17 @@ def test_target_host_merge_2():
     assert tgt.host.kind.name == "llvm"
 
 
+def test_target_tvm_object():
+    """Test creating Target by using TVM Objects"""
+    String = tvm.runtime.container.String
+    tgt = tvm.target.Target(target=String("cuda --host llvm"))
+    assert tgt.kind.name == "cuda"
+    assert tgt.host.kind.name == "llvm"
+    tgt = tvm.target.Target(target=String("cuda"), host=String("llvm"))
+    assert tgt.kind.name == "cuda"
+    assert tgt.host.kind.name == "llvm"
+
+
 @pytest.mark.skip(reason="Causing infinite loop because of pytest and handle issue")
 def test_target_host_merge_3():
     with pytest.raises(ValueError, match=r"target host has to be a string or dictionary."):
@@ -372,6 +383,27 @@ def test_check_and_update_host_consist_3():
     assert target.host == host
 
 
+def test_check_and_update_host_consist_4():
+    """Test `check_and_update_host_consist` by using TVM Objects"""
+    cuda_device_type = tvm.device("cuda").device_type
+    target = {cuda_device_type: Target(target="cuda", host="llvm")}
+    host = None
+    target_1, host_1 = Target.check_and_update_host_consist(target, host)
+    assert isinstance(target_1, dict)
+    assert target_1[cuda_device_type].kind.name == "cuda"
+    assert target_1[cuda_device_type].host.kind.name == "llvm"
+    assert host_1 is None
+
+    target = {cuda_device_type: Target(tvm.runtime.container.String("cuda"))}
+    host = Target(tvm.runtime.container.String("llvm"))
+    target = tvm.runtime.convert(target)
+    assert isinstance(target, tvm.ir.container.Map)
+    target_2, host_2 = Target.check_and_update_host_consist(target, host)
+    assert isinstance(target_2, dict)
+    assert target_2[cuda_device_type].kind.name == "cuda"
+    assert host_2.kind.name == "llvm"
+
+
 def test_target_attr_bool_value():
     target0 = Target("vulkan --supports_float16=True")
     assert target0.attrs["supports_float16"] == 1