You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2018/07/12 08:39:49 UTC

[1/5] incubator-singa git commit: SINGA-379 Implement batchnorm operation and its related functions for autograd

Repository: incubator-singa
Updated Branches:
  refs/heads/master b30d7ea55 -> f134a24e2


SINGA-379 Implement batchnorm operation and its related functions for autograd

- Reformat the existing code and rename some variables.

- Fix some build (make) errors.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/10274f3b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/10274f3b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/10274f3b

Branch: refs/heads/master
Commit: 10274f3bf82106595305c58644c69353f4b414a8
Parents: a105b24
Author: xuewanqi <xu...@outlook.com>
Authored: Wed Jul 11 02:59:35 2018 +0000
Committer: Wang Wei <wa...@gmail.com>
Committed: Wed Jul 11 21:57:47 2018 +0800

----------------------------------------------------------------------
 src/api/model_operation.i        |  16 +--
 src/model/operation/batchnorm.cc | 246 +++++++++++++++++-----------------
 src/model/operation/batchnorm.h  |  48 +++----
 3 files changed, 158 insertions(+), 152 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/10274f3b/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 783a1f8..95efd26 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -51,28 +51,28 @@ Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle
 
 class BatchNormHandle{
   public:
-    BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+    BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
 
     size_t batchsize;
-    Tensor runningMean_;
-    Tensor runningVariance_;
+    Tensor runningMean;
+    Tensor runningVariance;
 
 };
 
 
 class CudnnBatchNormHandle: public BatchNormHandle{
     public:
-      CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+      CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
 
     size_t batchsize;
-    Tensor runningMean_;
-    Tensor runningVariance_;
+    Tensor runningMean;
+    Tensor runningVariance;
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_, 
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, 
    std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_, const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, const CudnnBatchNormHandle &cbnh);
 
 std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy,
   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/10274f3b/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
old mode 100644
new mode 100755
index 9b6f9cd..145b90b
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -1,164 +1,170 @@
 #include "./batchnorm.h"
 
-namespace singa{
-
-BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, 
-  const Tensor& RunningVariance_){
-  factor_ = momentum;
-  batchsize = input.shape()[0];
-  channels_= input.shape()[2];
-  if (input.nDim()== 4u){
-    height_= input.shape()[3];
-    width_=input.shape()[4];
-    is_2d_= false;
-  }else{
-    size_t height_ = 1;
-    size_t width_ = 1;
-    bool is_2d_ = true;
+namespace singa {
+
+BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
+                                 const Tensor& RunningVariance) {
+  factor = momentum;
+  batchsize = input.shape(0);
+  channels = input.shape(1);
+  if (input.nDim() == 4u) {
+    height = input.shape(2);
+    width = input.shape(3);
+    is_2d = false;
+  } else if (input.nDim() == 2u) {
+    height = 1;
+    width = 1;
+    is_2d = true;
+  } else {
+    LOG(FATAL) << "The dimension of input should either be 4D or 2D.";
   }
-  runningMean_= RunningMean_;
-  runningVariance_= RunningVariance_;
+  runningMean = RunningMean;
+  runningVariance = RunningVariance;
 };
 
-CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, 
-  const Tensor& RunningVariance_):BatchNormHandle(momentum, input, RunningMean_, RunningVariance_){
-  if (is_2d_)
-      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
+    const Tensor& RunningVariance): BatchNormHandle(momentum, input, RunningMean, RunningVariance) {
+  if (is_2d)
+    mode = CUDNN_BATCHNORM_PER_ACTIVATION;
   else
-      mode_ = CUDNN_BATCHNORM_SPATIAL;
-  auto dtype = input.data_type();
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
-  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType(dtype), batchsize,
-                                         channels_, height_, width_));
-  CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc_, CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType(dtype), 1, channels_,
+    mode = CUDNN_BATCHNORM_SPATIAL;
+  DataType dtype = input.data_type();
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc, CUDNN_TENSOR_NCHW,
+                                         GetCudnnDataType(dtype),
+                                         batchsize,
+                                         channels, height, width));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc, CUDNN_TENSOR_NCHW,
+                                         GetCudnnDataType(dtype), 1, channels,
                                          1, 1));
-  };
+};
+
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_, 
-  std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
-  
-  auto shape = x.shape();
+  Shape shape = x.shape();
   Tensor output;
   Tensor input;  //for unification of 2d and 4d cases.
-  if (cbnh.is_2d_)
+  if (cbnh.is_2d)
     input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
   else
     input = x;
   output.ResetLike(x);
 
-  Tensor resultSaveMean_;
-  Tensor resultSaveVariance_;
+  Tensor resultSaveMean;
+  Tensor resultSaveVariance;
 
-  resultSaveMean_.Reshape(Shape{cbnh.channels_});
-  resultSaveVariance_.Reshape(Shape{cbnh.channels_});
+  resultSaveMean.Reshape(Shape{cbnh.channels});
+  resultSaveVariance.Reshape(Shape{cbnh.channels});
 
-  cache.push_back(resultSaveMean_);
-  cache.push_back(resultSaveVariance_);
-  cache.push_back(bnScale_);
+  cache.push_back(resultSaveMean);
+  cache.push_back(resultSaveVariance);
+  cache.push_back(bnScale);
   //cache={x, mean, var, scale}
 
