You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that was lost in the plain-text conversion.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/09/02 06:21:50 UTC
[tvm] branch main updated: [ROCm][TVMC] Add ROCm to the TVMC driver (#8896)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eaf888c [ROCm][TVMC] Add ROCm to the TVMC driver (#8896)
eaf888c is described below
commit eaf888c56827ebac1a43f01f93e6d4e6f8623a28
Author: mvermeulen <54...@users.noreply.github.com>
AuthorDate: Thu Sep 2 01:21:21 2021 -0500
[ROCm][TVMC] Add ROCm to the TVMC driver (#8896)
* Add ROCm to list of RPC clients.
* Add ROCm to list of TVMC devices.
* Enable ROCm by adding session call.
---
python/tvm/driver/tvmc/runner.py | 4 +++-
python/tvm/rpc/client.py | 4 ++++
2 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index 489604d..5a15d22 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -50,7 +50,7 @@ def add_run_parser(subparsers):
# like 'webgpu', etc (@leandron)
parser.add_argument(
"--device",
- choices=["cpu", "cuda", "cl", "metal", "vulkan"],
+ choices=["cpu", "cuda", "cl", "metal", "vulkan", "rocm"],
default="cpu",
help="target device to run the compiled module. Defaults to 'cpu'",
)
@@ -394,6 +394,8 @@ def run_module(
dev = session.metal()
elif device == "vulkan":
dev = session.vulkan()
+ elif device == "rocm":
+ dev = session.rocm()
else:
assert device == "cpu"
dev = session.cpu()
diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py
index d8199c4..a983439 100644
--- a/python/tvm/rpc/client.py
+++ b/python/tvm/rpc/client.py
@@ -217,6 +217,10 @@ class RPCSession(object):
"""Construct Metal device."""
return self.device(8, dev_id)
+ def rocm(self, dev_id=0):
+ """Construct ROCm device."""
+ return self.device(10, dev_id)
+
def ext_dev(self, dev_id=0):
"""Construct extension device."""
return self.device(12, dev_id)