Posted to commits@tvm.apache.org by ho...@apache.org on 2023/09/20 22:26:08 UTC

[tvm] branch unity updated: [Disco] Integrate RCCL (#15776)

This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2fdedf1ea8 [Disco] Integrate RCCL (#15776)
2fdedf1ea8 is described below

commit 2fdedf1ea88a149cdeb4c82cc8d8aa9c6ffdaa55
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Wed Sep 20 15:26:02 2023 -0700

    [Disco] Integrate RCCL (#15776)
    
    This PR integrates RCCL for AMD multi-GPU parallelism.
---
 CMakeLists.txt                                   |  11 +-
 python/tvm/target/detect_target.py               |   2 +-
 src/runtime/disco/{nccl/nccl.cc => ccl/ccl.cc}   | 127 +++++++++++++++--------
 src/runtime/disco/{nccl => ccl}/utils.h          |  23 ++--
 src/runtime/rocm/rocm_device_api.cc              |   4 +
 tests/python/disco/{test_nccl.py => test_ccl.py} |  98 ++++++++---------
 6 files changed, 155 insertions(+), 110 deletions(-)
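
With this change a single source file, src/runtime/disco/ccl/ccl.cc, is compiled against either NCCL (CUDA builds with USE_CUDA and USE_NCCL, TVM_NCCL_RCCL_SWITCH=0) or RCCL (ROCm builds with USE_ROCM and USE_RCCL, TVM_NCCL_RCCL_SWITCH=1), and the collectives are registered under the matching backend name. Below is a minimal usage sketch from Python, mirroring the updated tests further down; it assumes a two-GPU machine and a runtime built with one of the two backends, and is illustrative rather than part of this commit.

    import tvm
    from tvm.runtime import disco as di

    # "runtime.disco.compiled_ccl" is added in this commit and reports which
    # backend the runtime was built with ("nccl" or "rccl").
    ccl = tvm.get_global_func("runtime.disco.compiled_ccl")()

    # Spawn two disco workers and initialize the detected collective library
    # on devices 0 and 1; this dispatches to runtime.disco.<ccl>.init_ccl.
    sess = di.ProcessSession(num_workers=2)
    sess.init_ccl(ccl, 0, 1)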

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5bb3d3e2e0..bc09655923 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -444,14 +444,14 @@ endif(USE_PROFILER)
 if(USE_CUDA AND USE_NCCL)
   message(STATUS "Build with NCCL...")
   find_nccl(${USE_NCCL})
-  tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc)
+  tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/ccl/*.cc)
   list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
 endif()
 
 if(USE_ROCM AND USE_RCCL)
   message(STATUS "Build with RCCL...")
   find_rccl(${USE_RCCL})
-  tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/rccl/*.cc)
+  tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/ccl/*.cc)
   list(APPEND RUNTIME_SRCS ${RUNTIME_RCCL_SRC})
 endif()
 
@@ -891,11 +891,7 @@ endif()
 if(USE_CUDA AND USE_NCCL)
   target_link_libraries(tvm_runtime PRIVATE nccl)
   target_link_libraries(tvm PRIVATE nccl)
-endif()
-
-if(USE_CUDA AND USE_NCCL)
-  target_link_libraries(tvm PRIVATE nccl)
-  target_link_libraries(tvm_runtime PRIVATE nccl)
+  set_source_files_properties(src/runtime/disco/ccl/ccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0")
 endif()
 
 if(USE_CUDA AND USE_NVTX)
@@ -905,4 +901,5 @@ endif()
 if(USE_ROCM AND USE_RCCL)
   target_link_libraries(tvm PRIVATE rccl)
   target_link_libraries(tvm_runtime PRIVATE rccl)
+  set_source_files_properties(src/runtime/disco/ccl/ccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=1")
 endif()
diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
index 5c139cc949..0cfcf17e6a 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -50,7 +50,7 @@ def _detect_rocm(dev: Device) -> Target:
     return Target(
         {
             "kind": "rocm",
-            "mtriple": "amdgcn-and-amdhsa-hcc",
+            "mtriple": "amdgcn-amd-amdhsa-hcc",
             "max_shared_memory_per_block": dev.max_shared_memory_per_block,
             "max_threads_per_block": dev.max_threads_per_block,
             "thread_warp_size": dev.warp_size,
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/ccl/ccl.cc
similarity index 66%
rename from src/runtime/disco/nccl/nccl.cc
rename to src/runtime/disco/ccl/ccl.cc
index 07ffbed6f0..8e656f9ab2 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/ccl/ccl.cc
@@ -16,9 +16,48 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#ifndef TVM_NCCL_RCCL_SWITCH
+#define TVM_NCCL_RCCL_SWITCH 0  // 0: NCCL, 1: RCCL
+#endif
+
+#if TVM_NCCL_RCCL_SWITCH == 0
 #include <cuda_runtime_api.h>
-#include <dlpack/dlpack.h>
 #include <nccl.h>
+
+#include "../../cuda/cuda_common.h"
+
+using runtimeStream_t = cudaStream_t;
+
+#define TVM_DISCO_DEVICE_CALL CUDA_CALL
+#define TVM_DISCO_DEVICE_SET_DEVICE cudaSetDevice
+#define TVM_DISCO_DEVICE_STREAM_CREATE cudaStreamCreate
+#define TVM_DISCO_DEVICE_STREAM_SYNC cudaStreamSynchronize
+#define TVM_DISCO_DEVICE_STREAM_DESTROY cudaStreamDestroy
+#define TVM_DISCO_DEVICE_NAME "cuda"
+#define TVM_DISCO_CCL_DESTROY ncclCommDestroy
+#define TVM_DISCO_CCL_NAME "nccl"
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
+#else
+#include <hip/hip_runtime_api.h>
+#include <hip/hip_version.h>
+#include <rccl/rccl.h>
+
+#include "../../rocm/rocm_common.h"
+
+using runtimeStream_t = hipStream_t;
+
+#define TVM_DISCO_DEVICE_CALL ROCM_CALL
+#define TVM_DISCO_DEVICE_SET_DEVICE hipSetDevice
+#define TVM_DISCO_DEVICE_STREAM_CREATE hipStreamCreate
+#define TVM_DISCO_DEVICE_STREAM_SYNC hipStreamSynchronize
+#define TVM_DISCO_DEVICE_STREAM_DESTROY hipStreamDestroy
+#define TVM_DISCO_DEVICE_NAME "rocm"
+#define TVM_DISCO_CCL_DESTROY ncclCommDestroy
+#define TVM_DISCO_CCL_NAME "rccl"
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLROCM;
+#endif
+
+#include <dlpack/dlpack.h>
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/disco/session.h>
 #include <tvm/runtime/registry.h>
@@ -29,40 +68,39 @@
 #include <vector>
 
 #include "../../../support/process_id.h"
-#include "../../cuda/cuda_common.h"
 #include "./utils.h"
 
 namespace tvm {
 namespace runtime {
-namespace nccl {
+namespace ccl {
 
-struct NCCLThreadLocalContext {
+struct CCLThreadLocalContext {
   DiscoWorker* worker;
   int device_id;
-  cudaStream_t default_stream;
+  runtimeStream_t default_stream;
   ncclComm_t comm;
 
   void Clear() {
-    NCCL_CALL(ncclCommDestroy(comm));
-    CUDA_CALL(cudaStreamDestroy(default_stream));
+    NCCL_CALL(TVM_DISCO_CCL_DESTROY(comm));
+    TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_DESTROY(default_stream));
   }
 
-  cudaStream_t GetDefaultStream() {
-    const auto* func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  runtimeStream_t GetDefaultStream() {
+    const auto* func = tvm::runtime::Registry::Get("runtime.get_" TVM_DISCO_DEVICE_NAME "_stream");
     ICHECK(func != nullptr);
-    cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+    runtimeStream_t stream = static_cast<runtimeStream_t>((*func)().operator void*());
     return stream == nullptr ? default_stream : stream;
   }
 
-  static NCCLThreadLocalContext* Get() {
-    thread_local static NCCLThreadLocalContext ctx;
+  static CCLThreadLocalContext* Get() {
+    thread_local static CCLThreadLocalContext ctx;
     return &ctx;
   }
 };
 
 void InitCCL(Session sess, ShapeTuple device_ids) {
-  DRef func = sess->GetGlobalFunc("runtime.disco.nccl.init_ccl_per_worker");
-  LOG(INFO) << "Initializing NCCL with devices: " << device_ids;
+  DRef func = sess->GetGlobalFunc("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker");
+  LOG(INFO) << "Initializing " TVM_DISCO_CCL_NAME " with devices: " << device_ids;
   ncclUniqueId id;
   TVMByteArray array;
   NCCL_CALL(ncclGetUniqueId(&id));
@@ -72,7 +110,7 @@ void InitCCL(Session sess, ShapeTuple device_ids) {
 }
 
 void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
-  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   DiscoWorker* worker = DiscoWorker::ThreadLocal();
   ICHECK(worker != nullptr);
   CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES)
@@ -80,11 +118,11 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
       << unique_id_bytes.size() << ".";
   // Step up local context of NCCL
   int device_id = device_ids[worker->worker_id];
-  CUDA_CALL(cudaSetDevice(device_id));
-  CUDA_CALL(cudaStreamCreate(&ctx->default_stream));
-  Device device{DLDeviceType::kDLCUDA, device_id};
+  TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_SET_DEVICE(device_id));
+  TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_CREATE(&ctx->default_stream));
+  Device device{TVM_DISCO_DEVICE_TYPE, device_id};
   worker->default_device = device;
-  worker->ccl = "nccl";
+  worker->ccl = TVM_DISCO_CCL_NAME;
   ctx->worker = worker;
   ctx->device_id = device_id;
   // Initialize the communicator
@@ -94,21 +132,21 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
 }
 
 void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
-  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
-  cudaStream_t stream = ctx->GetDefaultStream();
+  runtimeStream_t stream = ctx->GetDefaultStream();
   NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
                           /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
 }
 
 void BroadcastFromWorker0(NDArray send, NDArray recv) {
-  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ICHECK(send.Shape()->Product() == recv.Shape()->Product());
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
-  cudaStream_t stream = ctx->GetDefaultStream();
+  runtimeStream_t stream = ctx->GetDefaultStream();
   NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
                           /*root=*/0, ctx->comm, stream));
@@ -116,10 +154,10 @@ void BroadcastFromWorker0(NDArray send, NDArray recv) {
 
 void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
   CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None";
-  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   int worker_id = ctx->worker->worker_id;
   int num_workers = ctx->worker->num_workers;
-  cudaStream_t stream = ctx->GetDefaultStream();
+  runtimeStream_t stream = ctx->GetDefaultStream();
   if (worker_id == 0) {
     CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0.";
     NDArray buffer = send.value();
@@ -157,10 +195,10 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
 
 void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
   CHECK(send.defined()) << "ValueError: buffer `send` must not be None";
-  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   int worker_id = ctx->worker->worker_id;
   int num_workers = ctx->worker->num_workers;
-  cudaStream_t stream = ctx->GetDefaultStream();
+  runtimeStream_t stream = ctx->GetDefaultStream();
   if (worker_id == 0) {
     CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
     NDArray buffer = recv.value();
@@ -197,8 +235,8 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
 }
 
 void RecvFromWorker0(NDArray buffer) {
-  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
-  cudaStream_t stream = ctx->GetDefaultStream();
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+  runtimeStream_t stream = ctx->GetDefaultStream();
   CHECK_NE(ctx->worker->worker_id, 0)
       << "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
   NCCL_CALL(ncclGroupStart());
@@ -208,26 +246,33 @@ void RecvFromWorker0(NDArray buffer) {
 }
 
 void SyncWorker() {
-  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ICHECK(ctx->worker != nullptr);
-  cudaStream_t stream = ctx->GetDefaultStream();
-  CUDA_CALL(cudaStreamSynchronize(stream));
+  runtimeStream_t stream = ctx->GetDefaultStream();
+  TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_SYNC(stream));
 }
 
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl").set_body_typed(InitCCL);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl_per_worker").set_body_typed(InitCCLPerWorker);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce")
+TVM_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() -> String {
+  return TVM_DISCO_CCL_NAME;
+});
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker")
+    .set_body_typed(InitCCLPerWorker);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce")
     .set_body_typed([](NDArray send, int kind, NDArray recv) {
       CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
       AllReduce(send, static_cast<ReduceKind>(kind), recv);
     });
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.broadcast_from_worker0")
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0")
     .set_body_typed(BroadcastFromWorker0);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.gather_to_worker0").set_body_typed(GatherToWorker0);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.recv_from_worker0").set_body_typed(RecvFromWorker0);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.sync_worker").set_body_typed(SyncWorker);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0")
+    .set_body_typed(ScatterFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0")
+    .set_body_typed(GatherToWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
+    .set_body_typed(RecvFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker);
 
-}  // namespace nccl
+}  // namespace ccl
 }  // namespace runtime
 }  // namespace tvm
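
The globals formerly registered as runtime.disco.nccl.* are now registered under the backend name chosen at compile time, alongside the new runtime.disco.compiled_ccl query. A hedged sketch of looking them up from Python (lookup only, and assuming a runtime built with NCCL or RCCL; the collectives themselves are meant to run on disco workers, where the thread-local context above is initialized).

    import tvm

    # "nccl" on CUDA builds, "rccl" on ROCm builds.
    ccl = tvm.get_global_func("runtime.disco.compiled_ccl")()

    # Per-backend collectives, e.g. "runtime.disco.rccl.allreduce" when the
    # file is compiled with TVM_NCCL_RCCL_SWITCH=1. The allreduce wrapper
    # takes the reduce kind as an integer checked against the range 0..4.
    allreduce = tvm.get_global_func("runtime.disco." + ccl + ".allreduce")
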
diff --git a/src/runtime/disco/nccl/utils.h b/src/runtime/disco/ccl/utils.h
similarity index 77%
rename from src/runtime/disco/nccl/utils.h
rename to src/runtime/disco/ccl/utils.h
index 7f40365136..c5066796c0 100644
--- a/src/runtime/disco/nccl/utils.h
+++ b/src/runtime/disco/ccl/utils.h
@@ -16,10 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#ifndef TVM_RUNTIME_DISCO_NCCL_UTILS_H_
-#define TVM_RUNTIME_DISCO_NCCL_UTILS_H_
+#ifndef TVM_RUNTIME_DISCO_CCL_UTILS_H_
+#define TVM_RUNTIME_DISCO_CCL_UTILS_H_
 
-#include <nccl.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/disco/session.h>
 
@@ -27,14 +26,14 @@
 
 namespace tvm {
 namespace runtime {
-namespace nccl {
+namespace ccl {
 
-#define NCCL_CALL(cmd)                                       \
-  do {                                                       \
-    ncclResult_t r = cmd;                                    \
-    if (r != ncclSuccess) {                                  \
-      LOG(FATAL) << "NCCLErrror: " << ncclGetErrorString(r); \
-    }                                                        \
+#define NCCL_CALL(cmd)                                                      \
+  do {                                                                      \
+    ncclResult_t r = cmd;                                                   \
+    if (r != ncclSuccess) {                                                 \
+      LOG(FATAL) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \
+    }                                                                       \
   } while (0)
 
 inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
@@ -89,7 +88,7 @@ inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
   throw;
 }
 
-}  // namespace nccl
+}  // namespace ccl
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_DISCO_NCCL_UTILS_H_
+#endif  // TVM_RUNTIME_DISCO_CCL_UTILS_H_
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index c2fb42ee36..fa19c0148d 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -247,5 +247,9 @@ TVM_REGISTER_GLOBAL("profiling.timer.rocm").set_body_typed([](Device dev) {
   return Timer(make_object<ROCMTimerNode>());
 });
 
+TVM_REGISTER_GLOBAL("runtime.get_rocm_stream").set_body_typed([]() {
+  return static_cast<void*>(ROCMThreadEntry::ThreadLocal()->stream);
+});
+
 }  // namespace runtime
 }  // namespace tvm
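
The new runtime.get_rocm_stream packed function mirrors the existing runtime.get_cuda_stream and is what CCLThreadLocalContext::GetDefaultStream resolves when ccl.cc is built with TVM_NCCL_RCCL_SWITCH=1 (device name "rocm"). A hedged check from Python, assuming a ROCm-enabled build; not part of this commit.

    import tvm

    # Returns the ROCm device API's thread-local hipStream_t as a raw void*
    # handle; it may be null, in which case the CCL layer falls back to the
    # default stream it created in InitCCLPerWorker.
    get_stream = tvm.get_global_func("runtime.get_rocm_stream")
    stream_handle = get_stream()
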
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_ccl.py
similarity index 88%
rename from tests/python/disco/test_nccl.py
rename to tests/python/disco/test_ccl.py
index e86c973fc2..ecd2e07287 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_ccl.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
-"""Tests for NCCL"""
+"""Tests for NCCL/RCCL"""
 import tempfile
 
 import numpy as np
@@ -28,22 +28,35 @@ from tvm import relax as rx
 from tvm.runtime import disco as di
 from tvm.runtime.relax_vm import VirtualMachine
 from tvm.script import relax as R
+from tvm import get_global_func
 
 _all_session_kinds = [di.ThreadedSession, di.ProcessSession]
+_ccl = [get_global_func("runtime.disco.compiled_ccl")()]
+
+
+def create_device_target(ccl):
+    if ccl == "nccl":
+        dev = tvm.cuda(0)
+    else:
+        dev = tvm.rocm(0)
+    target = tvm.target.Target.from_device(dev)
+    return (dev, target)
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_init(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_init(session_kind, ccl):
     devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
-    sess.init_ccl("nccl", *devices)
+    sess.init_ccl(ccl, *devices)
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_allreduce(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_allreduce(session_kind, ccl):
     devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
-    sess.init_ccl("nccl", *devices)
+    sess.init_ccl(ccl, *devices)
 
     array_1 = np.arange(12, dtype="float32").reshape(3, 4)
     array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
@@ -65,10 +78,11 @@ def test_allreduce(session_kind):
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_broadcast_from_worker0(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_broadcast_from_worker0(session_kind, ccl):
     devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
-    sess.init_ccl("nccl", *devices)
+    sess.init_ccl(ccl, *devices)
 
     array = np.arange(12, dtype="float32").reshape(3, 4)
     d_array = sess.empty((3, 4), "float32")
@@ -80,10 +94,11 @@ def test_broadcast_from_worker0(session_kind):
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_scatter(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_scatter(session_kind, ccl):
     devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
-    sess.init_ccl("nccl", *devices)
+    sess.init_ccl(ccl, *devices)
 
     array = np.arange(36, dtype="float32").reshape(3, 4, 3)
     d_src = sess.empty((3, 4, 3), "float32")
@@ -103,10 +118,11 @@ def test_scatter(session_kind):
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_gather(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_gather(session_kind, ccl):
     devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
-    sess.init_ccl("nccl", *devices)
+    sess.init_ccl(ccl, *devices)
 
     array = np.arange(36, dtype="float32")
     d_src = sess.empty((3, 3, 2), "float32")
@@ -121,10 +137,11 @@ def test_gather(session_kind):
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_mlp(session_kind):  # pylint: disable=too-many-locals
+@pytest.mark.parametrize("ccl", _ccl)
+def test_mlp(session_kind, ccl):  # pylint: disable=too-many-locals
     devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
-    sess.init_ccl("nccl", *devices)
+    sess.init_ccl(ccl, *devices)
 
     # pylint: disable=invalid-name
     @tvm.script.ir_module
@@ -162,16 +179,7 @@ def test_mlp(session_kind):  # pylint: disable=too-many-locals
             return lv3
 
     # pylint: enable=invalid-name
-    target = tvm.target.Target(
-        {
-            "kind": "cuda",
-            "max_shared_memory_per_block": 49152,
-            "max_threads_per_block": 1024,
-            "thread_warp_size": 32,
-            "registers_per_block": 65536,
-            "arch": "sm_80",
-        }
-    )
+    dev, target = create_device_target(ccl)
 
     def relax_build(mod, target):
         with target:
@@ -183,16 +191,16 @@ def test_mlp(session_kind):  # pylint: disable=too-many-locals
                 dl.gpu.GeneralReduction(),
                 dl.gpu.Fallback(),
             )(mod)
-            return rx.build(mod, target="cuda")
+            return rx.build(mod, target=target)
 
     # pylint: disable=invalid-name
     X = np.random.randn(128, 128).astype("float32")
     W1 = np.random.randn(128, 128).astype("float32")
     W2 = np.random.randn(128, 128).astype("float32")
-    Y_expected = VirtualMachine(relax_build(MLP, target), device=tvm.cuda(0))["main"](
-        tvm.nd.array(X, device=tvm.cuda(0)),
-        tvm.nd.array(W1, device=tvm.cuda(0)),
-        tvm.nd.array(W2, device=tvm.cuda(0)),
+    Y_expected = VirtualMachine(relax_build(MLP, target), device=dev)["main"](
+        tvm.nd.array(X, device=dev),
+        tvm.nd.array(W1, device=dev),
+        tvm.nd.array(W2, device=dev),
     ).numpy()
 
     with tempfile.TemporaryDirectory() as tmpdir:
@@ -211,7 +219,7 @@ def test_mlp(session_kind):  # pylint: disable=too-many-locals
         d_W2.debug_copy_from(0, W2[:64, :])
         d_W2.debug_copy_from(1, W2[64:, :])
         d_Y = mod["main"](d_X, d_W1, d_W2)
-        Y_result = tvm.nd.empty((128, 128), "float32", device=tvm.cuda(0))
+        Y_result = tvm.nd.empty((128, 128), "float32", device=dev)
         sess.copy_from_worker_0(Y_result, d_Y)
         sess.sync_worker_0()
         Y_result = Y_result.numpy()
@@ -220,10 +228,11 @@ def test_mlp(session_kind):  # pylint: disable=too-many-locals
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_attention(session_kind):  # pylint: disable=too-many-locals,too-many-statements
+@pytest.mark.parametrize("ccl", _ccl)
+def test_attention(session_kind, ccl):  # pylint: disable=too-many-locals,too-many-statements
     devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
-    sess.init_ccl("nccl", *devices)
+    sess.init_ccl(ccl, *devices)
 
     # pylint: disable=invalid-name
     @tvm.script.ir_module
@@ -309,16 +318,7 @@ def test_attention(session_kind):  # pylint: disable=too-many-locals,too-many-st
             return lv17
 
     # pylint: enable=invalid-name
-    target = tvm.target.Target(
-        {
-            "kind": "cuda",
-            "max_shared_memory_per_block": 49152,
-            "max_threads_per_block": 1024,
-            "thread_warp_size": 32,
-            "registers_per_block": 65536,
-            "arch": "sm_80",
-        }
-    )
+    dev, target = create_device_target(ccl)
 
     def relax_build(mod, target):
         with target:
@@ -330,7 +330,7 @@ def test_attention(session_kind):  # pylint: disable=too-many-locals,too-many-st
                 dl.gpu.GeneralReduction(),
                 dl.gpu.Fallback(),
             )(mod)
-            return rx.build(mod, target="cuda")
+            return rx.build(mod, target=target)
 
     # pylint: disable=invalid-name
     X = np.random.randn(1, 10, 128).astype("float32")
@@ -338,12 +338,12 @@ def test_attention(session_kind):  # pylint: disable=too-many-locals,too-many-st
     Wk = np.random.randn(128, 512).astype("float32")
     Wv = np.random.randn(128, 512).astype("float32")
     Wo = np.random.randn(512, 128).astype("float32")
-    Y_expected = VirtualMachine(relax_build(Attention, target), device=tvm.cuda(0))["main"](
-        tvm.nd.array(X, device=tvm.cuda(0)),
-        tvm.nd.array(Wq, device=tvm.cuda(0)),
-        tvm.nd.array(Wk, device=tvm.cuda(0)),
-        tvm.nd.array(Wv, device=tvm.cuda(0)),
-        tvm.nd.array(Wo, device=tvm.cuda(0)),
+    Y_expected = VirtualMachine(relax_build(Attention, target), device=dev)["main"](
+        tvm.nd.array(X, device=dev),
+        tvm.nd.array(Wq, device=dev),
+        tvm.nd.array(Wk, device=dev),
+        tvm.nd.array(Wv, device=dev),
+        tvm.nd.array(Wo, device=dev),
     ).numpy()
 
     with tempfile.TemporaryDirectory() as tmpdir:
@@ -368,7 +368,7 @@ def test_attention(session_kind):  # pylint: disable=too-many-locals,too-many-st
         d_Wo.debug_copy_from(0, Wo[:256, :])
         d_Wo.debug_copy_from(1, Wo[256:, :])
         d_Y = mod["main"](d_X, d_Wq, d_Wk, d_Wv, d_Wo)
-        Y_result = tvm.nd.empty((1, 10, 128), "float32", device=tvm.cuda(0))
+        Y_result = tvm.nd.empty((1, 10, 128), "float32", device=dev)
         sess.copy_from_worker_0(Y_result, d_Y)
         sess.sync_worker_0()
         Y_result = Y_result.numpy()