You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ho...@apache.org on 2023/09/20 22:26:08 UTC
[tvm] branch unity updated: [Disco] Integrate RCCL (#15776)
This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2fdedf1ea8 [Disco] Integrate RCCL (#15776)
2fdedf1ea8 is described below
commit 2fdedf1ea88a149cdeb4c82cc8d8aa9c6ffdaa55
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Wed Sep 20 15:26:02 2023 -0700
[Disco] Integrate RCCL (#15776)
This PR integrates RCCL for AMD multi-GPU parallelism.
---
CMakeLists.txt | 11 +-
python/tvm/target/detect_target.py | 2 +-
src/runtime/disco/{nccl/nccl.cc => ccl/ccl.cc} | 127 +++++++++++++++--------
src/runtime/disco/{nccl => ccl}/utils.h | 23 ++--
src/runtime/rocm/rocm_device_api.cc | 4 +
tests/python/disco/{test_nccl.py => test_ccl.py} | 98 ++++++++---------
6 files changed, 155 insertions(+), 110 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5bb3d3e2e0..bc09655923 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -444,14 +444,14 @@ endif(USE_PROFILER)
if(USE_CUDA AND USE_NCCL)
message(STATUS "Build with NCCL...")
find_nccl(${USE_NCCL})
- tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc)
+ tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/ccl/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
endif()
if(USE_ROCM AND USE_RCCL)
message(STATUS "Build with RCCL...")
find_rccl(${USE_RCCL})
- tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/rccl/*.cc)
+ tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/ccl/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_RCCL_SRC})
endif()
@@ -891,11 +891,7 @@ endif()
if(USE_CUDA AND USE_NCCL)
target_link_libraries(tvm_runtime PRIVATE nccl)
target_link_libraries(tvm PRIVATE nccl)
-endif()
-
-if(USE_CUDA AND USE_NCCL)
- target_link_libraries(tvm PRIVATE nccl)
- target_link_libraries(tvm_runtime PRIVATE nccl)
+ set_source_files_properties(src/runtime/disco/ccl/ccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0")
endif()
if(USE_CUDA AND USE_NVTX)
@@ -905,4 +901,5 @@ endif()
if(USE_ROCM AND USE_RCCL)
target_link_libraries(tvm PRIVATE rccl)
target_link_libraries(tvm_runtime PRIVATE rccl)
+ set_source_files_properties(src/runtime/disco/ccl/ccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=1")
endif()
diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
index 5c139cc949..0cfcf17e6a 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -50,7 +50,7 @@ def _detect_rocm(dev: Device) -> Target:
return Target(
{
"kind": "rocm",
- "mtriple": "amdgcn-and-amdhsa-hcc",
+ "mtriple": "amdgcn-amd-amdhsa-hcc",
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/ccl/ccl.cc
similarity index 66%
rename from src/runtime/disco/nccl/nccl.cc
rename to src/runtime/disco/ccl/ccl.cc
index 07ffbed6f0..8e656f9ab2 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/ccl/ccl.cc
@@ -16,9 +16,48 @@
* specific language governing permissions and limitations
* under the License.
*/
+#ifndef TVM_NCCL_RCCL_SWITCH
+#define TVM_NCCL_RCCL_SWITCH 0 // 0: NCCL, 1: RCCL
+#endif
+
+#if TVM_NCCL_RCCL_SWITCH == 0
#include <cuda_runtime_api.h>
-#include <dlpack/dlpack.h>
#include <nccl.h>
+
+#include "../../cuda/cuda_common.h"
+
+using runtimeStream_t = cudaStream_t;
+
+#define TVM_DISCO_DEVICE_CALL CUDA_CALL
+#define TVM_DISCO_DEVICE_SET_DEVICE cudaSetDevice
+#define TVM_DISCO_DEVICE_STREAM_CREATE cudaStreamCreate
+#define TVM_DISCO_DEVICE_STREAM_SYNC cudaStreamSynchronize
+#define TVM_DISCO_DEVICE_STREAM_DESTROY cudaStreamDestroy
+#define TVM_DISCO_DEVICE_NAME "cuda"
+#define TVM_DISCO_CCL_DESTROY ncclCommDestroy
+#define TVM_DISCO_CCL_NAME "nccl"
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
+#else
+#include <hip/hip_runtime_api.h>
+#include <hip/hip_version.h>
+#include <rccl/rccl.h>
+
+#include "../../rocm/rocm_common.h"
+
+using runtimeStream_t = hipStream_t;
+
+#define TVM_DISCO_DEVICE_CALL ROCM_CALL
+#define TVM_DISCO_DEVICE_SET_DEVICE hipSetDevice
+#define TVM_DISCO_DEVICE_STREAM_CREATE hipStreamCreate
+#define TVM_DISCO_DEVICE_STREAM_SYNC hipStreamSynchronize
+#define TVM_DISCO_DEVICE_STREAM_DESTROY hipStreamDestroy
+#define TVM_DISCO_DEVICE_NAME "rocm"
+#define TVM_DISCO_CCL_DESTROY ncclCommDestroy
+#define TVM_DISCO_CCL_NAME "rccl"
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLROCM;
+#endif
+
+#include <dlpack/dlpack.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/disco/session.h>
#include <tvm/runtime/registry.h>
@@ -29,40 +68,39 @@
#include <vector>
#include "../../../support/process_id.h"
-#include "../../cuda/cuda_common.h"
#include "./utils.h"
namespace tvm {
namespace runtime {
-namespace nccl {
+namespace ccl {
-struct NCCLThreadLocalContext {
+struct CCLThreadLocalContext {
DiscoWorker* worker;
int device_id;
- cudaStream_t default_stream;
+ runtimeStream_t default_stream;
ncclComm_t comm;
void Clear() {
- NCCL_CALL(ncclCommDestroy(comm));
- CUDA_CALL(cudaStreamDestroy(default_stream));
+ NCCL_CALL(TVM_DISCO_CCL_DESTROY(comm));
+ TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_DESTROY(default_stream));
}
- cudaStream_t GetDefaultStream() {
- const auto* func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ runtimeStream_t GetDefaultStream() {
+ const auto* func = tvm::runtime::Registry::Get("runtime.get_" TVM_DISCO_DEVICE_NAME "_stream");
ICHECK(func != nullptr);
- cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+ runtimeStream_t stream = static_cast<runtimeStream_t>((*func)().operator void*());
return stream == nullptr ? default_stream : stream;
}
- static NCCLThreadLocalContext* Get() {
- thread_local static NCCLThreadLocalContext ctx;
+ static CCLThreadLocalContext* Get() {
+ thread_local static CCLThreadLocalContext ctx;
return &ctx;
}
};
void InitCCL(Session sess, ShapeTuple device_ids) {
- DRef func = sess->GetGlobalFunc("runtime.disco.nccl.init_ccl_per_worker");
- LOG(INFO) << "Initializing NCCL with devices: " << device_ids;
+ DRef func = sess->GetGlobalFunc("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker");
+ LOG(INFO) << "Initializing " TVM_DISCO_CCL_NAME " with devices: " << device_ids;
ncclUniqueId id;
TVMByteArray array;
NCCL_CALL(ncclGetUniqueId(&id));
@@ -72,7 +110,7 @@ void InitCCL(Session sess, ShapeTuple device_ids) {
}
void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
- NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
DiscoWorker* worker = DiscoWorker::ThreadLocal();
ICHECK(worker != nullptr);
CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES)
@@ -80,11 +118,11 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
<< unique_id_bytes.size() << ".";
// Step up local context of NCCL
int device_id = device_ids[worker->worker_id];
- CUDA_CALL(cudaSetDevice(device_id));
- CUDA_CALL(cudaStreamCreate(&ctx->default_stream));
- Device device{DLDeviceType::kDLCUDA, device_id};
+ TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_SET_DEVICE(device_id));
+ TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_CREATE(&ctx->default_stream));
+ Device device{TVM_DISCO_DEVICE_TYPE, device_id};
worker->default_device = device;
- worker->ccl = "nccl";
+ worker->ccl = TVM_DISCO_CCL_NAME;
ctx->worker = worker;
ctx->device_id = device_id;
// Initialize the communicator
@@ -94,21 +132,21 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
}
void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
- NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
- cudaStream_t stream = ctx->GetDefaultStream();
+ runtimeStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
/*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
}
void BroadcastFromWorker0(NDArray send, NDArray recv) {
- NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ICHECK(send.Shape()->Product() == recv.Shape()->Product());
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
- cudaStream_t stream = ctx->GetDefaultStream();
+ runtimeStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
/*root=*/0, ctx->comm, stream));
@@ -116,10 +154,10 @@ void BroadcastFromWorker0(NDArray send, NDArray recv) {
void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None";
- NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
int worker_id = ctx->worker->worker_id;
int num_workers = ctx->worker->num_workers;
- cudaStream_t stream = ctx->GetDefaultStream();
+ runtimeStream_t stream = ctx->GetDefaultStream();
if (worker_id == 0) {
CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0.";
NDArray buffer = send.value();
@@ -157,10 +195,10 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
CHECK(send.defined()) << "ValueError: buffer `send` must not be None";
- NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
int worker_id = ctx->worker->worker_id;
int num_workers = ctx->worker->num_workers;
- cudaStream_t stream = ctx->GetDefaultStream();
+ runtimeStream_t stream = ctx->GetDefaultStream();
if (worker_id == 0) {
CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
NDArray buffer = recv.value();
@@ -197,8 +235,8 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
}
void RecvFromWorker0(NDArray buffer) {
- NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
- cudaStream_t stream = ctx->GetDefaultStream();
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ runtimeStream_t stream = ctx->GetDefaultStream();
CHECK_NE(ctx->worker->worker_id, 0)
<< "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
NCCL_CALL(ncclGroupStart());
@@ -208,26 +246,33 @@ void RecvFromWorker0(NDArray buffer) {
}
void SyncWorker() {
- NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ICHECK(ctx->worker != nullptr);
- cudaStream_t stream = ctx->GetDefaultStream();
- CUDA_CALL(cudaStreamSynchronize(stream));
+ runtimeStream_t stream = ctx->GetDefaultStream();
+ TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_SYNC(stream));
}
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl").set_body_typed(InitCCL);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl_per_worker").set_body_typed(InitCCLPerWorker);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce")
+TVM_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() -> String {
+ return TVM_DISCO_CCL_NAME;
+});
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker")
+ .set_body_typed(InitCCLPerWorker);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce")
.set_body_typed([](NDArray send, int kind, NDArray recv) {
CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
AllReduce(send, static_cast<ReduceKind>(kind), recv);
});
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.broadcast_from_worker0")
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0")
.set_body_typed(BroadcastFromWorker0);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.gather_to_worker0").set_body_typed(GatherToWorker0);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.recv_from_worker0").set_body_typed(RecvFromWorker0);
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.sync_worker").set_body_typed(SyncWorker);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0")
+ .set_body_typed(ScatterFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0")
+ .set_body_typed(GatherToWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
+ .set_body_typed(RecvFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker);
-} // namespace nccl
+} // namespace ccl
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/disco/nccl/utils.h b/src/runtime/disco/ccl/utils.h
similarity index 77%
rename from src/runtime/disco/nccl/utils.h
rename to src/runtime/disco/ccl/utils.h
index 7f40365136..c5066796c0 100644
--- a/src/runtime/disco/nccl/utils.h
+++ b/src/runtime/disco/ccl/utils.h
@@ -16,10 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_RUNTIME_DISCO_NCCL_UTILS_H_
-#define TVM_RUNTIME_DISCO_NCCL_UTILS_H_
+#ifndef TVM_RUNTIME_DISCO_CCL_UTILS_H_
+#define TVM_RUNTIME_DISCO_CCL_UTILS_H_
-#include <nccl.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/disco/session.h>
@@ -27,14 +26,14 @@
namespace tvm {
namespace runtime {
-namespace nccl {
+namespace ccl {
-#define NCCL_CALL(cmd) \
- do { \
- ncclResult_t r = cmd; \
- if (r != ncclSuccess) { \
- LOG(FATAL) << "NCCLErrror: " << ncclGetErrorString(r); \
- } \
+#define NCCL_CALL(cmd) \
+ do { \
+ ncclResult_t r = cmd; \
+ if (r != ncclSuccess) { \
+ LOG(FATAL) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \
+ } \
} while (0)
inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
@@ -89,7 +88,7 @@ inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
throw;
}
-} // namespace nccl
+} // namespace ccl
} // namespace runtime
} // namespace tvm
-#endif // TVM_RUNTIME_DISCO_NCCL_UTILS_H_
+#endif // TVM_RUNTIME_DISCO_CCL_UTILS_H_
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index c2fb42ee36..fa19c0148d 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -247,5 +247,9 @@ TVM_REGISTER_GLOBAL("profiling.timer.rocm").set_body_typed([](Device dev) {
return Timer(make_object<ROCMTimerNode>());
});
+TVM_REGISTER_GLOBAL("runtime.get_rocm_stream").set_body_typed([]() {
+ return static_cast<void*>(ROCMThreadEntry::ThreadLocal()->stream);
+});
+
} // namespace runtime
} // namespace tvm
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_ccl.py
similarity index 88%
rename from tests/python/disco/test_nccl.py
rename to tests/python/disco/test_ccl.py
index e86c973fc2..ecd2e07287 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_ccl.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
-"""Tests for NCCL"""
+"""Tests for NCCL/RCCL"""
import tempfile
import numpy as np
@@ -28,22 +28,35 @@ from tvm import relax as rx
from tvm.runtime import disco as di
from tvm.runtime.relax_vm import VirtualMachine
from tvm.script import relax as R
+from tvm import get_global_func
_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
+_ccl = [get_global_func("runtime.disco.compiled_ccl")()]
+
+
+def create_device_target(ccl):
+ if ccl == "nccl":
+ dev = tvm.cuda(0)
+ else:
+ dev = tvm.rocm(0)
+ target = tvm.target.Target.from_device(dev)
+ return (dev, target)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_init(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_init(session_kind, ccl):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
+ sess.init_ccl(ccl, *devices)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_allreduce(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_allreduce(session_kind, ccl):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
+ sess.init_ccl(ccl, *devices)
array_1 = np.arange(12, dtype="float32").reshape(3, 4)
array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
@@ -65,10 +78,11 @@ def test_allreduce(session_kind):
@pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_broadcast_from_worker0(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_broadcast_from_worker0(session_kind, ccl):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
+ sess.init_ccl(ccl, *devices)
array = np.arange(12, dtype="float32").reshape(3, 4)
d_array = sess.empty((3, 4), "float32")
@@ -80,10 +94,11 @@ def test_broadcast_from_worker0(session_kind):
@pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_scatter(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_scatter(session_kind, ccl):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
+ sess.init_ccl(ccl, *devices)
array = np.arange(36, dtype="float32").reshape(3, 4, 3)
d_src = sess.empty((3, 4, 3), "float32")
@@ -103,10 +118,11 @@ def test_scatter(session_kind):
@pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_gather(session_kind):
+@pytest.mark.parametrize("ccl", _ccl)
+def test_gather(session_kind, ccl):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
+ sess.init_ccl(ccl, *devices)
array = np.arange(36, dtype="float32")
d_src = sess.empty((3, 3, 2), "float32")
@@ -121,10 +137,11 @@ def test_gather(session_kind):
@pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_mlp(session_kind): # pylint: disable=too-many-locals
+@pytest.mark.parametrize("ccl", _ccl)
+def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
+ sess.init_ccl(ccl, *devices)
# pylint: disable=invalid-name
@tvm.script.ir_module
@@ -162,16 +179,7 @@ def test_mlp(session_kind): # pylint: disable=too-many-locals
return lv3
# pylint: enable=invalid-name
- target = tvm.target.Target(
- {
- "kind": "cuda",
- "max_shared_memory_per_block": 49152,
- "max_threads_per_block": 1024,
- "thread_warp_size": 32,
- "registers_per_block": 65536,
- "arch": "sm_80",
- }
- )
+ dev, target = create_device_target(ccl)
def relax_build(mod, target):
with target:
@@ -183,16 +191,16 @@ def test_mlp(session_kind): # pylint: disable=too-many-locals
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
)(mod)
- return rx.build(mod, target="cuda")
+ return rx.build(mod, target=target)
# pylint: disable=invalid-name
X = np.random.randn(128, 128).astype("float32")
W1 = np.random.randn(128, 128).astype("float32")
W2 = np.random.randn(128, 128).astype("float32")
- Y_expected = VirtualMachine(relax_build(MLP, target), device=tvm.cuda(0))["main"](
- tvm.nd.array(X, device=tvm.cuda(0)),
- tvm.nd.array(W1, device=tvm.cuda(0)),
- tvm.nd.array(W2, device=tvm.cuda(0)),
+ Y_expected = VirtualMachine(relax_build(MLP, target), device=dev)["main"](
+ tvm.nd.array(X, device=dev),
+ tvm.nd.array(W1, device=dev),
+ tvm.nd.array(W2, device=dev),
).numpy()
with tempfile.TemporaryDirectory() as tmpdir:
@@ -211,7 +219,7 @@ def test_mlp(session_kind): # pylint: disable=too-many-locals
d_W2.debug_copy_from(0, W2[:64, :])
d_W2.debug_copy_from(1, W2[64:, :])
d_Y = mod["main"](d_X, d_W1, d_W2)
- Y_result = tvm.nd.empty((128, 128), "float32", device=tvm.cuda(0))
+ Y_result = tvm.nd.empty((128, 128), "float32", device=dev)
sess.copy_from_worker_0(Y_result, d_Y)
sess.sync_worker_0()
Y_result = Y_result.numpy()
@@ -220,10 +228,11 @@ def test_mlp(session_kind): # pylint: disable=too-many-locals
@pytest.mark.parametrize("session_kind", _all_session_kinds)
-def test_attention(session_kind): # pylint: disable=too-many-locals,too-many-statements
+@pytest.mark.parametrize("ccl", _ccl)
+def test_attention(session_kind, ccl): # pylint: disable=too-many-locals,too-many-statements
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
+ sess.init_ccl(ccl, *devices)
# pylint: disable=invalid-name
@tvm.script.ir_module
@@ -309,16 +318,7 @@ def test_attention(session_kind): # pylint: disable=too-many-locals,too-many-st
return lv17
# pylint: enable=invalid-name
- target = tvm.target.Target(
- {
- "kind": "cuda",
- "max_shared_memory_per_block": 49152,
- "max_threads_per_block": 1024,
- "thread_warp_size": 32,
- "registers_per_block": 65536,
- "arch": "sm_80",
- }
- )
+ dev, target = create_device_target(ccl)
def relax_build(mod, target):
with target:
@@ -330,7 +330,7 @@ def test_attention(session_kind): # pylint: disable=too-many-locals,too-many-st
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
)(mod)
- return rx.build(mod, target="cuda")
+ return rx.build(mod, target=target)
# pylint: disable=invalid-name
X = np.random.randn(1, 10, 128).astype("float32")
@@ -338,12 +338,12 @@ def test_attention(session_kind): # pylint: disable=too-many-locals,too-many-st
Wk = np.random.randn(128, 512).astype("float32")
Wv = np.random.randn(128, 512).astype("float32")
Wo = np.random.randn(512, 128).astype("float32")
- Y_expected = VirtualMachine(relax_build(Attention, target), device=tvm.cuda(0))["main"](
- tvm.nd.array(X, device=tvm.cuda(0)),
- tvm.nd.array(Wq, device=tvm.cuda(0)),
- tvm.nd.array(Wk, device=tvm.cuda(0)),
- tvm.nd.array(Wv, device=tvm.cuda(0)),
- tvm.nd.array(Wo, device=tvm.cuda(0)),
+ Y_expected = VirtualMachine(relax_build(Attention, target), device=dev)["main"](
+ tvm.nd.array(X, device=dev),
+ tvm.nd.array(Wq, device=dev),
+ tvm.nd.array(Wk, device=dev),
+ tvm.nd.array(Wv, device=dev),
+ tvm.nd.array(Wo, device=dev),
).numpy()
with tempfile.TemporaryDirectory() as tmpdir:
@@ -368,7 +368,7 @@ def test_attention(session_kind): # pylint: disable=too-many-locals,too-many-st
d_Wo.debug_copy_from(0, Wo[:256, :])
d_Wo.debug_copy_from(1, Wo[256:, :])
d_Y = mod["main"](d_X, d_Wq, d_Wk, d_Wv, d_Wo)
- Y_result = tvm.nd.empty((1, 10, 128), "float32", device=tvm.cuda(0))
+ Y_result = tvm.nd.empty((1, 10, 128), "float32", device=dev)
sess.copy_from_worker_0(Y_result, d_Y)
sess.sync_worker_0()
Y_result = Y_result.numpy()