You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2019/11/20 03:07:22 UTC

[singa] branch master updated: SINGA-487 Add support of gradient compression to half precision

This is an automated email from the ASF dual-hosted git repository.

wangwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/master by this push:
     new 859785f  SINGA-487 Add support of gradient compression to half precision
     new 00f80e1  Merge pull request #562 from chrishkchris/SINGA-487_4
859785f is described below

commit 859785fc3400734a2fe26956433e8bf017dc5b3c
Author: chrishkchris <ch...@yahoo.com.hk>
AuthorDate: Tue Nov 19 10:13:04 2019 +0000

    SINGA-487 Add support of gradient compression to half precision
---
 include/singa/io/communicator.h |  12 ++++-
 python/singa/opt.py             |  41 ++++++++++++++++
 src/api/dist_communicator.i     |   2 +
 src/core/tensor/math_kernel.cu  |  98 ++++++++++++++++++++++---------------
 src/core/tensor/math_kernel.h   |   7 ++-
 src/io/communicator.cc          | 104 ++++++++++++++++++++++++++++++++++------
 6 files changed, 209 insertions(+), 55 deletions(-)

diff --git a/include/singa/io/communicator.h b/include/singa/io/communicator.h
index 802b972..e3717fc 100644
--- a/include/singa/io/communicator.h
+++ b/include/singa/io/communicator.h
@@ -30,6 +30,7 @@
 #include <mpi.h>
 
 #include "singa/core/tensor.h"
+#include "cuda_fp16.h"  // __half
 using std::vector;
 
 namespace singa{
@@ -67,11 +68,16 @@ public:
   bool UseMPI;
   float *fusedSendBuff;
   float *fusedRecvBuff;
+  __half *fusedSendBuffHalf;  // float16 send/recv staging buffers, maxSize elements each
+  __half *fusedRecvBuffHalf;
   size_t maxSize;
 
   ncclUniqueId id;
+  // stream s runs the NCCL all-reduce operations
   cudaStream_t s;
-  cudaStream_t c;
+  // c1: copies into (and, for fusedSynch, out of) the buffers; c2: copy-out of the half-precision calls
+  cudaStream_t c1;
+  cudaStream_t c2;
   ncclComm_t comm;
   cudaEvent_t event;
 
@@ -80,10 +86,12 @@ public:
   ~Communicator();
   void synch(Tensor &t);
   void fusedSynch(vector<Tensor> &t);
+  void synchHalf(Tensor &t);  // like synch (in place), but communicates t as float16
+  void fusedSynchHalf(vector<Tensor> &t);  // like fusedSynch, but communicates as float16
   void wait();
 
 private:
-  void allReduce(int size, void* sendbuff, void* recvbuff);
+  void allReduce(int size, void* sendbuff, void* recvbuff, ncclDataType_t ncclType);  // ncclSum of size elements of ncclType, on stream s
   void setup(int gpu_num);
 
 };
diff --git a/python/singa/opt.py b/python/singa/opt.py
index 46c4c72..06286f2 100755
--- a/python/singa/opt.py
+++ b/python/singa/opt.py
@@ -171,6 +171,10 @@ class DistOpt(object):
         # The class is designed to wrap an optimizer to do disttributed training.
         # opt: The optimizer to be wrapped. nDev: number of devices(GPUs) a
         # process will control/use.
+        # nccl_id: an NCCL id holder object carrying the unique communication id
+        # gpu_num: the id of the GPU this process uses within its node
+        # gpu_per_node: the number of GPUs in a single node
+        # buffSize: the buffer size used by the NCCL communicator, default is 16 MB
 
         # world_size: total number of processes.
         # rank_in_local: local rank of a process on the current node.
@@ -199,10 +203,19 @@ class DistOpt(object):
         tensor = singa.VecTensor(tensor)
         self.communicator.fusedSynch(tensor)
 
+    def all_reduce_half(self, tensor):  # all-reduce one tensor, communicated as float16
+        self.communicator.synchHalf(tensor)
+
+    def fused_all_reduce_half(self, tensor):  # fused all-reduce of a list of tensors, as float16
+        tensor = singa.VecTensor(tensor)
+        self.communicator.fusedSynchHalf(tensor)
+
     def wait(self):
         self.communicator.wait()
 
     def backward_and_update(self, loss, threshold = 2097152):
