You are viewing a plain-text version of this content. The canonical link to the original message was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/23 14:11:21 UTC
[tvm] branch main updated: [Target] Fix device mask issue and typos (#9768)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4af2a66 [Target] Fix device mask issue and typos (#9768)
4af2a66 is described below
commit 4af2a6614b901d38a1914bd28ec38080217105d1
Author: Colin Y. Li <cy...@live.com>
AuthorDate: Thu Dec 23 22:09:35 2021 +0800
[Target] Fix device mask issue and typos (#9768)
* [Target] Fix device mask issue and typos
* Skip target hook
---
python/tvm/_ffi/runtime_ctypes.py | 5 +++--
python/tvm/driver/tvmc/compiler.py | 2 +-
python/tvm/relay/op/op.py | 6 +++---
python/tvm/target/target.py | 15 +++++++++++++++
src/target/target_kind.cc | 12 +++++++++++-
tests/python/unittest/test_target_target.py | 17 +++++++++++++++++
6 files changed, 50 insertions(+), 7 deletions(-)
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 297e24d..03a68e9 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -203,7 +203,6 @@ class Device(ctypes.Structure):
2: "cuda",
4: "opencl",
5: "aocl",
- 6: "sdaccel",
7: "vulkan",
8: "metal",
9: "vpi",
@@ -217,13 +216,15 @@ class Device(ctypes.Structure):
"stackvm": 1,
"cpu": 1,
"c": 1,
+ "hybrid": 1,
+ "composite": 1,
"cuda": 2,
"nvptx": 2,
"cl": 4,
"opencl": 4,
+ "sdaccel": 4,
"aocl": 5,
"aocl_sw_emu": 5,
- "sdaccel": 6,
"vulkan": 7,
"metal": 8,
"vpi": 9,
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index d390ce5..dbf7e46 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -76,7 +76,7 @@ def add_compile_parser(subparsers, _):
"-o",
"--output",
default="module.tar",
- help="output the compiled module to a specifed archive. Defaults to 'module.tar'.",
+ help="output the compiled module to a specified archive. Defaults to 'module.tar'.",
)
parser.add_argument(
"-f",
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index bbaffc4..ec48ea1 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -60,7 +60,7 @@ def register(op_name, describe=""):
def register_stateful(op_name, stateful, level=10):
- """Register operator pattern for an op.
+ """Register stateful flag for an op.
Parameters
----------
@@ -81,7 +81,7 @@ class OpPattern(object):
See Also
--------
- top.tag : Contains explanation of the tag type.
+ topi.tag : Contains explanation of the tag type.
"""
# Elementwise operator
@@ -393,7 +393,7 @@ def register_pattern(op_name, pattern, level=10):
def register_gradient(op_name, fgradient=None, level=10):
- """Register operator pattern for an op.
+ """Register operator gradient function for an op.
Parameters
----------
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index c7f6d41..723cb91 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -186,6 +186,21 @@ class Target(Object):
def libs(self):
return list(self.attrs.get("libs", []))
+ def get_kind_attr(self, attr_name):
+ """Get an attribute registered on this target's kind (e.g. a target hook).
+
+ Parameters
+ ----------
+ attr_name : str
+ The name of the target-kind attribute to look up.
+
+ Returns
+ -------
+ value : object
+ The attribute value, or None if this kind does not register it.
+ """
+ return _ffi_api.TargetKindGetAttr(self.kind, attr_name)
+
@staticmethod
def list_kinds():
"""Returns the list of available target names."""
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 59efc8c..e4bf48b 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -400,10 +400,20 @@ TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev) // line break
TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break
.add_attr_option<Bool>("system-lib");
-TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option<Array<Target>>("devices");
+TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break
+ .add_attr_option<Array<Target>>("devices");
/********** Registry **********/
+TVM_REGISTER_GLOBAL("target.TargetKindGetAttr")  // look up a per-kind attribute by name
+ .set_body_typed([](TargetKind kind, String attr_name) -> TVMRetValue {
+ auto target_attr_map = TargetKind::GetAttrMap<TVMRetValue>(attr_name);
+ TVMRetValue rv;  // left empty (null) when `kind` has no value for `attr_name`
+ if (target_attr_map.count(kind)) {
+ rv = target_attr_map[kind];
+ }
+ return rv;
+ });
TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);
TVM_REGISTER_GLOBAL("target.ListTargetKindOptions")
.set_body_typed(TargetKindRegEntry::ListTargetKindOptions);
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 46ceeae..3a8cba5 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -43,6 +43,23 @@ def rocm_func(data):
return data + 10
+def test_all_targets_device_type_verify():
+ """Check that each target kind's device_type agrees with Device.STR2MASK."""
+ all_targets = [tvm.target.Target(t) for t in tvm.target.Target.list_kinds()]
+
+ for tgt in all_targets:
+ # skip kinds that register a target hook (RelayToTIR or TIRToRuntime)
+ relay_to_tir = tgt.get_kind_attr("RelayToTIR")
+ tir_to_runtime = tgt.get_kind_attr("TIRToRuntime")
+ if relay_to_tir is not None or tir_to_runtime is not None:
+ continue
+
+ if tgt.kind.name not in tvm._ffi.runtime_ctypes.Device.STR2MASK:
+ raise KeyError("Cannot find target kind: %s in Device.STR2MASK" % tgt.kind.name)
+
+ assert tgt.kind.device_type == tvm._ffi.runtime_ctypes.Device.STR2MASK[tgt.kind.name]
+
+
def test_target_dispatch():
with tvm.target.cuda():
assert mygeneric(1) == 3