Posted to commits@singa.apache.org by wa...@apache.org on 2019/11/20 03:07:22 UTC
[singa] branch master updated: SINGA-487 Add support of gradient compression to half precision
This is an automated email from the ASF dual-hosted git repository.
wangwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/master by this push:
new 859785f SINGA-487 Add support of gradient compression to half precision
new 00f80e1 Merge pull request #562 from chrishkchris/SINGA-487_4
859785f is described below
commit 859785fc3400734a2fe26956433e8bf017dc5b3c
Author: chrishkchris <ch...@yahoo.com.hk>
AuthorDate: Tue Nov 19 10:13:04 2019 +0000
SINGA-487 Add support of gradient compression to half precision
---
include/singa/io/communicator.h | 12 ++++-
python/singa/opt.py | 41 ++++++++++++++++
src/api/dist_communicator.i | 2 +
src/core/tensor/math_kernel.cu | 98 ++++++++++++++++++++++---------------
src/core/tensor/math_kernel.h | 7 ++-
src/io/communicator.cc | 104 ++++++++++++++++++++++++++++++++++------
6 files changed, 209 insertions(+), 55 deletions(-)
diff --git a/include/singa/io/communicator.h b/include/singa/io/communicator.h
index 802b972..e3717fc 100644
--- a/include/singa/io/communicator.h
+++ b/include/singa/io/communicator.h
@@ -30,6 +30,7 @@
#include <mpi.h>
#include "singa/core/tensor.h"
+#include "cuda_fp16.h"
using std::vector;
namespace singa{
@@ -67,11 +68,16 @@ public:
bool UseMPI;
float *fusedSendBuff;
float *fusedRecvBuff;
+ __half *fusedSendBuffHalf;
+ __half *fusedRecvBuffHalf;
size_t maxSize;
ncclUniqueId id;
+ // cuda stream s is for the nccl allreduce
cudaStream_t s;
- cudaStream_t c;
+ // cuda streams c1 and c2 are mainly for data copies to and from the memory buffers
+ cudaStream_t c1;
+ cudaStream_t c2;
ncclComm_t comm;
cudaEvent_t event;
@@ -80,10 +86,12 @@ public:
~Communicator();
void synch(Tensor &t);
void fusedSynch(vector<Tensor> &t);
+ void synchHalf(Tensor &t);
+ void fusedSynchHalf(vector<Tensor> &t);
void wait();
private:
- void allReduce(int size, void* sendbuff, void* recvbuff);
+ void allReduce(int size, void* sendbuff, void* recvbuff, ncclDataType_t ncclType);
void setup(int gpu_num);
};
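
The header change above splits the old copy stream c into c1 and c2, so the pre-allreduce packing/conversion and the post-allreduce conversion/unpacking can each be ordered against the NCCL stream s independently. The ordering primitive behind this design is CUDA's event record/wait pattern; below is a minimal, hedged sketch of it (the function and the stream roles are illustrative, not code from this patch):

#include <cuda_runtime.h>

// Make stream `s` wait for everything queued on stream `c1` so far, without
// blocking the host thread. In the communicator, c1 carries the
// cudaMemcpyAsync packing and the fp32->fp16 conversion, while s carries
// the NCCL allreduce.
void orderAfter(cudaStream_t c1, cudaStream_t s) {
  cudaEvent_t ev;
  cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
  cudaEventRecord(ev, c1);        // snapshot of the work queued on c1
  cudaStreamWaitEvent(s, ev, 0);  // later work on s waits for that snapshot
  cudaEventDestroy(ev);           // safe: destruction is deferred until fired
}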
diff --git a/python/singa/opt.py b/python/singa/opt.py
index 46c4c72..06286f2 100755
--- a/python/singa/opt.py
+++ b/python/singa/opt.py
@@ -171,6 +171,10 @@ class DistOpt(object):
# The class is designed to wrap an optimizer to do distributed training.
# opt: The optimizer to be wrapped. nDev: number of devices (GPUs) a
# process will control/use.
+ # nccl_id: an nccl id holder object for a unique communication id
+ # gpu_num: the GPU id in a single node
+ # gpu_per_node: the number of GPUs in a single node
+ # buffSize: the buffer size used by the nccl communicator, default is 16 MB
# world_size: total number of processes.
# rank_in_local: local rank of a process on the current node.
@@ -199,10 +203,19 @@ class DistOpt(object):
tensor = singa.VecTensor(tensor)
self.communicator.fusedSynch(tensor)
+ def all_reduce_half(self, tensor):
+ self.communicator.synchHalf(tensor)
+
+ def fused_all_reduce_half(self, tensor):
+ tensor = singa.VecTensor(tensor)
+ self.communicator.fusedSynchHalf(tensor)
+
def wait(self):
self.communicator.wait()
def backward_and_update(self, loss, threshold = 2097152):
+ # backward propagation from the loss, followed by parameter update
+ # it applies tensor fusion, which fuses all the tensors smaller than the threshold value
plist = []
acc = 0
glist = []
@@ -224,3 +237,31 @@ class DistOpt(object):
self.wait()
for p, g in plist:
self.update(p, g)
+
+ def backward_and_update_half(self, loss, threshold = 2097152, clipping = False, clip_Value = 100):
+ # THIS IS AN EXPERIMENTAL FUNCTION FOR RESEARCH PURPOSES:
+ # It converts the gradients to 16-bit half-precision format before allreduce
+ # To assist training, this function provides an option to perform gradient clipping
+ plist = []
+ acc = 0
+ glist = []
+ for p, g in autograd.backward(loss):
+ if clipping:
+ g = autograd.clip(g, -clip_Value, clip_Value)
+ if g.size() > threshold:
+ # larger than threshold -> reduced directly
+ self.all_reduce_half(g.data)
+ else:
+ # smaller than threshold -> accumulate
+ glist.append(g.data)
+ acc += g.size()
+ if (acc > threshold):
+ self.fused_all_reduce_half(glist)
+ acc = 0
+ glist = []
+ plist.append((p, g))
+ if glist:
+ self.fused_all_reduce_half(glist)
+ self.wait()
+ for p, g in plist:
+ self.update(p, g)
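
The clipping option above exists because half precision overflows easily: the largest finite fp16 value is 65504, so any gradient entry beyond that becomes infinity when converted by KernelFloat2Half. As a hedged illustration (this kernel is not part of the patch), clipping and conversion could even be fused into one pass, written in the same grid-stride style as the patch's kernels:

#include <cuda_fp16.h>

// Hypothetical fused kernel: clamp each gradient to [-limit, limit], then
// round it to half precision. The patch instead clips in autograd (opt.py)
// and converts in KernelFloat2Half; this merely combines the two steps.
__global__ void KernelClipFloat2Half(const size_t n, const float limit,
                                     const float *in, __half *out) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = __float2half_rn(fminf(fmaxf(in[i], -limit), limit));
  }
}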
diff --git a/src/api/dist_communicator.i b/src/api/dist_communicator.i
index 8c777f6..2da9602 100644
--- a/src/api/dist_communicator.i
+++ b/src/api/dist_communicator.i
@@ -47,6 +47,8 @@ public:
Communicator(int gpu_num, int gpu_per_node, const NcclIdHolder &holder, int limit);
void synch(Tensor &t);
void fusedSynch(std::vector<Tensor> &t);
+ void synchHalf(Tensor &t);
+ void fusedSynchHalf(std::vector<Tensor> &t);
void wait();
};
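
These two declarations are what let SWIG expose the half-precision paths to the Python DistOpt wrappers above. For reference, a hedged C++ sketch of the same call sequence (a helper of our own; setup of the communicator and tensors is omitted):

#include <vector>
#include "singa/core/tensor.h"
#include "singa/io/communicator.h"

// Drive the new fp16 entry points directly from C++; DistOpt's
// all_reduce_half / fused_all_reduce_half wrappers issue these same calls.
void allReduceHalf(singa::Communicator &comm,
                   std::vector<singa::Tensor> &small_grads,
                   singa::Tensor &big_grad) {
  comm.fusedSynchHalf(small_grads);  // fuse small tensors, convert, allreduce
  comm.synchHalf(big_grad);          // fp16 allreduce of one large tensor
  comm.wait();  // order the default stream after streams s, c1 and c2
}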
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 2ce6531..b25455d 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -348,6 +348,20 @@ __global__ void KernelSoftmaxCrossEntropyBwd(const bool int_target, const size_t
}
}
+__global__ void KernelFloat2Half(const size_t n, const float *in, __half *out) {
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+ i += blockDim.x * gridDim.x) {
+ out[i] = __float2half_rn(in[i]);
+ }
+}
+
+__global__ void KernelHalf2Float(const size_t n, const __half *in, float *out) {
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+ i += blockDim.x * gridDim.x) {
+ out[i] = __half2float(in[i]);
+ }
+}
+
//cuda unary elementwise ops kernel template
#define GenUnaryCudaKernel(fn,kernelfn,cudafn) \
__global__ void kernelfn(const size_t n, const float *in, float *out) { \
@@ -357,7 +371,7 @@ __global__ void KernelSoftmaxCrossEntropyBwd(const bool int_target, const size_t
} \
} \
void fn(const size_t n, const float *in, float *out, cudaStream_t s) { \
- kernelfn <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out); \
+ kernelfn <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out); \
}
GenUnaryCudaKernel(cos,KernelCos,cosf);
@@ -378,128 +392,136 @@ GenUnaryCudaKernel(atanh,KernelAtanh,atanhf);
// Functions call kernels
// ********************************
+void float2half(const size_t n, const float *in, __half *out, cudaStream_t s) {
+ KernelFloat2Half <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
+void half2float(const size_t n, const __half *in, float *out, cudaStream_t s) {
+ KernelHalf2Float <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
void set(const size_t n, const float v, float *out, cudaStream_t s) {
- KernelSet <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, v, out);
+ KernelSet <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, v, out);
}
void abs(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelAbs <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelAbs <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void sign(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelSign <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelSign <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void log(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void sqrt(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelSqrt <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelSqrt <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void square(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelSquare <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelSquare <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void relu(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelRelu <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelRelu <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void sigmoid(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelSigmoid <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelSigmoid <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void softplus(const size_t n, const float *in, float *out, cudaStream_t s) {
- KernelSoftplus <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+ KernelSoftplus <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
void clamp(const size_t n, const float low, const float high, const float *in,
float *out, cudaStream_t s) {
- KernelClamp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, low, high, in, out);
+ KernelClamp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, low, high, in, out);
}
void pow(const size_t n, const float *in, const float x, float *out,
cudaStream_t s) {
- KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
+ KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, x, out);
}
void add(const size_t n, const float *in, const float x, float *out,
cudaStream_t s) {
- KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
+ KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, x, out);
}
void mult(const size_t n, const float *in, const float x, float *out,
cudaStream_t s) {
- KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
+ KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, x, out);
}
void div(const size_t n, const float x, const float *in, float *out,
cudaStream_t s) {
- KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
+ KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, x, in, out);
}
void threshold(const size_t n, const float x, const float *in, float *out,
cudaStream_t s) {
- KernelThreshold <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
+ KernelThreshold <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, x, in, out);
}
void gt(const size_t num, const float *in, const float x, float *out,
cudaStream_t s) {
- KernelGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+ KernelGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
}
void gt(const size_t num, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelBGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+ KernelBGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
}
void ge(const size_t num, const float *in, const float x, float *out,
cudaStream_t s) {
- KernelGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+ KernelGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
}
void ge(const size_t num, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelBGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+ KernelBGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
}
void lt(const size_t num, const float *in, const float x, float *out,
cudaStream_t s) {
- KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+ KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
}
void lt(const size_t num, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelBLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+ KernelBLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
}
void le(const size_t num, const float *in, const float x, float *out,
cudaStream_t s) {
- KernelLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+ KernelLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
}
void le(const size_t num, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelBLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+ KernelBLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
}
void pow(const size_t n, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+ KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
}
void add(const size_t n, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+ KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
}
void sub(const size_t n, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelSub <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+ KernelSub <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
}
void mult(const size_t n, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+ KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
}
void div(const size_t n, const float *in1, const float *in2, float *out,
cudaStream_t s) {
- KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+ KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
}
/*
@@ -513,42 +535,42 @@ void sum(const size_t n, const float *in, float *out, cudaStream_t s) {
void ComputeCrossEntropy(const bool int_target, size_t batchsize, const size_t dim, const float *p,
const int *t, float *loss, cudaStream_t stream) {
- KernelComputeCrossEntropy <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF>>>
+ KernelComputeCrossEntropy <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF, 0, stream>>>
(int_target, batchsize, dim, p, t, loss);
}
void SoftmaxCrossEntropyBwd(const bool int_target, size_t batchsize, const size_t dim, const float *p,
const int *t, float *grad, cudaStream_t stream) {
- KernelSoftmaxCrossEntropyBwd <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF>>>
+ KernelSoftmaxCrossEntropyBwd <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF, 0, stream>>>
(int_target, batchsize, dim, p, t, grad);
}
void RowMax(const size_t nrow, const size_t ncol, const float *inPtr,
float *outPtr, cudaStream_t stream) {
- KernelRowMax <<<ceil(nrow / CU1DBLOCKF), CU1DBLOCKF>>>(nrow, ncol, inPtr, outPtr);
+ KernelRowMax <<<ceil(nrow / CU1DBLOCKF), CU1DBLOCKF, 0, stream>>>(nrow, ncol, inPtr, outPtr);
}
/*
void square_grad(int n, const float *in, float *out, cudaStream_t s) {
- kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+ kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
}
void tanh_grad(int n, const float *in, float *out, cudaStream_t s) {
- kernel_tanh_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+ kernel_tanh_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
}
void relu_grad(int n, const float *in, float *out, cudaStream_t s) {
- kernel_relu_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+ kernel_relu_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
}
void sigmoid_grad(int n, const float *in, float *out, cudaStream_t s) {
- kernel_sigmoid_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+ kernel_sigmoid_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
}
void softplus_grad(int n, const float *in, float *out, cudaStream_t s) {
- kernel_softplus_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+ kernel_softplus_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
}
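
Most of this hunk is one mechanical change: the kernels already received a cudaStream_t argument but were launched on the default stream; appending `, 0, s` to the execution configuration supplies the dynamic shared-memory size (third parameter, 0 bytes here) and the target stream (fourth parameter). A minimal sketch of the before/after with a hypothetical kernel, using 256 as a stand-in for the CU1DBLOCKF block-size constant:

#include <math.h>

// Grid-stride kernel in the same style as the kernels in this file.
__global__ void KernelScale(const size_t n, const float a, float *x) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    x[i] *= a;
  }
}

void scale(const size_t n, const float a, float *x, cudaStream_t s) {
  // before: KernelScale <<<ceil(n / 256.0f), 256>>> (n, a, x);  // default stream
  KernelScale <<<ceil(n / 256.0f), 256, 0, s>>> (n, a, x);       // stream s
}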
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 6fa2a64..fb8601d 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -21,7 +21,7 @@
#ifndef SRC_CORE_TENSOR__MATH_KERNEL_H_
#define SRC_CORE_TENSOR__MATH_KERNEL_H_
-
+#include "cuda_fp16.h"
#include "singa/singa_config.h"
#ifdef USE_CUDA
@@ -123,6 +123,11 @@ void SoftmaxCrossEntropyBwd(bool int_target, const size_t batchsize,
void RowMax(const size_t nrow, const size_t ncol, const float *inPtr,
float *outPtr, cudaStream_t stream);
+
+void float2half(const size_t n, const float *in, __half *out, cudaStream_t s);
+
+void half2float(const size_t n, const __half *in, float *out, cudaStream_t s);
+
} // cuda
} // namespace singa
diff --git a/src/io/communicator.cc b/src/io/communicator.cc
index 46e8dbf..7995e75 100644
--- a/src/io/communicator.cc
+++ b/src/io/communicator.cc
@@ -22,10 +22,10 @@
#ifdef USE_DIST
#include "singa/io/communicator.h"
+#include "./math_kernel.h"
namespace singa{
-
static uint64_t getHostHash(const char* string) {
// Based on DJB2, result = result * 33 + char
uint64_t result = 5381;
@@ -110,21 +110,24 @@ void Communicator::setup(int gpu_num){
CUDA_CHECK(cudaSetDevice(gpu_num));
NCCLCHECK(ncclCommInitRank(&comm, totalMPIRanksInGlobal, id, MPIRankInGlobal));
- CUDA_CHECK(cudaStreamCreateWithPriority(&s, cudaStreamNonBlocking, 0));
- CUDA_CHECK(cudaStreamCreateWithPriority(&c, cudaStreamNonBlocking, 1));
+ CUDA_CHECK(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
+ CUDA_CHECK(cudaStreamCreateWithFlags(&c1, cudaStreamNonBlocking));
+ CUDA_CHECK(cudaStreamCreateWithFlags(&c2, cudaStreamNonBlocking));
CUDA_CHECK(cudaMalloc(&fusedSendBuff, maxSize * sizeof(float)));
CUDA_CHECK(cudaMalloc(&fusedRecvBuff, maxSize * sizeof(float)));
+ CUDA_CHECK(cudaMalloc(&fusedSendBuffHalf, maxSize * sizeof(__half)));
+ CUDA_CHECK(cudaMalloc(&fusedRecvBuffHalf, maxSize * sizeof(__half)));
CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventBlockingSync | cudaEventDisableTiming));
}
-void Communicator::allReduce(int size, void* sendbuff, void* recvbuff)
+void Communicator::allReduce(int size, void* sendbuff, void* recvbuff, ncclDataType_t ncclType)
{
NCCLCHECK(ncclAllReduce((const void*)sendbuff,
(void*)recvbuff,
size,
- ncclFloat,
+ ncclType,
ncclSum,
comm,
s));
@@ -132,10 +135,12 @@ void Communicator::allReduce(int size, void* sendbuff, void* recvbuff)
}
void Communicator::wait(){
- //synchronizing on CUDA stream to complete NCCL communication
+ // synchronizing on all the CUDA streams used by the communicator
CUDA_CHECK(cudaEventRecord(event, s));
CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
- CUDA_CHECK(cudaEventRecord(event, c));
+ CUDA_CHECK(cudaEventRecord(event, c1));
+ CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
+ CUDA_CHECK(cudaEventRecord(event, c2));
CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
}
@@ -145,38 +150,45 @@ Communicator::~Communicator(){
if (UseMPI == true) MPICHECK(MPI_Finalize());
CUDA_CHECK(cudaFree(fusedSendBuff));
CUDA_CHECK(cudaFree(fusedRecvBuff));
+ CUDA_CHECK(cudaFree(fusedSendBuffHalf));
+ CUDA_CHECK(cudaFree(fusedRecvBuffHalf));
+ CUDA_CHECK(cudaStreamDestroy(s));
+ CUDA_CHECK(cudaStreamDestroy(c1));
+ CUDA_CHECK(cudaStreamDestroy(c2));
+
}
+
void Communicator::fusedSynch(vector<Tensor> &t){
// record the event of the default cuda stream and follow it
CUDA_CHECK(cudaEventRecord(event, NULL));
- CUDA_CHECK(cudaStreamWaitEvent(c, event, 0));
+ CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
size_t offset = 0;
//memory copy to fusedBuff
for (size_t i = 0; i < t.size(); i++)
{
- CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c));
+ CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
offset += t[i].Size();
}
// wait for the memcpy to complete
- CUDA_CHECK(cudaEventRecord(event, c));
+ CUDA_CHECK(cudaEventRecord(event, c1));
CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
- allReduce((int) offset, (void*) fusedSendBuff, (void*) fusedRecvBuff);
+ allReduce((int) offset, (void*) fusedSendBuff, (void*) fusedRecvBuff, ncclFloat);
// wait for the allreduce to complete
CUDA_CHECK(cudaEventRecord(event, s));
- CUDA_CHECK(cudaStreamWaitEvent(c, event, 0));
+ CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
//copy data back to tensors after allreduce
offset = 0;
for (size_t i = 0; i < t.size(); i++)
{
- CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c));
+ CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
offset += t[i].Size();
}
@@ -189,10 +201,74 @@ void Communicator::synch(Tensor &t){
CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
void* addr = t.block()->mutable_data();
- allReduce(t.Size(), addr, addr);
+ allReduce(t.Size(), addr, addr, ncclFloat);
+
+}
+
+void Communicator::fusedSynchHalf(vector<Tensor> &t){
+
+ // record an event on the default cuda stream and make c1 wait on it
+ CUDA_CHECK(cudaEventRecord(event, NULL));
+ CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
+
+ size_t offset = 0;
+
+ //memory copy to fusedBuff
+ for (size_t i = 0; i < t.size(); i++)
+ {
+ CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+ offset += t[i].Size();
+ }
+
+ cuda::float2half(offset, fusedSendBuff, fusedSendBuffHalf, c1);
+
+ // wait for the memcpy to complete
+ CUDA_CHECK(cudaEventRecord(event, c1));
+ CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
+
+ allReduce((int) offset, (void*) fusedSendBuffHalf, (void*) fusedRecvBuffHalf, ncclHalf);
+
+ // wait for the allreduce to complete
+ CUDA_CHECK(cudaEventRecord(event, s));
+ CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
+
+ cuda::half2float(offset, fusedRecvBuffHalf, fusedRecvBuff, c2);
+
+ //copy data back to tensors after allreduce
+ offset = 0;
+ for (size_t i = 0; i < t.size(); i++)
+ {
+ CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c2));
+ offset += t[i].Size();
+ }
}
+void Communicator::synchHalf(Tensor &t){
+
+ float* addr = static_cast<float*>(t.block()->mutable_data());
+
+ // record an event on the default cuda stream and make c1 wait on it
+ CUDA_CHECK(cudaEventRecord(event, NULL));
+ CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
+
+ cuda::float2half(t.Size(), addr, fusedSendBuffHalf, c1);
+
+ // wait for the conversion to half precision to complete
+ CUDA_CHECK(cudaEventRecord(event, c1));
+ CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
+
+ allReduce(t.Size(), (void*) fusedSendBuffHalf, (void*) fusedRecvBuffHalf, ncclHalf);
+
+ // wait for the allreduce to complete
+ CUDA_CHECK(cudaEventRecord(event, s));
+ CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
+
+ cuda::half2float(t.Size(), fusedRecvBuffHalf, addr, c2);
+
+}
+
+
}
#endif // USE_DIST
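
For reference, fusedSynchHalf boils down to a four-stage pipeline across the three streams. A condensed, hedged restatement (identifiers follow the patch, with `count` standing for the accumulated element count called offset above; error-check macros and the packing/unpacking cudaMemcpyAsync loops are elided into comments):

#include <cuda_fp16.h>
#include <nccl.h>
#include "./math_kernel.h"  // cuda::float2half / cuda::half2float from this patch

void fusedHalfAllReduce(size_t count, float *sendF32, __half *sendF16,
                        __half *recvF16, float *recvF32, ncclComm_t comm,
                        cudaEvent_t event, cudaStream_t c1, cudaStream_t s,
                        cudaStream_t c2) {
  // 1) pack: cudaMemcpyAsync each fp32 tensor into sendF32 on c1 (elided)
  // 2) compress on c1: halves the bytes NCCL must move over the network
  singa::cuda::float2half(count, sendF32, sendF16, c1);
  cudaEventRecord(event, c1);
  cudaStreamWaitEvent(s, event, 0);   // NCCL stream waits for c1
  // 3) allreduce in half precision
  ncclAllReduce(sendF16, recvF16, count, ncclHalf, ncclSum, comm, s);
  cudaEventRecord(event, s);
  cudaStreamWaitEvent(c2, event, 0);  // copy-back stream waits for s
  // 4) decompress on c2, then cudaMemcpyAsync back into the tensors (elided)
  singa::cuda::half2float(count, recvF16, recvF32, c2);
}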