You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2018/07/12 08:39:50 UTC

[2/5] incubator-singa git commit: SINGA-379 Implement batchnorm operation and its related functions for autograd

SINGA-379 Implement batchnorm operation and its related functions for autograd

- Implement batchnorm2d-related functions (GPU part).

- Add SWIG interface declarations for the new functions.

- Create the corresponding autograd operation and NewLayer for batchnorm2d.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/a105b240
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/a105b240
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/a105b240

Branch: refs/heads/master
Commit: a105b2404e36c81f30b873e0c74ccdb9a7e36bfd
Parents: b30d7ea
Author: xuewanqi <xu...@outlook.com>
Authored: Sun Jul 8 15:08:43 2018 +0000
Committer: Wang Wei <wa...@gmail.com>
Committed: Wed Jul 11 21:57:47 2018 +0800

----------------------------------------------------------------------
 python/singa/autograd.py           |  80 +++++++++++++++-
 src/api/model_operation.i          |  31 ++++++
 src/model/layer/cudnn_batchnorm.cc |   2 +-
 src/model/operation/batchnorm.cc   | 164 ++++++++++++++++++++++++++++++++
 src/model/operation/batchnorm.h    |  62 ++++++++++++
 test/python/test_operation.py      |  17 ++++
 6 files changed, 354 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index aa6b37a..97a75b4 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -27,7 +27,7 @@ from .tensor import Tensor
 from . import layer
 from singa.proto import model_pb2
 from . import singa_wrap as singa
-
+#from .tensor import einsum
 
 CTensor = singa.Tensor
 training = False
@@ -415,6 +415,14 @@ class SoftMax(Operation):
         out = out_1 - out_2
         dx = CTensor(out_1.shape)
         dx.CopyFloatDataFromHostPtr(out.flatten())
+        '''grad = Tensor(data=dy)
+        output = Tensor(data=self.output)
+        out_1 = einsum('ki,ki->ki', grad, output)
+        medium_out = einsum('ki,kj->kij', output, output)
+        out_2 = einsum('kij,kj->ki', medium_out, grad)
+        out = out_1 - out_2
+        dx = CTensor(out_1.data.shape)
+        dx.CopyFloatDataFromHostPtr(out.data.flatten())'''
         if self.axis == 0:
             return dx
         elif self.axis == 1:
@@ -761,3 +769,108 @@ class Conv2D(Layer):

         y = conv2d(x, self.W, self.b, self.handle)
         return y