-    output.device()->Exec(
-        [&output, &input, &bnScale_, &bnBias_, &cache, &cbnh](Context* ctx) {
-          Block* inBlock = input.block(), * outBlock = output.block(),
-                 * saveMeanBlock = cache[1].block(),
-                 * saveVarBlock = cache[2].block(),
-                 * runningMeanBlock = cbnh.runningMean_.block(),
-                 * runningVarBlock = cbnh.runningVariance_.block(),
-                 * bnScaleBlock = bnScale_.block(),
-                 * bnBiasBlock = bnBias_.block();
-          const float alpha = 1.0f, beta = 0.0f;
-          double epsilon = CUDNN_BN_MIN_EPSILON;
-          CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
-              ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, cbnh.shape_desc_,
-              inBlock->data(), cbnh.shape_desc_, outBlock->mutable_data(),
-              cbnh.param_desc_, bnScaleBlock->data(), bnBiasBlock->data(), cbnh.factor_,
-              runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
-              epsilon, saveMeanBlock->mutable_data(),
-              saveVarBlock->mutable_data()));
-        },
-        {input.block(), bnScale_.block(), bnBias_.block()},
-        {output.block(), cbnh.runningMean_.block(), cbnh.runningVariance_.block(),
-         cache[1].block(), cache[2].block()}); 
-  if (cbnh.is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
+  output.device()->Exec(
+  [&output, &input, &bnScale, &bnBias, &cache, &cbnh](Context * ctx) {
+    Block* inBlock = input.block(), * outBlock = output.block(),
+           * saveMeanBlock = cache[1].block(),
+             * saveVarBlock = cache[2].block(),
+               * runningMeanBlock = cbnh.runningMean.block(),
+                 * runningVarBlock = cbnh.runningVariance.block(),
+                   * bnScaleBlock = bnScale.block(),
+                     * bnBiasBlock = bnBias.block();
+    const float alpha = 1.0f, beta = 0.0f;
+    double epsilon = CUDNN_BN_MIN_EPSILON;
+    CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
+                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
+                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
+                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(), cbnh.factor,
+                  runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
+                  epsilon, saveMeanBlock->mutable_data(),
+                  saveVarBlock->mutable_data()));
+  },
+  {input.block(), bnScale.block(), bnBias.block()},
+  { output.block(), cbnh.runningMean.block(), cbnh.runningVariance.block(),
+    cache[1].block(), cache[2].block()
+  });
+  if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
   return output;
 };
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_, 
-   const CudnnBatchNormHandle &cbnh) {
-  auto shape = x.shape();
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                    const CudnnBatchNormHandle &cbnh) {
+  Shape shape = x.shape();
   Tensor output;
   Tensor input;  //for unification of 2d and 4d cases.
-  if (cbnh.is_2d_)
+  if (cbnh.is_2d)
     input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
   else
     input = x;
   output.ResetLike(x);
-    output.device()->Exec(
-        [&output, &input, &bnScale_, &bnBias_, &cbnh](Context* ctx) {
-          Block* inBlock = input.block(), * outBlock = output.block(),
-                 * runningMeanBlock = cbnh.runningMean_.block(),
-                 * runningVarBlock = cbnh.runningVariance_.block(),
-                 * bnScaleBlock = bnScale_.block(),
-                 * bnBiasBlock = bnBias_.block();
-          const float alpha = 1.0f, beta = 0.0f;
-          double epsilon = CUDNN_BN_MIN_EPSILON;
-          CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
-              ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, cbnh.shape_desc_,
-              inBlock->data(), cbnh.shape_desc_, outBlock->mutable_data(),
-              cbnh.param_desc_, bnScaleBlock->data(), bnBiasBlock->data(),
-              runningMeanBlock->data(), runningVarBlock->data(), epsilon));
-        },
-        {input.block(), bnScale_.block(), bnBias_.block(), cbnh.runningMean_.block(),
-         cbnh.runningVariance_.block()},
-        {output.block()});
-  if (cbnh.is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
+  output.device()->Exec(
+  [&output, &input, &bnScale, &bnBias, &cbnh](Context * ctx) {
+    Block* inBlock = input.block(), * outBlock = output.block(),
+           * runningMeanBlock = cbnh.runningMean.block(),
+             * runningVarBlock = cbnh.runningVariance.block(),
+               * bnScaleBlock = bnScale.block(),
+                 * bnBiasBlock = bnBias.block();
+    const float alpha = 1.0f, beta = 0.0f;
+    double epsilon = CUDNN_BN_MIN_EPSILON;
+    CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
+                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
+                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
+                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(),
+                  runningMeanBlock->data(), runningVarBlock->data(), epsilon));
+  },
+  { input.block(), bnScale.block(), bnBias.block(), cbnh.runningMean.block(),
+    cbnh.runningVariance.block()
+  },
+  {output.block()});
+  if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
   return output;
 };
 
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh){
+std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
 
   vector<Tensor> out_grads;
   Tensor dx;
   dx.ResetLike(dy);
 
-  Tensor dbnScale_;
-  dbnScale_.ResetLike(cache[3]);
+  Tensor dbnScale;
+  dbnScale.ResetLike(cache[3]);
 
-  Tensor dbnBias_;
-  dbnBias_.ResetLike(cache[3]);
-  //dbnBias_.ResetLike(bnBias_);
+  Tensor dbnBias;
+  dbnBias.ResetLike(cache[3]);
+  //dbnBias.ResetLike(bnBias);
 
   dx.device()->Exec(
-      [&dx, &dbnScale_, &dbnBias_, &dy, &cache, &cbnh](Context* ctx) {
-        Block* dyblock = dy.block(), * dxblock = dx.block(),
-               * xblock = cache[0].block(), * bnScaleBlock = cache[3].block(),
-               * dbnScaleBlock = dbnScale_.block(),
-               * dbnBiasBlock = dbnBias_.block(),
-               * saveMeanBlock = cache[1].block(),
-               * saveVarBlock = cache[2].block();
-        const float alpha = 1.0f, beta = .0f;
-        double epsilon = CUDNN_BN_MIN_EPSILON;
-        CUDNN_CHECK(cudnnBatchNormalizationBackward(
-            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, &alpha, &beta,
-            cbnh.shape_desc_, xblock->data(), cbnh.shape_desc_, dyblock->data(),
-            cbnh.shape_desc_, dxblock->mutable_data(), cbnh.param_desc_,
-            bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
-            dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
-            saveVarBlock->data()));
-      },
-      {cache[0].block(), dy.block(), cache[3].block(), cache[1].block(),
-       cache[2].block()},
-      {dx.block(), dbnScale_.block(), dbnBias_.block()});
-  
-  if (cbnh.is_2d_) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
+  [&dx, &dbnScale, &dbnBias, &dy, &cache, &cbnh](Context * ctx) {
+    Block* dyblock = dy.block(), * dxblock = dx.block(),
+           * xblock = cache[0].block(), * bnScaleBlock = cache[3].block(),
+             * dbnScaleBlock = dbnScale.block(),
+               * dbnBiasBlock = dbnBias.block(),
+                 * saveMeanBlock = cache[1].block(),
+                   * saveVarBlock = cache[2].block();
+    const float alpha = 1.0f, beta = .0f;
+    double epsilon = CUDNN_BN_MIN_EPSILON;
+    CUDNN_CHECK(cudnnBatchNormalizationBackward(
+                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, &alpha, &beta,
+                  cbnh.shape_desc, xblock->data(), cbnh.shape_desc, dyblock->data(),
+                  cbnh.shape_desc, dxblock->mutable_data(), cbnh.param_desc,
+                  bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
+                  dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
+                  saveVarBlock->data()));
+  },
+  { cache[0].block(), dy.block(), cache[3].block(), cache[1].block(),
+    cache[2].block()
+  },
+  {dx.block(), dbnScale.block(), dbnBias.block()});
+
+  if (cbnh.is_2d) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
   out_grads.push_back(dx);
-  out_grads.push_back(dbnScale_);
-  out_grads.push_back(dbnBias_);
-return out_grads;
+  out_grads.push_back(dbnScale);
+  out_grads.push_back(dbnBias);
+  return out_grads;
 };
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/10274f3b/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
old mode 100644
new mode 100755
index f2da4cd..f21bd1d
--- a/src/model/operation/batchnorm.h
+++ b/src/model/operation/batchnorm.h
@@ -9,24 +9,24 @@
 #include "../layer/cudnn_utils.h" // check_cudnn
 #endif // USE_CUDNN 
 
