Posted to commits@singa.apache.org by wa...@apache.org on 2020/02/11 05:27:25 UTC
[singa] branch dev updated: clang-format for distributed module
This is an automated email from the ASF dual-hosted git repository.
wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
new 1f1dcdf clang-format for distributed module
new 3d5f53b Merge pull request #595 from chrishkchris/clang-format-dist
1f1dcdf is described below
commit 1f1dcdf64ca08e8bbba1ec7500e1a3c8b67dfa2f
Author: Chris Yeung <ch...@yahoo.com.hk>
AuthorDate: Tue Feb 11 12:08:38 2020 +0800
clang-format for distributed module
---
include/singa/io/communicator.h | 93 ++++-----
src/io/communicator.cc | 414 +++++++++++++++++++---------------------
2 files changed, 248 insertions(+), 259 deletions(-)
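For reference, formatting changes like the ones below can usually be reproduced by running clang-format in place over the affected files, with the style options taken from the repository's .clang-format configuration (assumed here); the command is illustrative, not a record of how this commit was produced:

    clang-format -i include/singa/io/communicator.h src/io/communicator.cc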
diff --git a/include/singa/io/communicator.h b/include/singa/io/communicator.h
index 7192d44..b9fc6d5 100644
--- a/include/singa/io/communicator.h
+++ b/include/singa/io/communicator.h
@@ -34,69 +34,77 @@
#include <cusparse.h>
using std::vector;
-namespace singa{
-
-#define CUSPARSE_CHECK(cmd) do { \
- cusparseStatus_t e = cmd; \
- if (e != CUSPARSE_STATUS_SUCCESS) { \
- printf("Falied: Cusparse Error %s:%d '%d'\n", \
- __FILE__,__LINE__, int(e)); \
- exit(EXIT_FAILURE); \
- } \
-} while(0)
-
-#define MPICHECK(cmd) do { \
- int e = cmd; \
- if( e != MPI_SUCCESS ) { \
- printf("Failed: MPI error %s:%d '%d'\n", \
- __FILE__,__LINE__, e); \
- exit(EXIT_FAILURE); \
- } \
-} while(0)
-
-#define NCCLCHECK(cmd) do { \
- ncclResult_t r = cmd; \
- if (r!= ncclSuccess) { \
- printf("Failed, NCCL error %s:%d '%s'\n", \
- __FILE__,__LINE__,ncclGetErrorString(r)); \
- exit(EXIT_FAILURE); \
- } \
-} while(0)
+namespace singa {
+
+#define CUSPARSE_CHECK(cmd) \
+ do { \
+ cusparseStatus_t e = cmd; \
+ if (e != CUSPARSE_STATUS_SUCCESS) { \
+ printf("Falied: Cusparse Error %s:%d '%d'\n", __FILE__, __LINE__, \
+ int(e)); \
+ exit(EXIT_FAILURE); \
+ } \
+ } while (0)
+
+#define MPICHECK(cmd) \
+ do { \
+ int e = cmd; \
+ if (e != MPI_SUCCESS) { \
+ printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
+ exit(EXIT_FAILURE); \
+ } \
+ } while (0)
+
+#define NCCLCHECK(cmd) \
+ do { \
+ ncclResult_t r = cmd; \
+ if (r != ncclSuccess) { \
+ printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
+ ncclGetErrorString(r)); \
+ exit(EXIT_FAILURE); \
+ } \
+ } while (0)
class NcclIdHolder {
-public:
+ public:
ncclUniqueId id;
- NcclIdHolder();
+ NcclIdHolder();
~NcclIdHolder();
};
class Communicator {
-public:
+ public:
int MPIRankInGlobal;
int totalMPIRanksInGlobal;
int MPIRankInLocal;
Communicator(int limit);
- Communicator(int gpu_num, int gpu_per_node, const NcclIdHolder &holder, int size);
+ Communicator(int gpu_num, int gpu_per_node, const NcclIdHolder &holder,
+ int size);
~Communicator();
void synch(Tensor &t);
void fusedSynch(vector<Tensor> &t);
void synchHalf(Tensor &t);
void fusedSynchHalf(vector<Tensor> &t);
- void fusedSparsification(vector<Tensor> &t, Tensor &accumulation, float sparsThreshold, bool topK);
+ void fusedSparsification(vector<Tensor> &t, Tensor &accumulation,
+ float sparsThreshold, bool topK);
void fusedSparsification(vector<Tensor> &t, float sparsThreshold, bool topK);
- void sparsification(Tensor &t, Tensor &accumulation, float sparsThreshold, bool topK);
+ void sparsification(Tensor &t, Tensor &accumulation, float sparsThreshold,
+ bool topK);
void sparsification(Tensor &t, float sparsThreshold, bool topK);
void wait();
-private:
- void allReduce(int size, void* sendbuff, void* recvbuff, ncclDataType_t ncclType);
+ private:
+ void allReduce(int size, void *sendbuff, void *recvbuff,
+ ncclDataType_t ncclType);
void setup();
void sparsInit();
- void _fusedSparsification(vector<Tensor> &t, Tensor* accumulation, float sparsThreshold, bool topK);
- void _sparsification(Tensor &t, Tensor* accumulation, float sparsThreshold, bool topK);
- void valSparsAllReduce(size_t num, float* accumulation);
- void topKSparsAllReduce(size_t num, float* accumulation);
+ void _fusedSparsification(vector<Tensor> &t, Tensor *accumulation,
+ float sparsThreshold, bool topK);
+ void _sparsification(Tensor &t, Tensor *accumulation, float sparsThreshold,
+ bool topK);
+ void valSparsAllReduce(size_t num, float *accumulation);
+ void topKSparsAllReduce(size_t num, float *accumulation);
float *fusedSendBuff;
float *fusedRecvBuff;
@@ -130,11 +138,8 @@ private:
float *sparsRecvBuff;
float *backupBuff;
int *fusedIndex;
-
};
-
-
}
-#endif // USE_DIST
+#endif // USE_DIST
#endif
diff --git a/src/io/communicator.cc b/src/io/communicator.cc
index 30b944b..1690308 100644
--- a/src/io/communicator.cc
+++ b/src/io/communicator.cc
@@ -24,45 +24,41 @@
#include "singa/io/communicator.h"
#include "./math_kernel.h"
-namespace singa{
+namespace singa {
-static uint64_t getHostHash(const char* string) {
+static uint64_t getHostHash(const char *string) {
// Based on DJB2, result = result * 33 + char
uint64_t result = 5381;
- for (int c = 0; string[c] != '\0'; c++){
+ for (int c = 0; string[c] != '\0'; c++) {
result = ((result << 5) + result) + string[c];
}
return result;
}
-
-static void getHostName(char* hostname, int maxlen) {
+static void getHostName(char *hostname, int maxlen) {
gethostname(hostname, maxlen);
- for (int i=0; i< maxlen; i++) {
+ for (int i = 0; i < maxlen; i++) {
if (hostname[i] == '.') {
- hostname[i] = '\0';
- return;
+ hostname[i] = '\0';
+ return;
}
}
}
-NcclIdHolder::NcclIdHolder(){
- ncclGetUniqueId(&id);
-} // end of constructor
+NcclIdHolder::NcclIdHolder() { ncclGetUniqueId(&id); } // end of constructor
-NcclIdHolder::~NcclIdHolder(){
-}
+NcclIdHolder::~NcclIdHolder() {}
// contructer for application with python multi-processing module
-Communicator::Communicator(int gpu_num, int gpu_per_node, const NcclIdHolder &holder, int buffSize){
-
- maxSize = (size_t) buffSize;
+Communicator::Communicator(int gpu_num, int gpu_per_node,
+ const NcclIdHolder &holder, int buffSize) {
+ maxSize = (size_t)buffSize;
// this contructor is for NCCL WITHOUT MPI
UseMPI = false;
// Determine the rank of the collective communication
- totalMPIRanksInGlobal=gpu_per_node;
- MPIRankInLocal=gpu_num;
- MPIRankInGlobal=gpu_num;
+ totalMPIRanksInGlobal = gpu_per_node;
+ MPIRankInLocal = gpu_num;
+ MPIRankInGlobal = gpu_num;
// copy the nccl unqiue id from the input id holder
id = holder.id;
@@ -70,12 +66,11 @@ Communicator::Communicator(int gpu_num, int gpu_per_node, const NcclIdHolder &ho
// setup cuda stream and nccl communicator
setup();
-} // end of constructor
+} // end of constructor
// contructer for application with MPI
-Communicator::Communicator(int buffSize){
-
- maxSize = (size_t) buffSize;
+Communicator::Communicator(int buffSize) {
+ maxSize = (size_t)buffSize;
// this contructor is for NCCL WITH MPI
UseMPI = true;
@@ -85,31 +80,31 @@ Communicator::Communicator(int buffSize){
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &totalMPIRanksInGlobal));
// calculating MPIRankInLocal which is used in selecting a GPU
- MPIRankInLocal=0;
+ MPIRankInLocal = 0;
uint64_t hostHashs[totalMPIRanksInGlobal];
char hostname[1024];
getHostName(hostname, 1024);
hostHashs[MPIRankInGlobal] = getHostHash(hostname);
MPICHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs,
- sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
- for (int p=0; p<totalMPIRanksInGlobal; p++) {
- if (p == MPIRankInGlobal) break;
- if (hostHashs[p] == hostHashs[MPIRankInGlobal]) MPIRankInLocal++;
+ sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
+ for (int p = 0; p < totalMPIRanksInGlobal; p++) {
+ if (p == MPIRankInGlobal) break;
+ if (hostHashs[p] == hostHashs[MPIRankInGlobal]) MPIRankInLocal++;
}
// generating NCCL unique nccl ID at one process and broadcasting it to all
if (MPIRankInGlobal == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
- // setup cuda stream and nccl communicator
+ // setup cuda stream and nccl communicator
setup();
-} // end of constructor
-
-void Communicator::setup(){
+} // end of constructor
+void Communicator::setup() {
CUDA_CHECK(cudaSetDevice(MPIRankInLocal));
- NCCLCHECK(ncclCommInitRank(&comm, totalMPIRanksInGlobal, id, MPIRankInGlobal));
+ NCCLCHECK(
+ ncclCommInitRank(&comm, totalMPIRanksInGlobal, id, MPIRankInGlobal));
CUDA_CHECK(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
CUDA_CHECK(cudaStreamCreateWithFlags(&c1, cudaStreamNonBlocking));
CUDA_CHECK(cudaStreamCreateWithFlags(&c2, cudaStreamNonBlocking));
@@ -117,45 +112,38 @@ void Communicator::setup(){
CUDA_CHECK(cudaMalloc(&fusedRecvBuff, maxSize * sizeof(float)));
CUDA_CHECK(cudaMalloc(&fusedSendBuffHalf, maxSize * sizeof(__half)));
CUDA_CHECK(cudaMalloc(&fusedRecvBuffHalf, maxSize * sizeof(__half)));
- CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventBlockingSync | cudaEventDisableTiming));
+ CUDA_CHECK(cudaEventCreateWithFlags(
+ &event, cudaEventBlockingSync | cudaEventDisableTiming));
sparsInitialized = false;
}
-void Communicator::sparsInit(){
-
- //initize sparsification environment
+void Communicator::sparsInit() {
+ // initize sparsification environment
CUDA_CHECK(cudaSetDevice(MPIRankInLocal));
- CUDA_CHECK(cudaMalloc(&sparsRecvBuff, (int) (maxSize * sizeof(float) * totalMPIRanksInGlobal)));
- CUDA_CHECK(cudaMalloc(&sparsSendBuff, (int) (maxSize * sizeof(float))));
+ CUDA_CHECK(cudaMalloc(
+ &sparsRecvBuff, (int)(maxSize * sizeof(float) * totalMPIRanksInGlobal)));
+ CUDA_CHECK(cudaMalloc(&sparsSendBuff, (int)(maxSize * sizeof(float))));
CUDA_CHECK(cudaMalloc(&backupBuff, maxSize * sizeof(float)));
CUDA_CHECK(cudaMalloc(&fusedIndex, maxSize * sizeof(int)));
- CUDA_CHECK(cudaMalloc(&xInd, (int) (sizeof(int) * maxSize)));
- CUDA_CHECK(cudaMalloc(&xVal, (int) (sizeof(float) * maxSize)));
+ CUDA_CHECK(cudaMalloc(&xInd, (int)(sizeof(int) * maxSize)));
+ CUDA_CHECK(cudaMalloc(&xVal, (int)(sizeof(float) * maxSize)));
CUSPARSE_CHECK(cusparseCreate(&cusparse_handle));
CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, c2));
- nnz = (int*) malloc(sizeof(int));
- nnzAll = (int*) malloc(sizeof(int) * totalMPIRanksInGlobal);
+ nnz = (int *)malloc(sizeof(int));
+ nnzAll = (int *)malloc(sizeof(int) * totalMPIRanksInGlobal);
CUDA_CHECK(cudaMalloc(&nnzGPU, sizeof(int) * totalMPIRanksInGlobal));
CUDA_CHECK(cudaMalloc(&nnzAllGPU, sizeof(int) * totalMPIRanksInGlobal));
sparsInitialized = true;
-
}
-void Communicator::allReduce(int size, void* sendbuff, void* recvbuff, ncclDataType_t ncclType)
-{
-
- NCCLCHECK(ncclAllReduce((const void*)sendbuff,
- (void*)recvbuff,
- size,
- ncclType,
- ncclSum,
- comm,
- s));
-
+void Communicator::allReduce(int size, void *sendbuff, void *recvbuff,
+ ncclDataType_t ncclType) {
+ NCCLCHECK(ncclAllReduce((const void *)sendbuff, (void *)recvbuff, size,
+ ncclType, ncclSum, comm, s));
}
-void Communicator::wait(){
- //synchronizing on all the CUDA streams used by communicator
+void Communicator::wait() {
+ // synchronizing on all the CUDA streams used by communicator
CUDA_CHECK(cudaEventRecord(event, s));
CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
CUDA_CHECK(cudaEventRecord(event, c1));
@@ -164,8 +152,8 @@ void Communicator::wait(){
CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
}
-Communicator::~Communicator(){
- //finalizing NCCL
+Communicator::~Communicator() {
+ // finalizing NCCL
ncclCommDestroy(comm);
if (UseMPI == true) MPICHECK(MPI_Finalize());
CUDA_CHECK(cudaFree(fusedSendBuff));
@@ -186,22 +174,21 @@ Communicator::~Communicator(){
CUDA_CHECK(cudaFree(nnzGPU));
CUDA_CHECK(cudaFree(nnzAllGPU));
}
-
}
-
-void Communicator::fusedSynch(vector<Tensor> &t){
-
+void Communicator::fusedSynch(vector<Tensor> &t) {
// record the event of the default cuda stream and follow it
CUDA_CHECK(cudaEventRecord(event, NULL));
CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
-
+
size_t offset = 0;
- //memory copy to fusedBuff
- for (size_t i = 0; i < t.size(); i++)
- {
- CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+ // memory copy to fusedBuff
+ for (size_t i = 0; i < t.size(); i++) {
+ CUDA_CHECK(cudaMemcpyAsync((void *)(fusedSendBuff + offset),
+ (const void *)t[i].block()->mutable_data(),
+ t[i].Size() * sizeof(float),
+ cudaMemcpyDeviceToDevice, c1));
offset += t[i].Size();
}
@@ -209,45 +196,46 @@ void Communicator::fusedSynch(vector<Tensor> &t){
CUDA_CHECK(cudaEventRecord(event, c1));
CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
- allReduce((int) offset, (void*) fusedSendBuff, (void*) fusedRecvBuff, ncclFloat);
+ allReduce((int)offset, (void *)fusedSendBuff, (void *)fusedRecvBuff,
+ ncclFloat);
// wait for the allreduce to complete
CUDA_CHECK(cudaEventRecord(event, s));
CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
- //copy data back to tensors after allreduce
+ // copy data back to tensors after allreduce
offset = 0;
- for (size_t i = 0; i < t.size(); i++)
- {
- CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+ for (size_t i = 0; i < t.size(); i++) {
+ CUDA_CHECK(cudaMemcpyAsync((void *)t[i].block()->mutable_data(),
+ (const void *)(fusedRecvBuff + offset),
+ t[i].Size() * sizeof(float),
+ cudaMemcpyDeviceToDevice, c1));
offset += t[i].Size();
}
-
}
-void Communicator::synch(Tensor &t){
-
+void Communicator::synch(Tensor &t) {
// record the event of the default cuda stream and follow it
CUDA_CHECK(cudaEventRecord(event, NULL));
CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
- void* addr = t.block()->mutable_data();
+ void *addr = t.block()->mutable_data();
allReduce(t.Size(), addr, addr, ncclFloat);
-
}
-void Communicator::fusedSynchHalf(vector<Tensor> &t){
-
+void Communicator::fusedSynchHalf(vector<Tensor> &t) {
// record the event of the default cuda stream and follow it
CUDA_CHECK(cudaEventRecord(event, NULL));
CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
-
+
size_t offset = 0;
- //memory copy to fusedBuff
- for (size_t i = 0; i < t.size(); i++)
- {
- CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+ // memory copy to fusedBuff
+ for (size_t i = 0; i < t.size(); i++) {
+ CUDA_CHECK(cudaMemcpyAsync((void *)(fusedSendBuff + offset),
+ (const void *)t[i].block()->mutable_data(),
+ t[i].Size() * sizeof(float),
+ cudaMemcpyDeviceToDevice, c1));
offset += t[i].Size();
}
@@ -257,7 +245,8 @@ void Communicator::fusedSynchHalf(vector<Tensor> &t){
CUDA_CHECK(cudaEventRecord(event, c1));
CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
- allReduce((int) offset, (void*) fusedSendBuffHalf, (void*) fusedRecvBuffHalf, ncclHalf);
+ allReduce((int)offset, (void *)fusedSendBuffHalf, (void *)fusedRecvBuffHalf,
+ ncclHalf);
// wait for the allreduce to complete
CUDA_CHECK(cudaEventRecord(event, s));
@@ -265,19 +254,19 @@ void Communicator::fusedSynchHalf(vector<Tensor> &t){
cuda::half2float(offset, fusedRecvBuffHalf, fusedRecvBuff, c2);
- //copy data back to tensors after allreduce
+ // copy data back to tensors after allreduce
offset = 0;
- for (size_t i = 0; i < t.size(); i++)
- {
- CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c2));
+ for (size_t i = 0; i < t.size(); i++) {
+ CUDA_CHECK(cudaMemcpyAsync((void *)t[i].block()->mutable_data(),
+ (const void *)(fusedRecvBuff + offset),
+ t[i].Size() * sizeof(float),
+ cudaMemcpyDeviceToDevice, c2));
offset += t[i].Size();
}
-
}
-void Communicator::synchHalf(Tensor &t){
-
- float* addr = static_cast<float*>(t.block()->mutable_data());
+void Communicator::synchHalf(Tensor &t) {
+ float *addr = static_cast<float *>(t.block()->mutable_data());
// record the event of the default cuda stream and follow it
CUDA_CHECK(cudaEventRecord(event, NULL));
@@ -289,26 +278,27 @@ void Communicator::synchHalf(Tensor &t){
CUDA_CHECK(cudaEventRecord(event, c1));
CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
- allReduce(t.Size(), (void*) fusedSendBuffHalf, (void*) fusedRecvBuffHalf, ncclHalf);
+ allReduce(t.Size(), (void *)fusedSendBuffHalf, (void *)fusedRecvBuffHalf,
+ ncclHalf);
// wait for the allreduce to complete
CUDA_CHECK(cudaEventRecord(event, s));
CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
cuda::half2float(t.Size(), fusedRecvBuffHalf, addr, c2);
-
}
-void Communicator::sparsification(Tensor &t, Tensor &accumulation, float sparsThreshold, bool topK){
+void Communicator::sparsification(Tensor &t, Tensor &accumulation,
+ float sparsThreshold, bool topK) {
_sparsification(t, &accumulation, sparsThreshold, topK);
}
-void Communicator::sparsification(Tensor &t, float sparsThreshold, bool topK){
- _sparsification(t, (Tensor *) NULL, sparsThreshold, topK);
+void Communicator::sparsification(Tensor &t, float sparsThreshold, bool topK) {
+ _sparsification(t, (Tensor *)NULL, sparsThreshold, topK);
}
-void Communicator::_sparsification(Tensor &t, Tensor* accumulation, float sparsThreshold, bool topK){
-
+void Communicator::_sparsification(Tensor &t, Tensor *accumulation,
+ float sparsThreshold, bool topK) {
// threshold for sprasification
threshold = sparsThreshold;
@@ -316,85 +306,92 @@ void Communicator::_sparsification(Tensor &t, Tensor* accumulation, float sparsT
CUDA_CHECK(cudaEventRecord(event, NULL));
CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
- //memory copy to fusedBuff
- CUDA_CHECK(cudaMemcpyAsync((void*) fusedSendBuff, (const void*) t.block()->mutable_data(), t.Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+ // memory copy to fusedBuff
+ CUDA_CHECK(cudaMemcpyAsync(
+ (void *)fusedSendBuff, (const void *)t.block()->mutable_data(),
+ t.Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
float *accumPtr;
if (accumulation != NULL)
- accumPtr = (float*) accumulation->block()->mutable_data();
+ accumPtr = (float *)accumulation->block()->mutable_data();
else
- accumPtr = NULL;
+ accumPtr = NULL;
if (topK == false)
valSparsAllReduce(t.Size(), accumPtr);
else
topKSparsAllReduce(t.Size(), accumPtr);
- //copy data back to tensor after allreduce
- CUDA_CHECK(cudaMemcpyAsync((void*) t.block()->mutable_data(), (const void*) fusedRecvBuff, t.Size() * sizeof(float), cudaMemcpyDeviceToDevice, c2));
-
+ // copy data back to tensor after allreduce
+ CUDA_CHECK(cudaMemcpyAsync(
+ (void *)t.block()->mutable_data(), (const void *)fusedRecvBuff,
+ t.Size() * sizeof(float), cudaMemcpyDeviceToDevice, c2));
}
-void Communicator::fusedSparsification(vector<Tensor> &t, Tensor &accumulation, float sparsThreshold, bool topK){
+void Communicator::fusedSparsification(vector<Tensor> &t, Tensor &accumulation,
+ float sparsThreshold, bool topK) {
_fusedSparsification(t, &accumulation, sparsThreshold, topK);
}
-void Communicator::fusedSparsification(vector<Tensor> &t, float sparsThreshold, bool topK){
- _fusedSparsification(t, (Tensor *) NULL, sparsThreshold, topK);
+void Communicator::fusedSparsification(vector<Tensor> &t, float sparsThreshold,
+ bool topK) {
+ _fusedSparsification(t, (Tensor *)NULL, sparsThreshold, topK);
}
-void Communicator::_fusedSparsification(vector<Tensor> &t, Tensor* accumulation, float sparsThreshold, bool topK){
-
+void Communicator::_fusedSparsification(vector<Tensor> &t, Tensor *accumulation,
+ float sparsThreshold, bool topK) {
// threshold for sprasification
threshold = sparsThreshold;
// record the event of the default cuda stream and follow it
CUDA_CHECK(cudaEventRecord(event, NULL));
CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
-
+
size_t offset = 0;
- //memory copy to fusedBuff
- for (size_t i = 0; i < t.size(); i++)
- {
- CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+ // memory copy to fusedBuff
+ for (size_t i = 0; i < t.size(); i++) {
+ CUDA_CHECK(cudaMemcpyAsync((void *)(fusedSendBuff + offset),
+ (const void *)t[i].block()->mutable_data(),
+ t[i].Size() * sizeof(float),
+ cudaMemcpyDeviceToDevice, c1));
offset += t[i].Size();
}
float *accumPtr;
if (accumulation != NULL)
- accumPtr = (float*) accumulation->block()->mutable_data();
+ accumPtr = (float *)accumulation->block()->mutable_data();
else
- accumPtr = NULL;
+ accumPtr = NULL;
if (topK == false)
valSparsAllReduce(offset, accumPtr);
else
- topKSparsAllReduce(offset, accumPtr);
+ topKSparsAllReduce(offset, accumPtr);
- //copy data back to tensors after allreduce
+ // copy data back to tensors after allreduce
offset = 0;
- for (size_t i = 0; i < t.size(); i++)
- {
- CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c2));
+ for (size_t i = 0; i < t.size(); i++) {
+ CUDA_CHECK(cudaMemcpyAsync((void *)t[i].block()->mutable_data(),
+ (const void *)(fusedRecvBuff + offset),
+ t[i].Size() * sizeof(float),
+ cudaMemcpyDeviceToDevice, c2));
offset += t[i].Size();
}
-
}
-void Communicator::valSparsAllReduce(size_t num, float* accumulation){
-
- if (sparsInitialized == false)
- sparsInit();
+void Communicator::valSparsAllReduce(size_t num, float *accumulation) {
+ if (sparsInitialized == false) sparsInit();
- if (accumulation != NULL)
- {
+ if (accumulation != NULL) {
// add the previous accumulation
- cuda::add(num, fusedSendBuff, accumulation, fusedSendBuff, c1);
+ cuda::add(num, fusedSendBuff, accumulation, fusedSendBuff, c1);
// backup the fusedSendBuff
- CUDA_CHECK(cudaMemcpyAsync((void*) backupBuff, (const void*) fusedSendBuff, sizeof(float) * num, cudaMemcpyDeviceToDevice, c1));
+ CUDA_CHECK(cudaMemcpyAsync((void *)backupBuff, (const void *)fusedSendBuff,
+ sizeof(float) * num, cudaMemcpyDeviceToDevice,
+ c1));
}
// sparsification based on threshold
@@ -402,7 +399,7 @@ void Communicator::valSparsAllReduce(size_t num, float* accumulation){
// output the gradient accumulation
if (accumulation != NULL)
- cuda::sub(num, backupBuff, fusedSendBuff, accumulation, c1);
+ cuda::sub(num, backupBuff, fusedSendBuff, accumulation, c1);
// produce the index of the sparse array
cuda::sparsindex(num, fusedSendBuff, fusedIndex, c1);
@@ -410,83 +407,79 @@ void Communicator::valSparsAllReduce(size_t num, float* accumulation){
// remove zero of index to become sprase array and get the num of non-zero nnz
cuda::removezeroidx(num, fusedIndex, c1, nnz);
- CUDA_CHECK(cudaMemcpyAsync((void*) nnzGPU, (const void*) nnz, sizeof(int), cudaMemcpyHostToDevice, c1));
+ CUDA_CHECK(cudaMemcpyAsync((void *)nnzGPU, (const void *)nnz, sizeof(int),
+ cudaMemcpyHostToDevice, c1));
// all-gather all the nnz from different ranks
- NCCLCHECK(ncclAllGather((const void*)nnzGPU,
- (void*)nnzAllGPU,
- 1,
- ncclInt,
- comm,
- c1));
+ NCCLCHECK(ncclAllGather((const void *)nnzGPU, (void *)nnzAllGPU, 1, ncclInt,
+ comm, c1));
- CUDA_CHECK(cudaMemcpyAsync((void*) nnzAll, (const void*) nnzAllGPU, sizeof(int) * totalMPIRanksInGlobal, cudaMemcpyDeviceToHost, c1));
+ CUDA_CHECK(cudaMemcpyAsync((void *)nnzAll, (const void *)nnzAllGPU,
+ sizeof(int) * totalMPIRanksInGlobal,
+ cudaMemcpyDeviceToHost, c1));
CUDA_CHECK(cudaStreamSynchronize(c1));
int nnzMax = 0;
for (int i = 0; i < totalMPIRanksInGlobal; i++)
- if(nnzAll[i] > nnzMax)
- nnzMax = nnzAll[i];
+ if (nnzAll[i] > nnzMax) nnzMax = nnzAll[i];
// remove zero of values to become sprase array
cuda::removezeroval(num, fusedSendBuff, c1);
- CUDA_CHECK(cudaMemcpyAsync((void*) (sparsSendBuff), (const void*) fusedIndex, sizeof(int) * (*nnz), cudaMemcpyDeviceToDevice, c1));
- CUDA_CHECK(cudaMemcpyAsync((void*) (sparsSendBuff + (*nnz)), (const void*) fusedSendBuff, sizeof(float) * (*nnz), cudaMemcpyDeviceToDevice, c1));
+ CUDA_CHECK(cudaMemcpyAsync((void *)(sparsSendBuff), (const void *)fusedIndex,
+ sizeof(int) * (*nnz), cudaMemcpyDeviceToDevice,
+ c1));
+ CUDA_CHECK(cudaMemcpyAsync(
+ (void *)(sparsSendBuff + (*nnz)), (const void *)fusedSendBuff,
+ sizeof(float) * (*nnz), cudaMemcpyDeviceToDevice, c1));
// wait for the memcpy to complete
CUDA_CHECK(cudaEventRecord(event, c1));
CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
// all-gather all the sparse gradients
- NCCLCHECK(ncclAllGather((const void*)sparsSendBuff,
- (void*)sparsRecvBuff,
- 2 * nnzMax,
- ncclFloat,
- comm,
- s));
+ NCCLCHECK(ncclAllGather((const void *)sparsSendBuff, (void *)sparsRecvBuff,
+ 2 * nnzMax, ncclFloat, comm, s));
// wait for the all-gather to complete
CUDA_CHECK(cudaEventRecord(event, s));
CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
// reduce the sparse gradients, firstly setting the sum buff value to zero
- CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num *sizeof(float) , c2));
+ CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), c2));
size_t offset = 0;
float alpha = 1.0;
- // add the spase gradent from each rank to the sum buff to finish the all-reduce process
- for (int i = 0; i < totalMPIRanksInGlobal; i++)
- {
- CUDA_CHECK(cudaMemcpyAsync((void*) xInd, (const void*) (sparsRecvBuff + offset), sizeof(int) * nnzAll[i], cudaMemcpyDeviceToDevice, c2));
- offset += nnzAll[i];
- CUDA_CHECK(cudaMemcpyAsync((void*) xVal, (const void*) (sparsRecvBuff + offset), sizeof(float) * nnzAll[i], cudaMemcpyDeviceToDevice, c2));
- offset += (2 * nnzMax - nnzAll[i]);
- CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle,
- nnzAll[i],
- &alpha,
- xVal,
- xInd,
- fusedRecvBuff,
- CUSPARSE_INDEX_BASE_ONE));
+ // add the spase gradent from each rank to the sum buff to finish the
+ // all-reduce process
+ for (int i = 0; i < totalMPIRanksInGlobal; i++) {
+ CUDA_CHECK(
+ cudaMemcpyAsync((void *)xInd, (const void *)(sparsRecvBuff + offset),
+ sizeof(int) * nnzAll[i], cudaMemcpyDeviceToDevice, c2));
+ offset += nnzAll[i];
+ CUDA_CHECK(cudaMemcpyAsync(
+ (void *)xVal, (const void *)(sparsRecvBuff + offset),
+ sizeof(float) * nnzAll[i], cudaMemcpyDeviceToDevice, c2));
+ offset += (2 * nnzMax - nnzAll[i]);
+ CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzAll[i], &alpha, xVal,
+ xInd, fusedRecvBuff,
+ CUSPARSE_INDEX_BASE_ONE));
}
-
}
-void Communicator::topKSparsAllReduce(size_t num, float* accumulation){
-
- if (sparsInitialized == false)
- sparsInit();
+void Communicator::topKSparsAllReduce(size_t num, float *accumulation) {
+ if (sparsInitialized == false) sparsInit();
// use gradient accumulation
- if (accumulation != NULL)
- {
+ if (accumulation != NULL) {
// add the previous accumulation
- cuda::add(num, fusedSendBuff, accumulation, fusedSendBuff, c1);
+ cuda::add(num, fusedSendBuff, accumulation, fusedSendBuff, c1);
// backup the fusedSendBuff
- CUDA_CHECK(cudaMemcpyAsync((void*) backupBuff, (const void*) fusedSendBuff, sizeof(float) * num, cudaMemcpyDeviceToDevice, c1));
+ CUDA_CHECK(cudaMemcpyAsync((void *)backupBuff, (const void *)fusedSendBuff,
+ sizeof(float) * num, cudaMemcpyDeviceToDevice,
+ c1));
}
// generate an index and sort the fusedSendBuff from large to small values
@@ -494,70 +487,61 @@ void Communicator::topKSparsAllReduce(size_t num, float* accumulation){
cuda::sortbykey(num, fusedSendBuff, fusedIndex, c1);
// determine the number of topK for communication
- int nnzMax = (int) ceil(threshold * num);
+ int nnzMax = (int)ceil(threshold * num);
// output the gradient accumulation
float alpha = 1.0;
- if (accumulation != NULL)
- {
- CUDA_CHECK(cudaMemsetAsync(accumulation, 0, num * sizeof(float) , c1));
- CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, c1));
- CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle,
- nnzMax,
- &alpha,
- fusedSendBuff,
- fusedIndex,
- accumulation,
- CUSPARSE_INDEX_BASE_ONE));
- cuda::sub(num, backupBuff, accumulation, accumulation, c1);
+ if (accumulation != NULL) {
+ CUDA_CHECK(cudaMemsetAsync(accumulation, 0, num * sizeof(float), c1));
+ CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, c1));
+ CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzMax, &alpha,
+ fusedSendBuff, fusedIndex, accumulation,
+ CUSPARSE_INDEX_BASE_ONE));
+ cuda::sub(num, backupBuff, accumulation, accumulation, c1);
}
// the topK value and index will be sent
- CUDA_CHECK(cudaMemcpyAsync((void*) (sparsSendBuff), (const void*) fusedIndex, sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice, c1));
- CUDA_CHECK(cudaMemcpyAsync((void*) (sparsSendBuff + nnzMax), (const void*) fusedSendBuff, sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, c1));
+ CUDA_CHECK(cudaMemcpyAsync((void *)(sparsSendBuff), (const void *)fusedIndex,
+ sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice,
+ c1));
+ CUDA_CHECK(cudaMemcpyAsync(
+ (void *)(sparsSendBuff + nnzMax), (const void *)fusedSendBuff,
+ sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, c1));
// wait for the memcpy to complete
CUDA_CHECK(cudaEventRecord(event, c1));
CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
// all-gather all the sparse gradients
- NCCLCHECK(ncclAllGather((const void*)sparsSendBuff,
- (void*)sparsRecvBuff,
- 2 * nnzMax,
- ncclFloat,
- comm,
- s));
+ NCCLCHECK(ncclAllGather((const void *)sparsSendBuff, (void *)sparsRecvBuff,
+ 2 * nnzMax, ncclFloat, comm, s));
// wait for the all-gather to complete
CUDA_CHECK(cudaEventRecord(event, s));
CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
// reduce the sparse gradients, firstly setting the sum buff value to zero
- CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num *sizeof(float) , c2));
+ CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), c2));
size_t offset = 0;
CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, c2));
- // add the spase gradent from each rank to the sum buff to finish the all-reduce process
- for (int i = 0; i < totalMPIRanksInGlobal; i++)
- {
- CUDA_CHECK(cudaMemcpyAsync((void*) xInd, (const void*) (sparsRecvBuff + offset), sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice, c2));
+ // add the spase gradent from each rank to the sum buff to finish the
+ // all-reduce process
+ for (int i = 0; i < totalMPIRanksInGlobal; i++) {
+ CUDA_CHECK(
+ cudaMemcpyAsync((void *)xInd, (const void *)(sparsRecvBuff + offset),
+ sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice, c2));
offset += nnzMax;
- CUDA_CHECK(cudaMemcpyAsync((void*) xVal, (const void*) (sparsRecvBuff + offset), sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, c2));
+ CUDA_CHECK(
+ cudaMemcpyAsync((void *)xVal, (const void *)(sparsRecvBuff + offset),
+ sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, c2));
offset += nnzMax;
- CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle,
- nnzMax,
- &alpha,
- xVal,
- xInd,
- fusedRecvBuff,
- CUSPARSE_INDEX_BASE_ONE));
+ CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzMax, &alpha, xVal, xInd,
+ fusedRecvBuff, CUSPARSE_INDEX_BASE_ONE));
}
-
}
-
-
}
-#endif // USE_DIST
+#endif // USE_DIST