Posted to commits@singa.apache.org by wa...@apache.org on 2016/10/21 15:33:27 UTC
incubator-singa git commit: SINGA-267 Add spatial mode in batch normalization layer
Repository: incubator-singa
Updated Branches:
refs/heads/master 9eabb9563 -> 61faa840e
SINGA-267 Add spatial mode in batch normalization layer
Added spatial mode to the batch normalization layer in the C++ implementation,
which corresponds to CUDNN_BATCHNORM_SPATIAL in CuDNN.
Also added logic to automatically detect the proper mode in the batch
normalization layer: if the input is a 2D tensor, the batchnorm layer
chooses PER_ACTIVATION mode; if the input is a 4D tensor, it chooses
SPATIAL mode.
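As a reading aid (not part of the commit), here is a minimal standalone C++ sketch of that shape-based rule; BnMode and DetectBnMode are hypothetical names, and the per-sample shape excludes the batch dimension, so rank 1 corresponds to a 2D (N, C) input and rank 3 to a 4D (N, C, H, W) input:

    #include <cstddef>
    #include <vector>

    enum class BnMode { kPerActivation, kSpatial };

    // Hypothetical helper mirroring the detection added in this commit:
    // a 1-D per-sample shape means the batched input is (N, C), so the
    // per-activation mode applies; otherwise the input is (N, C, H, W)
    // and the spatial mode is used.
    BnMode DetectBnMode(const std::vector<size_t>& in_sample) {
      return in_sample.size() == 1 ? BnMode::kPerActivation : BnMode::kSpatial;
    }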
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/61faa840
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/61faa840
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/61faa840
Branch: refs/heads/master
Commit: 61faa840e155e4259920141c3c909483fa1c14f2
Parents: 9eabb95
Author: WANG Ji <ij...@gmail.com>
Authored: Fri Oct 21 12:14:10 2016 +0800
Committer: WANG Ji <ij...@gmail.com>
Committed: Fri Oct 21 22:11:34 2016 +0800
----------------------------------------------------------------------
src/model/layer/batchnorm.cc | 243 +++++++++++++++++++-------------
src/model/layer/cudnn_batchnorm.cc | 169 +++++++---------------
2 files changed, 200 insertions(+), 212 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/61faa840/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
index ad7b2b3..afe9a36 100644
--- a/src/model/layer/batchnorm.cc
+++ b/src/model/layer/batchnorm.cc
@@ -18,6 +18,7 @@
* under the License.
*
************************************************************/
+#include <vector>
#include "batchnorm.h"
namespace singa {
@@ -28,7 +29,7 @@ RegisterLayerClass(singacl_batchnorm, BatchNorm);
void BatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
Layer::Setup(in_sample, conf);
out_sample_shape_ = in_sample;
- factor_ = (float) conf.batchnorm_conf().factor();
+ factor_ = (float)conf.batchnorm_conf().factor();
channels_ = in_sample.at(0);
if (in_sample.size() == 3u)
height_ = in_sample.at(1);
@@ -43,7 +44,7 @@ void BatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
else
is_2d_ = false;
- bnScale_.Reshape(Shape{channels_ * height_ * width_});
+ bnScale_.Reshape(Shape{channels_});
bnBias_.ResetLike(bnScale_);
runningMean_.ResetLike(bnScale_);
runningVariance_.ResetLike(bnScale_);
@@ -70,39 +71,83 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
Tensor output, mean, var, xnorm;
output.ResetLike(x);
- if ((flag & kTrain) == kTrain) {
- mean = Average(x, 0);
- runningMean_ *= 1.0f - factor_;
- Axpy(factor_, mean, &runningMean_);
- xnorm = x.Clone();
- SubRow(mean, &xnorm);
- xnorm = Square(xnorm);
- var = Average(xnorm, 0);
- runningVariance_ *= 1.0f - factor_;
- Axpy(factor_, var, &runningVariance_);
- Tensor tmp = var.Clone();
- tmp = Sqrt(tmp);
- tmp += 1e-6f;
- xnorm = x.Clone();
- SubRow(mean, &xnorm);
- DivRow(tmp, &xnorm);
- output = xnorm.Clone();
- MultRow(bnScale_, &output);
- AddRow(bnBias_, &output);
- buf_.push(x);
- buf_.push(mean);
- buf_.push(var);
- buf_.push(xnorm);
- } else {
- xnorm = x.Clone();
- SubRow(runningMean_, &xnorm);
- Tensor tmp = runningVariance_.Clone();
- tmp = Sqrt(tmp);
- tmp += 1e-6f;
- DivRow(tmp, &xnorm);
- output = xnorm.Clone();
- MultRow(bnScale_, &output);
- AddRow(bnBias_, &output);
+ if ((flag & kTrain) == kTrain) { // forward for train
+ if (is_2d_) { // batchnorm_per_activation mode
+ mean = Average(x, 0);
+ runningMean_ *= 1.0f - factor_;
+ Axpy(factor_, mean, &runningMean_);
+ xnorm = x.Clone();
+ SubRow(mean, &xnorm);
+ xnorm = Square(xnorm);
+ var = Average(xnorm, 0);
+ runningVariance_ *= 1.0f - factor_;
+ Axpy(factor_, var, &runningVariance_);
+ Tensor tmp = var.Clone();
+ tmp = Sqrt(tmp);
+ tmp += 1e-6f;
+ xnorm = x.Clone();
+ SubRow(mean, &xnorm);
+ DivRow(tmp, &xnorm);
+ output = xnorm.Clone();
+ MultRow(bnScale_, &output);
+ AddRow(bnBias_, &output);
+ buf_.push(x);
+ buf_.push(mean);
+ buf_.push(var);
+ buf_.push(xnorm);
+ } else { // batchnorm_spatial mode
+ LOG(FATAL) << "Trainning SpatialBatchNormalization has not been "
+ "implemented yet...";
+ }
+ } else { // forward for test
+ if (is_2d_) { // batchnorm_per_activation mode
+ xnorm = x.Clone();
+ SubRow(runningMean_, &xnorm);
+ Tensor tmp = runningVariance_.Clone();
+ tmp = Sqrt(tmp);
+ tmp += 1e-6f;
+ DivRow(tmp, &xnorm);
+ output = xnorm.Clone();
+ MultRow(bnScale_, &output);
+ AddRow(bnBias_, &output);
+ } else { // batchnorm_spatial mode
+ runningMean_.Reshape(Shape{channels_, 1});
+ runningVariance_.Reshape(Shape{channels_, 1});
+ bnScale_.Reshape(Shape{channels_, 1});
+ bnBias_.Reshape(Shape{channels_, 1});
+
+ std::vector<Tensor> mean_stack, var_stack, scale_stack, bias_stack;
+ for (int i = 0; i < height_ * width_; ++i) {
+ mean_stack.push_back(runningMean_);
+ var_stack.push_back(runningVariance_);
+ scale_stack.push_back(bnScale_);
+ bias_stack.push_back(bnBias_);
+ }
+ auto mean = ConcatenateColumns(mean_stack);
+ auto var = ConcatenateColumns(var_stack);
+ auto scale = ConcatenateColumns(scale_stack);
+ auto bias = ConcatenateColumns(bias_stack);
+
+ mean.Reshape(Shape{channels_ * height_ * width_});
+ var.Reshape(Shape{channels_ * height_ * width_});
+ scale.Reshape(Shape{channels_ * height_ * width_});
+ bias.Reshape(Shape{channels_ * height_ * width_});
+
+ xnorm = x.Clone();
+ SubRow(mean, &xnorm);
+ var = Sqrt(var);
+ var += 1e-6f;
+ DivRow(var, &xnorm);
+ output = xnorm.Clone();
+
+ MultRow(scale, &output);
+ AddRow(bias, &output);
+
+ runningMean_.Reshape(Shape{channels_});
+ runningVariance_.Reshape(Shape{channels_});
+ bnScale_.Reshape(Shape{channels_});
+ bnBias_.Reshape(Shape{channels_});
+ }
}
if (!is_2d_)
@@ -127,71 +172,75 @@ const std::pair<Tensor, vector<Tensor>> BatchNorm::Backward(
vector<Tensor> param_grad;
if ((flag & kTrain) == kTrain) {
- // gxnrom
- Tensor gxnorm = dy.Clone();
- MultRow(bnScale_, &gxnorm);
- // gvar
- Tensor tmp = var.Clone();
- tmp += 1e-6f;
- tmp = Pow(var, -1.5f);
- tmp *= -0.5f;
-
- Tensor tmpx = input.Clone();
- SubRow(mean, &tmpx);
-
- tmpx = tmpx * gxnorm;
- MultRow(tmp, &tmpx);
- Tensor gvar;
- gvar.ResetLike(var);
- SumRows(tmpx, &gvar);
- // gmean
- tmp = var.Clone();
- tmp += 1e-6f;
- tmp = Pow(tmp, -0.5f);
- tmp *= -1.0f;
- Tensor tmpx_r;
- tmpx_r.ResetLike(tmp);
- SumRows(gxnorm, &tmpx_r);
- Tensor gmean = tmpx_r * tmp;
-
- tmpx = input.Clone();
- SubRow(mean, &tmpx);
- SumRows(tmpx, &tmp);
- tmp *= -2.0f / input.shape(0);
- tmp = tmp * gvar;
- gmean = gmean + tmp;
- // dx
- tmp = var.Clone();
- tmp += 1e-6f;
- tmp = Pow(tmp, -0.5f);
- dx = gxnorm.Clone();
- MultRow(tmp, &dx);
-
- tmpx = input.Clone();
- SubRow(mean, &tmpx);
- tmpx *= 2.0f / input.shape(0);
- MultRow(gvar, &tmpx);
- dx = dx + tmpx;
-
- tmp = gmean.Clone();
- tmp *= 1.0f / input.shape(0);
-
- AddRow(tmp, &dx);
- // dbnScale
- tmpx = dy * xnorm;
- SumRows(tmpx, &dbnScale_);
- // dbnBias
- SumRows(dy, &dbnBias_);
- param_grad.push_back(dbnScale_);
- param_grad.push_back(dbnBias_);
- Tensor dummy;
- param_grad.push_back(dummy);
- param_grad.push_back(dummy);
+ if (is_2d_) {
+ // gxnorm
+ Tensor gxnorm = dy.Clone();
+ MultRow(bnScale_, &gxnorm);
+ // gvar
+ Tensor tmp = var.Clone();
+ tmp += 1e-6f;
+ tmp = Pow(tmp, -1.5f);
+ tmp *= -0.5f;
+
+ Tensor tmpx = input.Clone();
+ SubRow(mean, &tmpx);
+
+ tmpx = tmpx * gxnorm;
+ MultRow(tmp, &tmpx);
+ Tensor gvar;
+ gvar.ResetLike(var);
+ SumRows(tmpx, &gvar);
+ // gmean
+ tmp = var.Clone();
+ tmp += 1e-6f;
+ tmp = Pow(tmp, -0.5f);
+ tmp *= -1.0f;
+ Tensor tmpx_r;
+ tmpx_r.ResetLike(tmp);
+ SumRows(gxnorm, &tmpx_r);
+ Tensor gmean = tmpx_r * tmp;
+
+ tmpx = input.Clone();
+ SubRow(mean, &tmpx);
+ SumRows(tmpx, &tmp);
+ tmp *= -2.0f / input.shape(0);
+ tmp = tmp * gvar;
+ gmean = gmean + tmp;
+ // dx
+ tmp = var.Clone();
+ tmp += 1e-6f;
+ tmp = Pow(tmp, -0.5f);
+ dx = gxnorm.Clone();
+ MultRow(tmp, &dx);
+
+ tmpx = input.Clone();
+ SubRow(mean, &tmpx);
+ tmpx *= 2.0f / input.shape(0);
+ MultRow(gvar, &tmpx);
+ dx = dx + tmpx;
+
+ tmp = gmean.Clone();
+ tmp *= 1.0f / input.shape(0);
+
+ AddRow(tmp, &dx);
+ // dbnScale
+ tmpx = dy * xnorm;
+ SumRows(tmpx, &dbnScale_);
+ // dbnBias
+ SumRows(dy, &dbnBias_);
+ param_grad.push_back(dbnScale_);
+ param_grad.push_back(dbnBias_);
+ Tensor dummy;
+ param_grad.push_back(dummy);
+ param_grad.push_back(dummy);
+ } else {
+ LOG(FATAL) << "Trainning SpatialBatchNormalization has not been "
+ "implemented yet...";
+ }
} else {
LOG(ERROR) << "Do not call backward for evaluation phase";
}
- if (!is_2d_)
- dx.Reshape(Shape{dx.shape(0), channels_, height_, width_});
+ if (!is_2d_) dx.Reshape(Shape{dx.shape(0), channels_, height_, width_});
return std::make_pair(dx, param_grad);
}
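As a reading aid (not part of the commit), the inference branch above computes the usual batch-norm transform with per-channel running statistics; in spatial mode the CPU code reaches this per-channel form by tiling runningMean_, runningVariance_, bnScale_ and bnBias_ across all H x W positions before the row-wise SubRow/DivRow/MultRow/AddRow calls. With the commit's epsilon of 1e-6 and its sqrt-then-add ordering, the transform is

    y_{n,c,h,w} = \gamma_c \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2} + \epsilon} + \beta_c, \qquad \epsilon = 10^{-6},

where \mu_c and \sigma_c^2 are the running mean and variance of channel c, and \gamma_c, \beta_c are the corresponding entries of bnScale_ and bnBias_.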
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/61faa840/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
index a7f80be..19a2ccb 100644
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -39,36 +39,25 @@ void CudnnBatchNorm::ToDevice(std::shared_ptr<Device> device) {
void CudnnBatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
BatchNorm::Setup(in_sample, conf);
- bnScale_.Reshape(Shape{channels_});
- bnBias_.Reshape(Shape{channels_});
- dbnScale_.Reshape(Shape{channels_});
- dbnBias_.Reshape(Shape{channels_});
- runningMean_.Reshape(Shape{channels_});
- runningVariance_.Reshape(Shape{channels_});
resultSaveMean_.Reshape(Shape{channels_});
resultSaveVariance_.Reshape(Shape{channels_});
}
void CudnnBatchNorm::InitCudnn(const Shape& shape, DataType dtype) {
CHECK(!has_init_cudnn_);
- mode_ = CUDNN_BATCHNORM_SPATIAL;
+ if (is_2d_)
+ mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+ else
+ mode_ = CUDNN_BATCHNORM_SPATIAL;
CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
CHECK_EQ(shape.size(), 4u);
- CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_,
- CUDNN_TENSOR_NCHW,
- GetCudnnDataType(dtype),
- shape[0],
- shape[1],
- shape[2],
- shape[3]));
- CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc_,
- CUDNN_TENSOR_NCHW,
- GetCudnnDataType(dtype),
- 1,
- shape[1],
- 1,
- 1));
+ CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW,
+ GetCudnnDataType(dtype), shape[0],
+ shape[1], shape[2], shape[3]));
+ CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc_, CUDNN_TENSOR_NCHW,
+ GetCudnnDataType(dtype), 1, shape[1],
+ 1, 1));
has_init_cudnn_ = true;
}
const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
@@ -76,96 +65,65 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
auto dtype = input.data_type();
Tensor output;
Tensor x;
- if(is_2d_)
+ if (is_2d_)
x = Reshape(input, Shape{shape.at(0), shape.at(1), 1, 1});
else
x = input;
shape = x.shape();
- if (!has_init_cudnn_)
- InitCudnn(shape, dtype);
+ if (!has_init_cudnn_) InitCudnn(shape, dtype);
// TODO(wangji): check device id of input and params
output.ResetLike(x);
if ((flag & kTrain) == kTrain) {
output.device()->Exec(
[=](Context* ctx) {
- Block *inBlock = x.block(), *outBlock = output.block(),
- *saveMeanBlock = resultSaveMean_.block(),
- *saveVarBlock = resultSaveVariance_.block(),
- *runningMeanBlock = runningMean_.block(),
- *runningVarBlock = runningVariance_.block(),
- *bnScaleBlock = bnScale_.block(),
- *bnBiasBlock = bnBias_.block();
+ Block* inBlock = x.block(), * outBlock = output.block(),
+ * saveMeanBlock = resultSaveMean_.block(),
+ * saveVarBlock = resultSaveVariance_.block(),
+ * runningMeanBlock = runningMean_.block(),
+ * runningVarBlock = runningVariance_.block(),
+ * bnScaleBlock = bnScale_.block(),
+ * bnBiasBlock = bnBias_.block();
const float alpha = 1.0f, beta = 0.0f;
double epsilon = CUDNN_BN_MIN_EPSILON;
CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
- ctx->cudnn_handle,
- this->mode_,
- &alpha,
- &beta,
- shape_desc_,
- inBlock->data(),
- shape_desc_,
- outBlock->mutable_data(),
- param_desc_,
- bnScaleBlock->data(),
- bnBiasBlock->data(),
- factor_,
- runningMeanBlock->mutable_data(),
- runningVarBlock->mutable_data(),
- epsilon,
- saveMeanBlock->mutable_data(),
+ ctx->cudnn_handle, this->mode_, &alpha, &beta, shape_desc_,
+ inBlock->data(), shape_desc_, outBlock->mutable_data(),
+ param_desc_, bnScaleBlock->data(), bnBiasBlock->data(), factor_,
+ runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
+ epsilon, saveMeanBlock->mutable_data(),
saveVarBlock->mutable_data()));
},
- {x.block(),
- bnScale_.block(),
- bnBias_.block()},
- {output.block(),
- runningMean_.block(),
- runningVariance_.block(),
- resultSaveMean_.block(),
- resultSaveVariance_.block()});
+ {x.block(), bnScale_.block(), bnBias_.block()},
+ {output.block(), runningMean_.block(), runningVariance_.block(),
+ resultSaveMean_.block(), resultSaveVariance_.block()});
buf_.push(x);
} else {
output.device()->Exec(
[=](Context* ctx) {
- Block *inBlock = x.block(), *outBlock = output.block(),
- *runningMeanBlock = runningMean_.block(),
- *runningVarBlock = runningVariance_.block(),
- *bnScaleBlock = bnScale_.block(),
- *bnBiasBlock = bnBias_.block();
+ Block* inBlock = x.block(), * outBlock = output.block(),
+ * runningMeanBlock = runningMean_.block(),
+ * runningVarBlock = runningVariance_.block(),
+ * bnScaleBlock = bnScale_.block(),
+ * bnBiasBlock = bnBias_.block();
const float alpha = 1.0f, beta = 0.0f;
double epsilon = CUDNN_BN_MIN_EPSILON;
CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
- ctx->cudnn_handle,
- this->mode_,
- &alpha,
- &beta,
- shape_desc_,
- inBlock->data(),
- shape_desc_,
- outBlock->mutable_data(),
- param_desc_,
- bnScaleBlock->data(),
- bnBiasBlock->data(),
- runningMeanBlock->data(),
- runningVarBlock->data(),
- epsilon));
+ ctx->cudnn_handle, this->mode_, &alpha, &beta, shape_desc_,
+ inBlock->data(), shape_desc_, outBlock->mutable_data(),
+ param_desc_, bnScaleBlock->data(), bnBiasBlock->data(),
+ runningMeanBlock->data(), runningVarBlock->data(), epsilon));
},
- {x.block(),
- bnScale_.block(),
- bnBias_.block(),
- runningMean_.block(),
+ {x.block(), bnScale_.block(), bnBias_.block(), runningMean_.block(),
runningVariance_.block()},
{output.block()});
}
- if (is_2d_)
- output.Reshape(Shape{shape.at(0), shape.at(1)});
+ if (is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
return output;
}
const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
int flag, const Tensor& grad) {
- vector <Tensor> param_grad;
+ vector<Tensor> param_grad;
Tensor dx;
if ((flag & kTrain) == kTrain) {
Tensor x = buf_.top();
@@ -173,44 +131,26 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
dx.ResetLike(grad);
dx.device()->Exec(
[=](Context* ctx) {
- Block *dyblock = grad.block(), *dxblock = dx.block(),
- *xblock = x.block(),
- *bnScaleBlock = bnScale_.block(),
- *dbnScaleBlock = dbnScale_.block(),
- *dbnBiasBlock = dbnBias_.block(),
- *saveMeanBlock = resultSaveMean_.block(),
- *saveVarBlock = resultSaveVariance_.block();
+ Block* dyblock = grad.block(), * dxblock = dx.block(),
+ * xblock = x.block(), * bnScaleBlock = bnScale_.block(),
+ * dbnScaleBlock = dbnScale_.block(),
+ * dbnBiasBlock = dbnBias_.block(),
+ * saveMeanBlock = resultSaveMean_.block(),
+ * saveVarBlock = resultSaveVariance_.block();
const float alpha = 1.0f, beta = .0f;
double epsilon = CUDNN_BN_MIN_EPSILON;
- CUDNN_CHECK(cudnnBatchNormalizationBackward(ctx->cudnn_handle,
- this->mode_,
- &alpha,
- &beta,
- &alpha,
- &beta,
- shape_desc_,
- xblock->data(),
- shape_desc_,
- dyblock->data(),
- shape_desc_,
- dxblock->mutable_data(),
- param_desc_,
- bnScaleBlock->data(),
- dbnScaleBlock->mutable_data(),
- dbnBiasBlock->mutable_data(),
- epsilon,
- saveMeanBlock->data(),
+ CUDNN_CHECK(cudnnBatchNormalizationBackward(
+ ctx->cudnn_handle, this->mode_, &alpha, &beta, &alpha, &beta,
+ shape_desc_, xblock->data(), shape_desc_, dyblock->data(),
+ shape_desc_, dxblock->mutable_data(), param_desc_,
+ bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
+ dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
saveVarBlock->data()));
},
- {dx.block(),
- grad.block(),
- bnScale_.block(),
- resultSaveMean_.block(),
+ {dx.block(), grad.block(), bnScale_.block(), resultSaveMean_.block(),
resultSaveVariance_.block()},
- {dx.block(),
- dbnScale_.block(),
- dbnBias_.block()});
+ {dx.block(), dbnScale_.block(), dbnBias_.block()});
} else {
LOG(ERROR) << "Do not call backward for evaluation phase";
}
@@ -219,8 +159,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
Tensor dummy;
param_grad.push_back(dummy);
param_grad.push_back(dummy);
- if (is_2d_)
- dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
+ if (is_2d_) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
return std::make_pair(dx, param_grad);
}
} // namespace
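For context (a sketch, not code from this commit): in CUDNN_BATCHNORM_SPATIAL mode the scale/bias/mean/variance tensors described by param_desc_ are shared across all H x W positions of a channel, which is why the hunk above sets that descriptor to 1 x C x 1 x 1 in NCHW layout, and why the CPU path in batchnorm.cc emulates the mode by tiling per-channel vectors. A minimal standalone illustration, assuming float data and a placeholder channel count (cuDNN also offers cudnnDeriveBNTensorDescriptor to derive this descriptor from the input descriptor, which this commit does not use):

    #include <cudnn.h>

    // Sketch only: build the per-channel parameter descriptor used by
    // spatial batch normalization. 'channels' is a placeholder; errors
    // are returned to the caller instead of using SINGA's CUDNN_CHECK.
    cudnnStatus_t MakeSpatialParamDesc(int channels,
                                       cudnnTensorDescriptor_t* desc) {
      cudnnStatus_t s = cudnnCreateTensorDescriptor(desc);
      if (s != CUDNN_STATUS_SUCCESS) return s;
      // 1 x C x 1 x 1 in NCHW: one scale/bias/mean/variance entry per channel.
      return cudnnSetTensor4dDescriptor(*desc, CUDNN_TENSOR_NCHW,
                                        CUDNN_DATA_FLOAT, 1, channels, 1, 1);
    }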