Posted to commits@singa.apache.org by wa...@apache.org on 2016/04/06 18:00:14 UTC
[2/4] incubator-singa git commit: SINGA-136 Support cuDNN v4
SINGA-136 Support cuDNN v4
Add cuDNN BatchNorm Layer
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/4b4ad057
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/4b4ad057
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/4b4ad057
Branch: refs/heads/master
Commit: 4b4ad0573ac0f3bcc0403fe13d81925782d4352d
Parents: afc50a9
Author: seaok <se...@gmail.com>
Authored: Wed Apr 6 14:58:26 2016 +0800
Committer: seaok <se...@gmail.com>
Committed: Wed Apr 6 14:58:26 2016 +0800
----------------------------------------------------------------------
examples/cifar10/cudnn_bm.conf | 334 ++++++++++++++++++++++++++++
include/singa/neuralnet/neuron_layer.h | 30 +++
src/driver.cc | 2 +
src/neuralnet/neuron_layer/bm.cc | 60 ++++++
src/neuralnet/neuron_layer/cudnn_bm.cc | 150 +++++++++++++
src/proto/job.proto | 7 +
6 files changed, 583 insertions(+)
----------------------------------------------------------------------
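For reference: "BM" here is batch normalization (Ioffe & Szegedy, 2015). In
CUDNN_BATCHNORM_SPATIAL mode the layer computes one mean/variance pair per
channel over the batch and spatial dimensions and applies

    y = bnScale[c] * (x - mean[c]) / sqrt(var[c] + epsilon) + bnBias[c]

which is why bnScale_, bnBias_ and the saved/running statistic blobs in the
code below all have shape {1, channels, 1, 1}.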
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/examples/cifar10/cudnn_bm.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/cudnn_bm.conf b/examples/cifar10/cudnn_bm.conf
new file mode 100644
index 0000000..071ed8a
--- /dev/null
+++ b/examples/cifar10/cudnn_bm.conf
@@ -0,0 +1,334 @@
+name: "cifar10-convnet"
+train_steps: 70000
+test_steps: 100
+test_freq: 1000
+#validate_steps: 100
+#validate_freq: 300
+disp_freq: 200
+gpu: 0
+#checkpoint_path: "examples/cifar10/checkpoint/step1000-worker0"
+train_one_batch {
+ alg: kBP
+}
+updater{
+ type: kSGD
+ weight_decay:0.004
+ momentum:0.9
+ learning_rate {
+ type: kFixedStep
+ fixedstep_conf:{
+ step:0
+ step:60000
+ step:65000
+ step_lr:0.001
+ step_lr:0.0001
+ step_lr:0.00001
+ }
+ }
+}
+neuralnet {
+ layer{
+ name: "data"
+ type: kRecordInput
+ store_conf {
+ backend: "kvfile"
+ path: "examples/cifar10/train_data.bin"
+ mean_file: "examples/cifar10/image_mean.bin"
+ batchsize: 100
+ #random_skip: 5000
+ shape: 3
+ shape: 32
+ shape: 32
+ }
+ include: kTrain
+ }
+# layer{
+# name: "data"
+# type: kRecordInput
+# store_conf {
+# backend: "kvfile"
+# path: "examples/cifar10/val_data.bin"
+# mean_file: "examples/cifar10/image_mean.bin"
+# batchsize: 64
+# random_skip: 5000
+# shape: 3
+# shape: 32
+# shape: 32
+# }
+# include: kVal
+# }
+ layer{
+ name: "data"
+ type: kRecordInput
+ store_conf {
+ backend: "kvfile"
+ path: "examples/cifar10/test_data.bin"
+ mean_file: "examples/cifar10/image_mean.bin"
+ batchsize: 100
+ shape: 3
+ shape: 32
+ shape: 32
+ }
+ include: kTest
+ }
+
+ layer {
+ name: "conv1"
+ type: kCudnnConv
+ srclayers: "data"
+ convolution_conf {
+ num_filters: 32
+ kernel: 5
+ stride: 1
+ pad:2
+ }
+ param {
+ name: "w1"
+ init {
+ type:kGaussian
+ std:0.0001
+ }
+ }
+ param {
+ name: "b1"
+ lr_scale:2.0
+ wd_scale: 0
+ init {
+ type: kConstant
+ value:0
+ }
+ }
+ }
+
+ layer {
+ name: "pool1"
+ type: kCudnnPool
+ srclayers: "conv1"
+ pooling_conf {
+ pool: MAX
+ kernel: 3
+ stride: 2
+ }
+ }
+ layer {
+ name: "bm1"
+ type: kCudnnBM
+ param {
+ name: "s11"
+ init {
+ type:kConstant
+ value:1
+ }
+ }
+ param {
+ name: "s12"
+ init {
+ type:kConstant
+ value:0
+ }
+ }
+ srclayers:"pool1"
+ }
+ layer {
+ name: "relu1"
+ type: kCudnnActivation
+ activation_conf {
+ type: RELU
+ }
+ share_src_blobs: true
+ srclayers:"bm1"
+ }
+ layer {
+ name: "conv2"
+ type: kCudnnConv
+ srclayers: "relu1"
+ convolution_conf {
+ num_filters: 32
+ kernel: 5
+ stride: 1
+ pad:2
+ }
+ param {
+ name: "w2"
+ init {
+ type:kGaussian
+ std:0.01
+ }
+ }
+ param {
+ name: "b2"
+ lr_scale:2.0
+ wd_scale: 0
+ init {
+ type: kConstant
+ value:0
+ }
+ }
+ }
+ layer {
+ name: "bm2"
+ type: kCudnnBM
+ param {
+ name: "s21"
+ init {
+ type:kConstant
+ value:1
+ }
+ }
+ param {
+ name: "s22"
+ init {
+ type:kConstant
+ value:0
+ }
+ }
+ srclayers:"conv2"
+ }
+ layer {
+ name: "relu2"
+ type: kCudnnActivation
+ activation_conf {
+ type: RELU
+ }
+ share_src_blobs: true
+ srclayers:"bm2"
+ }
+ layer {
+ name: "pool2"
+ type: kCudnnPool
+ srclayers: "relu2"
+ pooling_conf {
+ pool: AVG
+ kernel: 3
+ stride: 2
+ }
+ }
+ layer {
+ name: "conv3"
+ type: kCudnnConv
+ srclayers: "relu2"
+ convolution_conf {
+ num_filters: 64
+ kernel: 5
+ stride: 1
+ pad:2
+ }
+ param {
+ name: "w3"
+ init {
+ type:kGaussian
+ std:0.01
+ }
+ }
+ param {
+ name: "b3"
+ lr_scale: 2
+ wd_scale: 0
+ init {
+ type: kConstant
+ value:0
+ }
+ }
+ }
+ layer {
+ name: "bm3"
+ type: kCudnnBM
+ param {
+ name: "s31"
+ init {
+ type:kConstant
+ value:1
+ }
+ }
+ param {
+ name: "s32"
+ init {
+ type:kConstant
+ value:0
+ }
+ }
+ srclayers:"conv3"
+ }
+ layer {
+ name: "relu3"
+ type: kCudnnActivation
+ activation_conf {
+ type: RELU
+ }
+ share_src_blobs: true
+ srclayers:"bm3"
+ }
+ layer {
+ name: "pool3"
+ type: kCudnnPool
+ srclayers: "relu3"
+ pooling_conf {
+ pool: AVG
+ kernel: 3
+ stride: 2
+ }
+ }
+ layer {
+ name: "ip1"
+ type: kInnerProduct
+ srclayers:"pool3"
+ innerproduct_conf {
+ num_output: 10
+ }
+ param {
+ name: "w4"
+ wd_scale:250
+ init {
+ type:kGaussian
+ std:0.01
+ }
+ }
+ param {
+ name: "b4"
+ lr_scale:2.0
+ wd_scale:0
+ init {
+ type: kConstant
+ value:0
+ }
+ }
+ }
+ layer {
+ name : "softmax"
+ type: kCudnnSoftmax
+ srclayers: "ip1"
+ include: kTest
+ }
+
+ layer {
+ name : "accuracy"
+ type: kAccuracy
+ srclayers: "softmax"
+ srclayers: "data"
+ include: kTest
+ }
+ layer{
+ name: "loss"
+ type: kSoftmaxLoss
+ srclayers:"ip1"
+ srclayers: "data"
+ include : kTrain
+ }
+# uncomment "softmax", "argsort", "output" layer and comment "loss" layer
+# to extract features from argsort
+# layer {
+# name : "output"
+# type: kCSVOutput
+# srclayers: "argsort"
+# store_conf {
+# path: "examples/cifar10/out.csv"
+# }
+# }
+}
+cluster {
+ nworker_groups: 1
+ nserver_groups: 1
+ nworkers_per_group: 1
+ nworkers_per_procs: 1
+ workspace: "examples/cifar10"
+}
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/include/singa/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h
index f03e91b..3cdc137 100644
--- a/include/singa/neuralnet/neuron_layer.h
+++ b/include/singa/neuralnet/neuron_layer.h
@@ -351,6 +351,17 @@ class STanhLayer : public NeuronLayer {
void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
};
+
+class BMLayer : public NeuronLayer {
+ public:
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+ protected:
+ Param *bnScale_, *bnBias_;
+ int batchsize_, channels_, height_, width_;
+};
+
/*************** Layers implemented using cudnn v3 ***************/
#ifdef USE_CUDNN
#define CHECK_CUDNN(x) CHECK_EQ(x, CUDNN_STATUS_SUCCESS)
@@ -447,6 +458,25 @@ class CudnnSoftmaxLayer : public SoftmaxLayer, public CudnnBase {
void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
};
+
+/**
+ * Cudnn Batch Normalization layer
+ */
+class CudnnBMLayer : public BMLayer, public CudnnBase {
+ public:
+ ~CudnnBMLayer();
+ void InitCudnn() override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+ protected:
+ cudnnBatchNormMode_t mode_;
+ cudnnTensorDescriptor_t bnScaleBiasMeanVar_desc_;
+ cudnnTensorDescriptor_t bnScaleBiasDiff_desc_;
+ Blob<float> resultSaveMean_;
+ Blob<float> resultSaveInvVariance_;
+ Blob<float> resultRunningMean_;
+ Blob<float> resultRunningInvVariance_;
+};
#endif // USE_CUDNN
/******************** RBM layers *****************/
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 6163865..ce1c635 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -98,6 +98,7 @@ void Driver::Init(int argc, char **argv) {
RegisterLayer<CudnnLRNLayer, int>(kCudnnLRN);
RegisterLayer<CudnnSoftmaxLayer, int>(kCudnnSoftmax);
RegisterLayer<CudnnSoftmaxLossLayer, int>(kCudnnSoftmaxLoss);
+ RegisterLayer<CudnnBMLayer, int>(kCudnnBM);
#endif
RegisterLayer<DropoutLayer, int>(kDropout);
@@ -119,6 +120,7 @@ void Driver::Init(int argc, char **argv) {
RegisterLayer<STanhLayer, int>(kSTanh);
RegisterLayer<SoftmaxLayer, int>(kSoftmax);
RegisterLayer<GRULayer, int>(kGRU);
+ RegisterLayer<BMLayer, int>(kBM);
#ifdef USE_LMDB
RegisterLayer<LMDBDataLayer, int>(kLMDBData);
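For context: the two RegisterLayer calls added above are what map the new kBM
and kCudnnBM enum values from job.proto to their layer classes. A minimal
sketch of how a type-keyed factory of this kind can be implemented
(illustrative only; the class and method names below are assumptions, not
SINGA's actual factory code):

#include <functional>
#include <map>
#include <memory>

struct Layer { virtual ~Layer() = default; };
struct BMLayer : Layer {};  // stand-in for the real layer class

// Register() stores one creator per layer-type id; Create() instantiates
// the layer whose id appears in the job configuration.
class LayerFactory {
 public:
  template <typename T>
  void Register(int type) {
    creators_[type] = [] { return std::unique_ptr<Layer>(new T()); };
  }
  std::unique_ptr<Layer> Create(int type) const { return creators_.at(type)(); }
 private:
  std::map<int, std::function<std::unique_ptr<Layer>()>> creators_;
};

int main() {
  const int kBM = 217;  // enum value assigned in job.proto
  LayerFactory factory;
  factory.Register<BMLayer>(kBM);
  return factory.Create(kBM) ? 0 : 1;  // creates a BMLayer
}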
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/src/neuralnet/neuron_layer/bm.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/bm.cc b/src/neuralnet/neuron_layer/bm.cc
new file mode 100644
index 0000000..5784595
--- /dev/null
+++ b/src/neuralnet/neuron_layer/bm.cc
@@ -0,0 +1,60 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include <glog/logging.h>
+#include "singa/neuralnet/neuron_layer.h"
+#include "singa/utils/singleton.h"
+
+namespace singa {
+
+using std::vector;
+
+void BMLayer::Setup(const LayerProto& conf,
+ const vector<Layer*>& srclayers) {
+ Layer::Setup(conf, srclayers);
+ data_.ReshapeLike(srclayers[0]->data(this));
+ grad_.ReshapeLike(srclayers[0]->grad(this));
+
+ const vector<int>& srcshape = srclayers[0]->data(this).shape();
+ CHECK_EQ(srcshape.size(), 4) << "BMLayer expects 4D (NCHW) input";
+
+ batchsize_ = srcshape[0];
+ channels_ = srcshape[1];
+ height_ = srcshape[2];
+ width_ = srcshape[3];
+
+ bnScale_ = Param::Create(conf.param(0));
+ bnScale_->Setup(vector<int>{1, channels_, 1, 1});
+
+ bnBias_ = Param::Create(conf.param(1));
+ bnBias_->Setup(vector<int>{1, channels_, 1, 1});
+ bnScale_->InitValues(1);
+}
+
+void BMLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
+ // TODO: implement the CPU batch-norm forward pass
+}
+
+void BMLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
+ // TODO: implement the CPU batch-norm backward pass
+}
+
+} // namespace singa
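The CPU forward/backward passes are left as TODOs above. For illustration, a
standalone sketch (not SINGA API; the function name and the NCHW layout are
assumptions) of the per-channel computation the forward pass would need:

#include <cmath>
#include <cstddef>

// Batch-norm forward over an NCHW float buffer: per channel, normalize by
// the batch statistics, then apply the learned scale and bias.
void BatchNormForwardCPU(const float* x, float* y,
                         const float* scale, const float* bias,
                         int n, int c, int h, int w, float eps = 1e-5f) {
  const std::size_t hw = static_cast<std::size_t>(h) * w;
  const double count = static_cast<double>(n) * hw;
  for (int ch = 0; ch < c; ++ch) {
    double mean = 0.0, var = 0.0;
    for (int i = 0; i < n; ++i) {  // mean over batch and spatial dims
      const float* px = x + (static_cast<std::size_t>(i) * c + ch) * hw;
      for (std::size_t j = 0; j < hw; ++j) mean += px[j];
    }
    mean /= count;
    for (int i = 0; i < n; ++i) {  // variance over the same elements
      const float* px = x + (static_cast<std::size_t>(i) * c + ch) * hw;
      for (std::size_t j = 0; j < hw; ++j) {
        const double d = px[j] - mean;
        var += d * d;
      }
    }
    const float inv_std = 1.0f / std::sqrt(static_cast<float>(var / count) + eps);
    for (int i = 0; i < n; ++i) {  // normalize, then scale and shift
      const std::size_t off = (static_cast<std::size_t>(i) * c + ch) * hw;
      for (std::size_t j = 0; j < hw; ++j)
        y[off + j] = scale[ch] * (x[off + j] - static_cast<float>(mean)) * inv_std
                     + bias[ch];
    }
  }
}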
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/src/neuralnet/neuron_layer/cudnn_bm.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/cudnn_bm.cc b/src/neuralnet/neuron_layer/cudnn_bm.cc
new file mode 100644
index 0000000..fdc9ea9
--- /dev/null
+++ b/src/neuralnet/neuron_layer/cudnn_bm.cc
@@ -0,0 +1,150 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "singa/neuralnet/neuron_layer.h"
+
+namespace singa {
+
+CudnnBMLayer::~CudnnBMLayer() {
+ if (has_init_cudnn_) {
+ CHECK_CUDNN(cudnnDestroyTensorDescriptor(bnScaleBiasMeanVar_desc_));
+ CHECK_CUDNN(cudnnDestroyTensorDescriptor(bnScaleBiasDiff_desc_));
+ }
+}
+
+void CudnnBMLayer::InitCudnn() {
+ CudnnBase::InitCudnn();
+
+ CHECK_CUDNN(cudnnCreateTensorDescriptor(&bnScaleBiasMeanVar_desc_));
+ CHECK_CUDNN(cudnnCreateTensorDescriptor(&bnScaleBiasDiff_desc_));
+
+ CHECK_CUDNN(cudnnSetTensor4dDescriptor(src_desc_,
+ CUDNN_TENSOR_NCHW,
+ CUDNN_DATA_FLOAT,
+ batchsize_,
+ channels_,
+ height_,
+ width_));
+ CHECK_CUDNN(cudnnSetTensor4dDescriptor(my_desc_,
+ CUDNN_TENSOR_NCHW,
+ CUDNN_DATA_FLOAT,
+ batchsize_,
+ channels_,
+ height_,
+ width_));
+ CHECK_CUDNN(cudnnSetTensor4dDescriptor(bnScaleBiasMeanVar_desc_,
+ CUDNN_TENSOR_NCHW,
+ CUDNN_DATA_FLOAT,
+ 1,
+ channels_,
+ 1,
+ 1));
+ CHECK_CUDNN(cudnnSetTensor4dDescriptor(bnScaleBiasDiff_desc_,
+ CUDNN_TENSOR_NCHW,
+ CUDNN_DATA_FLOAT,
+ 1,
+ channels_,
+ 1,
+ 1));
+
+ vector<int> shape{1, channels_, 1, 1};
+
+ resultSaveMean_.Reshape(shape);
+ resultSaveInvVariance_.Reshape(shape);
+ resultRunningMean_.Reshape(shape);
+ resultRunningInvVariance_.Reshape(shape);
+
+ mode_ = CUDNN_BATCHNORM_SPATIAL;
+}
+
+void CudnnBMLayer::ComputeFeature(int flag,
+ const vector<Layer*>& srclayers) {
+ if (!has_init_cudnn_)
+ InitCudnn();
+
+ const float alpha = 1.0f, beta = 0.0f;
+ double exponentialAverageFactor = 1.0;  // running stats = latest batch stats
+ double epsilon = CUDNN_BN_MIN_EPSILON;
+
+ // use cuDNN's inference path (running statistics) unless training
+ if ((flag & kTrain) != kTrain) {
+ CHECK_CUDNN(cudnnBatchNormalizationForwardInference(handle_,
+ mode_,
+ &alpha,
+ &beta,
+ src_desc_,
+ srclayers.at(0)->data(this).gpu_data(),
+ my_desc_,
+ data_.mutable_gpu_data(),
+ bnScaleBiasMeanVar_desc_,
+ bnScale_->data().gpu_data(),
+ bnBias_->data().gpu_data(),
+ resultRunningMean_.gpu_data(),
+ resultRunningInvVariance_.gpu_data(),
+ epsilon));
+ } else {
+ CHECK_CUDNN(cudnnBatchNormalizationForwardTraining(handle_,
+ mode_,
+ &alpha,
+ &beta,
+ src_desc_,
+ srclayers.at(0)->data(this).gpu_data(),
+ my_desc_,
+ data_.mutable_gpu_data(),
+ bnScaleBiasMeanVar_desc_,
+ bnScale_->data().gpu_data(),
+ bnBias_->data().gpu_data(),
+ exponentialAverageFactor,
+ resultRunningMean_.mutable_gpu_data(),
+ resultRunningInvVariance_.mutable_gpu_data(),
+ epsilon,
+ resultSaveMean_.mutable_gpu_data(),
+ resultSaveInvVariance_.mutable_gpu_data()));
+ }
+}
+
+void CudnnBMLayer::ComputeGradient(int flag,
+ const vector<Layer*>& srclayers) {
+
+ const float alpha = 1.0f, beta = 0.0f, alphaDiff = 1.0f, betaDiff = 0.0f;
+ double epsilon = CUDNN_BN_MIN_EPSILON;
+
+ CHECK_CUDNN(cudnnBatchNormalizationBackward(handle_,
+ mode_,
+ &alpha,
+ &beta,
+ &alphaDiff,
+ &betaDiff,
+ src_desc_,
+ srclayers.at(0)->data(this).gpu_data(),
+ my_desc_,
+ grad_.gpu_data(),
+ src_desc_,
+ srclayers.at(0)->mutable_grad(this)->mutable_gpu_data(),
+ bnScaleBiasDiff_desc_,
+ bnScale_->data().gpu_data(),
+ bnScale_->grad().mutable_gpu_data(),
+ bnBias_->grad().mutable_gpu_data(),
+ epsilon,
+ resultSaveMean_.gpu_data(),
+ resultSaveInvVariance_.gpu_data()));
+}
+} // namespace singa
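Two notes on the cuDNN calls above. First, exponentialAverageFactor is fixed
at 1.0; since cuDNN updates runningMean = factor * batchMean +
(1 - factor) * runningMean, the running statistics always equal those of the
most recent mini-batch. The cuDNN documentation suggests factor = 1/(1+n) on
the n-th call to accumulate a cumulative moving average, e.g. (a hypothetical
helper, not part of this commit):

// Hypothetical (not in the commit): cumulative-moving-average factor.
// At the n-th training batch, 1/(1+n) weights all batches seen so far equally.
double CumulativeAverageFactor(long long batches_seen) {
  return 1.0 / (1.0 + static_cast<double>(batches_seen));
}

Second, the running mean and inverse variance live in plain Blobs rather than
Params, so they do not appear to be saved with checkpoints; inference from a
restored checkpoint would start from uninitialized statistics.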
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index 7bc0ea3..622248c 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -249,6 +249,7 @@ message LayerProto {
optional SoftmaxProto softmax_conf = 214;
optional GRUProto gru_conf = 215;
optional EmbeddingProto embedding_conf = 216;
+ optional BMProto bm_conf = 217;
// configuration for loss layers, id range [300, 400)
optional SoftmaxLossProto softmaxloss_conf = 301;
@@ -393,6 +394,10 @@ message EmbeddingProto {
optional int32 feature_dim = 2 [default = 100];
}
+
+message BMProto {
+}
+
message SoftmaxLossProto {
// computing accuracy against topk results
optional int32 topk = 1 [default = 1];
@@ -676,6 +681,7 @@ enum LayerType {
kSoftmax = 214;
kGRU = 215;
kEmbedding = 216;
+ kBM = 217;
// cudnn v3
kCudnnConv = 250;
@@ -683,6 +689,7 @@ enum LayerType {
kCudnnLRN = 252;
kCudnnSoftmax = 253;
kCudnnActivation = 254;
+ kCudnnBM = 255;
/*
* Loss layers