+        # backward propagation from the loss, followed by all-reduce and parameter update;
+        # it applies tensor fusion: tensors smaller than the threshold are fused before the all-reduce
         plist = []
         acc = 0
         glist = []
@@ -224,3 +237,31 @@ class DistOpt(object):
         self.wait()
         for p, g in plist:
             self.update(p, g)  
+
+    def backward_and_update_half(self, loss, threshold = 2097152, clipping = False, clip_Value = 100):
+        # THIS IS AN EXPERIMENTAL FUNCTION FOR RESEARCH PURPOSES:
+        # it converts the gradients to 16-bit half precision before the all-reduce;
+        # to assist training, gradients can optionally be clipped to [-clip_Value, clip_Value]
+        plist = []
+        acc = 0
+        glist = []
+        for p, g in autograd.backward(loss):
+            if clipping:
+                g = autograd.clip(g, -clip_Value, clip_Value)
+            if g.size() > threshold:
+                # larger than threshold -> all-reduced on its own
+                self.all_reduce_half(g.data)
+            else:
+                # smaller than threshold -> accumulate; flush the batch once acc exceeds threshold
+                glist.append(g.data)                    
+                acc += g.size()
+                if (acc > threshold):
+                    self.fused_all_reduce_half(glist)
+                    acc = 0
+                    glist = []
+            plist.append((p, g))
+        if glist:
+            self.fused_all_reduce_half(glist)
+        self.wait()
+        for p, g in plist:
+            self.update(p, g)  
diff --git a/src/api/dist_communicator.i b/src/api/dist_communicator.i
index 8c777f6..2da9602 100644
--- a/src/api/dist_communicator.i
+++ b/src/api/dist_communicator.i
@@ -47,6 +47,8 @@ public:
   Communicator(int gpu_num, int gpu_per_node, const NcclIdHolder &holder, int limit);
   void synch(Tensor &t);
   void fusedSynch(std::vector<Tensor> &t);
+  void synchHalf(Tensor &t);  // synch, communicated as float16
+  void fusedSynchHalf(std::vector<Tensor> &t);  // fusedSynch, communicated as float16
   void wait();
 };
 
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 2ce6531..b25455d 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -348,6 +348,20 @@ __global__ void KernelSoftmaxCrossEntropyBwd(const bool int_target, const size_t
   }
  }
  
+__global__ void KernelFloat2Half(const size_t n, const float *in, __half *out) {  // fp32 -> fp16, round-to-nearest-even
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {  // grid-stride loop: any grid size covers all n elements
+    out[i] = __float2half_rn(in[i]);
+  }
+}
+
+__global__ void KernelHalf2Float(const size_t n, const __half *in, float *out) {  // fp16 -> fp32 (exact)
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {  // grid-stride loop: any grid size covers all n elements
+    out[i] = __half2float(in[i]);
+  }
+}
+
 //cuda unary elementwise ops kernel template 
 #define GenUnaryCudaKernel(fn,kernelfn,cudafn)                                \
   __global__ void kernelfn(const size_t n, const float *in, float *out) {     \
@@ -357,7 +371,7 @@ __global__ void KernelSoftmaxCrossEntropyBwd(const bool int_target, const size_t
     }                                                                         \
   }                                                                           \
   void fn(const size_t n, const float *in, float *out, cudaStream_t s) {      \
-    kernelfn <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);             \
+    kernelfn <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out); /* on stream s */ \
   }
 
 GenUnaryCudaKernel(cos,KernelCos,cosf);
@@ -378,128 +392,136 @@ GenUnaryCudaKernel(atanh,KernelAtanh,atanhf);
 // Functions call kernels
 // ********************************
 
