You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/09/02 06:21:50 UTC

[tvm] branch main updated: [ROCm][TVMC] Add ROCm to the TVMC driver (#8896)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eaf888c  [ROCm][TVMC] Add ROCm to the TVMC driver (#8896)
eaf888c is described below

commit eaf888c56827ebac1a43f01f93e6d4e6f8623a28
Author: mvermeulen <54...@users.noreply.github.com>
AuthorDate: Thu Sep 2 01:21:21 2021 -0500

    [ROCm][TVMC] Add ROCm to the TVMC driver (#8896)
    
    * Add ROCm to list of RPC clients.
    
    * Add ROCm to list of TVMC devices.
    
    * Enable ROCm by adding session call.
---
 python/tvm/driver/tvmc/runner.py | 4 +++-
 python/tvm/rpc/client.py         | 4 ++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index 489604d..5a15d22 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -50,7 +50,7 @@ def add_run_parser(subparsers):
     #      like 'webgpu', etc (@leandron)
     parser.add_argument(
         "--device",
-        choices=["cpu", "cuda", "cl", "metal", "vulkan"],
+        choices=["cpu", "cuda", "cl", "metal", "vulkan", "rocm"],
         default="cpu",
         help="target device to run the compiled module. Defaults to 'cpu'",
     )
@@ -394,6 +394,8 @@ def run_module(
         dev = session.metal()
     elif device == "vulkan":
         dev = session.vulkan()
+    elif device == "rocm":
+        dev = session.rocm()
     else:
         assert device == "cpu"
         dev = session.cpu()
diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py
index d8199c4..a983439 100644
--- a/python/tvm/rpc/client.py
+++ b/python/tvm/rpc/client.py
@@ -217,6 +217,10 @@ class RPCSession(object):
         """Construct Metal device."""
         return self.device(8, dev_id)
 
+    def rocm(self, dev_id=0):
+        """Construct ROCm device."""
+        return self.device(10, dev_id)
+
     def ext_dev(self, dev_id=0):
         """Construct extension device."""
         return self.device(12, dev_id)