You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/23 14:11:21 UTC

[tvm] branch main updated: [Target] Fix device mask issue and typos (#9768)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4af2a66  [Target] Fix device mask issue and typos (#9768)
4af2a66 is described below

commit 4af2a6614b901d38a1914bd28ec38080217105d1
Author: Colin Y. Li <cy...@live.com>
AuthorDate: Thu Dec 23 22:09:35 2021 +0800

    [Target] Fix device mask issue and typos (#9768)
    
    * [Target] Fix device mask issue and typos
    
    * Skip target hook
---
 python/tvm/_ffi/runtime_ctypes.py           |  5 +++--
 python/tvm/driver/tvmc/compiler.py          |  2 +-
 python/tvm/relay/op/op.py                   |  6 +++---
 python/tvm/target/target.py                 | 15 +++++++++++++++
 src/target/target_kind.cc                   | 12 +++++++++++-
 tests/python/unittest/test_target_target.py | 17 +++++++++++++++++
 6 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 297e24d..03a68e9 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -203,7 +203,6 @@ class Device(ctypes.Structure):
         2: "cuda",
         4: "opencl",
         5: "aocl",
-        6: "sdaccel",
         7: "vulkan",
         8: "metal",
         9: "vpi",
@@ -217,13 +216,15 @@ class Device(ctypes.Structure):
         "stackvm": 1,
         "cpu": 1,
         "c": 1,
+        "hybrid": 1,
+        "composite": 1,
         "cuda": 2,
         "nvptx": 2,
         "cl": 4,
         "opencl": 4,
+        "sdaccel": 4,
         "aocl": 5,
         "aocl_sw_emu": 5,
-        "sdaccel": 6,
         "vulkan": 7,
         "metal": 8,
         "vpi": 9,
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index d390ce5..dbf7e46 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -76,7 +76,7 @@ def add_compile_parser(subparsers, _):
         "-o",
         "--output",
         default="module.tar",
-        help="output the compiled module to a specifed archive. Defaults to 'module.tar'.",
+        help="output the compiled module to a specified archive. Defaults to 'module.tar'.",
     )
     parser.add_argument(
         "-f",
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index bbaffc4..ec48ea1 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -60,7 +60,7 @@ def register(op_name, describe=""):
 
 
 def register_stateful(op_name, stateful, level=10):
-    """Register operator pattern for an op.
+    """Register stateful flag for an op.
 
     Parameters
     ----------
@@ -81,7 +81,7 @@ class OpPattern(object):
 
     See Also
     --------
-    top.tag : Contains explanation of the tag type.
+    topi.tag : Contains explanation of the tag type.
     """
 
     # Elementwise operator
@@ -393,7 +393,7 @@ def register_pattern(op_name, pattern, level=10):
 
 
 def register_gradient(op_name, fgradient=None, level=10):
-    """Register operator pattern for an op.
+    """Register operator gradient function for an op.
 
     Parameters
     ----------
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index c7f6d41..723cb91 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -186,6 +186,21 @@ class Target(Object):
     def libs(self):
         return list(self.attrs.get("libs", []))
 
+    def get_kind_attr(self, attr_name):
+        """Get additional attribute about the target kind.
+
+        Parameters
+        ----------
+        attr_name : str
+            The attribute name.
+
+        Returns
+        -------
+        value : object
+            The attribute value, or None if the target kind does not define this attribute.
+        """
+        return _ffi_api.TargetKindGetAttr(self.kind, attr_name)
+
     @staticmethod
     def list_kinds():
         """Returns the list of available target names."""
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 59efc8c..e4bf48b 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -400,10 +400,20 @@ TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev)  // line break
 TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU)  // line break
     .add_attr_option<Bool>("system-lib");
 
-TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option<Array<Target>>("devices");
+TVM_REGISTER_TARGET_KIND("composite", kDLCPU)  // line break
+    .add_attr_option<Array<Target>>("devices");
 
 /**********  Registry  **********/
 
+TVM_REGISTER_GLOBAL("target.TargetKindGetAttr")
+    .set_body_typed([](TargetKind kind, String attr_name) -> TVMRetValue {
+      auto target_attr_map = TargetKind::GetAttrMap<TVMRetValue>(attr_name);
+      TVMRetValue rv;
+      if (target_attr_map.count(kind)) {
+        rv = target_attr_map[kind];
+      }
+      return rv;
+    });
 TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);
 TVM_REGISTER_GLOBAL("target.ListTargetKindOptions")
     .set_body_typed(TargetKindRegEntry::ListTargetKindOptions);
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 46ceeae..3a8cba5 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -43,6 +43,23 @@ def rocm_func(data):
     return data + 10
 
 
+def test_all_targets_device_type_verify():
+    """Verify that every target kind's device type matches Device.STR2MASK."""
+    all_targets = [tvm.target.Target(t) for t in tvm.target.Target.list_kinds()]
+
+    for tgt in all_targets:
+        # Skip target kinds that register a target hook (RelayToTIR or TIRToRuntime).
+        relay_to_tir = tgt.get_kind_attr("RelayToTIR")
+        tir_to_runtime = tgt.get_kind_attr("TIRToRuntime")
+        if relay_to_tir is not None or tir_to_runtime is not None:
+            continue
+
+        if tgt.kind.name not in tvm._ffi.runtime_ctypes.Device.STR2MASK:
+            raise KeyError("Cannot find target kind: %s in Device.STR2MASK" % tgt.kind.name)
+
+        assert tgt.kind.device_type == tvm._ffi.runtime_ctypes.Device.STR2MASK[tgt.kind.name]
+
+
 def test_target_dispatch():
     with tvm.target.cuda():
         assert mygeneric(1) == 3