Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/03 07:49:01 UTC

[56/60] incubator-singa git commit: SINGA-174 Add Batch Normalization layer and Local Response Normalization layer.

SINGA-174 Add Batch Normalization layer and Local Response Normalization layer.

Implemented the Batch Normalization layer and the Local Response Normalization layer using cuDNN.
Tests passed.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/eadd3f96
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/eadd3f96
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/eadd3f96

Branch: refs/heads/dev
Commit: eadd3f969984cd4a94a1953a0b5e87af84ea5dc4
Parents: 64ea206
Author: WANG Ji <ij...@gmail.com>
Authored: Sun May 22 12:39:16 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Thu Jun 2 13:43:48 2016 +0800

----------------------------------------------------------------------
 include/singa/core/common.h        |   2 +-
 src/model/layer/batchnorm.cc       |  70 +++++++++
 src/model/layer/batchnorm.h        |  84 +++++++++++
 src/model/layer/cudnn_batchnorm.cc | 214 ++++++++++++++++++++++++++
 src/model/layer/cudnn_batchnorm.h  |  60 ++++++++
 src/model/layer/cudnn_lrn.cc       | 114 ++++++++++++++
 src/model/layer/cudnn_lrn.h        |  56 +++++++
 src/model/layer/lrn.cc             |  59 ++++++++
 src/model/layer/lrn.h              |  70 +++++++++
 src/proto/model.proto              |  13 +-
 test/singa/test_cudnn_batchnorm.cc | 257 ++++++++++++++++++++++++++++++++
 test/singa/test_cudnn_lrn.cc       | 205 +++++++++++++++++++++++++
 12 files changed, 1201 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/include/singa/core/common.h
----------------------------------------------------------------------
diff --git a/include/singa/core/common.h b/include/singa/core/common.h
index 9d005c4..e6f4c90 100644
--- a/include/singa/core/common.h
+++ b/include/singa/core/common.h
@@ -42,7 +42,7 @@ typedef struct _Cuda { } Cuda;
 typedef struct _Opencl { } Opencl;
 }  // namespace lang
 
-/// Blob reprent a chunk of memory (on device or host) managed by VirtualMemory.
+/// Blob represents a chunk of memory (on device or host) managed by VirtualMemory.
 class Blob {
  public:
   Blob(void* ptr, size_t size) : data_(ptr), size_(size), ref_count_(1) {}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
new file mode 100644
index 0000000..bcd0870
--- /dev/null
+++ b/src/model/layer/batchnorm.cc
@@ -0,0 +1,70 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#include "batchnorm.h"
+
+namespace singa {
+void BatchNorm::Setup(const LayerConf& conf) {
+  Layer::Setup(conf);
+  factor_ = conf.batchnorm_conf().factor();
+  channels_ = conf.batchnorm_conf().channels();
+  height_ = conf.batchnorm_conf().height();
+  width_ = conf.batchnorm_conf().width();
+
+  bnScale_.Reshape(Shape{channels_ * height_ * width_});
+  bnBias_.ResetLike(bnScale_);
+  runningMean_.ResetLike(bnScale_);
+  runningVariance_.ResetLike(bnScale_);
+
+  dbnScale_.ResetLike(bnScale_);
+  dbnBias_.ResetLike(bnBias_);
+  // Push back params into param_values_
+  // Assume the order of param is: bnScale, bnBias, runningMean, runningVariance
+  for (const auto &spec : conf.param()) param_specs_.push_back(spec);
+  param_values_.push_back(&bnScale_);
+  param_values_.push_back(&bnBias_);
+  param_values_.push_back(&runningMean_);
+  param_values_.push_back(&runningVariance_);
+}
+
+void BatchNorm::ToDevice(Device* device) {
+  bnScale_.ToDevice(device);
+  bnBias_.ToDevice(device);
+  dbnScale_.ToDevice(device);
+  dbnBias_.ToDevice(device);
+  runningMean_.ToDevice(device);
+  runningVariance_.ToDevice(device);
+}
+
+const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
+  LOG(FATAL) << "Not implemented";
+  Tensor output;
+  return output;
+}
+
+const std::pair<Tensor, vector<Tensor>> BatchNorm::Backward(
+    int flag, const Tensor& grad) {
+  LOG(FATAL) << "Not implemented";
+  Tensor dx;
+  vector<Tensor> param_grad;
+  return std::make_pair(dx, param_grad);
+}
+
+}  // namespace

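The base class above only registers the layer's parameters (Setup) and moves them between devices (ToDevice); Forward and Backward are left to the cuDNN subclass. For orientation, here is a minimal CPU sketch, not part of the commit, of the per-channel ("spatial") transform those parameters define, y = bnScale * (x - mean) / sqrt(var + eps) + bnBias, written against std::vector rather than singa::Tensor:

#include <cmath>
#include <cstdio>
#include <vector>

// x and y are laid out NCHW; mean and variance are taken over N*H*W for each
// channel, which is what CUDNN_BATCHNORM_SPATIAL computes as well.
void BatchNormForwardRef(const std::vector<float>& x, std::vector<float>& y,
                         int n, int c, int h, int w,
                         const std::vector<float>& scale,  // bnScale, one per channel
                         const std::vector<float>& bias,   // bnBias, one per channel
                         float eps = 1e-5f) {
  const int spatial = h * w;
  for (int ch = 0; ch < c; ++ch) {
    double mean = 0.0, var = 0.0;
    for (int img = 0; img < n; ++img)
      for (int s = 0; s < spatial; ++s)
        mean += x[(img * c + ch) * spatial + s];
    mean /= n * spatial;
    for (int img = 0; img < n; ++img)
      for (int s = 0; s < spatial; ++s) {
        double d = x[(img * c + ch) * spatial + s] - mean;
        var += d * d;
      }
    var /= n * spatial;
    const double inv_std = 1.0 / std::sqrt(var + eps);
    for (int img = 0; img < n; ++img)
      for (int s = 0; s < spatial; ++s) {
        const int idx = (img * c + ch) * spatial + s;
        y[idx] = scale[ch] * (x[idx] - mean) * inv_std + bias[ch];
      }
  }
}

int main() {
  // 1x2x2x2 toy input with scale = 1 and bias = 0: the output is simply the
  // normalized input, channel by channel.
  std::vector<float> x = {1, 2, 3, 4, 10, 20, 30, 40}, y(8);
  BatchNormForwardRef(x, y, 1, 2, 2, 2, {1, 1}, {0, 0});
  for (float v : y) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}
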
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/model/layer/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.h b/src/model/layer/batchnorm.h
new file mode 100644
index 0000000..0255179
--- /dev/null
+++ b/src/model/layer/batchnorm.h
@@ -0,0 +1,84 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#ifndef SINGA_MODEL_LAYER_BATCHNORM_H
+#define SINGA_MODEL_LAYER_BATCHNORM_H
+#include "singa/model/layer.h"
+#include "singa/core/common.h"
+#include "singa/proto/core.pb.h"
+#include <stack>
+
+namespace singa {
+class BatchNorm : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override {
+    return "Batch Normalization";
+  }
+
+  /// \copydoc Layer::Setup(const LayerConf&)
+  virtual void Setup(const LayerConf& conf) override;
+
+  const Tensor Forward(int flag, const Tensor& input)
+    override;
+
+  /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&);
+  const std::pair<Tensor, vector<Tensor>> Backward(
+      int flag, const Tensor& grad) override;
+
+  const float factor() const { return factor_; }
+  const Tensor& bnScale() const { return bnScale_; }
+  const Tensor& bnBias() const { return bnBias_; }
+  const Tensor& runningMean() const { return runningMean_; }
+  const Tensor& runningVariance() const { return runningVariance_; }
+  const size_t channels() const { return channels_; }
+  const size_t height() const { return height_; }
+  const size_t width() const { return width_; }
+  void set_bnScale(Tensor x) {
+    bnScale_.ResetLike(x);
+    bnScale_.CopyData(x);
+  }
+  void set_bnBias(Tensor x) {
+    bnBias_.ResetLike(x);
+    bnBias_.CopyData(x);
+  }
+  void set_runningMean(Tensor x) {
+    runningMean_.ResetLike(x);
+    runningMean_.CopyData(x);
+  }
+  void set_runningVariance(Tensor x) {
+    runningVariance_.ResetLike(x);
+    runningVariance_.CopyData(x);
+  }
+  virtual void ToDevice(Device* device) override;
+
+ protected:
+  float factor_;
+  size_t channels_, height_, width_;
+  Tensor bnScale_, bnBias_;
+  Tensor dbnScale_, dbnBias_;
+  Tensor runningMean_, runningVariance_;
+  // Store intermediate data, i.e., input tensor
+  std::stack<Tensor> buf_;
+  
+}; // class batchnorm
+} // namespace 
+
+#endif  // SINGA_MODEL_LAYER_BATCHNORM_H

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
new file mode 100644
index 0000000..8288a41
--- /dev/null
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -0,0 +1,214 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#include "cudnn_batchnorm.h"
+#ifdef USE_CUDNN
+
+namespace singa {
+
+CudnnBatchNorm::~CudnnBatchNorm() {
+  if (has_init_cudnn_) {
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(shape_desc_));
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(param_desc_));
+  }
+}
+
+void CudnnBatchNorm::ToDevice(Device* device) {
+  BatchNorm::ToDevice(device);
+  resultSaveMean_.ToDevice(device);
+  resultSaveVariance_.ToDevice(device);
+}
+
+void CudnnBatchNorm::Setup(const LayerConf& conf) {
+  BatchNorm::Setup(conf);
+  bnScale_.Reshape(Shape{1,channels_,1,1});
+  bnBias_.ResetLike(bnScale_);
+  dbnScale_.ResetLike(bnScale_);
+  dbnBias_.ResetLike(bnScale_);
+  runningMean_.ResetLike(bnScale_);
+  runningVariance_.ResetLike(bnScale_);
+  resultSaveMean_.ResetLike(bnScale_);
+  resultSaveVariance_.ResetLike(bnScale_);
+}
+
+void CudnnBatchNorm::InitCudnn(const Shape& shape, DataType dtype) {
+  CHECK(!has_init_cudnn_);
+  mode_ = CUDNN_BATCHNORM_SPATIAL;
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
+  CHECK_EQ(shape.size(), 4u);
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_,
+        CUDNN_TENSOR_NCHW,
+        GetCudnnDataType(dtype),
+        shape[0],
+        shape[1],
+        shape[2],
+        shape[3]));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc_,
+        CUDNN_TENSOR_NCHW,
+        GetCudnnDataType(dtype),
+        1,
+        shape[1],
+        1,
+        1));
+  has_init_cudnn_ = true;
+}
+const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
+  auto shape = input.shape();
+  auto dtype = input.data_type();
+  Tensor output;
+  if (!has_init_cudnn_)
+    InitCudnn(shape, dtype);
+  // TODO(wangji): check device id of input and params
+  output.ResetLike(input);
+  if ((flag & kTrain) == kTrain) {
+    output.device()->Exec(
+        [=](Context* ctx) {
+          Blob *inBlob = input.blob(), *outBlob = output.blob(),
+            *saveMeanBlob = resultSaveMean_.blob(),
+            *saveVarBlob = resultSaveVariance_.blob(),
+            *runningMeanBlob = runningMean_.blob(),
+            *runningVarBlob = runningVariance_.blob(),
+            *bnScaleBlob = bnScale_.blob(),
+            *bnBiasBlob = bnBias_.blob();
+          const float alpha = 1.0f, beta = 0.0f;
+          double epsilon = CUDNN_BN_MIN_EPSILON;
+          CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
+              ctx->cudnn_handle,
+              this->mode_,
+              &alpha,
+              &beta,
+              shape_desc_,
+              inBlob->data(),
+              shape_desc_,
+              outBlob->mutable_data(),
+              param_desc_,
+              bnScaleBlob->data(),
+              bnBiasBlob->data(),
+              factor_,
+              runningMeanBlob->mutable_data(),
+              runningVarBlob->mutable_data(),
+              epsilon,
+              saveMeanBlob->mutable_data(),
+              saveVarBlob->mutable_data()));
+        },
+        {input.blob(),
+         bnScale_.blob(),
+         bnBias_.blob()},
+        {output.blob(),
+         runningMean_.blob(),
+         runningVariance_.blob(),
+         resultSaveMean_.blob(),
+         resultSaveVariance_.blob()});
+    buf_.push(input);
+  } else {
+    output.device()->Exec(
+        [=](Context* ctx) {
+          Blob *inBlob = input.blob(), *outBlob = output.blob(),
+            *runningMeanBlob = runningMean_.blob(),
+            *runningVarBlob = runningVariance_.blob(),
+            *bnScaleBlob = bnScale_.blob(),
+            *bnBiasBlob = bnBias_.blob();
+          const float alpha = 1.0f, beta = 0.0f;
+          double epsilon = CUDNN_BN_MIN_EPSILON;
+          CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
+              ctx->cudnn_handle,
+              this->mode_,
+              &alpha,
+              &beta,
+              shape_desc_,
+              inBlob->data(),
+              shape_desc_,
+              outBlob->mutable_data(),
+              param_desc_,
+              bnScaleBlob->data(),
+              bnBiasBlob->data(),
+              runningMeanBlob->data(),
+              runningVarBlob->data(),
+              epsilon));
+        },
+        {input.blob(),
+         bnScale_.blob(),
+         bnBias_.blob(),
+         runningMean_.blob(),
+         runningVariance_.blob()},
+        {output.blob()});
+  }
+  return output;
+}
+
+const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
+    int flag, const Tensor& grad) {
+  vector <Tensor> param_grad;
+  Tensor dx;
+  if ((flag & kTrain) == kTrain) {
+    Tensor input = buf_.top();
+    buf_.pop();
+    dx.ResetLike(grad);
+    dx.device()->Exec(
+        [=](Context* ctx) {
+          Blob *dyblob = grad.blob(), *dxblob = dx.blob(),
+            *xblob = input.blob(),
+            *bnScaleBlob = bnScale_.blob(),
+            *dbnScaleBlob = dbnScale_.blob(),
+            *dbnBiasBlob = dbnBias_.blob(),
+            *saveMeanBlob = resultSaveMean_.blob(),
+            *saveVarBlob = resultSaveVariance_.blob();
+          const float alpha = 1.0f, beta = .0f;
+          double epsilon = CUDNN_BN_MIN_EPSILON;
+          CUDNN_CHECK(cudnnBatchNormalizationBackward(ctx->cudnn_handle,
+              this->mode_,
+              &alpha,
+              &beta,
+              &alpha,
+              &beta,
+              shape_desc_,
+              xblob->data(),
+              shape_desc_,
+              dyblob->data(),
+              shape_desc_,
+              dxblob->mutable_data(),
+              param_desc_,
+              bnScaleBlob->data(),
+              dbnScaleBlob->mutable_data(),
+              dbnBiasBlob->mutable_data(),
+              epsilon,
+              saveMeanBlob->data(),
+              saveVarBlob->data()));
+
+        },
+        {dx.blob(),
+         grad.blob(),
+         bnScale_.blob(),
+         resultSaveMean_.blob(),
+         resultSaveVariance_.blob()},
+        {dx.blob(),
+         dbnScale_.blob(),
+         dbnBias_.blob()});
+  } else {
+    LOG(ERROR) << "Do not call backward for evaluation phase";
+  }
+  param_grad.push_back(dbnScale_);
+  param_grad.push_back(dbnBias_);
+  return std::make_pair(dx, param_grad);
+}
+}  // namespace
+
+#endif  // USE_CUDNN

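For quick orientation, the call sequence the cuDNN layer expects is condensed below from test_cudnn_batchnorm.cc later in this message; the zero-filled input and the reuse of the output as a dummy gradient are placeholders, not part of the commit:

#include "../src/model/layer/cudnn_batchnorm.h"

int main() {
  singa::CudaGPU cuda(0, 1);
  singa::CudnnBatchNorm bn;

  singa::LayerConf conf;
  singa::BatchNormConf *bn_conf = conf.mutable_batchnorm_conf();
  bn_conf->set_factor(0.9);
  bn_conf->set_channels(2);
  bn_conf->set_height(4);
  bn_conf->set_width(4);
  bn.Setup(conf);
  bn.ToDevice(&cuda);

  // bnScale (gamma) = 1, bnBias (beta) = 0, running stats start at 0; the
  // parameter shape {1, channels, 1, 1} matches the cuDNN param descriptor.
  const float ones[] = {1, 1}, zeros[] = {0, 0};
  singa::Tensor scale(singa::Shape{1, 2, 1, 1}, &cuda);
  singa::Tensor bias(singa::Shape{1, 2, 1, 1}, &cuda);
  scale.CopyDataFromHostPtr(ones, 2);
  bias.CopyDataFromHostPtr(zeros, 2);
  bn.set_bnScale(scale);
  bn.set_bnBias(bias);
  bn.set_runningMean(bias);
  bn.set_runningVariance(bias);

  // One training step: Forward caches the input on buf_, Backward pops it and
  // returns {dx, {dbnScale, dbnBias}}.
  float xdata[1 * 2 * 4 * 4] = {0};
  singa::Tensor x(singa::Shape{1, 2, 4, 4}, &cuda);
  x.CopyDataFromHostPtr(xdata, 1 * 2 * 4 * 4);
  singa::Tensor y = bn.Forward(singa::kTrain, x);
  auto grads = bn.Backward(singa::kTrain, y);
  return 0;
}
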
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/model/layer/cudnn_batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.h b/src/model/layer/cudnn_batchnorm.h
new file mode 100644
index 0000000..83258d2
--- /dev/null
+++ b/src/model/layer/cudnn_batchnorm.h
@@ -0,0 +1,60 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#ifndef SINGA_MODEL_LAYER_CUDNN_BATCHNORM_H
+#define SINGA_MODEL_LAYER_CUDNN_BATCHNORM_H
+#include "singa_config.h"
+#ifdef USE_CUDNN
+
+#include "batchnorm.h"
+#include "cudnn_utils.h"
+
+namespace singa {
+class CudnnBatchNorm : public BatchNorm {
+ public:
+   ~CudnnBatchNorm();
+   /// \copydoc Layer::layer_type()
+   const std::string layer_type() const override {
+     return "CudnnBatchNorm";
+   }
+
+   void Setup(const LayerConf& conf) override;
+
+   const Tensor Forward(int flag, const Tensor& input)
+     override;
+   const std::pair<Tensor, vector<Tensor>> Backward(
+       int flag, const Tensor& grad) override;
+
+   /// Init cudnn related data structures.
+   void InitCudnn(const Shape& shape, DataType dtype);
+   void ToDevice(Device* device) override;
+
+ private:
+   bool has_init_cudnn_ = false;
+   cudnnBatchNormMode_t mode_;
+   cudnnLRNDescriptor_t lrn_desc_;
+   cudnnTensorDescriptor_t shape_desc_, param_desc_;
+   Tensor resultSaveMean_, resultSaveVariance_;
+   
+}; // class CudnnBatchNorm
+}  // namespace
+
+#endif  // USE_CUDNN
+#endif  // SINGA_MODEL_LAYER_CUDNN_BATCHNORM_H

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/model/layer/cudnn_lrn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_lrn.cc b/src/model/layer/cudnn_lrn.cc
new file mode 100644
index 0000000..ee661b6
--- /dev/null
+++ b/src/model/layer/cudnn_lrn.cc
@@ -0,0 +1,114 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#include "cudnn_lrn.h"
+#ifdef USE_CUDNN
+#include "cudnn_utils.h"
+
+namespace singa {
+CudnnLRN::~CudnnLRN() {
+  if (has_init_cudnn_) {
+    CUDNN_CHECK(cudnnDestroyLRNDescriptor(lrn_desc_));
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(shape_desc_));
+  }
+}
+void CudnnLRN::InitCudnn(const Shape& shape , DataType dtype) {
+  CHECK(!has_init_cudnn_);
+  mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
+  CHECK_EQ(shape.size(), 4);
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_,
+      CUDNN_TENSOR_NCHW,
+      GetCudnnDataType(dtype),
+      shape[0],
+      shape[1],
+      shape[2],
+      shape[3]));
+  CUDNN_CHECK(cudnnCreateLRNDescriptor(&lrn_desc_));
+  CUDNN_CHECK(cudnnSetLRNDescriptor(lrn_desc_,
+        local_size_,
+        alpha_,
+        beta_,
+        k_));
+  has_init_cudnn_ = true;
+}
+const Tensor CudnnLRN::Forward(int flag, const Tensor& input) {
+  auto shape = input.shape();
+  auto dtype = input.data_type();
+  if (!has_init_cudnn_)
+    InitCudnn(shape, dtype);
+  Tensor output;
+  output.ResetLike(input);
+  output.device()->Exec(
+      [=](Context* ctx) {
+        Blob *inblob = input.blob(), *outblob = output.blob();
+        const float alpha = 1.0f, beta = 0.0f;
+        CUDNN_CHECK(cudnnLRNCrossChannelForward(ctx->cudnn_handle,
+            this->lrn_desc_,
+            this->mode_,
+            &alpha,
+            this->shape_desc_,
+            inblob->data(),
+            &beta,
+            this->shape_desc_,
+            outblob->mutable_data()));
+      }, {input.blob()}, {output.blob()});
+  buf_.push(input);
+  buf_.push(output);
+  return output;
+}
+
+const std::pair<Tensor, vector<Tensor>> CudnnLRN::Backward(
+    int flag, const Tensor& grad) {
+  vector <Tensor> param_grad;
+  Tensor dx;
+  Tensor output = buf_.top();
+  buf_.pop();
+  Tensor input = buf_.top();
+  buf_.pop();
+  if ((flag & kTrain) == kTrain) {
+    dx.ResetLike(grad);
+    dx.device()->Exec(
+        [=](Context *ctx) {
+          Blob *dyblob = grad.blob(), *dxblob = dx.blob();
+          Blob *yblob = output.blob(), *xblob = input.blob();
+          float alpha = 1.0f, beta = 0.0f;
+          CUDNN_CHECK(cudnnLRNCrossChannelBackward(ctx->cudnn_handle,
+              this->lrn_desc_,
+              this->mode_,
+              &alpha,
+              this->shape_desc_,
+              yblob->data(),
+              this->shape_desc_,
+              dyblob->data(),
+              this->shape_desc_,
+              xblob->data(),
+              &beta,
+              this->shape_desc_,
+              dxblob->mutable_data()));
+        }, {output.blob(), grad.blob(), input.blob()}, {dx.blob()});
+  } else {
+    LOG(ERROR) << "Do not call backward for evaluation phase";
+  }
+  return std::make_pair(dx, param_grad);
+}
+}  // namespace
+
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/model/layer/cudnn_lrn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_lrn.h b/src/model/layer/cudnn_lrn.h
new file mode 100644
index 0000000..0f650fe
--- /dev/null
+++ b/src/model/layer/cudnn_lrn.h
@@ -0,0 +1,56 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#ifndef SINGA_MODEL_LAYER_CUDNN_LRN_H_
+#define SINGA_MODEL_LAYER_CUDNN_LRN_H_
+#include "singa_config.h"
+#ifdef USE_CUDNN
+
+#include "lrn.h"
+#include "cudnn_utils.h"
+
+namespace singa {
+class CudnnLRN : public LRN {
+ public:
+   ~CudnnLRN();
+   /// \copydoc Layer::layer_type()
+   const std::string layer_type() const override {
+     return "CudnnLRN";
+   }
+
+   const Tensor Forward(int flag, const Tensor& input)
+     override;
+   const std::pair<Tensor, vector<Tensor>> Backward(
+       int flag, const Tensor& grad) override;
+
+   /// Init cudnn related data structures.
+   void InitCudnn(const Shape& shape, DataType dtype);
+
+ private:
+   bool has_init_cudnn_ = false;
+   cudnnLRNMode_t mode_;
+   cudnnLRNDescriptor_t lrn_desc_;
+   cudnnTensorDescriptor_t shape_desc_;
+   
+}; // class CudnnLRN
+}  // namespace
+
+#endif  // USE_CUDNN
+#endif  // SINGA_MODEL_LAYER_CUDNN_LRN_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/model/layer/lrn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.cc b/src/model/layer/lrn.cc
new file mode 100644
index 0000000..55135f1
--- /dev/null
+++ b/src/model/layer/lrn.cc
@@ -0,0 +1,59 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#include "lrn.h"
+
+namespace singa{
+void LRN::Setup(const LayerConf& conf) {
+  Layer::Setup(conf);
+  local_size_ = conf.lrn_conf().local_size();
+  CHECK_EQ(local_size_ % 2, 1) << "LRN only supports odd values for local_size";
+  k_ = conf.lrn_conf().k();
+  alpha_ = conf.lrn_conf().alpha();
+  beta_ = conf.lrn_conf().beta();
+}
+
+const Tensor LRN::Forward(int flag, const Tensor& input) {
+  //Tensor output;
+  //const float salpha = alpha_ / local_size_;
+  LOG(FATAL) << "Not implemented";
+  /* Tensor API may be need
+   * 1. set
+   * template <typename Dtype>
+   * void Set(Dtype val);
+   *
+   * 2. axpy
+   * 3. padding
+   *
+   *
+   */
+  Tensor output;
+  return output;
+}
+
+const std::pair<Tensor, vector<Tensor>> LRN::Backward(
+    int flag, const Tensor& grad) {
+  LOG(FATAL) << "Not implemented";
+  Tensor dx;
+  vector<Tensor> param_grad;
+  return std::make_pair(dx, param_grad);
+}
+
+}  // namespace

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/model/layer/lrn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.h b/src/model/layer/lrn.h
new file mode 100644
index 0000000..118d062
--- /dev/null
+++ b/src/model/layer/lrn.h
@@ -0,0 +1,70 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#ifndef SINGA_MODEL_LAYER_LRN_H_
+#define SINGA_MODEL_LAYER_LRN_H_
+#include "singa/model/layer.h"
+#include <stack>
+
+namespace singa {
+class LRN : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override {
+    return "LRN";
+  }
+
+  /// \copydoc Layer::Setup(const LayerConf&)
+  void Setup(const LayerConf& conf) override;
+
+  /**
+   * Local Response Normalization edge
+   *
+   * @f$ b_i = a_i / x_i^\beta @f$
+   * @f$ x_i = k + \alpha \sum_{j=\max(0,i-n/2)}^{\min(N,i+n/2)} a_j^2 @f$
+   * n is size of local response area.
+   * @f$a_i@f$, the activation (after ReLU) of a neuron convolved with the i-th kernel.
+   * @f$b_i@f$, the neuron after normalization, N is the total num of kernels
+   */
+  const Tensor Forward(int flag, const Tensor& input)
+    override;
+
+  /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&);
+  const std::pair<Tensor, vector<Tensor>> Backward(
+      int flag, const Tensor& grad) override;
+
+  int local_size() const { return local_size_; }
+  float alpha() const { return alpha_; }
+  float beta() const { return beta_; }
+  float k() const { return k_; }
+
+ protected:
+  //!< hyper-parameter: size local response (neighbor) area
+  int local_size_;
+  //!< other hyper-parameters
+  float alpha_, beta_, k_;
+  // store intermediate data, i.e., input tensor
+  std::stack<Tensor> buf_;
+  
+}; // class LRN
+} // namespace 
+
+#endif  // SINGA_MODEL_LAYER_LRN_H_
+

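The doc comment in lrn.h above spells out the cross-channel normalization. As a minimal sketch, not part of the commit, the same formula applied at one spatial position looks like this in plain C++ (note that some implementations, e.g. Caffe and cuDNN, scale alpha by the window size before summing):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// b_i = a_i / (k + alpha * sum_{j in window(i)} a_j^2)^beta, where the window
// covers local_size channels centred at channel i and is clipped at the ends.
std::vector<float> LRNAcrossChannels(const std::vector<float>& a,
                                     int local_size, float k, float alpha,
                                     float beta) {
  const int N = static_cast<int>(a.size());  // number of channels
  std::vector<float> b(N);
  for (int i = 0; i < N; ++i) {
    const int lo = std::max(0, i - local_size / 2);
    const int hi = std::min(N - 1, i + local_size / 2);
    float sum_sq = 0.f;
    for (int j = lo; j <= hi; ++j) sum_sq += a[j] * a[j];
    b[i] = a[i] / std::pow(k + alpha * sum_sq, beta);
  }
  return b;
}

int main() {
  // Channel responses at one (n, h, w) position, with the hyper-parameters
  // used by test_cudnn_lrn.cc below: local_size=3, k=1, alpha=0.1, beta=0.75.
  std::vector<float> a = {0.5f, 1.0f, -0.25f, 2.0f};
  for (float v : LRNAcrossChannels(a, 3, 1.f, 0.1f, 0.75f))
    std::printf("%g ", v);
  std::printf("\n");
  return 0;
}
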
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index 16ba62f..d368296 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -231,6 +231,7 @@ message LayerConf {
   // Used in SINGA
   optional DenseConf dense_conf = 201;
   optional MetricConf metric_conf = 200;
+  optional BatchNormConf batchnorm_conf = 202;
 }
 
 // Message that stores hyper-parameters used to apply transformation
@@ -902,14 +903,12 @@ message SPPConf {
   }
   optional uint32 pyramid_height = 1;
   optional PoolMethod pool = 2 [default = MAX]; // The pooling method
-  /*
   enum Engine {
     DEFAULT = 0;
     CAFFE = 1;
     CUDNN = 2;
   }
   optional Engine engine = 6 [default = DEFAULT];
-  */
 }
 
 message PReLUConf {
@@ -921,3 +920,13 @@ message PReLUConf {
   // Whether or not slope paramters are shared across channels.
   optional bool channel_shared = 2 [default = false];
 }
+
+message BatchNormConf {
+  // Used in the moving average computation runningMean =
+  // newMean*factor + runningMean*(1-factor).
+  optional double factor = 1 [default = 0.9];
+  // input shape
+  optional int32 channels = 2;
+  optional int32 height = 3;
+  optional int32 width = 4;
+}

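To make the factor semantics concrete (using the default factor = 0.9): starting from a running mean of 0, a first mini-batch with mean 0.5 stores 0.9 * 0.5 + 0.1 * 0 = 0.45, and a second mini-batch with mean 0.7 updates it to 0.9 * 0.7 + 0.1 * 0.45 = 0.675. This is the exponentialAverageFactor handed to cudnnBatchNormalizationForwardTraining in cudnn_batchnorm.cc above; the Backward test below sets factor = 1 so that the running statistics are simply replaced by each batch's statistics.
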
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/test/singa/test_cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_batchnorm.cc b/test/singa/test_cudnn_batchnorm.cc
new file mode 100644
index 0000000..d38fdaa
--- /dev/null
+++ b/test/singa/test_cudnn_batchnorm.cc
@@ -0,0 +1,257 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+
+#include "../src/model/layer/cudnn_batchnorm.h"
+
+#ifdef USE_CUDNN
+#include "gtest/gtest.h"
+
+using singa::CudnnBatchNorm;
+
+TEST(CudnnBatchNorm, Setup) {
+  CudnnBatchNorm batchnorm;
+  EXPECT_EQ("CudnnBatchNorm", batchnorm.layer_type());
+
+  singa::LayerConf conf;
+  singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();
+  batchnorm_conf->set_factor(0.01);
+  batchnorm_conf->set_channels(2);
+  batchnorm_conf->set_height(4);
+  batchnorm_conf->set_width(4);
+  batchnorm.Setup(conf);
+
+  EXPECT_FLOAT_EQ(0.01, batchnorm.factor());
+  EXPECT_EQ(2u, batchnorm.channels());
+  EXPECT_EQ(4u, batchnorm.height());
+  EXPECT_EQ(4u, batchnorm.width());
+}
+
+TEST(CudnnBatchNorm, Forward) {
+  CudnnBatchNorm batchnorm;
+  const float x[] = {
+    0.0736655, 0.0459045, 0.0779517, 0.0771059,
+    0.0586862, 0.0561263, 0.0708457, 0.0977273,
+    0.0405025, -0.170897, 0.0208982, 0.136865,
+    -0.0367905, -0.0618205, -0.0103908, -0.0522777,
+    -0.122161, -0.025427, -0.0718576, -0.185941,
+    0.0166533, 0.178679, -0.0576606, -0.137817,
+    0.150676, 0.153442, -0.0929899, -0.148675,
+    -0.112459, -0.106284, -0.103074, -0.0668811
+  };
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape{1,2,4,4}, &cuda);
+  in.CopyDataFromHostPtr(x, 1*2*4*4);
+  const float alpha_[] = {1, 1};
+  singa::Tensor alpha(singa::Shape{1,2,1,1}, &cuda);
+  alpha.CopyDataFromHostPtr(alpha_, 1*2*1*1);
+
+  const float beta_[] = {0, 0};
+  singa::Tensor beta(singa::Shape{1,2,1,1}, &cuda);
+  beta.CopyDataFromHostPtr(beta_, 1*2*1*1);
+
+  singa::LayerConf conf;
+  singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();
+  batchnorm_conf->set_factor(0.9);
+  batchnorm_conf->set_channels(2);
+  batchnorm_conf->set_height(4);
+  batchnorm_conf->set_width(4);
+  batchnorm.Setup(conf);
+
+  batchnorm.ToDevice(&cuda);
+  batchnorm.set_bnScale(alpha);
+  batchnorm.set_bnBias(beta);
+  batchnorm.set_runningMean(beta);
+  batchnorm.set_runningVariance(beta);
+  singa::Tensor out = batchnorm.Forward(singa::kTrain, in);
+  singa::CppCPU host(0, 1);
+  out.ToHost();
+  const float *outptr = out.data<const float *>();
+  const auto & shape = out.shape();
+  EXPECT_EQ(4u, shape.size());
+  EXPECT_EQ(1u, shape[0]);
+  EXPECT_EQ(2u, shape[1]);
+  EXPECT_EQ(4u, shape[2]);
+  EXPECT_EQ(4u, shape[3]);
+  EXPECT_NEAR(0.637092, outptr[0], 1e-4f);
+  EXPECT_NEAR(0.262057, outptr[1], 1e-4f);
+  EXPECT_NEAR(0.694995, outptr[2], 1e-4f);
+  EXPECT_NEAR(0.683569, outptr[3], 1e-4f);
+  EXPECT_NEAR(0.43473, outptr[4], 1e-4f);
+  EXPECT_NEAR(0.400147, outptr[5], 1e-4f);
+  EXPECT_NEAR(0.598998, outptr[6], 1e-4f);
+  EXPECT_NEAR(0.962152, outptr[7], 1e-4f);
+  EXPECT_NEAR(0.189079, outptr[8], 1e-4f);
+  EXPECT_NEAR(-2.6668, outptr[9], 1e-4f);
+  EXPECT_NEAR(-0.0757632, outptr[10], 1e-4f);
+  EXPECT_NEAR(1.49088, outptr[11], 1e-4f);
+  EXPECT_NEAR(-0.855104, outptr[12], 1e-4f);
+  EXPECT_NEAR(-1.19324, outptr[13], 1e-4f);
+  EXPECT_NEAR(-0.498459, outptr[14], 1e-4f);
+  EXPECT_NEAR(-1.06433, outptr[15], 1e-4f);
+  EXPECT_NEAR(-0.696646, outptr[16], 1e-4f);
+  EXPECT_NEAR(0.185125, outptr[17], 1e-4f);
+  EXPECT_NEAR(-0.238109, outptr[18], 1e-4f);
+  EXPECT_NEAR(-1.27803, outptr[19], 1e-4f);
+  EXPECT_NEAR(0.568704, outptr[20], 1e-4f);
+  EXPECT_NEAR(2.04564, outptr[21], 1e-4f);
+  EXPECT_NEAR(-0.108697, outptr[22], 1e-4f);
+  EXPECT_NEAR(-0.839356, outptr[23], 1e-4f);
+  EXPECT_NEAR(1.79038, outptr[24], 1e-4f);
+  EXPECT_NEAR(1.81559, outptr[25], 1e-4f);
+  EXPECT_NEAR(-0.430738, outptr[26], 1e-4f);
+  EXPECT_NEAR(-0.938335, outptr[27], 1e-4f);
+  EXPECT_NEAR(-0.608203, outptr[28], 1e-4f);
+  EXPECT_NEAR(-0.551921, outptr[29], 1e-4f);
+  EXPECT_NEAR(-0.522658, outptr[30], 1e-4f);
+  EXPECT_NEAR(-0.192746, outptr[31], 1e-4f);
+}
+
+TEST(CudnnBatchNorm, Backward) {
+  CudnnBatchNorm batchnorm;
+  const float x[] = {
+    0.0736655, 0.0459045, 0.0779517, 0.0771059,
+    0.0586862, 0.0561263, 0.0708457, 0.0977273,
+    0.0405025, -0.170897, 0.0208982, 0.136865,
+    -0.0367905, -0.0618205, -0.0103908, -0.0522777,
+    -0.122161, -0.025427, -0.0718576, -0.185941,
+    0.0166533, 0.178679, -0.0576606, -0.137817,
+    0.150676, 0.153442, -0.0929899, -0.148675,
+    -0.112459, -0.106284, -0.103074, -0.0668811
+  };
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor x_tensor(singa::Shape{1,2,4,4}, &cuda);
+  x_tensor.CopyDataFromHostPtr(x, 1*2*4*4);
+
+  singa::LayerConf conf;
+  singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();
+  batchnorm_conf->set_factor(1);
+  batchnorm_conf->set_channels(2);
+  batchnorm_conf->set_height(4);
+  batchnorm_conf->set_width(4);
+  batchnorm.Setup(conf);
+
+  const float dy[] = {
+    -0.0064714, 0, 0, 0,
+    0, -0.00297655, -0.0195729, 0,
+    0, 0, 0, 0,
+    0, 0, 0, -0.0032594,
+    0, 0, 0, 0,
+    0, 0, 0.0125562, 0,
+    0.00041933, 0.000386108, -0.0074611, 0.0015929,
+    0.00468428, 0.00735506, -0.00682525, 0.00342023
+  };
+
+  singa::Tensor dy_tensor(singa::Shape{1,2,4,4}, &cuda);
+  dy_tensor.CopyDataFromHostPtr(dy, 1*2*4*4);
+  const float alpha_[] = {1, 1};
+  singa::Tensor alpha(singa::Shape{1,2,1,1}, &cuda);
+  alpha.CopyDataFromHostPtr(alpha_, 1*2*1*1);
+
+  const float beta_[] = {0, 0};
+  singa::Tensor beta(singa::Shape{1,2,1,1}, &cuda);
+  beta.CopyDataFromHostPtr(beta_, 1*2*1*1);
+
+  const float mean_[] = {0.0123405, -0.0622333};
+  singa::Tensor mean(singa::Shape{1,2,1,1}, &cuda);
+  mean.CopyDataFromHostPtr(mean_, 1*2*1*1);
+
+  const float var_[] = {15.9948, 8.68198};
+  singa::Tensor var(singa::Shape{1,2,1,1}, &cuda);
+  var.CopyDataFromHostPtr(var_, 1*2*1*1);
+
+  batchnorm.ToDevice(&cuda);
+  batchnorm.set_bnScale(alpha);
+  batchnorm.set_bnBias(beta);
+  batchnorm.set_runningMean(beta);
+  batchnorm.set_runningVariance(beta);
+  batchnorm.Forward(singa::kTrain, x_tensor);
+  const auto ret = batchnorm.Backward(singa::kTrain, dy_tensor);
+  singa::CppCPU host(0, 1);
+  singa::Tensor dx = ret.first;
+  dx.ToDevice(&host);
+  const float *dxptr = dx.data<const float *>();
+  const auto & shape = dx.shape();
+  EXPECT_EQ(4u, shape.size());
+  EXPECT_EQ(1u, shape[0]);
+  EXPECT_EQ(2u, shape[1]);
+  EXPECT_EQ(4u, shape[2]);
+  EXPECT_EQ(4u, shape[3]);
+  EXPECT_NEAR(-0.0528703, dxptr[0], 1e-4f);
+  EXPECT_NEAR(0.0302578, dxptr[1], 1e-4f);
+  EXPECT_NEAR(0.0352178, dxptr[2], 1e-4f);
+  EXPECT_NEAR(0.0350869, dxptr[3], 1e-4f);
+  EXPECT_NEAR(0.032236, dxptr[4], 1e-4f);
+  EXPECT_NEAR(-0.00837157, dxptr[5], 1e-4f);
+  EXPECT_NEAR(-0.2303, dxptr[6], 1e-4f);
+  EXPECT_NEAR(0.0382786, dxptr[7], 1e-4f);
+  EXPECT_NEAR(0.0294217, dxptr[8], 1e-4f);
+  EXPECT_NEAR(-0.00329757, dxptr[9], 1e-4f);
+  EXPECT_NEAR(0.0263874, dxptr[10], 1e-4f);
+  EXPECT_NEAR(0.0443361, dxptr[11], 1e-4f);
+  EXPECT_NEAR(0.0174587, dxptr[12], 1e-4f);
+  EXPECT_NEAR(0.0135847, dxptr[13], 1e-4f);
+  EXPECT_NEAR(0.0215447, dxptr[14], 1e-4f);
+  EXPECT_NEAR(-0.0289709, dxptr[15], 1e-4f);
+  EXPECT_NEAR(-0.0100591, dxptr[16], 1e-4f);
+  EXPECT_NEAR(-0.00895677, dxptr[17], 1e-4f);
+  EXPECT_NEAR(-0.00948587, dxptr[18], 1e-4f);
+  EXPECT_NEAR(-0.0107859, dxptr[19], 1e-4f);
+  EXPECT_NEAR(-0.00847725, dxptr[20], 1e-4f);
+  EXPECT_NEAR(-0.0066309, dxptr[21], 1e-4f);
+  EXPECT_NEAR(0.105131, dxptr[22], 1e-4f);
+  EXPECT_NEAR(-0.0102375, dxptr[23], 1e-4f);
+  EXPECT_NEAR(-0.00312763, dxptr[24], 1e-4f);
+  EXPECT_NEAR(-0.00339895, dxptr[25], 1e-4f);
+  EXPECT_NEAR(-0.0777377, dxptr[26], 1e-4f);
+  EXPECT_NEAR(0.00415871, dxptr[27], 1e-4f);
+  EXPECT_NEAR(0.0327506, dxptr[28], 1e-4f);
+  EXPECT_NEAR(0.0571663, dxptr[29], 1e-4f);
+  EXPECT_NEAR(-0.0720566, dxptr[30], 1e-4f);
+  EXPECT_NEAR(0.0217477, dxptr[31], 1e-4f);
+
+  singa::Tensor dbnScale = ret.second.at(0);
+  dbnScale.ToDevice(&host);
+  const float *dbnScaleptr = dbnScale.data<const float *>();
+  const auto & dbnScaleShape = dbnScale.shape();
+  EXPECT_EQ(4u, dbnScaleShape.size());
+  EXPECT_EQ(1u, dbnScaleShape[0]);
+  EXPECT_EQ(2u, dbnScaleShape[1]);
+  EXPECT_EQ(1u, dbnScaleShape[2]);
+  EXPECT_EQ(1u, dbnScaleShape[3]);
+
+  EXPECT_NEAR(-0.013569f, dbnScaleptr[0], 1e-4f);
+  EXPECT_NEAR(-0.00219431f, dbnScaleptr[1], 1e-4f);
+
+  singa::Tensor dbnBias = ret.second.at(1);
+  dbnBias.ToDevice(&host);
+  const float *dbnBiasptr = dbnBias.data<const float *>();
+  const auto & dbnBiasShape = dbnBias.shape();
+  EXPECT_EQ(4u, dbnBiasShape.size());
+  EXPECT_EQ(1u, dbnBiasShape[0]);
+  EXPECT_EQ(2u, dbnBiasShape[1]);
+  EXPECT_EQ(1u, dbnBiasShape[2]);
+  EXPECT_EQ(1u, dbnBiasShape[3]);
+
+  EXPECT_NEAR(-0.0322803f, dbnBiasptr[0], 1e-4f);
+  EXPECT_NEAR(0.0161278f, dbnBiasptr[1], 1e-4f);
+}
+
+#endif  //  USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/eadd3f96/test/singa/test_cudnn_lrn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_lrn.cc b/test/singa/test_cudnn_lrn.cc
new file mode 100644
index 0000000..390c588
--- /dev/null
+++ b/test/singa/test_cudnn_lrn.cc
@@ -0,0 +1,205 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+
+#include "../src/model/layer/cudnn_lrn.h"
+
+#ifdef USE_CUDNN
+// cudnn lrn is added in cudnn 4
+#if CUDNN_VERSION_MAJOR >=4
+#include "gtest/gtest.h"
+
+using singa::CudnnLRN;
+
+TEST(CudnnLRN, Setup) {
+  CudnnLRN lrn;
+  EXPECT_EQ("CudnnLRN", lrn.layer_type());
+
+  singa::LayerConf conf;
+  singa::LRNConf *lrn_conf = conf.mutable_lrn_conf();
+  lrn_conf->set_k(1.0);
+  lrn_conf->set_local_size(3);
+  lrn_conf->set_alpha(0.1);
+  lrn_conf->set_beta(0.75);
+  lrn.Setup(conf);
+
+  EXPECT_FLOAT_EQ(1.0, lrn.k());
+  EXPECT_EQ(3, lrn.local_size());
+  EXPECT_FLOAT_EQ(0.1, lrn.alpha());
+  EXPECT_FLOAT_EQ(0.75, lrn.beta());
+}
+
+TEST(CudnnLRN, Forward) {
+  CudnnLRN lrn;
+  const float x[] = {
+    0.00658502, -0.0496967, -0.0333733, -0.0263094,
+    -0.044298, 0.0211638, 0.0829358, -0.0172312,
+    -0.0665471, -0.10017, -0.0750333, -0.104551,
+    -0.00981208, -0.0583349, -0.0751652, 0.011747,
+    0.0151165, 0.0304321, 0.0736639, -0.00652653,
+    0.00962833, 0.169646, -0.044588, -0.00244141,
+    0.0597329, -0.0530868, 0.0124246, 0.108429,
+    0.0451175, 0.0247055, 0.0304345, 0.0179575
+  };
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape{1,2,4,4}, &cuda);
+  in.CopyDataFromHostPtr(x, 1*2*4*4);
+
+  singa::LayerConf conf;
+  singa::LRNConf *lrn_conf = conf.mutable_lrn_conf();
+  lrn_conf->set_k(1.0);
+  lrn_conf->set_local_size(3);
+  lrn_conf->set_alpha(0.1);
+  lrn_conf->set_beta(0.75);
+  lrn.Setup(conf);
+
+  singa::Tensor out = lrn.Forward(singa::kTrain, in);
+  singa::CppCPU host(0, 1);
+  out.ToDevice(&host);
+  const float *outptr = out.data<const float *>();
+  const auto & shape = out.shape();
+  EXPECT_EQ(4u, shape.size());
+  EXPECT_EQ(1u, shape[0]);
+  EXPECT_EQ(2u, shape[1]);
+  EXPECT_EQ(4u, shape[2]);
+  EXPECT_EQ(4u, shape[3]);
+
+  EXPECT_NEAR(0.00658498f, outptr[0], 1e-6f);
+  EXPECT_NEAR(-0.0496925f, outptr[1], 1e-6f);
+  EXPECT_NEAR(-0.0333678f, outptr[2], 1e-6f);
+  EXPECT_NEAR(-0.0263089f, outptr[3], 1e-6f);
+  EXPECT_NEAR(-0.0442958f, outptr[4], 1e-6f);
+  EXPECT_NEAR(0.0211483f, outptr[5], 1e-6f);
+  EXPECT_NEAR(0.0829174f, outptr[6], 1e-6f);
+  EXPECT_NEAR(-0.0172311f, outptr[7], 1e-6f);
+  EXPECT_NEAR(-0.0665338f, outptr[8], 1e-6f);
+  EXPECT_NEAR(-0.100138f, outptr[9], 1e-6f);
+  EXPECT_NEAR(-0.0750224f, outptr[10], 1e-6f);
+  EXPECT_NEAR(-0.104492f, outptr[11], 1e-6f);
+  EXPECT_NEAR(-0.00981155f, outptr[12], 1e-6f);
+  EXPECT_NEAR(-0.058329f, outptr[13], 1e-6f);
+  EXPECT_NEAR(-0.0751528f, outptr[14], 1e-6f);
+  EXPECT_NEAR(0.0117468f, outptr[15], 1e-6f);
+  EXPECT_NEAR(0.0151164f, outptr[16], 1e-6f);
+  EXPECT_NEAR(0.0304296f, outptr[17], 1e-6f);
+  EXPECT_NEAR(0.0736518f, outptr[18], 1e-6f);
+  EXPECT_NEAR(-0.00652641f, outptr[19], 1e-6f);
+  EXPECT_NEAR(0.00962783f, outptr[20], 1e-6f);
+  EXPECT_NEAR(0.169522f, outptr[21], 1e-6f);
+  EXPECT_NEAR(-0.0445781f, outptr[22], 1e-6f);
+  EXPECT_NEAR(-0.00244139f, outptr[23], 1e-6f);
+  EXPECT_NEAR(0.0597209f, outptr[24], 1e-6f);
+  EXPECT_NEAR(-0.0530697f, outptr[25], 1e-6f);
+  EXPECT_NEAR(0.0124228f, outptr[26], 1e-6f);
+  EXPECT_NEAR(0.108367f, outptr[27], 1e-6f);
+  EXPECT_NEAR(0.045115f, outptr[28], 1e-6f);
+  EXPECT_NEAR(0.024703f, outptr[29], 1e-6f);
+  EXPECT_NEAR(0.0304295f, outptr[30], 1e-6f);
+  EXPECT_NEAR(0.0179573f, outptr[31], 1e-6f);
+}
+
+TEST(CudnnLRN, Backward) {
+  CudnnLRN lrn;
+
+  const float x[] = {
+    0.00658502, -0.0496967, -0.0333733, -0.0263094,
+    -0.044298, 0.0211638, 0.0829358, -0.0172312,
+    -0.0665471, -0.10017, -0.0750333, -0.104551,
+    -0.00981208, -0.0583349, -0.0751652, 0.011747,
+    0.0151165, 0.0304321, 0.0736639, -0.00652653,
+    0.00962833, 0.169646, -0.044588, -0.00244141,
+    0.0597329, -0.0530868, 0.0124246, 0.108429,
+    0.0451175, 0.0247055, 0.0304345, 0.0179575
+  };
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor x_tensor(singa::Shape{1,2,4,4}, &cuda);
+  x_tensor.CopyDataFromHostPtr(x, 1*2*4*4);
+
+  const float dy[] = {
+    -0.103178, -0.0326904, 0.293932, 0.355288,
+    -0.0288079, -0.0543308, -0.0668226, 0.0462216,
+    -0.0448064, -0.068982, -0.0509133, -0.0721143,
+    0.0959078, -0.0389037, -0.0510071, -0.178793,
+    0.00428248, -0.001132, -0.19928, 0.011935,
+    0.00622313, 0.143793, 0.0253894, 0.0104906,
+    -0.170673, 0.0283919, 0.00523488, -0.0455003,
+    0.177807, 0.000892812, -0.00113197, 0.00327798
+  };
+
+  singa::Tensor dy_tensor(singa::Shape{1,2,4,4}, &cuda);
+  dy_tensor.CopyDataFromHostPtr(dy, 1*2*4*4);
+
+  singa::LayerConf conf;
+  singa::LRNConf *lrn_conf = conf.mutable_lrn_conf();
+  lrn_conf->set_k(1.0);
+  lrn_conf->set_local_size(3);
+  lrn_conf->set_alpha(0.1);
+  lrn_conf->set_beta(0.75);
+  lrn.Setup(conf);
+
+  lrn.Forward(singa::kTrain, x_tensor);
+  const auto ret = lrn.Backward(singa::kTrain, dy_tensor);
+  singa::CppCPU host(0, 1);
+  singa::Tensor dx = ret.first;
+  dx.ToDevice(&host);
+  const float *dxptr = dx.data<const float *>();
+  const auto & shape = dx.shape();
+  EXPECT_EQ(4u, shape.size());
+  EXPECT_EQ(1u, shape[0]);
+  EXPECT_EQ(2u, shape[1]);
+  EXPECT_EQ(4u, shape[2]);
+  EXPECT_EQ(4u, shape[3]);
+
+  EXPECT_NEAR(-0.103177, dxptr[0], 1e-6f);
+  EXPECT_NEAR(-0.0326837, dxptr[1], 1e-6f);
+  EXPECT_NEAR(0.293844, dxptr[2], 1e-6f);
+  EXPECT_NEAR(0.355269, dxptr[3], 1e-6f);
+  EXPECT_NEAR(-0.0288034, dxptr[4], 1e-6f);
+  EXPECT_NEAR(-0.0543157, dxptr[5], 1e-6f);
+  EXPECT_NEAR(-0.0667802, dxptr[6], 1e-6f);
+  EXPECT_NEAR(0.0462206, dxptr[7], 1e-6f);
+  EXPECT_NEAR(-0.0448215, dxptr[8], 1e-6f);
+  EXPECT_NEAR(-0.0689328, dxptr[9], 1e-6f);
+  EXPECT_NEAR(-0.0508914, dxptr[10], 1e-6f);
+  EXPECT_NEAR(-0.0720598, dxptr[11], 1e-6f);
+  EXPECT_NEAR(0.0959062, dxptr[12], 1e-6f);
+  EXPECT_NEAR(-0.0388931, dxptr[13], 1e-6f);
+  EXPECT_NEAR(-0.0509844, dxptr[14], 1e-6f);
+  EXPECT_NEAR(-0.17879, dxptr[15], 1e-6f);
+  EXPECT_NEAR(0.00428292, dxptr[16], 1e-6f);
+  EXPECT_NEAR(-0.00113432, dxptr[17], 1e-6f);
+  EXPECT_NEAR(-0.199158, dxptr[18], 1e-6f);
+  EXPECT_NEAR(0.0119317, dxptr[19], 1e-6f);
+  EXPECT_NEAR(0.00622216, dxptr[20], 1e-6f);
+  EXPECT_NEAR(0.143491, dxptr[21], 1e-6f);
+  EXPECT_NEAR(0.0253689, dxptr[22], 1e-6f);
+  EXPECT_NEAR(0.0104904, dxptr[23], 1e-6f);
+  EXPECT_NEAR(-0.170617, dxptr[24], 1e-6f);
+  EXPECT_NEAR(0.0283971, dxptr[25], 1e-6f);
+  EXPECT_NEAR(0.00523171, dxptr[26], 1e-6f);
+  EXPECT_NEAR(-0.0454887, dxptr[27], 1e-6f);
+  EXPECT_NEAR(0.177781, dxptr[28], 1e-6f);
+  EXPECT_NEAR(0.000889893, dxptr[29], 1e-6f);
+  EXPECT_NEAR(-0.00113756, dxptr[30], 1e-6f);
+  EXPECT_NEAR(0.00327978, dxptr[31], 1e-6f);
+}
+
+#endif  //  CUDNN_VERSION_MAJOR >= 4
+#endif  //  USE_CUDNN