+
+
+class BatchNorm2d(NewLayer):
+    '''Batch normalization over axis 1 (the channel axis) of the input.
+
+    Only the cuDNN path is implemented; inputs whose device id is -1
+    (presumably the host CPU) raise NotImplementedError.
+
+    Args:
+        num_features (int): number of channels; must equal x.shape[1].
+        momentum (float): passed unchanged to CudnnBatchNormHandle, which
+            uses it as cuDNN's exponential-average factor for the running
+            statistics. NOTE(review): cuDNN weights the *new* batch
+            statistics by this factor, so 0.9 keeps little history --
+            confirm this is intended.
+    '''
+
+    def __init__(self, num_features, momentum=0.9):
+        self.channels = num_features
+        self.momentum = momentum
+
+        param_shape = (self.channels,)
+
+        self.scale = Tensor(shape=param_shape, requires_grad=True,
+                            stores_grad=True)
+        self.scale.set_value(1.0)
+
+        self.bias = Tensor(shape=param_shape, requires_grad=True,
+                           stores_grad=True)
+        self.bias.set_value(0.0)
+
+        # Inference reads the running statistics and training blends into
+        # them, so like scale and bias they need defined starting values.
+        self.runningmean = Tensor(shape=param_shape, requires_grad=False,
+                                  stores_grad=False)
+        self.runningmean.set_value(0.0)
+        self.runningvariance = Tensor(shape=param_shape, requires_grad=False,
+                                      stores_grad=False)
+        self.runningvariance.set_value(1.0)
+
+    def __call__(self, x):
+        '''Normalizes x, whose axis 1 must have self.channels entries.'''
+        assert x.shape[1] == self.channels, 'number of channels dismatched.'
+
+        self.device_check(x, self.scale, self.bias, self.runningmean,
+                          self.runningvariance)
+
+        if x.device.id() == -1:
+            raise NotImplementedError
+
+        # The cuDNN descriptors are built from the full input shape, so the
+        # handle must be rebuilt whenever any dimension changes, not only
+        # the batch size.
+        if getattr(self, 'handle_shape', None) != x.shape:
+            self.handle = singa.CudnnBatchNormHandle(
+                self.momentum, x.data, self.runningmean.data,
+                self.runningvariance.data)
+            self.handle_shape = x.shape
+        self.handle.device_id = x.device.id()
+
+        return batchnorm2d(x, self.scale, self.bias, self.handle)
+
+
+class _BatchNorm2d(Operation):
+    '''Autograd operation dispatching to the cuDNN batch-norm functions.
+
+    Args:
+        handle: a CudnnBatchNormHandle built for the current input shape.
+    '''
+
+    def __init__(self, handle):
+        self.handle = handle
+
+    def forward(self, x, scale, bias):
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        if training:
+            # GpuBatchNormForwardTraining leaves the cache as
+            # (x, saved mean, saved variance, scale) for backward().
+            # NOTE(review): the C++ side takes std::vector<Tensor>& and
+            # appends to it, which an immutable Python tuple cannot
+            # receive -- pass a SWIG-wrapped vector here instead.
+            self.cache = (x,)
+            return singa.GpuBatchNormForwardTraining(
+                x, scale, bias, self.cache, self.handle)
+        return singa.GpuBatchNormForwardInference(x, scale, bias,
+                                                  self.handle)
+
+    def backward(self, dy):
+        '''Returns (dx, dscale, dbias) for the upstream gradient dy.'''
+        assert training is True and hasattr(
+            self, 'cache'), 'Please set training as True before do BP. '
+
+        if dy.device().id() != self.handle.device_id:
+            dy.ToDevice(self.cache[0].device())
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        dx, ds, db = singa.GpuBatchNormBackward(dy, self.cache, self.handle)
+        return dx, ds, db
+
+
+def batchnorm2d(x, scale, bias, handle):
+    '''Applies batch normalization to x with the given cuDNN handle.'''
+    return _BatchNorm2d(handle)(x, scale, bias)[0]

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 3858a2b..783a1f8 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -5,6 +5,7 @@
 %include "std_string.i"
 %{
 #include "../src/model/operation/convolution.h"
+#include "../src/model/operation/batchnorm.h"
 %}
 namespace singa {
 
@@ -48,4 +49,37 @@ Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle

 #endif  // USE_CUDNN

+// Batch-norm state shared by all backends; these fields are exposed to
+// Python.
+class BatchNormHandle{
+  public:
+    BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+
+    size_t batchsize;
+    Tensor runningMean_;
+    Tensor runningVariance_;
+};
+
+#if USE_CUDNN
+
+// cuDNN backend. batchsize, runningMean_ and runningVariance_ are inherited
+// from BatchNormHandle and therefore not redeclared.
+class CudnnBatchNormHandle: public BatchNormHandle{
+  public:
+    CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+};
+
+// cache receives the state that GpuBatchNormBackward needs, so it has to be
+// a wrapped (mutable) std::vector<Tensor> on the Python side.
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
+  std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
+  const CudnnBatchNormHandle &cbnh);
+
+std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy,
+  const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+
+#endif  // USE_CUDNN
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
old mode 100644
new mode 100755
index 389b41b..4816817
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -167,7 +167,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
               saveVarBlock->data()));
 
         },