+void float2half(const size_t n, const float *in, __half *out, cudaStream_t s) {  // fp32 -> fp16, enqueued on stream s
+  KernelFloat2Half <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
+void half2float(const size_t n, const __half *in, float *out, cudaStream_t s) {  // fp16 -> fp32, enqueued on stream s
+  KernelHalf2Float <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
 void set(const size_t n, const float v, float *out, cudaStream_t s) {
-  KernelSet <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, v, out);
+  KernelSet <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, v, out);
 }
 
 void abs(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelAbs <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelAbs <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
 void sign(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelSign <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelSign <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
 void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
 void log(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
 void sqrt(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelSqrt <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelSqrt <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
 void square(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelSquare <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelSquare <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
 void relu(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelRelu <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelRelu <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 void sigmoid(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelSigmoid <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelSigmoid <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 void softplus(const size_t n, const float *in, float *out, cudaStream_t s) {
-  KernelSoftplus <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+  KernelSoftplus <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 void clamp(const size_t n, const float low, const float high, const float *in,
            float *out, cudaStream_t s) {
-  KernelClamp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, low, high, in, out);
+  KernelClamp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, low, high, in, out);
 }
 
 void pow(const size_t n, const float *in, const float x, float *out,
          cudaStream_t s) {
-  KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
+  KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, x, out);
 }
 
 void add(const size_t n, const float *in, const float x, float *out,
          cudaStream_t s) {
-  KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
+  KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, x, out);
 }
 
 void mult(const size_t n, const float *in, const float x, float *out,
           cudaStream_t s) {
-  KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
+  KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, x, out);
 }
 
 void div(const size_t n, const float x, const float *in, float *out,
           cudaStream_t s) {
-  KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
+  KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, x, in, out);
 }
 
 void threshold(const size_t n, const float x, const float *in, float *out,
                cudaStream_t s) {
-  KernelThreshold <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
+  KernelThreshold <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, x, in, out);
 }
 
 void gt(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
-  KernelGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+  KernelGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
 }
 void gt(const size_t num, const float *in1, const float *in2, float *out,
         cudaStream_t s) {
-  KernelBGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+  KernelBGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
 }
 void ge(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
-  KernelGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+  KernelGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
 }
 void ge(const size_t num, const float *in1, const float *in2, float *out,
         cudaStream_t s) {
-  KernelBGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+  KernelBGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
 }
 void lt(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
-  KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+  KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
 }
 void lt(const size_t num, const float *in1, const float *in2, float *out,
         cudaStream_t s) {
-  KernelBLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+  KernelBLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
 }
 void le(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
-  KernelLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+  KernelLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
 }
 void le(const size_t num, const float *in1, const float *in2, float *out,
         cudaStream_t s) {
-  KernelBLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in1, in2, out);
+  KernelBLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
 }
 void pow(const size_t n, const float *in1, const float *in2, float *out,
          cudaStream_t s) {
-  KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+  KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
 }
 
 void add(const size_t n, const float *in1, const float *in2, float *out,
          cudaStream_t s) {
-  KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+  KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
 }
 
 void sub(const size_t n, const float *in1, const float *in2, float *out,
          cudaStream_t s) {
-  KernelSub <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+  KernelSub <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
 }
 
 void mult(const size_t n, const float *in1, const float *in2, float *out,
           cudaStream_t s) {
-  KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+  KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
 }
 
 void div(const size_t n, const float *in1, const float *in2, float *out,
          cudaStream_t s) {
-  KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
+  KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in1, in2, out);
 }
 
 /*
@@ -513,42 +535,42 @@ void sum(const size_t n, const float *in, float *out, cudaStream_t s) {
 
 void ComputeCrossEntropy(const bool int_target, size_t batchsize, const size_t dim, const float *p,
                          const int *t, float *loss, cudaStream_t stream) {
-  KernelComputeCrossEntropy <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF>>>
+  KernelComputeCrossEntropy <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF, 0, stream>>>  // on the caller's stream
       (int_target, batchsize, dim, p, t, loss);
 }
 
 void SoftmaxCrossEntropyBwd(const bool int_target, size_t batchsize, const size_t dim, const float *p,
                             const int *t, float *grad, cudaStream_t stream) {
-  KernelSoftmaxCrossEntropyBwd <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF>>>
+  KernelSoftmaxCrossEntropyBwd <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF, 0, stream>>>  // on the caller's stream
       (int_target, batchsize, dim, p, t, grad);
 }
 
 void RowMax(const size_t nrow, const size_t ncol, const float *inPtr,
     float *outPtr, cudaStream_t stream) {
-  KernelRowMax <<<ceil(nrow / CU1DBLOCKF), CU1DBLOCKF>>>(nrow, ncol, inPtr, outPtr);
+  KernelRowMax <<<ceil(nrow / CU1DBLOCKF), CU1DBLOCKF, 0, stream>>>(nrow, ncol, inPtr, outPtr);  // on the caller's stream
 }
 
 /*
 void square_grad(int n, const float *in, float *out, cudaStream_t s) {
-  kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+  kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
 }
 
 void tanh_grad(int n, const float *in, float *out, cudaStream_t s) {
-  kernel_tanh_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+  kernel_tanh_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
 }
 
 
 void relu_grad(int n, const float *in, float *out, cudaStream_t s) {
-  kernel_relu_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+  kernel_relu_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
 }
 
 
 void sigmoid_grad(int n, const float *in, float *out, cudaStream_t s) {
-  kernel_sigmoid_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+  kernel_sigmoid_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
 }
 
 void softplus_grad(int n, const float *in, float *out, cudaStream_t s) {
-  kernel_softplus_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+  kernel_softplus_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (in, out, n);
 }
 
 
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 6fa2a64..fb8601d 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -21,7 +21,9 @@
 #ifndef SRC_CORE_TENSOR__MATH_KERNEL_H_
 #define SRC_CORE_TENSOR__MATH_KERNEL_H_
 
-
 #include "singa/singa_config.h"
 #ifdef USE_CUDA
+// __half (used by float2half / half2float below) comes from the CUDA toolkit,
+// so include it inside the USE_CUDA guard; builds without CUDA must not need it
+#include "cuda_fp16.h"
 
@@ -123,6 +125,11 @@ void SoftmaxCrossEntropyBwd(bool int_target, const size_t batchsize,
 
 void RowMax(const size_t nrow, const size_t ncol, const float *inPtr,
     float *outPtr, cudaStream_t stream);
+
+void float2half(const size_t n, const float *in, __half *out, cudaStream_t s);
+
+void half2float(const size_t n, const __half *in, float *out, cudaStream_t s);
+
 }  // cuda
 
 }  // namespace singa
diff --git a/src/io/communicator.cc b/src/io/communicator.cc
index 46e8dbf..7995e75 100644
--- a/src/io/communicator.cc
+++ b/src/io/communicator.cc
@@ -22,10 +22,10 @@
 #ifdef USE_DIST
 
 #include "singa/io/communicator.h"
+#include "./math_kernel.h"  // cuda::float2half / cuda::half2float
 
 namespace singa{
 
-
 static uint64_t getHostHash(const char* string) {
   // Based on DJB2, result = result * 33 + char
   uint64_t result = 5381;
@@ -110,21 +110,24 @@ void Communicator::setup(int gpu_num){
 
   CUDA_CHECK(cudaSetDevice(gpu_num));
   NCCLCHECK(ncclCommInitRank(&comm, totalMPIRanksInGlobal, id, MPIRankInGlobal));
-  CUDA_CHECK(cudaStreamCreateWithPriority(&s, cudaStreamNonBlocking, 0));
-  CUDA_CHECK(cudaStreamCreateWithPriority(&c, cudaStreamNonBlocking, 1));
+  CUDA_CHECK(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));  // NCCL all-reduce; non-blocking, so ordering with the default stream is done via events
+  CUDA_CHECK(cudaStreamCreateWithFlags(&c1, cudaStreamNonBlocking));  // copies/conversion into the buffers
+  CUDA_CHECK(cudaStreamCreateWithFlags(&c2, cudaStreamNonBlocking));  // copy-out of the half-precision calls
   CUDA_CHECK(cudaMalloc(&fusedSendBuff, maxSize * sizeof(float)));
   CUDA_CHECK(cudaMalloc(&fusedRecvBuff, maxSize * sizeof(float)));
+  CUDA_CHECK(cudaMalloc(&fusedSendBuffHalf, maxSize * sizeof(__half)));  // float16 staging buffers,
+  CUDA_CHECK(cudaMalloc(&fusedRecvBuffHalf, maxSize * sizeof(__half)));  // maxSize elements each
   CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventBlockingSync | cudaEventDisableTiming));
 
 }
 
-void Communicator::allReduce(int size, void* sendbuff, void* recvbuff)
+void Communicator::allReduce(int size, void* sendbuff, void* recvbuff, ncclDataType_t ncclType)  // sums size elements of ncclType across ranks, enqueued on stream s
 {
 
   NCCLCHECK(ncclAllReduce((const void*)sendbuff,
                              (void*)recvbuff,
                              size,
-                             ncclFloat,
+                             ncclType,
                              ncclSum,
                              comm, 
                              s));
@@ -132,10 +135,12 @@ void Communicator::allReduce(int size, void* sendbuff, void* recvbuff)
 }
 
 void Communicator::wait(){
-  //synchronizing on CUDA stream to complete NCCL communication
+  // make the default stream wait for all work queued so far on s, c1 and c2 (does not block the host)
   CUDA_CHECK(cudaEventRecord(event, s));
   CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
-  CUDA_CHECK(cudaEventRecord(event, c));
+  CUDA_CHECK(cudaEventRecord(event, c1));
+  CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
+  CUDA_CHECK(cudaEventRecord(event, c2));
   CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
 }
 
@@ -145,38 +150,45 @@ Communicator::~Communicator(){
   if (UseMPI == true) MPICHECK(MPI_Finalize());
   CUDA_CHECK(cudaFree(fusedSendBuff));
   CUDA_CHECK(cudaFree(fusedRecvBuff));
+  CUDA_CHECK(cudaFree(fusedSendBuffHalf));
+  CUDA_CHECK(cudaFree(fusedRecvBuffHalf));
+  CUDA_CHECK(cudaStreamDestroy(s));
+  CUDA_CHECK(cudaStreamDestroy(c1));
+  CUDA_CHECK(cudaStreamDestroy(c2));
+
 }
 
+// fusedSynch: packs t into fusedSendBuff (stream c1), all-reduces it as ncclFloat on stream s, then unpacks fusedRecvBuff back into t (stream c1)
 void Communicator::fusedSynch(vector<Tensor> &t){
 
   // record the event of the default cuda stream and follow it
   CUDA_CHECK(cudaEventRecord(event, NULL));
-  CUDA_CHECK(cudaStreamWaitEvent(c, event, 0));
+  CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
   
   size_t offset = 0;
 
   //memory copy to fusedBuff
   for (size_t i = 0; i < t.size(); i++)
   {
-    CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c));
+    CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
     offset += t[i].Size();
   }
 
   // wait for the memcpy to complete
-  CUDA_CHECK(cudaEventRecord(event, c));
+  CUDA_CHECK(cudaEventRecord(event, c1));
   CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
 
-  allReduce((int) offset, (void*) fusedSendBuff, (void*) fusedRecvBuff);
+  allReduce((int) offset, (void*) fusedSendBuff, (void*) fusedRecvBuff, ncclFloat);
 
   // wait for the allreduce to complete
   CUDA_CHECK(cudaEventRecord(event, s));
-  CUDA_CHECK(cudaStreamWaitEvent(c, event, 0));
+  CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
 
   //copy data back to tensors after allreduce
   offset = 0;
   for (size_t i = 0; i < t.size(); i++)
   {
-    CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c));
+    CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
     offset += t[i].Size();
   }
 
@@ -189,10 +201,112 @@ void Communicator::synch(Tensor &t){
   CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
 
   void* addr = t.block()->mutable_data();
-  allReduce(t.Size(), addr, addr);
+  allReduce(t.Size(), addr, addr, ncclFloat);
+
+}
+
+// Half-precision variant of fusedSynch: the tensors in t are packed into
+// fusedSendBuff and converted to float16 (stream c1), all-reduced as ncclHalf
+// (stream s), then converted back to float32 and copied out to the tensors
+// (stream c2). Call wait() before the default stream reads the results.
+void Communicator::fusedSynchHalf(vector<Tensor> &t){
+
+  // total number of elements to pack into the fused buffers
+  size_t total = 0;
+  for (size_t i = 0; i < t.size(); i++) total += t[i].Size();
+  // nothing to reduce (this also avoids launching kernels with an empty grid)
+  if (total == 0) return;
+
+  // the fused buffers hold maxSize elements; rather than overflow them, reduce
+  // the tensors one at a time (synchHalf splits an oversized tensor into chunks)
+  if (total > maxSize) {
+    for (size_t i = 0; i < t.size(); i++) synchHalf(t[i]);
+    return;
+  }
+
+  // record the event of the default cuda stream and follow it
+  CUDA_CHECK(cudaEventRecord(event, NULL));
+  CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
+
+  size_t offset = 0;
+
+  //memory copy to fusedBuff
+  for (size_t i = 0; i < t.size(); i++)
+  {
+    CUDA_CHECK(cudaMemcpyAsync((void*) (fusedSendBuff + offset), (const void*) t[i].block()->mutable_data(), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+    offset += t[i].Size();
+  }
+
+  cuda::float2half(offset, fusedSendBuff, fusedSendBuffHalf, c1);
+
+  // wait for the memcpy and the conversion to half precision to complete
+  CUDA_CHECK(cudaEventRecord(event, c1));
+  CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
+
+  allReduce((int) offset, (void*) fusedSendBuffHalf, (void*) fusedRecvBuffHalf, ncclHalf);
+
+  // wait for the allreduce to complete
+  CUDA_CHECK(cudaEventRecord(event, s));
+  CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
+
+  cuda::half2float(offset, fusedRecvBuffHalf, fusedRecvBuff, c2);
+
+  //copy data back to tensors after allreduce
+  offset = 0;
+  for (size_t i = 0; i < t.size(); i++)
+  {
+    CUDA_CHECK(cudaMemcpyAsync((void*) t[i].block()->mutable_data(), (const void*) (fusedRecvBuff + offset), t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c2));
+    offset += t[i].Size();
+  }
+
+  // The fused and half buffers are reused by later calls, which enqueue their
+  // first work on c1 (and make s wait for c1) but never wait for c2. Make c1
+  // wait for c2 so no later call overwrites the buffers while they are in use.
+  CUDA_CHECK(cudaEventRecord(event, c2));
+  CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
 
 }
 
+// Half-precision all-reduce of one tensor, in place: t is converted to
+// float16 (stream c1), all-reduced as ncclHalf (stream s) and converted back
+// into t's float32 storage (stream c2). Call wait() before the default stream
+// reads t.
+void Communicator::synchHalf(Tensor &t){
+
+  float* addr = static_cast<float*>(t.block()->mutable_data());
+  const size_t total = t.Size();
+
+  // record the event of the default cuda stream and follow it
+  CUDA_CHECK(cudaEventRecord(event, NULL));
+  CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
+
+  // the half-precision buffers hold maxSize elements, so a larger tensor is
+  // reduced in consecutive chunks of at most maxSize elements (maxSize > 0 assumed)
+  for (size_t start = 0; start < total; start += maxSize) {
+    const size_t len = (total - start < maxSize) ? (total - start) : maxSize;
+
+    cuda::float2half(len, addr + start, fusedSendBuffHalf, c1);
+
+    // wait for conversion to half precision complete
+    CUDA_CHECK(cudaEventRecord(event, c1));
+    CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
+
+    allReduce((int) len, (void*) fusedSendBuffHalf, (void*) fusedRecvBuffHalf, ncclHalf);
+
+    // wait for the allreduce to complete
+    CUDA_CHECK(cudaEventRecord(event, s));
+    CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
+
+    cuda::half2float(len, fusedRecvBuffHalf, addr + start, c2);
+
+    // the half buffers are reused by the next chunk and by later calls, which
+    // never wait for c2: make c1 wait for c2 so they are not overwritten early
+    CUDA_CHECK(cudaEventRecord(event, c2));
+    CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
+  }
+}
+
+
 }
 
 #endif // USE_DIST