Posted to commits@singa.apache.org by wa...@apache.org on 2016/04/06 18:00:14 UTC

[2/4] incubator-singa git commit: SINGA-136 Support cuDNN v4

SINGA-136 Support cuDNN v4

Add cuDNN BatchNorm Layer
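
For reference: spatial batch normalization (the CUDNN_BATCHNORM_SPATIAL mode this layer uses) computes one mean and variance per channel over the whole batch and both spatial dimensions, normalizes with them, and then applies a learned per-channel scale and bias. The new layer delegates all of this to cudnnBatchNormalizationForwardTraining/Inference; the sketch below is a plain-C++ illustration of the same forward computation and is not part of this commit (all names in it are ours).

    // Minimal CPU sketch of spatial batch normalization over an NCHW tensor:
    // one mean/variance/scale/bias per channel, shared across N, H and W.
    #include <cmath>

    void BatchNormForward(const float* x, float* y,
                          const float* scale, const float* bias,
                          int n, int c, int h, int w, float epsilon) {
      const int spatial = h * w;
      const int count = n * spatial;  // values per channel
      for (int ch = 0; ch < c; ++ch) {
        // mean over all N*H*W values of this channel
        double mean = 0.0;
        for (int i = 0; i < n; ++i)
          for (int s = 0; s < spatial; ++s)
            mean += x[(i * c + ch) * spatial + s];
        mean /= count;
        // variance over the same values
        double var = 0.0;
        for (int i = 0; i < n; ++i)
          for (int s = 0; s < spatial; ++s) {
            const double d = x[(i * c + ch) * spatial + s] - mean;
            var += d * d;
          }
        var /= count;
        // normalize, then scale and shift
        const double inv_std = 1.0 / std::sqrt(var + epsilon);
        for (int i = 0; i < n; ++i)
          for (int s = 0; s < spatial; ++s) {
            const int idx = (i * c + ch) * spatial + s;
            y[idx] = scale[ch] * (x[idx] - mean) * inv_std + bias[ch];
          }
      }
    }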


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/4b4ad057
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/4b4ad057
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/4b4ad057

Branch: refs/heads/master
Commit: 4b4ad0573ac0f3bcc0403fe13d81925782d4352d
Parents: afc50a9
Author: seaok <se...@gmail.com>
Authored: Wed Apr 6 14:58:26 2016 +0800
Committer: seaok <se...@gmail.com>
Committed: Wed Apr 6 14:58:26 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/cudnn_bm.conf         | 334 ++++++++++++++++++++++++++++
 include/singa/neuralnet/neuron_layer.h |  35 +++
 src/driver.cc                          |   2 +
 src/neuralnet/neuron_layer/bm.cc       |  61 ++++++
 src/neuralnet/neuron_layer/cudnn_bm.cc | 150 +++++++++++++
 src/proto/job.proto                    |   9 +
 6 files changed, 591 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/examples/cifar10/cudnn_bm.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/cudnn_bm.conf b/examples/cifar10/cudnn_bm.conf
new file mode 100644
index 0000000..071ed8a
--- /dev/null
+++ b/examples/cifar10/cudnn_bm.conf
@@ -0,0 +1,334 @@
+name: "cifar10-convnet"
+train_steps: 70000
+test_steps: 100
+test_freq: 1000
+#validate_steps: 100
+#validate_freq: 300
+disp_freq: 200
+gpu: 0
+#checkpoint_path: "examples/cifar10/checkpoint/step1000-worker0"
+train_one_batch {
+  alg: kBP
+}
+updater{
+  type: kSGD
+  weight_decay:0.004
+  momentum:0.9
+  learning_rate {
+    type: kFixedStep
+    fixedstep_conf:{
+      step:0
+      step:60000
+      step:65000
+      step_lr:0.001
+      step_lr:0.0001
+      step_lr:0.00001
+    }
+  }
+}
+neuralnet {
+  layer{
+    name: "data"
+    type: kRecordInput
+    store_conf {
+      backend: "kvfile"
+      path: "examples/cifar10/train_data.bin"
+      mean_file: "examples/cifar10/image_mean.bin"
+      batchsize: 100
+      #random_skip: 5000
+      shape: 3
+      shape: 32
+      shape: 32
+    }
+    include: kTrain
+  }
+#  layer{
+#    name: "data"
+#    type: kRecordInput
+#    store_conf {
+#      backend: "kvfile"
+#      path: "examples/cifar10/val_data.bin"
+#      mean_file: "examples/cifar10/image_mean.bin"
+#      batchsize: 64
+#      random_skip: 5000
+#      shape: 3
+#      shape: 32
+#      shape: 32
+#    }
+#    include: kVal
+#  }
+  layer{
+    name: "data"
+    type: kRecordInput
+    store_conf {
+      backend: "kvfile"
+      path: "examples/cifar10/test_data.bin"
+      mean_file: "examples/cifar10/image_mean.bin"
+      batchsize: 100
+      shape: 3
+      shape: 32
+      shape: 32
+    }
+    include: kTest
+  }
+
+  layer {
+    name: "conv1"
+    type: kCudnnConv
+    srclayers: "data"
+    convolution_conf {
+      num_filters: 32
+      kernel: 5
+      stride: 1
+      pad:2
+    }
+    param {
+      name: "w1"
+      init {
+        type:kGaussian
+        std:0.0001
+      }
+    }
+    param {
+      name: "b1"
+      lr_scale:2.0
+      wd_scale: 0
+      init {
+        type: kConstant
+        value:0
+      }
+    }
+  }
+
+  layer {
+    name: "pool1"
+    type: kCudnnPool
+    srclayers: "conv1"
+    pooling_conf {
+      pool: MAX
+      kernel: 3
+      stride: 2
+    }
+  }
+  layer {
+    name: "bm1"
+    type: kCudnnBM
+    srclayers: "pool1"
+    param {
+      name: "s11"
+      init {
+        type: kConstant
+        value: 1
+      }
+    }
+    param {
+      name: "s12"
+      init {
+        type: kConstant
+        value: 0
+      }
+    }
+  }
+  layer {
+    name: "relu1"
+    type: kCudnnActivation
+    activation_conf {
+      type: RELU
+    }
+    share_src_blobs: true
+    srclayers:"bm1"
+  }
+  layer {
+    name: "conv2"
+    type: kCudnnConv
+    srclayers: "relu1"
+    convolution_conf {
+      num_filters: 32
+      kernel: 5
+      stride: 1
+      pad:2
+    }
+    param {
+      name: "w2"
+      init {
+        type:kGaussian
+        std:0.01
+      }
+    }
+    param {
+      name: "b2"
+      lr_scale:2.0
+      wd_scale: 0
+      init {
+        type: kConstant
+        value:0
+      }
+    }
+  }
+  layer {
+    name: "bm2"
+    type: kCudnnBM
+    srclayers: "conv2"
+    param {
+      name: "s21"
+      init {
+        type: kConstant
+        value: 1
+      }
+    }
+    param {
+      name: "s22"
+      init {
+        type: kConstant
+        value: 0
+      }
+    }
+  }
+  layer {
+    name: "relu2"
+    type: kCudnnActivation
+    activation_conf {
+      type: RELU
+    }
+    share_src_blobs: true
+    srclayers:"bm2"
+  }
+  layer {
+    name: "pool2"
+    type: kCudnnPool
+    srclayers: "relu2"
+    pooling_conf {
+      pool: AVG
+      kernel: 3
+      stride: 2
+    }
+  }
+  layer {
+    name: "conv3"
+    type: kCudnnConv
+    srclayers: "relu2"
+    convolution_conf {
+      num_filters: 64
+      kernel: 5
+      stride: 1
+      pad:2
+    }
+    param {
+      name: "w3"
+      init {
+        type:kGaussian
+        std:0.01
+      }
+    }
+    param {
+      name: "b3"
+      lr_scale: 2
+      wd_scale: 0
+      init {
+        type: kConstant
+        value:0
+      }
+    }
+  }
+  layer {
+    name: "bm3"
+    type: kCudnnBM
+    srclayers: "conv3"
+    param {
+      name: "s31"
+      init {
+        type: kConstant
+        value: 1
+      }
+    }
+    param {
+      name: "s32"
+      init {
+        type: kConstant
+        value: 0
+      }
+    }
+  }
+  layer {
+    name: "relu3"
+    type: kCudnnActivation
+    activation_conf {
+      type: RELU
+    }
+    share_src_blobs: true
+    srclayers:"bm3"
+  }
+  layer {
+    name: "pool3"
+    type: kCudnnPool
+    srclayers: "relu3"
+    pooling_conf {
+      pool: AVG
+      kernel: 3
+      stride: 2
+    }
+  }
+  layer {
+    name: "ip1"
+    type: kInnerProduct
+    srclayers:"pool3"
+    innerproduct_conf {
+      num_output: 10
+    }
+    param {
+      name: "w4"
+      wd_scale:250
+      init {
+        type:kGaussian
+        std:0.01
+      }
+    }
+    param {
+      name: "b4"
+      lr_scale:2.0
+      wd_scale:0
+      init {
+        type: kConstant
+        value:0
+      }
+    }
+  }
+  layer {
+   name: "softmax"
+   type: kCudnnSoftmax
+   srclayers: "ip1"
+   include: kTest
+  }
+
+  layer {
+   name: "accuracy"
+   type: kAccuracy
+   srclayers: "softmax"
+   srclayers: "data"
+   include: kTest
+  }
+  layer{
+    name: "loss"
+    type: kSoftmaxLoss
+    srclayers:"ip1"
+    srclayers: "data"
+    include: kTrain
+  }
+# to write predictions to CSV, uncomment the "output" layer below, add an
+# "argsort" layer over "softmax" for it to read from, and comment out "loss"
+#  layer {
+#    name : "output"
+#    type: kCSVOutput
+#    srclayers: "argsort"
+#    store_conf {
+#      path: "examples/cifar10/out.csv"
+#    }
+#  }
+}
+cluster {
+  nworker_groups: 1
+  nserver_groups: 1
+  nworkers_per_group: 1
+  nworkers_per_procs: 1
+  workspace: "examples/cifar10"
+}
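
A note on the configuration pattern above: each kCudnnBM layer declares exactly two param blocks, and their order matters, because BMLayer::Setup (bm.cc below) reads conf.param(0) as the per-channel scale and conf.param(1) as the per-channel bias. Initializing the scale to 1 and the bias to 0, as done here, makes the layer start out as an identity transform of the normalized input. A minimal instance (layer and param names are placeholders):

    layer {
      name: "bm"
      type: kCudnnBM
      srclayers: "some_layer"
      param { name: "bn_scale" init { type: kConstant value: 1 } }
      param { name: "bn_bias" init { type: kConstant value: 0 } }
    }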

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/include/singa/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h
index f03e91b..3cdc137 100644
--- a/include/singa/neuralnet/neuron_layer.h
+++ b/include/singa/neuralnet/neuron_layer.h
@@ -351,6 +351,22 @@ class STanhLayer : public NeuronLayer {
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 };
 
+
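+/**
+ * Batch normalization layer: normalizes each channel of its input to zero
+ * mean and unit variance over the batch, then applies a learned per-channel
+ * scale (bnScale_) and bias (bnBias_).
+ */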
+class BMLayer : public NeuronLayer {
+ public:
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+ protected:
+  Param *bnScale_, *bnBias_;
+  int batchsize_,  channels_, height_, width_;
+};
+
 /*************** Layers implemented using cudnn v3 ***************/
 #ifdef USE_CUDNN
 #define CHECK_CUDNN(x) CHECK_EQ(x, CUDNN_STATUS_SUCCESS)
@@ -447,6 +463,25 @@ class CudnnSoftmaxLayer : public SoftmaxLayer, public CudnnBase {
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 };
+
+/**
+ * Cudnn Batch Normalization layer
+ */
+class CudnnBMLayer : public BMLayer, public CudnnBase {
+ public:
+  ~CudnnBMLayer();
+  void InitCudnn() override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+ protected:
+  cudnnBatchNormMode_t mode_;
+  cudnnTensorDescriptor_t bnScaleBiasMeanVar_desc_;
+  cudnnTensorDescriptor_t bnScaleBiasDiff_desc_;
+  Blob<float> resultSaveMean_;
+  Blob<float> resultSaveInvVariance_;
+  Blob<float> resultRunningMean_;
+  Blob<float> resultRunningInvVariance_;
+};
 #endif  // USE_CUDNN
 
 /******************** RBM layers *****************/
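
Design note: CudnnBMLayer follows the pattern of the other cuDNN layers in this header, inheriting from both a plain layer class and CudnnBase. BMLayer owns the parameters (bnScale_, bnBias_) and the shape bookkeeping, and is also registered on its own as the kBM (non-cuDNN) variant, while CudnnBase supplies the shared cudnnHandle_t and the src_desc_/my_desc_ tensor descriptors that InitCudnn() configures below.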

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 6163865..ce1c635 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -98,6 +98,7 @@ void Driver::Init(int argc, char **argv) {
   RegisterLayer<CudnnLRNLayer, int>(kCudnnLRN);
   RegisterLayer<CudnnSoftmaxLayer, int>(kCudnnSoftmax);
   RegisterLayer<CudnnSoftmaxLossLayer, int>(kCudnnSoftmaxLoss);
+  RegisterLayer<CudnnBMLayer, int>(kCudnnBM);
 #endif
 
   RegisterLayer<DropoutLayer, int>(kDropout);
@@ -119,6 +120,7 @@ void Driver::Init(int argc, char **argv) {
   RegisterLayer<STanhLayer, int>(kSTanh);
   RegisterLayer<SoftmaxLayer, int>(kSoftmax);
   RegisterLayer<GRULayer, int>(kGRU);
+  RegisterLayer<BMLayer, int>(kBM);
 
 #ifdef USE_LMDB
   RegisterLayer<LMDBDataLayer, int>(kLMDBData);
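
The two RegisterLayer calls above bind the new LayerType ids (kBM and kCudnnBM, added to job.proto below) to their implementing classes, so that layers can be instantiated from the job configuration by type id. A generic sketch of this kind of type-keyed factory (our illustration of the pattern, not SINGA's actual Driver code):

    #include <functional>
    #include <map>
    #include <memory>

    struct Layer { virtual ~Layer() = default; };

    // Maps an integer layer-type id to a creator for the matching subclass.
    class LayerFactory {
     public:
      template <typename T>
      void Register(int type) {
        creators_[type] = [] { return std::unique_ptr<Layer>(new T()); };
      }
      std::unique_ptr<Layer> Create(int type) { return creators_.at(type)(); }
     private:
      std::map<int, std::function<std::unique_ptr<Layer>()>> creators_;
    };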

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/src/neuralnet/neuron_layer/bm.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/bm.cc b/src/neuralnet/neuron_layer/bm.cc
new file mode 100644
index 0000000..5784595
--- /dev/null
+++ b/src/neuralnet/neuron_layer/bm.cc
@@ -0,0 +1,61 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include <glog/logging.h>
+#include "singa/neuralnet/neuron_layer.h"
+#include "singa/utils/singleton.h"
+
+namespace singa {
+
+using std::vector;
+
+void BMLayer::Setup(const LayerProto& conf,
+    const vector<Layer*>& srclayers) {
+  Layer::Setup(conf, srclayers);
+  data_.ReshapeLike(srclayers[0]->data(this));
+  grad_.ReshapeLike(srclayers[0]->grad(this));
+
+  const vector<int>& srcshape = srclayers[0]->data(this).shape();
+  // batch normalization expects 4D NCHW input
+  CHECK_EQ(srcshape.size(), 4);
+
+  batchsize_ = srcshape[0];
+  channels_ = srcshape[1];
+  height_ = srcshape[2];
+  width_ = srcshape[3];
+
+  bnScale_ = Param::Create(conf.param(0));
+  bnScale_->Setup(vector<int>{1, channels_, 1, 1});
+
+  bnBias_ = Param::Create(conf.param(1));
+  bnBias_->Setup(vector<int>{1, channels_, 1, 1});
+  bnScale_->InitValues(1);
+}
+
+void BMLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
+  // TODO: implement the CPU forward pass
+}
+
+void BMLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
+  // TODO: implement the CPU backward pass
+}
+
+}  //  namespace singa
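
The {1, channels_, 1, 1} shape given to bnScale_ and bnBias_ above is dictated by cuDNN: in CUDNN_BATCHNORM_SPATIAL mode the scale, bias, mean and variance tensors must all be 1xCx1x1, which is exactly how bnScaleBiasMeanVar_desc_ is configured in cudnn_bm.cc below.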

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/src/neuralnet/neuron_layer/cudnn_bm.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/cudnn_bm.cc b/src/neuralnet/neuron_layer/cudnn_bm.cc
new file mode 100644
index 0000000..fdc9ea9
--- /dev/null
+++ b/src/neuralnet/neuron_layer/cudnn_bm.cc
@@ -0,0 +1,150 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "singa/neuralnet/neuron_layer.h"
+
+namespace singa {
+
+CudnnBMLayer::~CudnnBMLayer() {
+  if (has_init_cudnn_) {
+    CHECK_CUDNN(cudnnDestroyTensorDescriptor(bnScaleBiasMeanVar_desc_));
+    CHECK_CUDNN(cudnnDestroyTensorDescriptor(bnScaleBiasDiff_desc_));
+  }
+}
+
+void CudnnBMLayer::InitCudnn() {
+  CudnnBase::InitCudnn();
+
+  CHECK_CUDNN(cudnnCreateTensorDescriptor(&bnScaleBiasMeanVar_desc_));
+  CHECK_CUDNN(cudnnCreateTensorDescriptor(&bnScaleBiasDiff_desc_));
+
+  CHECK_CUDNN(cudnnSetTensor4dDescriptor(src_desc_,
+        CUDNN_TENSOR_NCHW,
+        CUDNN_DATA_FLOAT,
+        batchsize_,
+        channels_,
+        height_,
+        width_));
+  CHECK_CUDNN(cudnnSetTensor4dDescriptor(my_desc_,
+        CUDNN_TENSOR_NCHW,
+        CUDNN_DATA_FLOAT,
+        batchsize_,
+        channels_,
+        height_,
+        width_));
+  CHECK_CUDNN(cudnnSetTensor4dDescriptor(bnScaleBiasMeanVar_desc_,
+        CUDNN_TENSOR_NCHW,
+        CUDNN_DATA_FLOAT,
+        1,
+        channels_,
+        1,
+        1));
+  CHECK_CUDNN(cudnnSetTensor4dDescriptor(bnScaleBiasDiff_desc_,
+        CUDNN_TENSOR_NCHW,
+        CUDNN_DATA_FLOAT,
+        1,
+        channels_,
+        1,
+        1));
+
+  vector<int> shape{1, channels_, 1, 1};
+
+  resultSaveMean_.Reshape(shape);
+  resultSaveInvVariance_.Reshape(shape);
+  resultRunningMean_.Reshape(shape);
+  resultRunningInvVariance_.Reshape(shape);
+
+  mode_ = CUDNN_BATCHNORM_SPATIAL;
+}
+
+void CudnnBMLayer::ComputeFeature(int flag,
+    const vector<Layer*>& srclayers) {
+  if (!has_init_cudnn_)
+    InitCudnn();
+
+  const float alpha = 1.0f, beta = 0.0f;
+  double exponentialAverageFactor = 1.0;
+  double epsilon = CUDNN_BN_MIN_EPSILON;
+
+  // take the cuDNN inference path unless the kTrain flag is set
+  if ((flag & kTrain) != kTrain) {
+    CHECK_CUDNN(cudnnBatchNormalizationForwardInference(handle_,
+          mode_,
+          &alpha,
+          &beta,
+          src_desc_,
+          srclayers.at(0)->data(this).gpu_data(),
+          my_desc_,
+          data_.mutable_gpu_data(),
+          bnScaleBiasMeanVar_desc_,
+          bnScale_->data().gpu_data(),
+          bnBias_->data().gpu_data(),
+          resultRunningMean_.gpu_data(),
+          resultRunningInvVariance_.gpu_data(),
+          epsilon));
+  } else {
+    CHECK_CUDNN(cudnnBatchNormalizationForwardTraining(handle_,
+          mode_,
+          &alpha,
+          &beta,
+          src_desc_,
+          srclayers.at(0)->data(this).gpu_data(),
+          my_desc_,
+          data_.mutable_gpu_data(),
+          bnScaleBiasMeanVar_desc_,
+          bnScale_->data().gpu_data(),
+          bnBias_->data().gpu_data(),
+          exponentialAverageFactor,
+          resultRunningMean_.mutable_gpu_data(),
+          resultRunningInvVariance_.mutable_gpu_data(),
+          epsilon,
+          resultSaveMean_.mutable_gpu_data(),
+          resultSaveInvVariance_.mutable_gpu_data()));
+  }
+}
+
+void CudnnBMLayer::ComputeGradient(int flag,
+    const vector<Layer*>& srclayers) {
+
+  const float alpha = 1.0f, beta = 0.0f, alphaDiff = 1.0f, betaDiff = 0.0f;
+  double epsilon = CUDNN_BN_MIN_EPSILON;
+
+  CHECK_CUDNN(cudnnBatchNormalizationBackward(handle_,
+      mode_,
+      &alpha,
+      &beta,
+      &alphaDiff,
+      &betaDiff,
+      src_desc_,
+      srclayers.at(0)->data(this).gpu_data(),
+      my_desc_,
+      grad_.gpu_data(),
+      src_desc_,
+      srclayers.at(0)->mutable_grad(this)->mutable_gpu_data(),
+      bnScaleBiasDiff_desc_,
+      bnScale_->data().gpu_data(),
+      bnScale_->grad().mutable_gpu_data(),
+      bnBias_->grad().mutable_gpu_data(),
+      epsilon,
+      resultSaveMean_.gpu_data(),
+      resultSaveInvVariance_.gpu_data()));
+}
+}  // namespace singa
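
One caveat for readers of ComputeFeature above: exponentialAverageFactor is hard-coded to 1.0. cuDNN updates the running statistics as running = (1 - factor) * running + factor * batch_stat, so with a factor of 1.0 the running mean/variance buffers always hold only the most recent batch's statistics rather than a smoothed average. A sketch of the update rule, plus a common alternative schedule (the schedule is our suggestion, not something this commit implements):

    // The moving-average update cuDNN applies to the running statistics.
    // factor == 1.0 (as in this commit) keeps only the latest batch's stats.
    double UpdateRunningStat(double running, double batch_stat, double factor) {
      return (1.0 - factor) * running + factor * batch_stat;
    }

    // Example alternative: a cumulative average over the first iterations.
    double FactorForIteration(int iteration) {
      return 1.0 / (1.0 + iteration);
    }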

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b4ad057/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index 7bc0ea3..622248c 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -249,6 +249,7 @@ message LayerProto {
   optional SoftmaxProto softmax_conf = 214;
   optional GRUProto gru_conf = 215;
   optional EmbeddingProto embedding_conf = 216;
+  optional BMProto bm_conf = 217;
 
   // configuration for loss layers, id range [300, 400)
   optional SoftmaxLossProto softmaxloss_conf = 301;
@@ -393,6 +394,12 @@ message EmbeddingProto {
   optional int32 feature_dim = 2 [default = 100];
 
 }
+
+message BMProto {
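+  // intentionally left empty for now: the batch norm layer takes all of
+  // its configuration (scale/bias initialization) through param blocks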
+}
+
 message SoftmaxLossProto {
   // computing accuracy against topk results
   optional int32 topk = 1 [default = 1];
@@ -676,6 +683,7 @@ enum LayerType {
   kSoftmax = 214;
   kGRU = 215;
   kEmbedding = 216;
+  kBM = 217;
 
   // cudnn v3
   kCudnnConv = 250;
@@ -683,6 +691,7 @@ enum LayerType {
   kCudnnLRN = 252;
   kCudnnSoftmax = 253;
   kCudnnActivation = 254;
+  kCudnnBM = 255;
 
   /*
    * Loss layers