-        {dx.block(), grad.block(), bnScale_.block(), resultSaveMean_.block(),
+        {x.block(), grad.block(), bnScale_.block(), resultSaveMean_.block(),
          resultSaveVariance_.block()},
         {dx.block(), dbnScale_.block(), dbnBias_.block()});
   } else {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
new file mode 100644
index 0000000..9b6f9cd
--- /dev/null
+++ b/src/model/operation/batchnorm.cc
@@ -0,0 +1,184 @@
+#include "./batchnorm.h"
+
+namespace singa {
+
+/// Records the momentum, the geometry of a sample input and the running
+/// statistics. input must be (N, C) or (N, C, H, W) with channels on axis 1;
+/// a 2-D input is handled as (N, C, 1, 1).
+BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input,
+                                 const Tensor& RunningMean_,
+                                 const Tensor& RunningVariance_) {
+  CHECK(input.nDim() == 2u || input.nDim() == 4u)
+      << "BatchNormHandle expects a 2-D or 4-D input, got " << input.nDim()
+      << " dimensions";
+  factor_ = momentum;
+  batchsize = input.shape().at(0);
+  channels_ = input.shape().at(1);
+  if (input.nDim() == 4u) {
+    height_ = input.shape().at(2);
+    width_ = input.shape().at(3);
+    is_2d_ = false;
+  } else {
+    // Assign the members: declaring locals here would leave them unset.
+    height_ = 1;
+    width_ = 1;
+    is_2d_ = true;
+  }
+  runningMean_ = RunningMean_;
+  runningVariance_ = RunningVariance_;
+}
+
+#ifdef USE_CUDNN
+
+namespace {
+
+/// Views a 2-D (N, C) tensor as (N, C, 1, 1) so that the 4-D cuDNN
+/// descriptors cover both cases; 4-D tensors are returned unchanged.
+Tensor AsNCHW(const Tensor& x, const CudnnBatchNormHandle& cbnh) {
+  if (!cbnh.is_2d_) return x;
+  const auto& shape = x.shape();
+  return Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
+}
+
+}  // namespace
+
+/// Creates the cuDNN descriptors for the sample input (shape_desc_) and for
+/// the per-channel scale/bias/statistics (param_desc_).
+CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum,
+                                           const Tensor& input,
+                                           const Tensor& RunningMean_,
+                                           const Tensor& RunningVariance_)
+    : BatchNormHandle(momentum, input, RunningMean_, RunningVariance_) {
+  mode_ = is_2d_ ? CUDNN_BATCHNORM_PER_ACTIVATION : CUDNN_BATCHNORM_SPATIAL;
+  const auto dtype = GetCudnnDataType(input.data_type());
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW, dtype,
+                                         batchsize, channels_, height_,
+                                         width_));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc_, CUDNN_TENSOR_NCHW, dtype,
+                                         1, channels_, 1, 1));
+}
+
+/// Training-mode forward pass: normalizes x with the batch statistics,
+/// updates the handle's running mean/variance, and leaves
+/// cache = {x, saved mean, saved variance, scale} for GpuBatchNormBackward.
+/// Returns a tensor shaped like x.
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_,
+                                   const Tensor& bnBias_,
+                                   std::vector<Tensor>& cache,
+                                   const CudnnBatchNormHandle& cbnh) {
+  const Tensor input = AsNCHW(x, cbnh);
+  Tensor output;
+  output.ResetLike(x);
+
+  // Allocate the saved statistics like the running statistics. Reshape() on
+  // a default-constructed Tensor presumably allocates on the default (host)
+  // device, but cuDNN writes these buffers from the GPU.
+  Tensor resultSaveMean_;
+  resultSaveMean_.ResetLike(cbnh.runningMean_);
+  Tensor resultSaveVariance_;
+  resultSaveVariance_.ResetLike(cbnh.runningVariance_);
+
+  // Rebuild the cache so the positional layout GpuBatchNormBackward reads
+  // does not depend on the caller having pre-inserted x at index 0.
+  cache.clear();
+  cache.push_back(x);
+  cache.push_back(resultSaveMean_);
+  cache.push_back(resultSaveVariance_);
+  cache.push_back(bnScale_);
+
+  output.device()->Exec(
+      [&output, &input, &bnScale_, &bnBias_, &resultSaveMean_,
+       &resultSaveVariance_, &cbnh](Context* ctx) {
+        const float alpha = 1.0f, beta = 0.0f;
+        const double epsilon = CUDNN_BN_MIN_EPSILON;
+        CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
+            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, cbnh.shape_desc_,
+            input.block()->data(), cbnh.shape_desc_,
+            output.block()->mutable_data(), cbnh.param_desc_,
+            bnScale_.block()->data(), bnBias_.block()->data(), cbnh.factor_,
+            cbnh.runningMean_.block()->mutable_data(),
+            cbnh.runningVariance_.block()->mutable_data(), epsilon,
+            resultSaveMean_.block()->mutable_data(),
+            resultSaveVariance_.block()->mutable_data()));
+      },
+      {input.block(), bnScale_.block(), bnBias_.block()},
+      {output.block(), cbnh.runningMean_.block(), cbnh.runningVariance_.block(),
+       resultSaveMean_.block(), resultSaveVariance_.block()});
+  if (cbnh.is_2d_) output.Reshape(Shape{x.shape().at(0), x.shape().at(1)});
+  return output;
+}
+
+/// Inference-mode forward pass using the handle's running statistics.
+/// Returns a tensor shaped like x.
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_,
+                                    const Tensor& bnBias_,
+                                    const CudnnBatchNormHandle& cbnh) {
+  const Tensor input = AsNCHW(x, cbnh);
+  Tensor output;
+  output.ResetLike(x);
+
+  output.device()->Exec(
+      [&output, &input, &bnScale_, &bnBias_, &cbnh](Context* ctx) {
+        const float alpha = 1.0f, beta = 0.0f;
+        const double epsilon = CUDNN_BN_MIN_EPSILON;
+        CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
+            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, cbnh.shape_desc_,
+            input.block()->data(), cbnh.shape_desc_,
+            output.block()->mutable_data(), cbnh.param_desc_,
+            bnScale_.block()->data(), bnBias_.block()->data(),
+            cbnh.runningMean_.block()->data(),
+            cbnh.runningVariance_.block()->data(), epsilon));
+      },
+      {input.block(), bnScale_.block(), bnBias_.block(),
+       cbnh.runningMean_.block(), cbnh.runningVariance_.block()},
+      {output.block()});
+  if (cbnh.is_2d_) output.Reshape(Shape{x.shape().at(0), x.shape().at(1)});
+  return output;
+}
+
+/// Backward pass. cache must be the vector produced by
+/// GpuBatchNormForwardTraining: {x, saved mean, saved variance, scale}.
+/// Returns {dx, dscale, dbias}.
+std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy,
+                                         const std::vector<Tensor>& cache,
+                                         const CudnnBatchNormHandle& cbnh) {
+  CHECK_GE(cache.size(), 4u)
+      << "GpuBatchNormBackward: cache must hold {x, mean, var, scale}";
+  const Tensor& x = cache[0];
+  const Tensor& saveMean = cache[1];
+  const Tensor& saveVar = cache[2];
+  const Tensor& bnScale = cache[3];
+
+  Tensor dx;
+  dx.ResetLike(dy);
+  Tensor dbnScale;
+  dbnScale.ResetLike(bnScale);
+  Tensor dbnBias;
+  dbnBias.ResetLike(bnScale);  // bias and scale share the per-channel shape
+
+  dx.device()->Exec(
+      [&dx, &dbnScale, &dbnBias, &dy, &x, &saveMean, &saveVar, &bnScale,
+       &cbnh](Context* ctx) {
+        const float alpha = 1.0f, beta = 0.0f;
+        const double epsilon = CUDNN_BN_MIN_EPSILON;
+        CUDNN_CHECK(cudnnBatchNormalizationBackward(
+            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, &alpha, &beta,
+            cbnh.shape_desc_, x.block()->data(), cbnh.shape_desc_,
+            dy.block()->data(), cbnh.shape_desc_, dx.block()->mutable_data(),
+            cbnh.param_desc_, bnScale.block()->data(),
+            dbnScale.block()->mutable_data(), dbnBias.block()->mutable_data(),
+            epsilon, saveMean.block()->data(), saveVar.block()->data()));
+      },
+      {x.block(), dy.block(), bnScale.block(), saveMean.block(),
+       saveVar.block()},
+      {dx.block(), dbnScale.block(), dbnBias.block()});
+
+  if (cbnh.is_2d_) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
+  return {dx, dbnScale, dbnBias};
+}
+
+#endif  // USE_CUDNN
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
new file mode 100644
index 0000000..f2da4cd
--- /dev/null
+++ b/src/model/operation/batchnorm.h
@@ -0,0 +1,79 @@
+#ifndef SINGA_MODEL_OPERATION_BATCHNORM_H_
+#define SINGA_MODEL_OPERATION_BATCHNORM_H_
+
+#include <vector>
+#include "singa/core/tensor.h"
+
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include "../layer/cudnn_utils.h"  // CUDNN_CHECK
+#endif  // USE_CUDNN
+
+namespace singa {
+
+/// Backend-independent batch-norm state captured from a sample input.
+class BatchNormHandle {
+ public:
+  BatchNormHandle(const float momentum, const Tensor& input,
+                  const Tensor& RunningMean_, const Tensor& RunningVariance_);
+  virtual ~BatchNormHandle() = default;
+
+  float factor_;  ///< the momentum passed to the constructor
+  size_t channels_;
+  size_t batchsize;
+
+  Tensor runningMean_;
+  Tensor runningVariance_;
+
+  bool is_2d_;  ///< true when the sample input is not 4-D
+
+  size_t height_;
+  size_t width_;
+};
+
+// TODO: CPU implementations (forward training/inference, backward).
+
+#ifdef USE_CUDNN
+
+/// cuDNN-specific handle. It owns two cuDNN tensor descriptors, so it is
+/// non-copyable and releases them on destruction.
+class CudnnBatchNormHandle : public BatchNormHandle {
+ public:
+  CudnnBatchNormHandle(const float momentum, const Tensor& input,
+                       const Tensor& RunningMean_,
+                       const Tensor& RunningVariance_);
+  ~CudnnBatchNormHandle() override {
+    if (shape_desc_ != nullptr)
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(shape_desc_));
+    if (param_desc_ != nullptr)
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(param_desc_));
+  }
+  CudnnBatchNormHandle(const CudnnBatchNormHandle&) = delete;
+  CudnnBatchNormHandle& operator=(const CudnnBatchNormHandle&) = delete;
+
+  cudnnBatchNormMode_t mode_;
+  cudnnTensorDescriptor_t shape_desc_ = nullptr;  ///< describes the input
+  cudnnTensorDescriptor_t param_desc_ = nullptr;  ///< describes scale/bias
+};
+
+/// Training-mode forward; fills cache for GpuBatchNormBackward.
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_,
+                                   const Tensor& bnBias_,
+                                   std::vector<Tensor>& cache,
+                                   const CudnnBatchNormHandle& cbnh);
+
+/// Inference-mode forward using the handle's running statistics.
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_,
+                                    const Tensor& bnBias_,
+                                    const CudnnBatchNormHandle& cbnh);
+
+/// Backward pass; returns {dx, dscale, dbias}.
+std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy,
+                                         const std::vector<Tensor>& cache,
+                                         const CudnnBatchNormHandle& cbnh);
+
+#endif  // USE_CUDNN
+
+}  // namespace singa
+
+#endif  // SINGA_MODEL_OPERATION_BATCHNORM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a105b240/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 315a992..0e851d7 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -70,5 +70,30 @@ class TestPythonOperation(unittest.TestCase):
         y_without_bias = conv_without_bias_1(cpu_input_tensor)
         self.check_shape(y_without_bias.shape, (2, 1, 2, 2))

+    def test_batchnorm2d_gpu(self):
+        # backward() asserts autograd.training, so run in training mode and
+        # restore the module-level flag afterwards.
+        prev_training = autograd.training
+        autograd.training = True
+        try:
+            batchnorm_0 = autograd.BatchNorm2d(3)
+
+            gpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3),
+                                             device=gpu_dev)
+            gpu_input_tensor.gaussian(0.0, 1.0)
+
+            dy = CTensor([2, 3, 3, 3])
+            singa.Gaussian(0.0, 1.0, dy)
+
+            y = batchnorm_0(gpu_input_tensor)
+            dx, ds, db = y.creator.backward(dy)
+
+            self.check_shape(y.shape, (2, 3, 3, 3))
+            self.check_shape(dx.shape(), (2, 3, 3, 3))
+            self.check_shape(ds.shape(), (3,))
+            self.check_shape(db.shape(), (3,))
+        finally:
+            autograd.training = prev_training
+
 if __name__ == '__main__':
     unittest.main()