-namespace singa{
+namespace singa {
 
-class BatchNormHandle{
-  public:
-  	BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+class BatchNormHandle {
+public:
+  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
 
-  	float factor_;
-  	size_t channels_;
-  	size_t batchsize;
+  float factor;
 
-  	Tensor runningMean_;
-  	Tensor runningVariance_;
+  size_t batchsize;
+  size_t channels;
+  size_t height;
+  size_t width;
 
-  	bool is_2d_ ;
-  	//bool train = true;
+  Tensor runningMean;
+  Tensor runningVariance;
 
-  	size_t height_;
-  	size_t width_;
+  bool is_2d;
+  //bool train = true;
 };
 
 //Tensor CpuBatchNormForwardTraining();
@@ -38,22 +38,22 @@ class BatchNormHandle{
 
 #ifdef USE_CUDNN
 
-class CudnnBatchNormHandle: public BatchNormHandle{
-    public:
-      CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+class CudnnBatchNormHandle: public BatchNormHandle {
+public:
+  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
 
-      //~CudnnBatchNormHandle();
+  //~CudnnBatchNormHandle();
 
-      cudnnBatchNormMode_t mode_;
-      cudnnTensorDescriptor_t shape_desc_ = nullptr;
-      cudnnTensorDescriptor_t param_desc_ = nullptr;
+  cudnnBatchNormMode_t mode;
+  cudnnTensorDescriptor_t shape_desc = nullptr;
+  cudnnTensorDescriptor_t param_desc = nullptr;
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_, 
-  std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_, 
-	const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                    const CudnnBatchNormHandle &cbnh);
 
 std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
 


[5/5] incubator-singa git commit: SINGA-379 Implement batchnorm operation and its related functions for autograd

Posted by zh...@apache.org.
SINGA-379 Implement batchnorm operation and its related functions for autograd

Test batchnorm by adding BatchNorm layers to the mnist_cnn.py example.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f134a24e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f134a24e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f134a24e

Branch: refs/heads/master
Commit: f134a24e2b58baad9dc29167e323d14cdf89d2a4
Parents: ce1a733
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Thu Jul 12 12:28:41 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Thu Jul 12 12:33:04 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py   |  4 ++++
 python/singa/autograd.py         | 10 +++++-----
 src/api/model_layer.i            | 18 +++++++++---------
 src/api/model_operation.i        |  9 +++++----
 src/model/operation/batchnorm.cc | 31 ++++++++++++++-----------------
 5 files changed, 37 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f134a24e/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index f78ccc8..b1d8dbe 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -106,15 +106,19 @@ if __name__ == '__main__':
 
     # operations initialization
     conv1 = autograd.Conv2D(1, 32, 3, padding=1, bias=False)
+    bn1 = autograd.BatchNorm(32)
     conv2 = autograd.Conv2D(32, 32, 3, padding=1)
+    bn2 = autograd.BatchNorm(32)
     linear = autograd.Linear(32 * 28 * 28, 10)
 
 
     def forward(x, t):
         y = conv1(x)
         y = autograd.relu(y)
+        y = bn1(y)
         y = autograd.max_pool_2d(y)
         y = conv2(y)
+        y = bn2(y)
         y = autograd.relu(y)
         y = autograd.max_pool_2d(y)
         y=autograd.flatten(y)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f134a24e/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 3a2eddd..d272dcd 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -771,7 +771,7 @@ class Conv2D(Layer):
         return y
 
 
-class BatchNorm2d(Layer):
+class BatchNorm(Layer):
 
     def __init__(self, num_features, momentum=0.9):
         self.channels = num_features
@@ -810,12 +810,12 @@ class BatchNorm2d(Layer):
                     self.momentum, x.data)
         self.handle.device_id = x.device.id()
 
-        y = batchnorm2d(x, self.scale, self.bias,
+        y = batchnorm(x, self.scale, self.bias,
                         self.running_mean, self.running_var, self.handle)
         return y
 
 
-class _BatchNorm2d(Operation):
+class _BatchNorm(Operation):
 
     def __init__(self, running_mean, running_var, handle):
         self.running_mean = running_mean.data
@@ -855,5 +855,5 @@ class _BatchNorm2d(Operation):
             return dx, ds, db
 
 
-def batchnorm2d(x, scale, bias, running_mean, running_var, handle):
-    return _BatchNorm2d(running_mean, running_var, handle)(x, scale, bias)[0]
+def batchnorm(x, scale, bias, running_mean, running_var, handle):
+    return _BatchNorm(running_mean, running_var, handle)(x, scale, bias)[0]

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f134a24e/src/api/model_layer.i
----------------------------------------------------------------------
diff --git a/src/api/model_layer.i b/src/api/model_layer.i
index d449f24..dc04be0 100644
--- a/src/api/model_layer.i
+++ b/src/api/model_layer.i
@@ -29,21 +29,21 @@
 
 
 %{
-// To make the code compatible between py2 and py3, the follow 
-// macro is required, which forces the 
-// interface (function) to accept byte string (from python) and 
-// return byte string (in python) in py3. Otherwise the strings 
+// To make the code compatible between py2 and py3, the follow
+// macro is required, which forces the
+// interface (function) to accept byte string (from python) and
+// return byte string (in python) in py3. Otherwise the strings
 // should be unicode strings in py3.
 // Note that by default the strings in python3 are of type unicode.
-// You have to encode it with the correct encoding (default is utf-8) 
+// You have to encode it with the correct encoding (default is utf-8)
 // to convert it into bytes. Sometimes, the string is already byte string
 // e.g. from protobuf SerializeToString, then there is no need to do
 // conversion. The output byte strings should be decoded into unicode.
-// For python2, the default type of string is byte string. 
+// For python2, the default type of string is byte string.
 //
-// Because protobuf::SerializeToString cannot be decoded into unicode 
-// string, we cannot use SWIG_PYTHON_2_UNICODE which forces the 
-// interface (function) to accept unicode strings as input args 
+// Because protobuf::SerializeToString cannot be decoded into unicode
+// string, we cannot use SWIG_PYTHON_2_UNICODE which forces the
+// interface (function) to accept unicode strings as input args
 // and return unicode strings.
 //
 // TODO(wangwei) make strings compatible between py2 and py3.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f134a24e/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 6f2d1fa..eb41fd0 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -7,6 +7,7 @@
 #include "../src/model/operation/convolution.h"
 #include "../src/model/operation/batchnorm.h"
 %}
+
 namespace singa {
 
 class ConvHandle {
@@ -68,15 +69,15 @@ class CudnnBatchNormHandle: public BatchNormHandle{
     size_t batchsize;
 };
 
-const vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh, 
+const std::vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh,
   const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, Tensor& running_mean, Tensor& running_var);
 
-Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh, const Tensor& x, 
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh, const Tensor& x,
   const Tensor& bnScale, const Tensor& bnBias,  const Tensor& running_mean, const Tensor& running_var);
 
-const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh, 
+const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
   const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean, const Tensor& var);
-     
+
 #endif  // USE_CUDNN
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f134a24e/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 7040895..29eaba9 100755
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -19,7 +19,7 @@ BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input) {
   }
 };
 
-#if USE_CUDNN
+#ifdef USE_CUDNN
 CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum,
     const Tensor& input): BatchNormHandle(momentum, input) {
   if (is_2d)
@@ -38,14 +38,14 @@ CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum,
                                          1, 1));
 };
 
-Tensor GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh,
+const std::vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh,
                                    const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
                                    Tensor& running_mean, Tensor& running_var) {
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(bnScale.device()->lang(), kCuda);
   CHECK_EQ(bnBias.device()->lang(), kCuda);
-  CHECK_EQ(runningMean.device()->lang(), kCuda);
-  CHECK_EQ(runningVariance.device()->lang(), kCuda);
+  CHECK_EQ(running_mean.device()->lang(), kCuda);
+  CHECK_EQ(running_var.device()->lang(), kCuda);
 
   Tensor mean, var;
   mean.ResetLike(running_mean);
@@ -78,7 +78,7 @@ Tensor GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh,
   });
   if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
   return {output, mean, var};
-};
+}
 
 Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
                                     const Tensor& x, const Tensor& bnScale,
@@ -86,8 +86,8 @@ Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(bnScale.device()->lang(), kCuda);
   CHECK_EQ(bnBias.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.running_mean.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.running_variance.device()->lang(), kCuda);
+  CHECK_EQ(running_mean.device()->lang(), kCuda);
+  CHECK_EQ(running_var.device()->lang(), kCuda);
 
   Shape shape = x.shape();
 
@@ -106,17 +106,13 @@ Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
                   input.block()->data(), cbnh.shape_desc, output.block()->mutable_data(),
                   cbnh.param_desc, bnScale.block()->data(), bnBias.block()->data(),
                   running_mean.block()->data(), running_var.block()->data(), epsilon));
-  }, {
-    input.block(), bnScale.block(), bnBias.block(), running_mean.block(),
-    running_variance.block()
-  },
+  }, { input.block(), bnScale.block(), bnBias.block(), running_mean.block(), running_var.block() },
   {output.block()});
-  if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
   return output;
