Posted to commits@singa.apache.org by zh...@apache.org on 2018/07/12 08:39:52 UTC

[4/5] incubator-singa git commit: SINGA-379 Implement batchnorm operation and its related functions for autograd

SINGA-379 Implement batchnorm operation and its related functions for autograd

Change the API (arguments) of the batchnorm functions: the handle is now the first argument, and the running mean/variance tensors are passed in explicitly instead of being stored inside the handle.

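For context, a minimal usage sketch of the Python layer after this change (assumptions: a CUDA build of SINGA; the shapes, device setup, and values below are illustrative only and not part of the commit):

    # Sketch of the new BatchNorm2d layer API; the layer now owns the running
    # statistics and passes them into every call. CPU execution is not
    # implemented, so a CUDA device is required.
    from singa import autograd, tensor, device

    dev = device.create_cuda_gpu()
    x = tensor.Tensor(shape=(8, 3, 32, 32), device=dev)
    x.set_value(1.0)

    bn = autograd.BatchNorm2d(3, momentum=0.9)

    autograd.training = True      # training: normalize with batch statistics
    y_train = bn(x)

    autograd.training = False     # inference: normalize with running statistics
    y_eval = bn(x)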

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ce1a7335
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ce1a7335
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ce1a7335

Branch: refs/heads/master
Commit: ce1a73359a6e3eb2c3e7ec5ac861ac1829144dad
Parents: 8654f89
Author: Wang Wei <wa...@gmail.com>
Authored: Thu Jul 12 00:31:40 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Thu Jul 12 12:32:50 2018 +0800

----------------------------------------------------------------------
 python/singa/autograd.py         |  42 +++++-----
 src/api/model_operation.i        |  38 ++++-----
 src/model/operation/batchnorm.cc | 147 +++++++++++++++-------------------
 src/model/operation/batchnorm.h  |  26 +++---
 4 files changed, 118 insertions(+), 135 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 4ba0b11..3a2eddd 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -33,7 +33,6 @@ CTensor = singa.Tensor
 training = False
 
 
-
 def infer_dependency(op):
     '''
     Infer the dependency of all operations with the
@@ -483,6 +482,7 @@ def cross_entropy(y, t):
 
 
 class SoftMaxCrossEntropy(Operation):
+
     def forward(self, x, t):
         self.p = singa.SoftMax(x)
         self.t = t
@@ -771,7 +771,7 @@ class Conv2D(Layer):
         return y
 
 
-class BatchNorm2d(NewLayer):
+class BatchNorm2d(Layer):
 
     def __init__(self, num_features, momentum=0.9):
         self.channels = num_features
@@ -787,16 +787,16 @@ class BatchNorm2d(NewLayer):
                            requires_grad=True, stores_grad=True)
         self.bias.set_value(0.0)
 
-        self.runningmean = Tensor(
+        self.running_mean = Tensor(
             shape=param_shape, requires_grad=False, stores_grad=False)
-        self.runningvariance = Tensor(
+        self.running_var = Tensor(
             shape=param_shape, requires_grad=False, stores_grad=False)
 
     def __call__(self, x):
         assert x.shape[1] == self.channels, 'number of channels mismatched.'
 
         self.device_check(x, self.scale, self.bias,
-                          self.runningmean, self.runningvariance)
+                          self.running_mean, self.running_var)
 
         if x.device.id() == -1:
             raise NotImplementedError
@@ -804,39 +804,40 @@ class BatchNorm2d(NewLayer):
         else:
             if not hasattr(self, 'handle'):
                 self.handle = singa.CudnnBatchNormHandle(
-                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                    self.momentum, x.data)
             elif x.shape[0] != self.handle.batchsize:
                 self.handle = singa.CudnnBatchNormHandle(
-                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                    self.momentum, x.data)
         self.handle.device_id = x.device.id()
 
-        y = batchnorm2d(x, self.scale, self.bias, self.handle)
+        y = batchnorm2d(x, self.scale, self.bias,
+                        self.running_mean, self.running_var, self.handle)
         return y
 
 
 class _BatchNorm2d(Operation):
 
-    def __init__(self, handle):
+    def __init__(self, running_mean, running_var, handle):
+        self.running_mean = running_mean.data
+        self.running_var = running_var.data
         self.handle = handle
 
     def forward(self, x, scale, bias):
         if training:
-            resultmean = CTensor([scale.shape(0)])
-            resultvariance = CTensor([scale.shape(0)])
-            self.cache = (x, resultmean, resultvariance, scale)
 
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                resultmean.ToDevice(x.device())
-                resultvariance.ToDevice(x.device())
-                return singa.GpuBatchNormForwardTraining(x, scale, bias, self.cache, self.handle)
-
+                y, mean, var = singa.GpuBatchNormForwardTraining(self.handle,
+                                                                 x, scale, bias, self.running_mean, self.running_var)
+                self.cache = (x, scale, mean, var)
         else:
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                return singa.GpuBatchNormForwardInference(x, scale, bias, self.handle)
+                y = singa.GpuBatchNormForwardInference(
+                    self.handle, x, scale, bias, self.running_mean, self.running_var)
+        return y
 
     def backward(self, dy):
         assert training is True and hasattr(
@@ -848,10 +849,11 @@ class _BatchNorm2d(Operation):
         if self.handle.device_id == -1:
             raise NotImplementedError
         else:
+            x, scale, mean, var = self.cache
             dx, ds, db = singa.GpuBatchNormBackward(
-                dy, self.cache, self.handle)
+                self.handle, dy, x, scale, mean, var)
             return dx, ds, db
 
 
-def batchnorm2d(x, scale, bias, handle):
-    return _BatchNorm2d(handle)(x, scale, bias)[0]
+def batchnorm2d(x, scale, bias, running_mean, running_var, handle):
+    return _BatchNorm2d(running_mean, running_var, handle)(x, scale, bias)[0]

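Since the running mean/variance now travel with the layer rather than the handle, cuDNN updates them in place on every training forward. A rough CPU sketch of that update, assuming cuDNN's documented exponentialAverageFactor semantics (the handle's factor is the layer's momentum; names and the biased-variance choice are illustrative, not the library code):

    # Rough numpy sketch of the in-place running-statistics update performed
    # by cudnnBatchNormalizationForwardTraining.
    import numpy as np

    def update_running_stats(running_mean, running_var, x, factor=0.9):
        # x: (N, C, H, W); statistics are per channel
        batch_mean = x.mean(axis=(0, 2, 3))
        batch_var = x.var(axis=(0, 2, 3))
        new_mean = (1.0 - factor) * running_mean + factor * batch_mean
        new_var = (1.0 - factor) * running_var + factor * batch_var
        return new_mean, new_var
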
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index a1d59ed..6f2d1fa 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -27,6 +27,17 @@ Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 
 Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);
 
+
+
+class BatchNormHandle{
+  public:
+    BatchNormHandle(const float momentum, const Tensor& input);
+
+    size_t batchsize;
+};
+
+
+
 #if USE_CUDNN
 class CudnnConvHandle: public ConvHandle {
  public:
@@ -47,36 +58,25 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 
 Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch);
 
-#endif  // USE_CUDNN
 
-class BatchNormHandle{
-  public:
-    BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
-
-    size_t batchsize;
-    Tensor runningMean;
-    Tensor runningVariance;
-
-};
 
 
 class CudnnBatchNormHandle: public BatchNormHandle{
     public:
-      CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+      CudnnBatchNormHandle(const float momentum, const Tensor& input);
 
     size_t batchsize;
-    Tensor runningMean;
-    Tensor runningVariance;
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, 
-   const std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh, 
+  const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, Tensor& running_mean, Tensor& running_var);
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh, const Tensor& x, 
+  const Tensor& bnScale, const Tensor& bnBias,  const Tensor& running_mean, const Tensor& running_var);
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy,
-  const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh, 
+  const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean, const Tensor& var);
      
-
+#endif  // USE_CUDNN
 }
 

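At the wrapper level every call now takes the handle first, the handle is built from only (momentum, input), and the training forward returns (y, batch mean, batch variance) instead of filling a cache. A hedged sketch of driving these bindings directly, mirroring what autograd.py does above (the CTensor arguments are assumed to already live on one CUDA device; singa_wrap is the SWIG module autograd.py imports):

    # Sketch of the new calling convention (handle first); all arguments are
    # CTensors on the same CUDA device. The helper function is illustrative.
    from singa import singa_wrap as singa

    def bn_training_step(x, scale, bias, r_mean, r_var, dy, momentum=0.9):
        handle = singa.CudnnBatchNormHandle(momentum, x)
        y, mean, var = singa.GpuBatchNormForwardTraining(
            handle, x, scale, bias, r_mean, r_var)   # r_mean/r_var updated in place
        dx, dscale, dbias = singa.GpuBatchNormBackward(
            handle, dy, x, scale, mean, var)
        return y, dx, dscale, dbias
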
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 6b2421d..7040895 100755
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -2,8 +2,7 @@
 
 namespace singa {
 
-BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
-                                 const Tensor& RunningVariance) {
+BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input) {
   factor = momentum;
   batchsize = input.shape(0);
   channels = input.shape(1);
@@ -18,12 +17,11 @@ BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, cons
   } else {
     LOG(FATAL) << "The dimension of input should either be 4D or 2D.";
   }
-  runningMean = RunningMean;
-  runningVariance = RunningVariance;
 };
 
-CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
-    const Tensor& RunningVariance): BatchNormHandle(momentum, input, RunningMean, RunningVariance) {
+#if USE_CUDNN
+CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum,
+    const Tensor& input): BatchNormHandle(momentum, input) {
   if (is_2d)
     mode = CUDNN_BATCHNORM_PER_ACTIVATION;
   else
@@ -40,85 +38,77 @@ CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& i
                                          1, 1));
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+const std::vector<Tensor> GpuBatchNormForwardTraining(
+    const CudnnBatchNormHandle &cbnh, const Tensor& x, const Tensor& bnScale,
+    const Tensor& bnBias, Tensor& running_mean, Tensor& running_var) {
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(bnScale.device()->lang(), kCuda);
   CHECK_EQ(bnBias.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
-  CHECK_EQ(cache[1].device()->lang(), kCuda);  //resultmean
-  CHECK_EQ(cache[2].device()->lang(), kCuda);  //resultvariance
+  CHECK_EQ(running_mean.device()->lang(), kCuda);
+  CHECK_EQ(running_var.device()->lang(), kCuda);
+
+  Tensor mean, var;
+  mean.ResetLike(running_mean);
+  var.ResetLike(running_var);
 
   Shape shape = x.shape();
-  Tensor output;
-  Tensor input;  //for unification of 2d and 4d cases.
+
+  Tensor input = x;  //for unification of 2d and 4d cases.
   if (cbnh.is_2d)
-    input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
-  else
-    input = x;
+    input.Reshape(Shape{shape.at(0), shape.at(1), 1, 1});
+
+  Tensor output;
   output.ResetLike(x);
 
   output.device()->Exec(
-  [&output, &input, &bnScale, &bnBias, &cache, &cbnh](Context * ctx) {
-    Block* inBlock = input.block(), * outBlock = output.block(),
-           * saveMeanBlock = cache[1].block(),
-             * saveVarBlock = cache[2].block(),
-               * runningMeanBlock = cbnh.runningMean.block(),
-                 * runningVarBlock = cbnh.runningVariance.block(),
-                   * bnScaleBlock = bnScale.block(),
-                     * bnBiasBlock = bnBias.block();
+  [&](Context * ctx) {
     const float alpha = 1.0f, beta = 0.0f;
     double epsilon = CUDNN_BN_MIN_EPSILON;
     CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
                   ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
-                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
-                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(), cbnh.factor,
-                  runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
-                  epsilon, saveMeanBlock->mutable_data(),
-                  saveVarBlock->mutable_data()));
+                  input.block()->data(), cbnh.shape_desc, output.block()->mutable_data(),
+                  cbnh.param_desc, bnScale.block()->data(), bnBias.block()->data(), cbnh.factor,
+                  running_mean.block()->mutable_data(), running_var.block()->mutable_data(),
+                  epsilon, mean.block()->mutable_data(),
+                  var.block()->mutable_data()));
   },
-  {input.block(), bnScale.block(), bnBias.block()},
-  { output.block(), cbnh.runningMean.block(), cbnh.runningVariance.block(),
-    cache[1].block(), cache[2].block()
+  {input.block(), bnScale.block(), bnBias.block(), running_mean.block(), running_var.block()}, {
+    output.block(), running_mean.block(), running_var.block(),
+    mean.block(), var.block()
   });
   if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
-  return output;
+  return {output, mean, var};
 };
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                    const CudnnBatchNormHandle &cbnh) {
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
+                                    const Tensor& x, const Tensor& bnScale,
+                                    const Tensor& bnBias, const Tensor& running_mean, const Tensor& running_var) {
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(bnScale.device()->lang(), kCuda);
   CHECK_EQ(bnBias.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
+  CHECK_EQ(running_mean.device()->lang(), kCuda);
+  CHECK_EQ(running_var.device()->lang(), kCuda);
 
   Shape shape = x.shape();
-  Tensor output;
-  Tensor input;  //for unification of 2d and 4d cases.
+
+  Tensor input = x;  //for unification of 2d and 4d cases.
   if (cbnh.is_2d)
-    input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
-  else
-    input = x;
+    input.Reshape(Shape{shape.at(0), shape.at(1), 1, 1});
+
+  Tensor output;
   output.ResetLike(x);
   output.device()->Exec(
-  [&output, &input, &bnScale, &bnBias, &cbnh](Context * ctx) {
-    Block* inBlock = input.block(), * outBlock = output.block(),
-           * runningMeanBlock = cbnh.runningMean.block(),
-             * runningVarBlock = cbnh.runningVariance.block(),
-               * bnScaleBlock = bnScale.block(),
-                 * bnBiasBlock = bnBias.block();
+  [&](Context * ctx) {
     const float alpha = 1.0f, beta = 0.0f;
     double epsilon = CUDNN_BN_MIN_EPSILON;
     CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
                   ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
-                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
-                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(),
-                  runningMeanBlock->data(), runningVarBlock->data(), epsilon));
-  },
-  { input.block(), bnScale.block(), bnBias.block(), cbnh.runningMean.block(),
-    cbnh.runningVariance.block()
+                  input.block()->data(), cbnh.shape_desc, output.block()->mutable_data(),
+                  cbnh.param_desc, bnScale.block()->data(), bnBias.block()->data(),
+                  running_mean.block()->data(), running_var.block()->data(), epsilon));
+  }, {
+    input.block(), bnScale.block(), bnBias.block(), running_mean.block(),
+    running_var.block()
   },
   {output.block()});
   if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
@@ -126,52 +116,43 @@ Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, cons
 };
 
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
+    const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean,
+    const Tensor& var) {
   CHECK_EQ(dy.device()->lang(), kCuda);
-  CHECK_EQ(cache[0].device()->lang(), kCuda);
-  CHECK_EQ(cache[1].device()->lang(), kCuda);
-  CHECK_EQ(cache[2].device()->lang(), kCuda);
-  CHECK_EQ(cache[3].device()->lang(), kCuda);
+  CHECK_EQ(x.device()->lang(), kCuda);
+  CHECK_EQ(bnScale.device()->lang(), kCuda);
+  CHECK_EQ(mean.device()->lang(), kCuda);
+  CHECK_EQ(var.device()->lang(), kCuda);
 
-  vector<Tensor> out_grads;
   Tensor dx;
   dx.ResetLike(dy);
 
   Tensor dbnScale;
-  dbnScale.ResetLike(cache[3]);
+  dbnScale.ResetLike(bnScale);
 
   Tensor dbnBias;
-  dbnBias.ResetLike(cache[3]);
-  //dbnBias.ResetLike(bnBias);
+  dbnBias.ResetLike(bnScale);
 
   dx.device()->Exec(
-  [&dx, &dbnScale, &dbnBias, &dy, &cache, &cbnh](Context * ctx) {
-    Block* dyblock = dy.block(), * dxblock = dx.block(),
-           * xblock = cache[0].block(), * bnScaleBlock = cache[3].block(),
-             * dbnScaleBlock = dbnScale.block(),
-               * dbnBiasBlock = dbnBias.block(),
-                 * saveMeanBlock = cache[1].block(),
-                   * saveVarBlock = cache[2].block();
+  [&](Context * ctx) {
     const float alpha = 1.0f, beta = .0f;
     double epsilon = CUDNN_BN_MIN_EPSILON;
     CUDNN_CHECK(cudnnBatchNormalizationBackward(
                   ctx->cudnn_handle, cbnh.mode, &alpha, &beta, &alpha, &beta,
-                  cbnh.shape_desc, xblock->data(), cbnh.shape_desc, dyblock->data(),
-                  cbnh.shape_desc, dxblock->mutable_data(), cbnh.param_desc,
-                  bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
-                  dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
-                  saveVarBlock->data()));
-  },
-  { cache[0].block(), dy.block(), cache[3].block(), cache[1].block(),
-    cache[2].block()
-  },
+                  cbnh.shape_desc, x.block()->data(), cbnh.shape_desc, dy.block()->data(),
+                  cbnh.shape_desc, dx.block()->mutable_data(), cbnh.param_desc,
+                  bnScale.block()->data(), dbnScale.block()->mutable_data(),
+                  dbnBias.block()->mutable_data(), epsilon, mean.block()->data(),
+                  var.block()->data()));
+  }, {x.block(), dy.block(), bnScale.block(), mean.block(), var.block()},
   {dx.block(), dbnScale.block(), dbnBias.block()});
 
   if (cbnh.is_2d) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
-  out_grads.push_back(dx);
-  out_grads.push_back(dbnScale);
-  out_grads.push_back(dbnBias);
-  return out_grads;
+  
+  return {dx, dbnScale, dbnBias};
 };
 
 }

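For readability, a plain-numpy reference of the per-channel math behind the two forward paths above (spatial mode, NCHW layout; eps stands in for CUDNN_BN_MIN_EPSILON, and this is a sketch rather than the cuDNN implementation):

    # Reference forward pass: training uses batch statistics and returns them
    # for backward; inference uses the frozen running statistics.
    import numpy as np

    def batchnorm_forward(x, scale, bias, running_mean, running_var,
                          training, eps=1e-5):
        if training:
            mean = x.mean(axis=(0, 2, 3))
            var = x.var(axis=(0, 2, 3))
        else:
            mean, var = running_mean, running_var
        x_hat = (x - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + eps)
        y = scale[None, :, None, None] * x_hat + bias[None, :, None, None]
        return y, mean, var
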
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
index ee182f9..f4372e3 100755
--- a/src/model/operation/batchnorm.h
+++ b/src/model/operation/batchnorm.h
@@ -12,8 +12,8 @@
 namespace singa {
 
 class BatchNormHandle {
-public:
-  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+ public:
+  BatchNormHandle(const float momentum, const Tensor& input);
 
   float factor;
 
@@ -21,10 +21,6 @@ public:
   size_t channels;
   size_t height;
   size_t width;
-
-  Tensor runningMean;
-  Tensor runningVariance;
-
   bool is_2d;
   //bool train = true;
 };
@@ -39,8 +35,8 @@ public:
 #ifdef USE_CUDNN
 
 class CudnnBatchNormHandle: public BatchNormHandle {
-public:
-  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+ public:
+  CudnnBatchNormHandle(const float momentum, const Tensor& input);
 
   //~CudnnBatchNormHandle();
 
@@ -49,13 +45,17 @@ public:
   cudnnTensorDescriptor_t param_desc = nullptr;
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle
+    &cbnh, const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+    Tensor& running_mean, Tensor& running_var);
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                    const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
+                                    const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                    const Tensor& running_mean, const Tensor& running_var);
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
+    const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean,
+    const Tensor& var);
 
 #endif  // USE_CUDNN
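
And a matching numpy reference for the backward declaration above. It uses the plain batch mean/variance, whereas cuDNN's saved values may use a different internal representation, so treat it as a mathematical sketch only:

    # Reference gradients for batch norm (per channel, reduction over N, H, W).
    import numpy as np

    def batchnorm_backward(dy, x, scale, mean, var, eps=1e-5):
        m = x.shape[0] * x.shape[2] * x.shape[3]
        inv_std = 1.0 / np.sqrt(var[None, :, None, None] + eps)
        x_hat = (x - mean[None, :, None, None]) * inv_std
        dbias = dy.sum(axis=(0, 2, 3))
        dscale = (dy * x_hat).sum(axis=(0, 2, 3))
        dx = (scale[None, :, None, None] * inv_std / m) * (
            m * dy
            - dbias[None, :, None, None]
            - x_hat * dscale[None, :, None, None])
        return dx, dscale, dbias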