Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/17 07:52:29 UTC

incubator-singa git commit: SINGA-180 Add Activation layer and Softmax layer

Repository: incubator-singa
Updated Branches:
  refs/heads/dev 74f02143a -> b91002b55


SINGA-180 Add Activation layer and Softmax layer

Fix a bug in the cuDNN softmax layer and let the softmax layer support 1D or 2D tensors as input.
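
For reference, a minimal usage sketch of the updated layer, mirroring the new tests further down in this patch (includes and error checking omitted; assumes a build with USE_CUDNN and a visible CUDA device):

  singa::LayerConf conf;
  // New SoftmaxConf field added by this commit; "accurate", "fast" and "log"
  // select the corresponding cuDNN softmax algorithm.
  conf.mutable_softmax_conf()->set_algorithm("accurate");

  singa::CudnnSoftmax sft;
  sft.Setup(singa::Shape{3}, conf);             // per-sample feature size

  singa::CudaGPU cuda(0, 1);
  singa::Tensor in(singa::Shape{2, 3}, &cuda);  // 2 samples x 3 features
  const float x[] = {1.f, 2.f, 0.f, -2.f, -3.f, -1.f};
  in.CopyDataFromHostPtr<float>(x, 6);

  // Row-wise softmax; Backward(kTrain, grad) then returns the input gradient.
  singa::Tensor out = sft.Forward(singa::kTrain, in);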


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b91002b5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b91002b5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b91002b5

Branch: refs/heads/dev
Commit: b91002b551781503507f0a15acfcd6e12279b765
Parents: 74f0214
Author: jixin <ji...@comp.nus.edu.sg>
Authored: Thu Jun 16 22:55:02 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Fri Jun 17 15:51:46 2016 +0800

----------------------------------------------------------------------
 include/singa/model/layer.h         |   2 +-
 src/model/layer/activation.cc       |  22 +++----
 src/model/layer/activation.h        |   2 +-
 src/model/layer/batchnorm.h         |   2 +-
 src/model/layer/convolution.h       |   2 +-
 src/model/layer/cudnn_activation.cc |   7 +--
 src/model/layer/cudnn_activation.h  |  11 ++--
 src/model/layer/cudnn_batchnorm.h   |  36 +++++------
 src/model/layer/cudnn_dropout.h     |   1 +
 src/model/layer/cudnn_lrn.h         |  32 +++++-----
 src/model/layer/cudnn_pooling.h     |   2 +-
 src/model/layer/cudnn_softmax.cc    |  40 +++++++++---
 src/model/layer/cudnn_softmax.h     |  11 +++-
 src/model/layer/dense.h             |   2 +-
 src/model/layer/dropout.h           |   2 +-
 src/model/layer/flatten.h           |   2 +-
 src/model/layer/lrn.h               |   2 +-
 src/model/layer/pooling.h           |   2 +-
 src/model/layer/prelu.h             |   2 +-
 src/model/layer/softmax.cc          |  23 +++----
 src/model/layer/softmax.h           |   6 +-
 src/proto/model.proto               |   4 ++
 test/singa/test_cudnn_activation.cc |   5 +-
 test/singa/test_cudnn_softmax.cc    | 105 +++++++++++++++++++++++++------
 24 files changed, 202 insertions(+), 123 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index 5f5c197..a505f15 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -79,7 +79,7 @@ class Layer {
   }
 
   /// Return the shape of the generated Tensor without the batchsize dimension
-  virtual const Shape GetOutputSampleShape() {
+  virtual const Shape GetOutputSampleShape() const {
     LOG(FATAL) << "Pls override this function";
     return vector<size_t>{};
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/activation.cc b/src/model/layer/activation.cc
index 2f76a6d..b00834f 100644
--- a/src/model/layer/activation.cc
+++ b/src/model/layer/activation.cc
@@ -33,19 +33,15 @@ const Tensor Activation::Forward(int flag, const Tensor& input) {
   Tensor output;
   if (mode_ == "SIGMOID") {
     output = Sigmoid(input);
-    if (flag & kTrain)
-      buf_.push(output);
+    if (flag & kTrain) buf_.push(output);
   } else if (mode_ == "TANH") {
     output = Tanh(input);
-    if (flag & kTrain)
-      buf_.push(output);
+    if (flag & kTrain) buf_.push(output);
   } else if (mode_ == "RELU") {
     output = ReLU(input);
-    if (flag & kTrain)
-      buf_.push(input);
-  } else {
+    if (flag & kTrain) buf_.push(input);
+  } else
     LOG(FATAL) << "Unkown activation: " << mode_;
-  }
   return output;
 }
 
@@ -57,15 +53,13 @@ const std::pair<Tensor, vector<Tensor>> Activation::Backward(
   // activation.
   Tensor input_grad, inout = buf_.top();
   buf_.pop();
-  if (mode_ == "SIGMOID") {
+  if (mode_ == "SIGMOID")
     input_grad = grad * inout * (inout * (-1.f) + 1.f);
-  } else if (mode_ == "TANH") {
+  else if (mode_ == "TANH")
     input_grad = grad * (inout * inout * (-1.f) + 1.f);
-  } else if (mode_ == "RELU") {
+  else if (mode_ == "RELU")
     input_grad = grad * (inout > 0.f) + (inout <= 0.f) * neg_slope_;
-  } else {
-    LOG(FATAL) << "Unkown activation: " << mode_;
-  }
+  else LOG(FATAL) << "Unkown activation: " << mode_;
   return std::make_pair(input_grad, param_grad);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/activation.h
----------------------------------------------------------------------
diff --git a/src/model/layer/activation.h b/src/model/layer/activation.h
index 1799514..db3a8f5 100644
--- a/src/model/layer/activation.h
+++ b/src/model/layer/activation.h
@@ -30,7 +30,7 @@ class Activation : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.h b/src/model/layer/batchnorm.h
index 433e0c7..35b05b1 100644
--- a/src/model/layer/batchnorm.h
+++ b/src/model/layer/batchnorm.h
@@ -35,7 +35,7 @@ class BatchNorm : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.h b/src/model/layer/convolution.h
index 3901049..7ea5712 100644
--- a/src/model/layer/convolution.h
+++ b/src/model/layer/convolution.h
@@ -31,7 +31,7 @@ class Convolution : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const vector<size_t>& in_shape, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_activation.cc b/src/model/layer/cudnn_activation.cc
index 98a5758..72352b8 100644
--- a/src/model/layer/cudnn_activation.cc
+++ b/src/model/layer/cudnn_activation.cc
@@ -48,9 +48,8 @@ void CudnnActivation::InitCudnn(size_t size, DataType dtype) {
   else
     LOG(FATAL) << "Unkown activation: " << mode_;
 
-  nan_opt_ = CUDNN_PROPAGATE_NAN;
-  CUDNN_CHECK(
-      cudnnSetActivationDescriptor(acti_desc_, cudnn_mode_, nan_opt_, 0.0f));
+  CUDNN_CHECK(cudnnSetActivationDescriptor(
+        acti_desc_, cudnn_mode_, CUDNN_PROPAGATE_NAN, 0.0f));
   has_init_cudnn_ = true;
 }
 
@@ -89,7 +88,7 @@ const Tensor CudnnActivation::Forward(int flag, const Tensor& input) {
 const std::pair<Tensor, vector<Tensor>> CudnnActivation::Backward(
     int flag, const Tensor& grad) {
   vector<Tensor> param_grad;
-  Tensor dx;  // inout = buf_.top();
+  Tensor dx;
   CHECK(!buf_.empty());
   // inout means either used as input or output, only one is valid for one type
   // of activation

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/cudnn_activation.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_activation.h b/src/model/layer/cudnn_activation.h
index b572db7..71bede5 100644
--- a/src/model/layer/cudnn_activation.h
+++ b/src/model/layer/cudnn_activation.h
@@ -41,16 +41,17 @@ class CudnnActivation : public Activation {
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                    const Tensor& grad) override;
 
-  /// Init cudnn related data structures.
-  void InitCudnn(size_t size, DataType dtype);
 
   const cudnnActivationMode_t CudnnMode() const { return cudnn_mode_; }
 
  private:
+  /// Init cudnn related data structures.
+  void InitCudnn(size_t size, DataType dtype);
+
+ private:
   bool has_init_cudnn_ = false;
-  cudnnActivationDescriptor_t acti_desc_;
-  cudnnTensorDescriptor_t desc_;
-  cudnnNanPropagation_t nan_opt_;
+  cudnnActivationDescriptor_t acti_desc_ = nullptr;
+  cudnnTensorDescriptor_t desc_ = nullptr;
   cudnnActivationMode_t cudnn_mode_;
 };
 }  // namespace

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/cudnn_batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.h b/src/model/layer/cudnn_batchnorm.h
index 47fd4c5..36dbbce 100644
--- a/src/model/layer/cudnn_batchnorm.h
+++ b/src/model/layer/cudnn_batchnorm.h
@@ -29,31 +29,29 @@
 namespace singa {
 class CudnnBatchNorm : public BatchNorm {
  public:
-   ~CudnnBatchNorm();
-   /// \copy doc Layer::layer_type()
-   const std::string layer_type() const override {
-     return "CudnnBatchNorm";
-   }
+  ~CudnnBatchNorm();
+  /// \copy doc Layer::layer_type()
+  const std::string layer_type() const override { return "CudnnBatchNorm"; }
 
-   void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
 
-   const Tensor Forward(int flag, const Tensor& input)
-     override;
-   const std::pair<Tensor, vector<Tensor>> Backward(
-       int flag, const Tensor& grad) override;
+  const Tensor Forward(int flag, const Tensor& input) override;
+  const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                   const Tensor& grad) override;
+  void ToDevice(Device* device) override;
 
-   /// Init cudnn related data structures.
-   void InitCudnn(const Shape& shape, DataType dtype);
-   void ToDevice(Device* device) override;
+ private:
+  /// Init cudnn related data structures.
+  void InitCudnn(const Shape& shape, DataType dtype);
 
  private:
-   bool has_init_cudnn_ = false;
-   cudnnBatchNormMode_t mode_;
-   cudnnLRNDescriptor_t lrn_desc_;
-   cudnnTensorDescriptor_t shape_desc_, param_desc_;
-   Tensor resultSaveMean_, resultSaveVariance_;
+  bool has_init_cudnn_ = false;
+  cudnnBatchNormMode_t mode_;
+  cudnnLRNDescriptor_t lrn_desc_ = nullptr;
+  cudnnTensorDescriptor_t shape_desc_ = nullptr, param_desc_ = nullptr;
+  Tensor resultSaveMean_, resultSaveVariance_;
 
-}; // class CudnnBatchNorm
+};  // class CudnnBatchNorm
 }  // namespace
 
 #endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/cudnn_dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.h b/src/model/layer/cudnn_dropout.h
index 7cb185b..83572cf 100644
--- a/src/model/layer/cudnn_dropout.h
+++ b/src/model/layer/cudnn_dropout.h
@@ -42,6 +42,7 @@ class CudnnDropout : public Dropout {
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                    const Tensor& grad) override;
 
+ private:
   /// Init cudnn related data structures.
   void InitCudnn(int size, DataType dtype, Device* dev, Context* ctx);
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/cudnn_lrn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_lrn.h b/src/model/layer/cudnn_lrn.h
index 0f650fe..ddf4a37 100644
--- a/src/model/layer/cudnn_lrn.h
+++ b/src/model/layer/cudnn_lrn.h
@@ -29,27 +29,25 @@
 namespace singa {
 class CudnnLRN : public LRN {
  public:
-   ~CudnnLRN();
-   /// \copy doc Layer::layer_type()
-   const std::string layer_type() const override {
-     return "CudnnLRN";
-   }
+  ~CudnnLRN();
+  /// \copy doc Layer::layer_type()
+  const std::string layer_type() const override { return "CudnnLRN"; }
 
-   const Tensor Forward(int flag, const Tensor& input)
-     override;
-   const std::pair<Tensor, vector<Tensor>> Backward(
-       int flag, const Tensor& grad) override;
+  const Tensor Forward(int flag, const Tensor& input) override;
+  const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                   const Tensor& grad) override;
 
-   /// Init cudnn related data structures.
-   void InitCudnn(const Shape& shape, DataType dtype);
+ private:
+  /// Init cudnn related data structures.
+  void InitCudnn(const Shape& shape, DataType dtype);
 
  private:
-   bool has_init_cudnn_ = false;
-   cudnnLRNMode_t mode_;
-   cudnnLRNDescriptor_t lrn_desc_;
-   cudnnTensorDescriptor_t shape_desc_;
-   
-}; // class CudnnLRN
+  bool has_init_cudnn_ = false;
+  cudnnLRNMode_t mode_;
+  cudnnLRNDescriptor_t lrn_desc_ = nullptr;
+  cudnnTensorDescriptor_t shape_desc_ = nullptr;
+
+};  // class CudnnLRN
 }  // namespcae
 
 #endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/cudnn_pooling.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.h b/src/model/layer/cudnn_pooling.h
index c3c7060..c323222 100644
--- a/src/model/layer/cudnn_pooling.h
+++ b/src/model/layer/cudnn_pooling.h
@@ -41,7 +41,7 @@ class CudnnPooling : public Pooling {
   const Tensor Forward(int flag, const Tensor &input) override;
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                    const Tensor &grad) override;
-
+ private:
   /// Init cudnn related data structures.
   void InitCudnn(const Tensor& input);
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_softmax.cc b/src/model/layer/cudnn_softmax.cc
index 16d4022..7efc797 100644
--- a/src/model/layer/cudnn_softmax.cc
+++ b/src/model/layer/cudnn_softmax.cc
@@ -26,30 +26,49 @@ CudnnSoftmax::~CudnnSoftmax() {
   if (desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_));
 }
 
-void CudnnSoftmax::InitCudnn(size_t size, DataType dtype) {
+void CudnnSoftmax::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Softmax::Setup(in_sample, conf);
+  SoftmaxConf sft_conf = conf.softmax_conf();
+  std::string algorithm = sft_conf.algorithm();
+  CHECK(algorithm == "accurate" || algorithm == "fast" || algorithm == "log")
+    << "CudnnSoftmax only supports three algorithm preferences: "
+    << "accurate, fast and log.";
+  if (algorithm == "accurate")
+    algorithm_ = CUDNN_SOFTMAX_ACCURATE;
+  else if (algorithm == "fast")
+    algorithm_ = CUDNN_SOFTMAX_FAST;
+  else algorithm_ = CUDNN_SOFTMAX_LOG;
+}
+
+void CudnnSoftmax::InitCudnn(Shape shape, DataType dtype) {
   CHECK(!has_init_cudnn_);
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_));
 
-  CUDNN_CHECK(cudnnSetTensor4dDescriptor(
-      desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, 1, 1, size));
-
-  algorithm_ = CUDNN_SOFTMAX_ACCURATE;
-  mode_ = CUDNN_SOFTMAX_MODE_INSTANCE;
+  CHECK_LE(shape.size(), 2u)
+    << "Tensor shape should range from 1 to 2D;"
+    << "otherwise, add flatten layer to transform";
+  if (shape.size() == 1u)
+    CUDNN_CHECK(cudnnSetTensor4dDescriptor( desc_,
+      CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, shape[0], 1, 1));
+  else
+    CUDNN_CHECK(cudnnSetTensor4dDescriptor( desc_, CUDNN_TENSOR_NCHW,
+      GetCudnnDataType(dtype), shape[0], shape[1], 1, 1));
   has_init_cudnn_ = true;
 }
 
 const Tensor CudnnSoftmax::Forward(int flag, const Tensor& input) {
-  auto size = input.Size();
+  auto shape = input.shape();
   DataType dtype = input.data_type();
   if (!has_init_cudnn_) {
-    InitCudnn(size, dtype);
+    InitCudnn(shape, dtype);
   }
   Tensor output;
   output.ResetLike(input);
   output.device()->Exec([input, output, this](Context* ctx) {
     Block* inblock = input.block(), * outblock = output.block();
     float alpha = 1.0f, beta = 0.0f;
-    cudnnSoftmaxForward(ctx->cudnn_handle, this->algorithm_, this->mode_,
+    cudnnSoftmaxForward(ctx->cudnn_handle, this->algorithm_,
+                        CUDNN_SOFTMAX_MODE_INSTANCE,
                         &alpha, this->desc_, inblock->data(), &beta,
                         this->desc_, outblock->mutable_data());
   }, {input.block()}, {output.block()});
@@ -68,7 +87,8 @@ const std::pair<Tensor, vector<Tensor>> CudnnSoftmax::Backward(
     Block* dyblock = grad.block(), * dxblock = dx.block(),
            * yblock = output.block();
     float alpha = 1.0f, beta = 0.0f;
-    cudnnSoftmaxBackward(ctx->cudnn_handle, this->algorithm_, this->mode_,
+    cudnnSoftmaxBackward(ctx->cudnn_handle, this->algorithm_,
+                         CUDNN_SOFTMAX_MODE_INSTANCE,
                          &alpha, this->desc_, yblock->data(), this->desc_,
                          dyblock->data(), &beta, this->desc_,
                          dxblock->mutable_data());
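
A note on the descriptor mapping above: with CUDNN_SOFTMAX_MODE_INSTANCE, cuDNN normalizes over the C, H, W dimensions of each sample, so describing a 2D (batch, features) tensor as N=batch, C=features, H=W=1 gives one softmax per row, while a 1D tensor is treated as a single sample. A standalone sketch of just that shape mapping (hypothetical helper, no cuDNN calls):

  #include <cstddef>
  #include <vector>

  // Hypothetical helper: map a 1D/2D tensor shape to the (N, C, H, W) dims
  // used for the cuDNN tensor descriptor, following CudnnSoftmax::InitCudnn.
  struct Nchw { int n, c, h, w; };

  Nchw SoftmaxDescriptorDims(const std::vector<size_t>& shape) {
    if (shape.size() == 1)                        // one sample, shape[0] features
      return {1, static_cast<int>(shape[0]), 1, 1};
    return {static_cast<int>(shape[0]),           // batch size
            static_cast<int>(shape[1]), 1, 1};    // features per sample
  }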

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/cudnn_softmax.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_softmax.h b/src/model/layer/cudnn_softmax.h
index ee92d6f..aca3729 100644
--- a/src/model/layer/cudnn_softmax.h
+++ b/src/model/layer/cudnn_softmax.h
@@ -36,18 +36,23 @@ class CudnnSoftmax : public Softmax {
   /// \copydoc Layer::layer_type()
   const std::string layer_type() const override { return "CudnnSoftmax"; }
 
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const Shape& in_sample_shape, const LayerConf &conf) override;
+
   const Tensor Forward(int flag, const Tensor& input) override;
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                    const Tensor& grad) override;
 
+  const cudnnSoftmaxAlgorithm_t Algorithm() const { return algorithm_; }
+
+ private:
   /// Init cudnn related data structures.
-  void InitCudnn(size_t size, DataType dtype);
+  void InitCudnn(Shape shape, DataType dtype);
 
  private:
   bool has_init_cudnn_ = false;
-  cudnnTensorDescriptor_t desc_;
+  cudnnTensorDescriptor_t desc_ = nullptr;
   cudnnSoftmaxAlgorithm_t algorithm_;
-  cudnnSoftmaxMode_t mode_;
 };
 }  // namespace
 #endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/dense.h
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.h b/src/model/layer/dense.h
index 6704106..8438d5c 100644
--- a/src/model/layer/dense.h
+++ b/src/model/layer/dense.h
@@ -32,7 +32,7 @@ class Dense : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(hdim_) << "You may haven't call Setup()";
     return vector<size_t>{hdim_};
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/dropout.h b/src/model/layer/dropout.h
index e9ff798..14be6a0 100644
--- a/src/model/layer/dropout.h
+++ b/src/model/layer/dropout.h
@@ -30,7 +30,7 @@ class Dropout : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/flatten.h
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.h b/src/model/layer/flatten.h
index 0981f32..6ac90c2 100644
--- a/src/model/layer/flatten.h
+++ b/src/model/layer/flatten.h
@@ -30,7 +30,7 @@ class Flatten : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/lrn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.h b/src/model/layer/lrn.h
index a165d12..0632f8c 100644
--- a/src/model/layer/lrn.h
+++ b/src/model/layer/lrn.h
@@ -33,7 +33,7 @@ class LRN : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/pooling.h
----------------------------------------------------------------------
diff --git a/src/model/layer/pooling.h b/src/model/layer/pooling.h
index ddee45b..26a1d07 100644
--- a/src/model/layer/pooling.h
+++ b/src/model/layer/pooling.h
@@ -31,7 +31,7 @@ class Pooling : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/prelu.h
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.h b/src/model/layer/prelu.h
index 7387bfb..ee571e1 100644
--- a/src/model/layer/prelu.h
+++ b/src/model/layer/prelu.h
@@ -32,7 +32,7 @@ class PReLU : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/softmax.cc b/src/model/layer/softmax.cc
index 25bb9fe..cccb06b 100644
--- a/src/model/layer/softmax.cc
+++ b/src/model/layer/softmax.cc
@@ -21,8 +21,6 @@ namespace singa {
 
 void Softmax::Setup(const Shape& in_sample, const LayerConf& conf) {
   Layer::Setup(in_sample, conf);
-  // TODO(wangwei) disable axis, use a flatten layer to reshape the tensor.
-  // axis_ = conf.softmax_conf().axis();  // default is 1
   CHECK_EQ(in_sample.size(), 1u);
   out_sample_shape_ = in_sample;
 }
@@ -30,11 +28,6 @@ void Softmax::Setup(const Shape& in_sample, const LayerConf& conf) {
 const Tensor Softmax::Forward(int flag, const Tensor& input) {
   CHECK_LE(input.nDim(), 2u);
   Tensor output =  SoftMax(input);
-  /*
-  size_t nrow = Product(input.shape(), 0, axis_);
-  const Tensor& tmp = Reshape(input, Shape{nrow, input.Size() / nrow});
-  output = SoftMax(tmp);
-  */
   if (flag & kTrain)
     buf_.push(output);
   return output;
@@ -43,19 +36,21 @@ const Tensor Softmax::Forward(int flag, const Tensor& input) {
 const std::pair<Tensor, vector<Tensor>> Softmax::Backward(int flag,
                                                           const Tensor& grad) {
   CHECK_LE(grad.nDim(), 2u);
-  size_t nrow = 1, ncol = grad.Size();
   Tensor input_grad = grad.Clone();
+  CHECK(!buf_.empty());
+  Tensor y = buf_.top();
+  buf_.pop();
+  CHECK(y.shape() == input_grad.shape());
+  Tensor sigma = input_grad * y;
+
+  size_t nrow = 1, ncol = grad.Size();
   if (grad.nDim() > 1) {
     nrow = grad.shape(0);
     ncol = grad.shape(1);
   } else {
     input_grad.Reshape({nrow, ncol});
+    sigma.Reshape({nrow, ncol});
   }
-  CHECK(!buf_.empty());
-  Tensor y = buf_.top();
-  buf_.pop();
-  CHECK(y.shape() == input_grad.shape());
-  Tensor sigma = input_grad * y;
   Tensor sum(Shape{nrow}, grad.device(), grad.data_type());
   SumColumns(sigma, &sum);
   // dL / dy_i = grad_i
@@ -65,6 +60,8 @@ const std::pair<Tensor, vector<Tensor>> Softmax::Backward(int flag,
   // dL / dx_i = y_i * (grad_i - sum), where sum = sum_i(grad_i * y_i);
   SubColumn(sum, &input_grad);
   input_grad = input_grad * y;
+  if (grad.nDim() == 1)
+    input_grad.Reshape(Shape{ncol});
   // Mult(input_grad, y, &input_grad);
   vector<Tensor> param_grad;
   return std::make_pair(input_grad, param_grad);
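
The reshaping above lets the same column-wise reduction serve both 1D and 2D inputs; per row, the implemented gradient is dL/dx_i = y_i * (g_i - sum_j g_j * y_j), where y is the buffered softmax output and g is the incoming gradient. A small host-side reference sketch of that formula (plain arrays, not the SINGA API; it matches the reference computation in the updated tests):

  #include <cstddef>
  #include <vector>

  // Reference softmax backward for a (batch x c) matrix:
  // dx[r][i] = y[r][i] * (g[r][i] - sum_j g[r][j] * y[r][j]).
  std::vector<float> SoftmaxBackwardRef(const std::vector<float>& y,
                                        const std::vector<float>& g,
                                        size_t batch, size_t c) {
    std::vector<float> dx(batch * c);
    for (size_t r = 0; r < batch; r++) {
      float sum = 0.f;
      for (size_t j = 0; j < c; j++) sum += g[r * c + j] * y[r * c + j];
      for (size_t j = 0; j < c; j++)
        dx[r * c + j] = y[r * c + j] * (g[r * c + j] - sum);
    }
    return dx;
  }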

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/model/layer/softmax.h
----------------------------------------------------------------------
diff --git a/src/model/layer/softmax.h b/src/model/layer/softmax.h
index fed544e..837b23a 100644
--- a/src/model/layer/softmax.h
+++ b/src/model/layer/softmax.h
@@ -20,6 +20,7 @@
 #include "singa/model/layer.h"
 #include <stack>
 namespace singa {
+/// Do softmax for 1D or 2D tensors along the last dimension.
 class Softmax : public Layer {
  public:
   /// \copydoc Layer::layer_type()
@@ -27,7 +28,7 @@ class Softmax : public Layer {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
-  const Shape GetOutputSampleShape() {
+  const Shape GetOutputSampleShape() const override {
     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
     return out_sample_shape_;
   }
@@ -39,10 +40,7 @@ class Softmax : public Layer {
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                    const Tensor& grad) override;
 
-  const int Axis() const { return axis_; }
-
  protected:
-  int axis_;
   std::stack<Tensor> buf_;
   Shape out_sample_shape_;
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index e9746c1..c06deec 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -829,6 +829,10 @@ message SoftmaxConf {
   // from the end (e.g., -1 for the last axis).
   // Any other axes will be evaluated as independent softmaxes.
   // optional int32 axis = 2 [default = 1];
+
+  /// The cudnn algorithm preferences
+  /// Options are: accurate, fast and log
+  optional string algorithm = 50 [default = "accurate"];
 }
 
 message TanHConf {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/test/singa/test_cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_activation.cc b/test/singa/test_cudnn_activation.cc
index 0dac497..da8ec62 100644
--- a/test/singa/test_cudnn_activation.cc
+++ b/test/singa/test_cudnn_activation.cc
@@ -39,8 +39,7 @@ TEST(TCudnnActivation, Setup) {
   reluconf->set_negative_slope(0.5f);
 
   acti.Setup(Shape{3}, conf);
-  acti.InitCudnn(1, singa::kFloat32);
-  EXPECT_EQ(CUDNN_ACTIVATION_RELU, acti.CudnnMode());
+//  EXPECT_EQ(CUDNN_ACTIVATION_RELU, acti.CudnnMode());
   EXPECT_EQ(0.5f, acti.Negative_slope());
 }
 
@@ -63,7 +62,6 @@ TEST(TCudnnActivation, Forward) {
       reluconf->set_negative_slope(neg_slope);
     }
     acti.Setup(Shape{n}, conf);
-    // acti.InitCudnn(n, singa::kFloat32);
 
     singa::Tensor out = acti.Forward(singa::kTrain, in);
     EXPECT_EQ(n, out.Size());
@@ -103,7 +101,6 @@ TEST(TCudnnActivation, Backward) {
       reluconf->set_negative_slope(neg_slope);
     }
     acti.Setup(Shape{n}, conf);
-    acti.InitCudnn(n, singa::kFloat32);
     singa::Tensor out = acti.Forward(singa::kTrain, in);
     EXPECT_EQ(n, out.Size());
     singa::CppCPU host(0, 1);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b91002b5/test/singa/test_cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_softmax.cc b/test/singa/test_cudnn_softmax.cc
index d671ecf..067491f 100644
--- a/test/singa/test_cudnn_softmax.cc
+++ b/test/singa/test_cudnn_softmax.cc
@@ -34,23 +34,25 @@ TEST(CudnnSoftmax, Setup) {
   EXPECT_EQ("CudnnSoftmax", sft.layer_type());
 
   singa::LayerConf conf;
-
-  sft.Setup(Shape{4}, conf);
-  sft.InitCudnn(1, singa::kFloat32);
+  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
+  softmaxconf->set_algorithm("fast");
+  sft.Setup(Shape{1}, conf);
+  EXPECT_EQ(CUDNN_SOFTMAX_FAST, sft.Algorithm());
 }
 
-TEST(CudnnSoftmax, Forward) {
-  const float x[] = {1.0f, 2.0f, 0.0f, -2.0f, -3.0f, -1.0};
+TEST(CudnnSoftmax, Forward1D) {
+  const float x[] = {1.f, 2.f, 0.f, -2.f, -3.f, -1.f};
   size_t n = sizeof(x) / sizeof(float);
   singa::CudaGPU cuda(0, 1);
-  singa::Tensor in(singa::Shape{n}, &cuda);
+  singa::Shape shape = {n};
+  singa::Tensor in(shape, &cuda);
   in.CopyDataFromHostPtr<float>(x, n);
 
   CudnnSoftmax sft;
   singa::LayerConf conf;
+  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
+  softmaxconf->set_algorithm("accurate");
   sft.Setup(Shape{1}, conf);
-  sft.InitCudnn(n, singa::kFloat32);
-
   singa::Tensor out = sft.Forward(singa::kTrain, in);
   singa::CppCPU host(0, 1);
   out.ToDevice(&host);
@@ -61,28 +63,30 @@ TEST(CudnnSoftmax, Forward) {
   float sigma = 0.f;
   for (size_t i = 0; i < n; i++) sigma += exp(x[i]);
   for (size_t i = 0; i < n; i++) y[i] = exp(x[i]) / sigma;
-  EXPECT_FLOAT_EQ(y[0], yptr[0]);
-  EXPECT_FLOAT_EQ(y[4], yptr[4]);
-  EXPECT_FLOAT_EQ(y[5], yptr[5]);
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(y[i], yptr[i]);
 }
 
-TEST(CudnnSoftmax, Backward) {
-  const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -1.0};
+TEST(CudnnSoftmax, Backward1D) {
+  const float x[] = {1.f, 2.f, 3.f, -2.f, -3.f, -1.f};
   size_t n = sizeof(x) / sizeof(float);
   singa::CudaGPU cuda(0, 1);
-  singa::Tensor in(singa::Shape{n}, &cuda);
+  singa::Shape shape = {n};
+  singa::Tensor in(shape, &cuda);
   in.CopyDataFromHostPtr<float>(x, n);
 
   CudnnSoftmax sft;
   singa::LayerConf conf;
+  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
+  softmaxconf->set_algorithm("accurate");
   sft.Setup(Shape{1}, conf);
+
   singa::Tensor out = sft.Forward(singa::kTrain, in);
   singa::CppCPU host(0, 1);
   out.ToDevice(&host);
   const float* yptr = out.data<const float*>();
 
-  const float grad[] = {2.0f, -3.0f, 1.0f, 3.0f, -1.0f, -2.0};
-  singa::Tensor out_diff(singa::Shape{n}, &cuda);
+  const float grad[] = {2.f, -3.f, 1.f, 3.f, -1.f, -2.f};
+  singa::Tensor out_diff(shape, &cuda);
   out_diff.CopyDataFromHostPtr<float>(grad, n);
   const auto ret = sft.Backward(singa::kTrain, out_diff);
   singa::Tensor in_diff = ret.first;
@@ -93,8 +97,71 @@ TEST(CudnnSoftmax, Backward) {
   float sigma = 0.f;
   for (size_t i = 0; i < n; i++) sigma += grad[i] * yptr[i];
   for (size_t i = 0; i < n; i++) dx[i] = (grad[i] - sigma) * yptr[i];
-  EXPECT_FLOAT_EQ(dx[0], xptr[0]);
-  EXPECT_FLOAT_EQ(dx[4], xptr[4]);
-  EXPECT_FLOAT_EQ(dx[5], xptr[5]);
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(dx[i], xptr[i]);
+}
+
+TEST(CudnnSoftmax, Forward2D) {
+  const float x[] = {1.f, 2.f, 0.f, -2.f, -3.f, -1.f};
+  size_t n = sizeof(x) / sizeof(float);
+  size_t batch = 2, c = 3;
+  singa::CudaGPU cuda(0, 1);
+  singa::Shape shape = {batch, c};
+  singa::Tensor in(shape, &cuda);
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  CudnnSoftmax sft;
+  singa::LayerConf conf;
+  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
+  softmaxconf->set_algorithm("accurate");
+  sft.Setup(Shape{c}, conf);
+
+  singa::Tensor out = sft.Forward(singa::kTrain, in);
+  singa::CppCPU host(0, 1);
+  out.ToDevice(&host);
+  const float* yptr = out.data<const float*>();
+  EXPECT_EQ(n, out.Size());
+
+  float* y = new float[n];
+  float* sigma = new float[batch];
+  for (size_t i = 0; i < batch; i++) sigma[i] = 0.f;
+  for (size_t i = 0; i < n; i++) sigma[i / c] += exp(x[i]);
+  for (size_t i = 0; i < n; i++) y[i] = exp(x[i]) / sigma[i / c];
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(y[i], yptr[i]);
+}
+
+TEST(CudnnSoftmax, Backward2D) {
+  const float x[] = {1.f, 2.f, 3.f, -2.f, -3.f, -1.f};
+  size_t n = sizeof(x) / sizeof(float);
+  size_t batch = 2, c = 3;
+  singa::CudaGPU cuda(0, 1);
+  singa::Shape shape = {batch, c};
+  singa::Tensor in(shape, &cuda);
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  CudnnSoftmax sft;
+  singa::LayerConf conf;
+  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
+  softmaxconf->set_algorithm("accurate");
+  sft.Setup(Shape{c}, conf);
+
+  singa::Tensor out = sft.Forward(singa::kTrain, in);
+  singa::CppCPU host(0, 1);
+  out.ToDevice(&host);
+  const float* yptr = out.data<const float*>();
+
+  const float grad[] = {2.f, -3.f, 1.f, 3.f, -1.f, -2.f};
+  singa::Tensor out_diff(shape, &cuda);
+  out_diff.CopyDataFromHostPtr<float>(grad, n);
+  const auto ret = sft.Backward(singa::kTrain, out_diff);
+  singa::Tensor in_diff = ret.first;
+  in_diff.ToDevice(&host);
+  const float* xptr = in_diff.data<const float*>();
+
+  float* dx = new float[n];
+  float* sigma = new float[batch];
+  for (size_t i = 0; i < batch; i++) sigma[i] = 0.f;
+  for (size_t i = 0; i < n; i++) sigma[i / c] += grad[i] * yptr[i];
+  for (size_t i = 0; i < n; i++) dx[i] = (grad[i] - sigma[i / c]) * yptr[i];
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(dx[i], xptr[i]);
 }
 #endif  // USE_CUDNN