-};
+}
 
 
-std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
+const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
     const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean,
     const Tensor& var) {
   CHECK_EQ(dy.device()->lang(), kCuda);
@@ -137,7 +133,7 @@ std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
 
   dx.device()->Exec(
   [&](Context * ctx) {
-    
+
     const float alpha = 1.0f, beta = .0f;
     double epsilon = CUDNN_BN_MIN_EPSILON;
     CUDNN_CHECK(cudnnBatchNormalizationBackward(
@@ -151,8 +147,9 @@ std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
   {dx.block(), dbnScale.block(), dbnBias.block()});
 
   if (cbnh.is_2d) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
-  
+
   return {dx, dbnScale, dbnBias};
-};
+}
 
+#endif  //USE_CUDNN
 }


[4/5] incubator-singa git commit: SINGA-379 Implement batchnorm operation and its related functions for autograd

Posted by zh...@apache.org.
SINGA-379 Implement batchnorm operation and its related functions for autograd

Change the API (arguments) of batchnorm functions.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ce1a7335
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ce1a7335
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ce1a7335

Branch: refs/heads/master
Commit: ce1a73359a6e3eb2c3e7ec5ac861ac1829144dad
Parents: 8654f89
Author: Wang Wei <wa...@gmail.com>
Authored: Thu Jul 12 00:31:40 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Thu Jul 12 12:32:50 2018 +0800

----------------------------------------------------------------------
 python/singa/autograd.py         |  42 +++++-----
 src/api/model_operation.i        |  38 ++++-----
 src/model/operation/batchnorm.cc | 147 +++++++++++++++-------------------
 src/model/operation/batchnorm.h  |  26 +++---
 4 files changed, 118 insertions(+), 135 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 4ba0b11..3a2eddd 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -33,7 +33,6 @@ CTensor = singa.Tensor
 training = False
 
 
-
 def infer_dependency(op):
     '''
     Infer the dependency of all operations with the
@@ -483,6 +482,7 @@ def cross_entropy(y, t):
 
 
 class SoftMaxCrossEntropy(Operation):
+
     def forward(self, x, t):
         self.p = singa.SoftMax(x)
         self.t = t
@@ -771,7 +771,7 @@ class Conv2D(Layer):
         return y
 
 
-class BatchNorm2d(NewLayer):
+class BatchNorm2d(Layer):
 
     def __init__(self, num_features, momentum=0.9):
         self.channels = num_features
@@ -787,16 +787,16 @@ class BatchNorm2d(NewLayer):
                            requires_grad=True, stores_grad=True)
         self.bias.set_value(0.0)
 
-        self.runningmean = Tensor(
+        self.running_mean = Tensor(
             shape=param_shape, requires_grad=False, stores_grad=False)
-        self.runningvariance = Tensor(
+        self.running_var = Tensor(
             shape=param_shape, requires_grad=False, stores_grad=False)
 
     def __call__(self, x):
         assert x.shape[1] == self.channels, 'number of channels dismatched.'
 
         self.device_check(x, self.scale, self.bias,
-                          self.runningmean, self.runningvariance)
+                          self.running_mean, self.running_var)
 
         if x.device.id() == -1:
             raise NotImplementedError
@@ -804,39 +804,40 @@ class BatchNorm2d(NewLayer):
         else:
             if not hasattr(self, 'handle'):
                 self.handle = singa.CudnnBatchNormHandle(
-                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                    self.momentum, x.data)
             elif x.shape[0] != self.handle.batchsize:
                 self.handle = singa.CudnnBatchNormHandle(
-                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                    self.momentum, x.data)
         self.handle.device_id = x.device.id()
 
-        y = batchnorm2d(x, self.scale, self.bias, self.handle)
+        y = batchnorm2d(x, self.scale, self.bias,
+                        self.running_mean, self.running_var, self.handle)
         return y
 
 
 class _BatchNorm2d(Operation):
 
-    def __init__(self, handle):
+    def __init__(self, running_mean, running_var, handle):
+        self.running_mean = running_mean.data
+        self.running_var = running_var.data
         self.handle = handle
 
     def forward(self, x, scale, bias):
         if training:
-            resultmean = CTensor([scale.shape(0)])
-            resultvariance = CTensor([scale.shape(0)])
-            self.cache = (x, resultmean, resultvariance, scale)
 
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                resultmean.ToDevice(x.device())
-                resultvariance.ToDevice(x.device())
-                return singa.GpuBatchNormForwardTraining(x, scale, bias, self.cache, self.handle)
-
+                y, mean, var = singa.GpuBatchNormForwardTraining(self.handle,
+                                                                 x, scale, bias, self.running_mean, self.running_var)
+                self.cache = (x, scale, mean, var)
         else:
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                return singa.GpuBatchNormForwardInference(x, scale, bias, self.handle)
+                y, _, _ = singa.GpuBatchNormForwardInference(
+                    self.handle, x, scale, bias, self.running_mean, self.running_var)
+        return y
 
     def backward(self, dy):
         assert training is True and hasattr(
@@ -848,10 +849,11 @@ class _BatchNorm2d(Operation):
         if self.handle.device_id == -1:
             raise NotImplementedError
         else:
+            x, scale, mean, var = self.cache
             dx, ds, db = singa.GpuBatchNormBackward(
-                dy, self.cache, self.handle)
+                self.handle, dy, x, scale, mean, var)
             return dx, ds, db
 
 
-def batchnorm2d(x, scale, bias, handle):
-    return _BatchNorm2d(handle)(x, scale, bias)[0]
+def batchnorm2d(x, scale, bias, running_mean, running_var, handle):
+    return _BatchNorm2d(running_mean, running_var, handle)(x, scale, bias)[0]

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index a1d59ed..6f2d1fa 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -27,6 +27,17 @@ Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 
 Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);
 
+
+
+class BatchNormHandle{
+  public:
+    BatchNormHandle(const float momentum, const Tensor& input);
+
+    size_t batchsize;
+};
+
+
+
 #if USE_CUDNN
 class CudnnConvHandle: public ConvHandle {
  public:
@@ -47,36 +58,25 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 
 Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch);
 
-#endif  // USE_CUDNN
 
-class BatchNormHandle{
-  public:
-    BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
-
-    size_t batchsize;
-    Tensor runningMean;
-    Tensor runningVariance;
-
-};
 
 
 class CudnnBatchNormHandle: public BatchNormHandle{
     public:
-      CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+      CudnnBatchNormHandle(const float momentum, const Tensor& input);
 
     size_t batchsize;
-    Tensor runningMean;
-    Tensor runningVariance;
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, 
-   const std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);
+const vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh, 
+  const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, Tensor& running_mean, Tensor& running_var);
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh, const Tensor& x, 
+  const Tensor& bnScale, const Tensor& bnBias,  const Tensor& running_mean, const Tensor& running_var);
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy,
-  const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh, 
+  const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean, const Tensor& var);
      
-
+#endif  // USE_CUDNN
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 6b2421d..7040895 100755
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -2,8 +2,7 @@
 
 namespace singa {
 
-BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
-                                 const Tensor& RunningVariance) {
+BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input) {
   factor = momentum;
   batchsize = input.shape(0);
   channels = input.shape(1);
@@ -18,12 +17,11 @@ BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, cons
   } else {
     LOG(FATAL) << "The dimension of input should either be 4D or 2D.";
   }
-  runningMean = RunningMean;
-  runningVariance = RunningVariance;
 };
 
-CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
-    const Tensor& RunningVariance): BatchNormHandle(momentum, input, RunningMean, RunningVariance) {
+#if USE_CUDNN
+CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum,
+    const Tensor& input): BatchNormHandle(momentum, input) {
   if (is_2d)
     mode = CUDNN_BATCHNORM_PER_ACTIVATION;
   else
@@ -40,85 +38,77 @@ CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& i
                                          1, 1));
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+Tensor GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh,
+                                   const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                   Tensor& running_mean, Tensor& running_var) {
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(bnScale.device()->lang(), kCuda);
   CHECK_EQ(bnBias.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
-  CHECK_EQ(cache[1].device()->lang(), kCuda);  //resultmean
-  CHECK_EQ(cache[2].device()->lang(), kCuda);  //resultvariance
+  CHECK_EQ(runningMean.device()->lang(), kCuda);
+  CHECK_EQ(runningVariance.device()->lang(), kCuda);
+
+  Tensor mean, var;
+  mean.ResetLike(running_mean);
+  var.ResetLike(running_var);
 
   Shape shape = x.shape();
-  Tensor output;
-  Tensor input;  //for unification of 2d and 4d cases.
+
+  Tensor input = x;  //for unification of 2d and 4d cases.
   if (cbnh.is_2d)
-    input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
-  else
-    input = x;
+    input.Reshape(Shape{shape.at(0), shape.at(1), 1, 1});
+
+  Tensor output;
   output.ResetLike(x);
 
   output.device()->Exec(
-  [&output, &input, &bnScale, &bnBias, &cache, &cbnh](Context * ctx) {
-    Block* inBlock = input.block(), * outBlock = output.block(),
-           * saveMeanBlock = cache[1].block(),
-             * saveVarBlock = cache[2].block(),
-               * runningMeanBlock = cbnh.runningMean.block(),
-                 * runningVarBlock = cbnh.runningVariance.block(),
-                   * bnScaleBlock = bnScale.block(),
-                     * bnBiasBlock = bnBias.block();
+  [&](Context * ctx) {
     const float alpha = 1.0f, beta = 0.0f;
     double epsilon = CUDNN_BN_MIN_EPSILON;
     CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
                   ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
-                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
-                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(), cbnh.factor,
-                  runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
-                  epsilon, saveMeanBlock->mutable_data(),
-                  saveVarBlock->mutable_data()));
+                  input.block()->data(), cbnh.shape_desc, output.block()->mutable_data(),
+                  cbnh.param_desc, bnScale.block()->data(), bnBias.block()->data(), cbnh.factor,
+                  running_mean.block()->mutable_data(), running_var.block()->mutable_data(),
+                  epsilon, mean.block()->mutable_data(),
+                  var.block()->mutable_data()));
   },
-  {input.block(), bnScale.block(), bnBias.block()},
-  { output.block(), cbnh.runningMean.block(), cbnh.runningVariance.block(),
-    cache[1].block(), cache[2].block()
+  {input.block(), bnScale.block(), bnBias.block(), running_mean.block(), running_var.block()}, {
+    output.block(), running_mean.block(), running_var.block(),
+    mean.block(), var.block()
   });
   if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
-  return output;
+  return {output, mean, var};
 };
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                    const CudnnBatchNormHandle &cbnh) {
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
+                                    const Tensor& x, const Tensor& bnScale,
+                                    const Tensor& bnBias, const Tensor& running_mean, const Tensor& running_var) {
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(bnScale.device()->lang(), kCuda);
   CHECK_EQ(bnBias.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.running_mean.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.running_variance.device()->lang(), kCuda);
 
   Shape shape = x.shape();
-  Tensor output;
-  Tensor input;  //for unification of 2d and 4d cases.
+
+  Tensor input = x;  //for unification of 2d and 4d cases.
   if (cbnh.is_2d)
-    input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
-  else
-    input = x;
+    input.Reshape(Shape{shape.at(0), shape.at(1), 1, 1});
+
+  Tensor output;
   output.ResetLike(x);
   output.device()->Exec(
-  [&output, &input, &bnScale, &bnBias, &cbnh](Context * ctx) {
-    Block* inBlock = input.block(), * outBlock = output.block(),
-           * runningMeanBlock = cbnh.runningMean.block(),
-             * runningVarBlock = cbnh.runningVariance.block(),
-               * bnScaleBlock = bnScale.block(),
-                 * bnBiasBlock = bnBias.block();
+  [&](Context * ctx) {
     const float alpha = 1.0f, beta = 0.0f;
     double epsilon = CUDNN_BN_MIN_EPSILON;
     CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
                   ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
-                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
-                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(),
-                  runningMeanBlock->data(), runningVarBlock->data(), epsilon));
-  },
-  { input.block(), bnScale.block(), bnBias.block(), cbnh.runningMean.block(),
-    cbnh.runningVariance.block()
+                  input.block()->data(), cbnh.shape_desc, output.block()->mutable_data(),
+                  cbnh.param_desc, bnScale.block()->data(), bnBias.block()->data(),
+                  running_mean.block()->data(), running_var.block()->data(), epsilon));
+  }, {
+    input.block(), bnScale.block(), bnBias.block(), running_mean.block(),
+    running_variance.block()
   },
   {output.block()});
   if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
@@ -126,52 +116,43 @@ Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, cons
 };
 
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
+    const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean,
+    const Tensor& var) {
   CHECK_EQ(dy.device()->lang(), kCuda);
-  CHECK_EQ(cache[0].device()->lang(), kCuda);
-  CHECK_EQ(cache[1].device()->lang(), kCuda);
-  CHECK_EQ(cache[2].device()->lang(), kCuda);
-  CHECK_EQ(cache[3].device()->lang(), kCuda);
+  CHECK_EQ(x.device()->lang(), kCuda);
+  CHECK_EQ(bnScale.device()->lang(), kCuda);
+  CHECK_EQ(mean.device()->lang(), kCuda);
+  CHECK_EQ(var.device()->lang(), kCuda);
 
   vector<Tensor> out_grads;
   Tensor dx;
   dx.ResetLike(dy);
 
   Tensor dbnScale;
-  dbnScale.ResetLike(cache[3]);
+  dbnScale.ResetLike(bnScale);
 
   Tensor dbnBias;
-  dbnBias.ResetLike(cache[3]);
-  //dbnBias.ResetLike(bnBias);
+  dbnBias.ResetLike(bnScale);
 
   dx.device()->Exec(
-  [&dx, &dbnScale, &dbnBias, &dy, &cache, &cbnh](Context * ctx) {
-    Block* dyblock = dy.block(), * dxblock = dx.block(),
-           * xblock = cache[0].block(), * bnScaleBlock = cache[3].block(),
-             * dbnScaleBlock = dbnScale.block(),
-               * dbnBiasBlock = dbnBias.block(),
-                 * saveMeanBlock = cache[1].block(),
-                   * saveVarBlock = cache[2].block();
+  [&](Context * ctx) {
+    
     const float alpha = 1.0f, beta = .0f;
     double epsilon = CUDNN_BN_MIN_EPSILON;
     CUDNN_CHECK(cudnnBatchNormalizationBackward(
                   ctx->cudnn_handle, cbnh.mode, &alpha, &beta, &alpha, &beta,
-                  cbnh.shape_desc, xblock->data(), cbnh.shape_desc, dyblock->data(),
-                  cbnh.shape_desc, dxblock->mutable_data(), cbnh.param_desc,
-                  bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
-                  dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
-                  saveVarBlock->data()));
-  },
-  { cache[0].block(), dy.block(), cache[3].block(), cache[1].block(),
-    cache[2].block()
-  },
+                  cbnh.shape_desc, x.block()->data(), cbnh.shape_desc, dy.block()->data(),
+                  cbnh.shape_desc, dx.block()->mutable_data(), cbnh.param_desc,
+                  bnScale.block()->data(), dbnScale.block()->mutable_data(),
+                  dbnBias.block()->mutable_data(), epsilon, mean.block()->data(),
+                  var.block()->data()));
+  }, {x.block(), dy.block(), bnScale.block(), mean.block(), var.block()},
   {dx.block(), dbnScale.block(), dbnBias.block()});
 
   if (cbnh.is_2d) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
-  out_grads.push_back(dx);
-  out_grads.push_back(dbnScale);
-  out_grads.push_back(dbnBias);
-  return out_grads;
+  
+  return {dx, dbnScale, dbnBias};
 };
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
index ee182f9..f4372e3 100755
--- a/src/model/operation/batchnorm.h
+++ b/src/model/operation/batchnorm.h
@@ -12,8 +12,8 @@
 namespace singa {
 
 class BatchNormHandle {
-public:
-  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+ public:
+  BatchNormHandle(const float momentum, const Tensor& input);
 
   float factor;
 
@@ -21,10 +21,6 @@ public:
   size_t channels;
   size_t height;
   size_t width;
-
-  Tensor runningMean;
-  Tensor runningVariance;
-
   bool is_2d;
   //bool train = true;
 };
@@ -39,8 +35,8 @@ public:
 #ifdef USE_CUDNN
 
 class CudnnBatchNormHandle: public BatchNormHandle {
-public:
-  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+ public:
+  CudnnBatchNormHandle(const float momentum, const Tensor& input);
 
   //~CudnnBatchNormHandle();
 
@@ -49,13 +45,17 @@ public:
   cudnnTensorDescriptor_t param_desc = nullptr;
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle
+    &cbnh, const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+    Tensor& running_mean, Tensor& running_var);
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                    const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
+                                    const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                    const Tensor& running_mean, const Tensor& running_var);
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
+    const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean,
+    const Tensor& var);
 
 #endif  // USE_CUDNN
 


[3/5] incubator-singa git commit: SINGA-379 Implement batchnorm operation and its related functions for autograd

Posted by zh...@apache.org.
SINGA-379 Implement batchnorm operation and its related functions for autograd

- fixed some bugs.

- modified the design of the batchnorm operation

- wrote a test file for the batchnorm layer (operation); the unit test passed.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/8654f894
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/8654f894
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/8654f894

Branch: refs/heads/master
Commit: 8654f8942a73d7bd86a0bf2e4b2a9f154b124d1e
Parents: 10274f3
Author: xuewanqi <xu...@outlook.com>
Authored: Wed Jul 11 06:40:07 2018 +0000
Committer: Wang Wei <wa...@gmail.com>
Committed: Wed Jul 11 21:57:48 2018 +0800

----------------------------------------------------------------------
 python/singa/autograd.py         | 42 ++++++++++++++++++++++++-----------
 src/api/model_operation.i        |  2 +-
 src/model/operation/batchnorm.cc | 35 +++++++++++++++++------------
 src/model/operation/batchnorm.h  |  2 +-
 test/python/test_operation.py    |  4 ++--
 5 files changed, 54 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 97a75b4..4ba0b11 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -770,35 +770,44 @@ class Conv2D(Layer):
         y = conv2d(x, self.W, self.b, self.handle)
         return y
 
+
 class BatchNorm2d(NewLayer):
-    def __init__(self, num_features, momentum = 0.9):
+
+    def __init__(self, num_features, momentum=0.9):
         self.channels = num_features
         self.momentum = momentum
 
         param_shape = (self.channels,)
 
-        self.scale = Tensor(shape=param_shape, requires_grad=True, stores_grad=True)
+        self.scale = Tensor(shape=param_shape,
+                            requires_grad=True, stores_grad=True)
         self.scale.set_value(1.0)
 
-        self.bias =  Tensor(shape=param_shape, requires_grad=True, stores_grad=True)
+        self.bias = Tensor(shape=param_shape,
+                           requires_grad=True, stores_grad=True)
         self.bias.set_value(0.0)
 
-        self.runningmean = Tensor(shape=param_shape, requires_grad=False, stores_grad=False)
-        self.runningvariance = Tensor(shape=param_shape, requires_grad=False, stores_grad=False)
+        self.runningmean = Tensor(
+            shape=param_shape, requires_grad=False, stores_grad=False)
+        self.runningvariance = Tensor(
+            shape=param_shape, requires_grad=False, stores_grad=False)
 
     def __call__(self, x):
         assert x.shape[1] == self.channels, 'number of channels dismatched.'
 
-        self.device_check(x, self.scale, self.bias, self.runningmean,self.runningvariance)
+        self.device_check(x, self.scale, self.bias,
+                          self.runningmean, self.runningvariance)
 
         if x.device.id() == -1:
             raise NotImplementedError
 
         else:
             if not hasattr(self, 'handle'):
-                self.handle = singa.CudnnBatchNormHandle(self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                self.handle = singa.CudnnBatchNormHandle(
+                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
             elif x.shape[0] != self.handle.batchsize:
-                self.handle = singa.CudnnBatchNormHandle(self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                self.handle = singa.CudnnBatchNormHandle(
+                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
         self.handle.device_id = x.device.id()
 
         y = batchnorm2d(x, self.scale, self.bias, self.handle)
@@ -806,26 +815,32 @@ class BatchNorm2d(NewLayer):
 
 
 class _BatchNorm2d(Operation):
-    def __init(self, handle):
+
+    def __init__(self, handle):
         self.handle = handle
 
     def forward(self, x, scale, bias):
         if training:
-            self.cache=(x,)
+            resultmean = CTensor([scale.shape(0)])
+            resultvariance = CTensor([scale.shape(0)])
+            self.cache = (x, resultmean, resultvariance, scale)
+
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
+                resultmean.ToDevice(x.device())
+                resultvariance.ToDevice(x.device())
                 return singa.GpuBatchNormForwardTraining(x, scale, bias, self.cache, self.handle)
 
         else:
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                return singa.GpuBatchNormForwardInference(x, scale, bias ,self.handle)
+                return singa.GpuBatchNormForwardInference(x, scale, bias, self.handle)
 
     def backward(self, dy):
         assert training is True and hasattr(
-            self, 'cahce'), 'Please set training as True before do BP. '
+            self, 'cache'), 'Please set training as True before do BP. '
 
         if dy.device().id() != self.handle.device_id:
             dy.ToDevice(self.cache[0].device())
@@ -833,7 +848,8 @@ class _BatchNorm2d(Operation):
         if self.handle.device_id == -1:
             raise NotImplementedError
         else:
-            dx, ds, db = singa.GpuBatchNormBackward(dy, self.cache, self.handle)
+            dx, ds, db = singa.GpuBatchNormBackward(
+                dy, self.cache, self.handle)
             return dx, ds, db
 
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 95efd26..a1d59ed 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -70,7 +70,7 @@ class CudnnBatchNormHandle: public BatchNormHandle{
 };
 
 Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, 
-   std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);
+   const std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);
 
 Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, const CudnnBatchNormHandle &cbnh);
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 145b90b..6b2421d 100755
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -8,8 +8,8 @@ BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, cons
   batchsize = input.shape(0);
   channels = input.shape(1);
   if (input.nDim() == 4u) {
-    height = input.shape(2);
-    width = input.shape(3);
+    height = input.shape().at(2);
+    width = input.shape().at(3);
     is_2d = false;
   } else if (input.nDim() == 2u) {
     height = 1;
@@ -41,7 +41,14 @@ CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& i
 };
 
 Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+  CHECK_EQ(x.device()->lang(), kCuda);
+  CHECK_EQ(bnScale.device()->lang(), kCuda);
+  CHECK_EQ(bnBias.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
+  CHECK_EQ(cache[1].device()->lang(), kCuda);  //resultmean
+  CHECK_EQ(cache[2].device()->lang(), kCuda);  //resultvariance
 
   Shape shape = x.shape();
   Tensor output;
@@ -52,17 +59,6 @@ Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const
     input = x;
   output.ResetLike(x);
 
-  Tensor resultSaveMean;
-  Tensor resultSaveVariance;
-
-  resultSaveMean.Reshape(Shape{cbnh.channels});
-  resultSaveVariance.Reshape(Shape{cbnh.channels});
-
-  cache.push_back(resultSaveMean);
-  cache.push_back(resultSaveVariance);
-  cache.push_back(bnScale);
-  //cache={x, mean, var, scale}
-
   output.device()->Exec(
   [&output, &input, &bnScale, &bnBias, &cache, &cbnh](Context * ctx) {
     Block* inBlock = input.block(), * outBlock = output.block(),
@@ -92,6 +88,12 @@ Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const
 
 Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
                                     const CudnnBatchNormHandle &cbnh) {
+  CHECK_EQ(x.device()->lang(), kCuda);
+  CHECK_EQ(bnScale.device()->lang(), kCuda);
+  CHECK_EQ(bnBias.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
+
   Shape shape = x.shape();
   Tensor output;
   Tensor input;  //for unification of 2d and 4d cases.
@@ -125,6 +127,11 @@ Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, cons
 
 
 std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+  CHECK_EQ(dy.device()->lang(), kCuda);
+  CHECK_EQ(cache[0].device()->lang(), kCuda);
+  CHECK_EQ(cache[1].device()->lang(), kCuda);
+  CHECK_EQ(cache[2].device()->lang(), kCuda);
+  CHECK_EQ(cache[3].device()->lang(), kCuda);
 
   vector<Tensor> out_grads;
   Tensor dx;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
index f21bd1d..ee182f9 100755
--- a/src/model/operation/batchnorm.h
+++ b/src/model/operation/batchnorm.h
@@ -50,7 +50,7 @@ public:
 };
 
 Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
 
 Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
                                     const CudnnBatchNormHandle &cbnh);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 0e851d7..67018c1 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -79,12 +79,12 @@ class TestPythonOperation(unittest.TestCase):
         dy = CTensor([2, 3, 3, 3])
         singa.Gaussian(0.0, 1.0, dy)
 
-        y=batchnorm_0(gpu_input_tensor)
+        y = batchnorm_0(gpu_input_tensor)
         dx, ds, db = y.creator.backward(dy)
 
         self.check_shape(y.shape, (2, 3, 3, 3))
         self.check_shape(dx.shape(), (2, 3, 3, 3))
-        self.check_shape(dx.shape(), (3,))
+        self.check_shape(ds.shape(), (3,))
         self.check_shape(db.shape(), (3,))
 
 if __name__ == '__main__':


[2/5] incubator-singa git commit: SINGA-379 Implement batchnorm operation and its related functions for autograd

Posted by zh...@apache.org.
SINGA-379 Implement batchnorm operation and its related functions for autograd

- implement batchnorm2d-related functions (GPU part)

- add interface files for developed functions

- create corresponding operation and NewLayer for batchnorm2d


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/a105b240
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/a105b240
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/a105b240

Branch: refs/heads/master
Commit: a105b2404e36c81f30b873e0c74ccdb9a7e36bfd
Parents: b30d7ea
Author: xuewanqi <xu...@outlook.com>
Authored: Sun Jul 8 15:08:43 2018 +0000
Committer: Wang Wei <wa...@gmail.com>
Committed: Wed Jul 11 21:57:47 2018 +0800

----------------------------------------------------------------------
 python/singa/autograd.py           |  80 +++++++++++++++-
 src/api/model_operation.i          |  31 ++++++
 src/model/layer/cudnn_batchnorm.cc |   2 +-
 src/model/operation/batchnorm.cc   | 164 ++++++++++++++++++++++++++++++++
 src/model/operation/batchnorm.h    |  62 ++++++++++++
 test/python/test_operation.py      |  17 ++++
 6 files changed, 354 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index aa6b37a..97a75b4 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -27,7 +27,7 @@ from .tensor import Tensor
 from . import layer
 from singa.proto import model_pb2
 from . import singa_wrap as singa
-
+#from .tensor import einsum
 
 CTensor = singa.Tensor
 training = False
@@ -415,6 +415,14 @@ class SoftMax(Operation):
         out = out_1 - out_2
         dx = CTensor(out_1.shape)
         dx.CopyFloatDataFromHostPtr(out.flatten())
+        '''grad = Tensor(data=dy)
+        output = Tensor(data=self.output)
+        out_1 = einsum('ki,ki->ki', grad, output)
+        medium_out = einsum('ki,kj->kij', output, output)
+        out_2 = einsum('kij,kj->ki', medium_out, grad)
+        out = out_1 - out_2
+        dx = CTensor(out_1.data.shape)
+        dx.CopyFloatDataFromHostPtr(out.data.flatten())'''
         if self.axis == 0:
             return dx
         elif self.axis == 1:
@@ -761,3 +769,92 @@ class Conv2D(Layer):
 
         y = conv2d(x, self.W, self.b, self.handle)
         return y
+
+class BatchNorm2d(NewLayer):
+    '''Batch normalization over axis 1 (the channel axis) of `x`.
+
+    Only the GPU (cuDNN) path is implemented; inputs on device id -1
+    (presumably the CPU) raise NotImplementedError.
+    '''
+    def __init__(self, num_features, momentum = 0.9):
+        self.channels = num_features
+        # Handed to the cuDNN handle, which uses it as the running-average factor.
+        self.momentum = momentum
+
+        param_shape = (self.channels,)
+
+        self.scale = Tensor(shape=param_shape, requires_grad=True, stores_grad=True)
+        self.scale.set_value(1.0)
+
+        self.bias =  Tensor(shape=param_shape, requires_grad=True, stores_grad=True)
+        self.bias.set_value(0.0)
+
+        # Running statistics are read by inference and blended in place by
+        # training, so start them at mean 0 / variance 1.
+        self.runningmean = Tensor(shape=param_shape, requires_grad=False, stores_grad=False)
+        self.runningmean.set_value(0.0)
+        self.runningvariance = Tensor(shape=param_shape, requires_grad=False, stores_grad=False)
+        self.runningvariance.set_value(1.0)
+
+    def __call__(self, x):
+        assert x.shape[1] == self.channels, 'number of channels dismatched.'
+
+        self.device_check(x, self.scale, self.bias, self.runningmean,self.runningvariance)
+
+        if x.device.id() == -1:
+            raise NotImplementedError
+
+        else:
+            # The handle's cuDNN descriptor encodes the whole input shape (N, C, H, W),
+            # so rebuild it whenever the shape changes, not only the batch size.
+            if not hasattr(self, 'handle') or x.shape != self.handle_shape:
+                self.handle = singa.CudnnBatchNormHandle(self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                self.handle_shape = x.shape
+        self.handle.device_id = x.device.id()
+
+        y = batchnorm2d(x, self.scale, self.bias, self.handle)
+        return y
+
+
+class _BatchNorm2d(Operation):
+    '''Autograd operation dispatching batch normalization to cuDNN.'''
+    def __init__(self, handle):
+        self.handle = handle
+
+    def forward(self, x, scale, bias):
+        if training:
+            # NOTE(review): the C++ side appends the saved statistics and
+            # scale to `cache` (layout {x, mean, var, scale}). If SWIG converts
+            # this tuple into a temporary std::vector, those entries never
+            # reach self.cache and backward() cannot use it -- confirm against
+            # the generated wrapper.
+            self.cache=(x,)
+            if self.handle.device_id == -1:
+                raise NotImplementedError
+            else:
+                return singa.GpuBatchNormForwardTraining(x, scale, bias, self.cache, self.handle)
+
+        else:
+            if self.handle.device_id == -1:
+                raise NotImplementedError
+            else:
+                return singa.GpuBatchNormForwardInference(x, scale, bias ,self.handle)
+
+    def backward(self, dy):
+        # forward() stores `cache` only in training mode.
+        assert training is True and hasattr(
+            self, 'cache'), 'Please set training as True before do BP. '
+
+        if dy.device().id() != self.handle.device_id:
+            dy.ToDevice(self.cache[0].device())
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            dx, ds, db = singa.GpuBatchNormBackward(dy, self.cache, self.handle)
+            return dx, ds, db
+
+
+def batchnorm2d(x, scale, bias, handle):
+    '''Run a new _BatchNorm2d op on (x, scale, bias); return its first output.'''
+    return _BatchNorm2d(handle)(x, scale, bias)[0]

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 3858a2b..783a1f8 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -5,6 +5,7 @@
 %include "std_string.i"
 %{
 #include "../src/model/operation/convolution.h"
+#include "../src/model/operation/batchnorm.h"
 %}
 namespace singa {
 
@@ -48,4 +49,38 @@ Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle
 
 #endif  // USE_CUDNN
 
+// Subset of singa::BatchNormHandle (batchnorm.h) exposed to Python.
+class BatchNormHandle{
+  public:
+    BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+
+    size_t batchsize;
+    Tensor runningMean_;
+    Tensor runningVariance_;
+
+};
+
+// batchnorm.h declares the cuDNN handle and the GPU functions only when
+// USE_CUDNN is set, so wrap them under the same condition.
+#if USE_CUDNN
+
+// batchsize, runningMean_ and runningVariance_ are inherited from BatchNormHandle.
+class CudnnBatchNormHandle: public BatchNormHandle{
+    public:
+      CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+};
+
+// Appends the saved statistics and scale to `cache` (expected to hold {x}).
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
+   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_, const CudnnBatchNormHandle &cbnh);
+
+// Returns {dx, dscale, dbias}.
+std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy,
+  const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+
+#endif  // USE_CUDNN
+
 }
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
old mode 100644
new mode 100755
index 389b41b..4816817
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -167,7 +167,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
               saveVarBlock->data()));
 
         },
-        {dx.block(), grad.block(), bnScale_.block(), resultSaveMean_.block(),
+        {x.block(), grad.block(), bnScale_.block(), resultSaveMean_.block(),
          resultSaveVariance_.block()},
         {dx.block(), dbnScale_.block(), dbnBias_.block()});
   } else {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
new file mode 100644
index 0000000..9b6f9cd
--- /dev/null
+++ b/src/model/operation/batchnorm.cc
@@ -0,0 +1,193 @@
+#include "./batchnorm.h"
+
+namespace singa{
+
+// Records the input layout and the running statistics. Dimension 1 is the
+// channel axis: a 4D input is read as (N, C, H, W); any other input is
+// treated as 2D (N, C), i.e. as (N, C, 1, 1) with H = W = 1.
+// NOTE(review): runningMean_/runningVariance_ are Tensor copies that the cuDNN
+// training call updates in place; this reaches the caller's tensors only if
+// copies share storage -- confirm against singa::Tensor's copy semantics.
+BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_,
+  const Tensor& RunningVariance_){
+  factor_ = momentum;
+  batchsize = input.shape()[0];
+  channels_ = input.shape()[1];
+  if (input.nDim() == 4u) {
+    height_ = input.shape()[2];
+    width_ = input.shape()[3];
+    is_2d_ = false;
+  } else {
+    // Assign the members (not new shadowing locals) so the 2D setup sticks.
+    height_ = 1;
+    width_ = 1;
+    is_2d_ = true;
+  }
+  runningMean_ = RunningMean_;
+  runningVariance_ = RunningVariance_;
+}
+
+#ifdef USE_CUDNN
+
+// Builds NCHW descriptors for the input (N, C, H, W) and for the per-channel
+// parameters (1, C, 1, 1). 2D inputs use CUDNN_BATCHNORM_PER_ACTIVATION,
+// 4D inputs CUDNN_BATCHNORM_SPATIAL.
+// NOTE(review): nothing here calls cudnnDestroyTensorDescriptor on
+// shape_desc_/param_desc_ (the destructor is commented out in the header).
+CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_,
+  const Tensor& RunningVariance_):BatchNormHandle(momentum, input, RunningMean_, RunningVariance_){
+  if (is_2d_)
+    mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+  else
+    mode_ = CUDNN_BATCHNORM_SPATIAL;
+  auto dtype = input.data_type();
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW,
+                                         GetCudnnDataType(dtype), batchsize,
+                                         channels_, height_, width_));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc_, CUDNN_TENSOR_NCHW,
+                                         GetCudnnDataType(dtype), 1, channels_,
+                                         1, 1));
+}
+
+// Training-mode forward pass. `cache` must hold exactly {x} on entry; on
+// return it holds {x, saveMean, saveVar, scale}, the layout read by
+// GpuBatchNormBackward. The running statistics in `cbnh` are updated in place
+// using the averaging factor cbnh.factor_.
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
+  std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+  // The kernel call below addresses the saved statistics as cache[1]/cache[2].
+  CHECK_EQ(cache.size(), 1u) << "cache must contain only the input x";
+
+  auto shape = x.shape();
+  Tensor output;
+  Tensor input;  //for unification of 2d and 4d cases.
+  if (cbnh.is_2d_)
+    input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
+  else
+    input = x;
+  output.ResetLike(x);
+
+  // Per-channel buffers that cuDNN fills with the batch statistics.
+  // Allocate them like bnScale_ (same size, device and dtype): a
+  // default-constructed Tensor is not tied to bnScale_'s device, so
+  // Reshape() alone need not place them on the GPU cuDNN writes to.
+  Tensor resultSaveMean_;
+  Tensor resultSaveVariance_;
+  resultSaveMean_.ResetLike(bnScale_);
+  resultSaveVariance_.ResetLike(bnScale_);
+
+  cache.push_back(resultSaveMean_);
+  cache.push_back(resultSaveVariance_);
+  cache.push_back(bnScale_);
+  //cache={x, mean, var, scale}
+
+  output.device()->Exec(
+      [&output, &input, &bnScale_, &bnBias_, &cache, &cbnh](Context* ctx) {
+        Block* inBlock = input.block(), * outBlock = output.block(),
+               * saveMeanBlock = cache[1].block(),
+               * saveVarBlock = cache[2].block(),
+               * runningMeanBlock = cbnh.runningMean_.block(),
+               * runningVarBlock = cbnh.runningVariance_.block(),
+               * bnScaleBlock = bnScale_.block(),
+               * bnBiasBlock = bnBias_.block();
+        const float alpha = 1.0f, beta = 0.0f;
+        double epsilon = CUDNN_BN_MIN_EPSILON;
+        CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
+            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, cbnh.shape_desc_,
+            inBlock->data(), cbnh.shape_desc_, outBlock->mutable_data(),
+            cbnh.param_desc_, bnScaleBlock->data(), bnBiasBlock->data(), cbnh.factor_,
+            runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
+            epsilon, saveMeanBlock->mutable_data(),
+            saveVarBlock->mutable_data()));
+      },
+      {input.block(), bnScale_.block(), bnBias_.block()},
+      {output.block(), cbnh.runningMean_.block(), cbnh.runningVariance_.block(),
+       cache[1].block(), cache[2].block()});
+  if (cbnh.is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
+  return output;
+}
+
+// Inference-mode forward pass: normalizes `x` with the running statistics
+// held by `cbnh`, which are only read.
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
+   const CudnnBatchNormHandle &cbnh) {
+  auto shape = x.shape();
+  Tensor output;
+  Tensor input;  //for unification of 2d and 4d cases.
+  if (cbnh.is_2d_)
+    input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
+  else
+    input = x;
+  output.ResetLike(x);
+  output.device()->Exec(
+      [&output, &input, &bnScale_, &bnBias_, &cbnh](Context* ctx) {
+        Block* inBlock = input.block(), * outBlock = output.block(),
+               * runningMeanBlock = cbnh.runningMean_.block(),
+               * runningVarBlock = cbnh.runningVariance_.block(),
+               * bnScaleBlock = bnScale_.block(),
+               * bnBiasBlock = bnBias_.block();
+        const float alpha = 1.0f, beta = 0.0f;
+        double epsilon = CUDNN_BN_MIN_EPSILON;
+        CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
+            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, cbnh.shape_desc_,
+            inBlock->data(), cbnh.shape_desc_, outBlock->mutable_data(),
+            cbnh.param_desc_, bnScaleBlock->data(), bnBiasBlock->data(),
+            runningMeanBlock->data(), runningVarBlock->data(), epsilon));
+      },
+      {input.block(), bnScale_.block(), bnBias_.block(), cbnh.runningMean_.block(),
+       cbnh.runningVariance_.block()},
+      {output.block()});
+  if (cbnh.is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
+  return output;
+}
+
+// Backward pass. `cache` is the {x, saveMean, saveVar, scale} vector filled
+// by GpuBatchNormForwardTraining. Returns {dx, dscale, dbias}.
+std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh){
+  CHECK_EQ(cache.size(), 4u) << "cache must be {x, mean, var, scale}";
+
+  std::vector<Tensor> out_grads;
+  Tensor dx;
+  dx.ResetLike(dy);
+
+  // Scale and bias gradients are per-channel, shaped like the scale (cache[3]).
+  Tensor dbnScale_;
+  dbnScale_.ResetLike(cache[3]);
+
+  Tensor dbnBias_;
+  dbnBias_.ResetLike(cache[3]);
+
+  dx.device()->Exec(
+      [&dx, &dbnScale_, &dbnBias_, &dy, &cache, &cbnh](Context* ctx) {
+        Block* dyblock = dy.block(), * dxblock = dx.block(),
+               * xblock = cache[0].block(), * bnScaleBlock = cache[3].block(),
+               * dbnScaleBlock = dbnScale_.block(),
+               * dbnBiasBlock = dbnBias_.block(),
+               * saveMeanBlock = cache[1].block(),
+               * saveVarBlock = cache[2].block();
+        const float alpha = 1.0f, beta = .0f;
+        double epsilon = CUDNN_BN_MIN_EPSILON;
+        CUDNN_CHECK(cudnnBatchNormalizationBackward(
+            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, &alpha, &beta,
+            cbnh.shape_desc_, xblock->data(), cbnh.shape_desc_, dyblock->data(),
+            cbnh.shape_desc_, dxblock->mutable_data(), cbnh.param_desc_,
+            bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
+            dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
+            saveVarBlock->data()));
+      },
+      {cache[0].block(), dy.block(), cache[3].block(), cache[1].block(),
+       cache[2].block()},
+      {dx.block(), dbnScale_.block(), dbnBias_.block()});
+
+  if (cbnh.is_2d_) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
+  out_grads.push_back(dx);
+  out_grads.push_back(dbnScale_);
+  out_grads.push_back(dbnBias_);
+  return out_grads;
+}
+
+#endif  // USE_CUDNN
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
new file mode 100644
index 0000000..f2da4cd
--- /dev/null
+++ b/src/model/operation/batchnorm.h
@@ -0,0 +1,71 @@
+#ifndef SINGA_MODEL_OPERATION_BATCHNORM_H_
+#define SINGA_MODEL_OPERATION_BATCHNORM_H_
+
+#include <vector>
+#include "singa/core/tensor.h"
+
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include "../layer/cudnn_utils.h" // check_cudnn
+#endif // USE_CUDNN
+
+namespace singa{
+
+// Batch-norm configuration read from an (N, C) or (N, C, H, W) input, plus
+// the running statistics. Defaults below keep every field defined even if a
+// constructor path forgets to set it.
+class BatchNormHandle{
+  public:
+    BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+
+    float factor_ = 0.0f;   // momentum, passed to cuDNN as the running-average factor
+    size_t channels_ = 0;
+    size_t batchsize = 0;
+
+    Tensor runningMean_;
+    Tensor runningVariance_;
+
+    bool is_2d_ = false;    // true when the input is handled as (N, C, 1, 1)
+    //bool train = true;
+
+    size_t height_ = 1;
+    size_t width_ = 1;
+};
+
+//Tensor CpuBatchNormForwardTraining();
+
+//Tensor CpuBatchNormForwardInference();
+
+//Tensor CpuBatchNormBackwardx();
+
+
+#ifdef USE_CUDNN
+
+// BatchNormHandle plus the cuDNN mode and tensor descriptors derived from it.
+class CudnnBatchNormHandle: public BatchNormHandle{
+    public:
+      CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+
+      //~CudnnBatchNormHandle();
+
+      cudnnBatchNormMode_t mode_;
+      cudnnTensorDescriptor_t shape_desc_ = nullptr;
+      cudnnTensorDescriptor_t param_desc_ = nullptr;
+};
+
+// Training forward; appends the saved statistics and scale to `cache`.
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
+  std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+
+// Inference forward using the running statistics in `cbnh`.
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
+  const CudnnBatchNormHandle &cbnh);
+
+// Backward pass; returns {dx, dscale, dbias}.
+std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+
+#endif  // USE_CUDNN
+
+}  // namespace singa
+
+#endif  // SINGA_MODEL_OPERATION_BATCHNORM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 315a992..0e851d7 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -70,5 +70,22 @@ class TestPythonOperation(unittest.TestCase):
         y_without_bias = conv_without_bias_1(cpu_input_tensor)
         self.check_shape(y_without_bias.shape, (2, 1, 2, 2))
 
+    def test_batchnorm2d_gpu(self):
+        # Forward + backward through BatchNorm2d on a GPU input; checks y/dx/ds/db shapes.
+        batchnorm_0 = autograd.BatchNorm2d(3)
+        gpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=gpu_dev)
+        gpu_input_tensor.gaussian(0.0, 1.0)
+
+        dy = CTensor([2, 3, 3, 3])
+        singa.Gaussian(0.0, 1.0, dy)
+
+        y = batchnorm_0(gpu_input_tensor)
+        dx, ds, db = y.creator.backward(dy)
+
+        self.check_shape(y.shape, (2, 3, 3, 3))
+        self.check_shape(dx.shape(), (2, 3, 3, 3))
+        self.check_shape(ds.shape(), (3,))
+        self.check_shape(db.shape(), (3,))
+
 if __name__ == '__main__':
     unittest.main()