You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/06/13 13:20:16 UTC

[23/50] [abbrv] incubator-singa git commit: SINGA-180 Add Activation layer and Softmax layer

SINGA-180 Add Activation layer and Softmax layer

Add cpu and cudnn implementation for activation and softmax layer.

Note: the activation layer currently supports the sigmoid/tanh functions and the relu forward computation.

Remove the tensor softmax function. Instead, use tensor op(*) and function(Sum) to implement the softmax function.

Add test files for activation and softmax layer.

Add Element-wise implementation for activation functions (relu/tanh/sigmoid).

Add tensor-scalar comparison functions (<, <=, >, >=), i.e., to compare a tensor with a constant.

Add implementation for tensor math functions (exp, log, pow).

Add functions for matrix op vector, where op is multiply or divide.

Pass all tests.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3e2507b7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3e2507b7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3e2507b7

Branch: refs/heads/master
Commit: 3e2507b7af8c4fe3746f3156f29eba99a30e546f
Parents: 2dac380
Author: jixin <ji...@comp.nus.edu.sg>
Authored: Fri May 27 22:03:35 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Tue May 31 22:08:31 2016 +0800

----------------------------------------------------------------------
 include/singa/core/tensor.h         | 107 +++++++++++++++++-----
 src/core/tensor/math_kernel.cu      | 132 +++++++++++++++++----------
 src/core/tensor/math_kernel.h       |   6 +-
 src/core/tensor/tensor.cc           | 152 ++++++++++++++++---------------
 src/core/tensor/tensor_math.h       |  47 +++++++++-
 src/core/tensor/tensor_math_cpp.h   | 148 ++++++++++++++++++++++++------
 src/core/tensor/tensor_math_cuda.h  |  54 +++++++----
 src/model/layer/activation.cc       |  67 ++++++++++++++
 src/model/layer/activation.h        |  51 +++++++++++
 src/model/layer/cudnn_activation.cc | 115 +++++++++++++++++++++++
 src/model/layer/cudnn_activation.h  |  58 ++++++++++++
 src/model/layer/cudnn_softmax.cc    |  77 ++++++++++++++++
 src/model/layer/cudnn_softmax.h     |  54 +++++++++++
 src/model/layer/softmax.cc          |  64 +++++++++++++
 src/model/layer/softmax.h           |  45 +++++++++
 test/singa/test_activation.cc       | 133 +++++++++++++++++++++++++++
 test/singa/test_cudnn_activation.cc | 136 +++++++++++++++++++++++++++
 test/singa/test_cudnn_dropout.cc    |   2 +-
 test/singa/test_cudnn_softmax.cc    | 107 ++++++++++++++++++++++
 test/singa/test_softmax.cc          | 110 ++++++++++++++++++++++
 20 files changed, 1468 insertions(+), 197 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 8682bca..bb8d7f8 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -62,7 +62,7 @@ inline size_t SizeOf(DataType t) {
 /// then it must be set up correctly (shape, device). Otherwise, runtime error
 /// like SegmentFault would happen. Simply type/device check would be conducted.
 class Tensor {
-public:
+ public:
   ~Tensor();
   Tensor();
   explicit Tensor(Shape &&shape, DataType dtype = kFloat32);
@@ -83,7 +83,8 @@ public:
   Device *device() const { return device_; }
 
   /// Return immutable Tensor values with given type.
-  template <typename DType> DType data() const {
+  template <typename DType>
+  DType data() const {
     return static_cast<DType>(blob()->data());
   }
 
@@ -130,7 +131,8 @@ public:
   void ToHost();
 
   /// Set each element of the tensor to be x
-  template <typename SType> void SetValue(const SType x);
+  template <typename SType>
+  void SetValue(const SType x);
 
   /// For init the tensor values, copy 'num' elements.
   template <typename DType>
@@ -141,7 +143,7 @@ public:
   void CopyData(const Tensor &other);
 
   /// Return an exactly the same Tensor with data been deep copied.
-  Tensor Clone();
+  Tensor Clone() const;
 
   // Tensor operations
 
@@ -167,23 +169,27 @@ public:
   // Scalar operations.
 
   /// T is a scalar type
-  template <typename DType> Tensor &operator+=(DType x);
+  template <typename DType>
+  Tensor &operator+=(DType x);
 
   /// T is a scalar type
-  template <typename DType> Tensor &operator-=(const DType x);
+  template <typename DType>
+  Tensor &operator-=(const DType x);
 
   /// T is a scalar type
-  template <typename DType> Tensor &operator*=(const DType x);
+  template <typename DType>
+  Tensor &operator*=(const DType x);
 
   /// T is a scalar type
-  template <typename DType> Tensor &operator/=(const DType x);
+  template <typename DType>
+  Tensor &operator/=(const DType x);
 
   /// save Tensor into a proto msg
   // void ToProto(TensorProto* t);
   /// load Tensor from proto msg
   // void FromProto(const TensorProto& t);
 
-protected:
+ protected:
   bool transpose_ = false;
   DataType data_type_ = kFloat32;
   Device *device_ = nullptr;
@@ -220,7 +226,8 @@ Tensor Sqrt(const Tensor &t);
 Tensor Square(const Tensor &t);
 Tensor Tanh(const Tensor &t);
 
-template <typename SType> SType Sum(const Tensor &t);
+template <typename SType>
+SType Sum(const Tensor &t);
 /// Sum elements in the Tensor, currently only support vector and matrix.
 /// if 'axis' is 0, sum all rows into a single row
 /// if 'axis' is 1, sum all columns into a single column
@@ -232,16 +239,48 @@ Tensor Sum(const Tensor &t, int axis);
 /// if 'axis' is 1, average all columns into a single column
 /// TODO(wangwei) support arbitrary Tensor like numpy.average
 Tensor Average(const Tensor &t, int axis);
+/// Regarding the internal data as 2d, with shape_[0]*...*shape_[axis-1] rows,
+/// and shape_[axis]*...*shape_[nDim()] columns.
+/// and do softmax along each row.
+Tensor SoftMax(const Tensor &t, int axis = 0);
+void SoftMax(const Tensor &t, int axis, Tensor *ret);
+
 /// Regarding the internal data as 2d, with shape_[0]*...*shape_[axis] rows,
 /// and shape_[axis+1]*...*shape_[nDim()] columns.
 /// and do softmax along each row.
-Tensor Softmax(const Tensor &t, int axis = -1);
-void Softmax(const Tensor &t, Tensor *ret, int axis = -1);
+// Tensor Softmax(const Tensor& t, int axis = -1);
+// void Softmax(const Tensor& t, Tensor* ret, int axis = -1);
+
+/// Element-wise operation, ret[i]= (t[i] < x) ? 1.f : 0.f
+template <typename DType>
+Tensor operator<(const Tensor &t, const DType x);
+template <typename DType>
+void LT(const Tensor &t, DType x, Tensor *ret);
+
+/// Element-wise operation, ret[i]= (t[i] <= x) ? 1.f : 0.f
+template <typename DType>
+Tensor operator<=(const Tensor &t, const DType x);
+template <typename DType>
+void LE(const Tensor &t, DType x, Tensor *ret);
+
+/// Element-wise operation, ret[i]= (t[i] > x) ? 1.f : 0.f
+template <typename DType>
+Tensor operator>(const Tensor &t, const DType x);
+template <typename DType>
+void GT(const Tensor &t, DType x, Tensor *ret);
+
+/// Element-wise operation, ret[i]= (t[i] >= x) ? 1.f : 0.f
+template <typename DType>
+Tensor operator>=(const Tensor &t, const DType x);
+template <typename DType>
+void GE(const Tensor &t, DType x, Tensor *ret);
 
 /// Element-wise opeartion, ret[i]=t[i]^x
-template <typename DType> Tensor Pow(const Tensor &t, DType x);
+template <typename DType>
+Tensor Pow(const Tensor &t, DType x);
 /// Element-wise opeartion, ret[i]=t[i]^x
-template <typename DType> void Pow(const Tensor &t, DType x, Tensor *ret);
+template <typename DType>
+void Pow(const Tensor &t, DType x, Tensor *ret);
 /// Element-wise opeartion, ret[i]=baes[i]^exp[i]
 Tensor Pow(const Tensor &base, Tensor exp);
 /// Element-wise opeartion, ret[i]=baes[i]^exp[i]
@@ -256,18 +295,25 @@ void EltwiseMult(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
 Tensor operator/(const Tensor &lhs, const Tensor &rhs);
 void Div(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
 
-template <typename DType> Tensor operator+(const Tensor &t, DType x);
-template <typename DType> void Add(const Tensor &t, DType x, Tensor *ret);
+template <typename DType>
+Tensor operator+(const Tensor &t, DType x);
+template <typename DType>
+void Add(const Tensor &t, DType x, Tensor *ret);
 
-template <typename DType> Tensor operator-(const Tensor &t, DType x);
-template <typename DType> void Sub(const Tensor &t, DType x, Tensor *ret);
+template <typename DType>
+Tensor operator-(const Tensor &t, DType x);
+template <typename DType>
+void Sub(const Tensor &t, DType x, Tensor *ret);
 
-template <typename DType> Tensor operator*(const Tensor &t, DType x);
+template <typename DType>
+Tensor operator*(const Tensor &t, DType x);
 template <typename DType>
 void EltwiseMult(const Tensor &t, DType x, Tensor *ret);
 
-template <typename DType> Tensor operator/(const Tensor &t, DType x);
-template <typename DType> void Div(const Tensor &t, DType x, Tensor *ret);
+template <typename DType>
+Tensor operator/(const Tensor &t, DType x);
+template <typename DType>
+void Div(const Tensor &t, DType x, Tensor *ret);
 
 // ================Blas operations============================================
 // We fix the scalar argument type to be float.
@@ -301,6 +347,7 @@ void Uniform(float low, float high, Tensor *t);
 void Gaussian(float mean, float std, Tensor *t);
 
 // follow the consistency guide
+// https://issues.apache.org/jira/browse/SINGA-182
 // ============Matrix vector operations=======================================
 /// Add column 'v' with each column of matrix M
 void AddColumn(const Tensor &v, Tensor *M);
@@ -329,12 +376,28 @@ void SumRows(const Tensor &M, Tensor *out);
 void SumColumns(const Tensor &M, Tensor *out);
 
 /// For each element x of Tensor 'in', compute alpha/x
-template <typename SType> Tensor Div(const SType alpha, const Tensor &in);
+template <typename SType>
+Tensor Div(const SType alpha, const Tensor &in);
 
 /// For each element x of Tensor 'in', compute alpha/x into 'out'
 template <typename SType>
 void Div(const SType alpha, const Tensor &in, Tensor *out);
 
+/*
+/// Multiply each column of the lhs matrix with the rhs column
+Tensor MultColumn(const Tensor &lhs, const Tensor &rhs);
+void MultColumn(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
+/// Multiply each row of the lhs matrix with the rhs row
+Tensor MultRow(const Tensor &lhs, const Tensor &rhs);
+void MultRow(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
+/// Div each row of the lhs matrix with the rhs column
+Tensor DivColumn(const Tensor &lhs, const Tensor &rhs);
+void DivColumn(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
+/// Divide each row of the lhs matrix by the rhs row
+Tensor DivRow(const Tensor &lhs, const Tensor &rhs);
+void DivRow(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
+*/
+
 }  // namespace singa
 
 #endif  // SINGA_CORE_TENSOR_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/math_kernel.cu
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 88041b1..aed6add 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -32,7 +32,7 @@
 #define CU1DBLOCK 1024
 #define CU1DBLOCKF 1024.0
 
-namespace singa{
+namespace singa {
 // Cuda Kernel Functions
 namespace cuda {
 __global__ void kernel_softmax_loss(const float *prob, const int *label,
@@ -147,7 +147,8 @@ __global__ void kernel_add_vec_row(const float *src_vec_data,
     des_mat_data[index] = src_mat_data[index] + src_vec_data[i];
   }
 }
-__global__ void kernel_add(const float *src1, const float *src2, float*out, int n) {
+__global__ void kernel_add(const float *src1, const float *src2, float *out,
+                           int n) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   int num_threads = blockDim.x * gridDim.x;
   for (; index < n; index += num_threads) {
@@ -155,7 +156,8 @@ __global__ void kernel_add(const float *src1, const float *src2, float*out, int
   }
 }
 
-__global__ void kernel_sub(const float *src1, const float *src2, float*out, int n) {
+__global__ void kernel_sub(const float *src1, const float *src2, float *out,
+                           int n) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   int num_threads = blockDim.x * gridDim.x;
   for (; index < n; index += num_threads) {
@@ -323,42 +325,28 @@ __global__ void kernel_threshold(const float *src_data, float *des_data,
     des_data[index] = src_data[index] < alpha ? 1.0f : 0.0f;
   }
 }
-
-/*
-void softmaxloss_forward(int n, int dim, const float *prob,
-    const int *label, float *loss) {
-  kernel_softmax_loss<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(prob, label, loss, n,
-      dim);
-}
-
-void softmaxloss_backward(int n, int dim, float scale,
-    const int *label, float *grad) {
-  kernel_softmax_gradient<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(grad, label, n,
-      dim, scale);
-}
-*/
 void sum(int n, const float *in, float *out) {
   int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n;
   //  here, we only need one block
   int num_blocks = 1;
 
-  kernel_sum_vec<<<num_blocks, threads_per_block>>>(in, out, n);
+  kernel_sum_vec << <num_blocks, threads_per_block>>> (in, out, n);
 }
 
 void sum_row(int rows, int cols, int stride, const float *in, float *out) {
   int threads_per_block = rows > CU1DBLOCK ? CU1DBLOCK : rows;
   int num_blocks = cols;
 
-  kernel_sum_row<<<num_blocks, threads_per_block>>>(in, out, rows, cols,
-                                                    stride);
+  kernel_sum_row << <num_blocks, threads_per_block>>>
+      (in, out, rows, cols, stride);
 }
 
 void sum_col(int rows, int cols, int stride, const float *in, float *out) {
   int threads_per_block = cols > CU1DBLOCK ? CU1DBLOCK : cols;
   int num_blocks = rows;
 
-  kernel_sum_col<<<num_blocks, threads_per_block>>>(in, out,
-                                                    rows, cols, stride);
+  kernel_sum_col << <num_blocks, threads_per_block>>>
+      (in, out, rows, cols, stride);
 }
 void add_row(int rows, int cols, int stride, const float *in_row,
              const float *in_mat, float *out) {
@@ -366,92 +354,91 @@ void add_row(int rows, int cols, int stride, const float *in_row,
   dim3 num_blocks(
       cols / threads_per_block.x + (cols % threads_per_block.x == 0 ? 0 : 1),
       rows / threads_per_block.y + (rows % threads_per_block.y == 0 ? 0 : 1));
-  kernel_add_vec_row<<<num_blocks, threads_per_block>>>(in_row, in_mat, out,
-                                                        rows, cols, stride);
+  kernel_add_vec_row << <num_blocks, threads_per_block>>>
+      (in_row, in_mat, out, rows, cols, stride);
 }
 void add(int n, const float *a, const float *b, float *out) {
-  kernel_add<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n);
+  kernel_add << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 }
 void sub(int n, const float *a, const float *b, float *out) {
-  kernel_sub<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n);
+  kernel_sub << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 }
 void exp(int n, const float *in, float *out) {
-  kernel_exp<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_exp << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void log(int n, const float *in, float *out) {
-  kernel_log<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_log << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void sigmoid(int n, const float *in, float *out) {
-  kernel_sigmoid<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_sigmoid << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void sigmoid_grad(int n, const float *in, float *out) {
-  kernel_sigmoid_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_sigmoid_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void relu(int n, const float *in, float *out) {
-  kernel_relu<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_relu << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void relu_grad(int n, const float *in, float *out) {
-  kernel_relu_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_relu_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void tanh(int n, const float *in, float *out) {
-  kernel_tanh<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_tanh << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void tanh_grad(int n, const float *in, float *out) {
-  kernel_tanh_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_tanh_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void softplus(int n, const float *in, float *out) {
-  kernel_softplus<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_softplus << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void softplus_grad(int n, const float *in, float *out) {
-  kernel_softplus_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_softplus_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void square(int n, const float *in, float *out) {
-  kernel_square<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_square << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void square_grad(int n, const float *in, float *out) {
-  kernel_square_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_square_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void sqrt(int n, const float *in, float *out) {
-  kernel_sqrt<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n);
+  kernel_sqrt << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
 void pow(int n, const float *a, const float *b, float *out) {
-  kernel_pow<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n);
+  kernel_pow << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 }
 
 void mult(int n, const float *a, const float *b, float *out) {
-  kernel_mult<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n);
+  kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 }
 
 void mult(int n, const float *a, const float x, float *out) {
-  kernel_mult<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, x, out, n);
+  kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, x, out, n);
 }
 
 void div(int n, const float *a, const float *b, float *out) {
-  kernel_div<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n);
+  kernel_div << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
 }
 
 void set_value(int n, float v, float *out) {
-  kernel_set_value<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(out, v, n);
+  kernel_set_value << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (out, v, n);
 }
 
 void threshold(int n, float alpha, const float *in, float *out) {
-  kernel_threshold<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, alpha, n);
+  kernel_threshold << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, alpha, n);
 }
 
-
 // follow the consistency guide for math API
 __global__ void KernelDiv(const size_t num, const float alpha, const float *in,
                           float *out) {
@@ -461,6 +448,36 @@ __global__ void KernelDiv(const size_t num, const float alpha, const float *in,
   }
 }
 
+__global__ void KernelGE(const int num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] >= x ? 1.0f : 0.0f;
+  }
+}
+__global__ void KernelGT(const int num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] > x ? 1.0f : 0.0f;
+  }
+}
+__global__ void KernelLE(const int num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] <= x ? 1.0f : 0.0f;
+  }
+}
+
+__global__ void KernelLT(const int num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] < x ? 1.0f : 0.0f;
+  }
+}
+
 __global__ void KernelSet(const size_t num, const float x, float *out) {
   for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
        idx += blockDim.x * gridDim.x) {
@@ -468,14 +485,31 @@ __global__ void KernelSet(const size_t num, const float x, float *out) {
   }
 }
 
+void Set(const size_t num, const float x, float *out, cudaStream_t s) {
+  KernelSet << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, x, out);
+}
 void Div(const size_t num, float alpha, const float *in, float *out,
          cudaStream_t s) {
-  KernelDiv<<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>>(num, alpha, in, out);
+  KernelDiv << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, alpha, in, out);
 }
 
-void Set(const size_t num, const float x, float *out, cudaStream_t s) {
-  KernelSet<<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>>(num, x, out);
+void GT(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelGT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+void GE(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelGE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 }
+void LT(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelLT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+void LE(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelLE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+
 }  // namespace cuda
 }  // namespace singa
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/math_kernel.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 925346e..5c906a9 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -86,7 +86,11 @@ void threshold(int n, float alpha, const float *in, float *out);
 void Div(const size_t num, const float x, const float *in, float *out,
          cudaStream_t s);
 void Set(const size_t num, const float x, float *out, cudaStream_t s);
-} // cuda
+void GT(size_t num, const float *in, const float x, float *out, cudaStream_t s);
+void GE(size_t num, const float *in, const float x, float *out, cudaStream_t s);
+void LT(size_t num, const float *in, const float x, float *out, cudaStream_t s);
+void LE(size_t num, const float *in, const float x, float *out, cudaStream_t s);
+}  // cuda
 
 }  // namespace singa
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index fcf42c2..5ae375c 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -142,7 +142,7 @@ void Tensor::CopyData(const Tensor &src) {
   }
 }
 
-Tensor Tensor::Clone() {
+Tensor Tensor::Clone() const {
   Tensor t(shape_, device_, data_type_);
   t.transpose_ = transpose_;
   t.CopyData(*this);
@@ -200,28 +200,28 @@ Tensor Reshape(const Tensor &in, Shape &&s) {
   return out;
 }
 
-#define GenUnaryTensorArgMemberFunction(op, fn)                                \
+#define GenUnaryTensorArgMemberFn(op, fn)                                \
   Tensor &Tensor::op(const Tensor &t) {                                        \
     fn(*this, t, this);                                                        \
     return *this;                                                              \
   }
 
-GenUnaryTensorArgMemberFunction(operator+=, Add);
-GenUnaryTensorArgMemberFunction(operator-=, Sub);
-GenUnaryTensorArgMemberFunction(operator*=, EltwiseMult);
-GenUnaryTensorArgMemberFunction(operator/=, Div);
+GenUnaryTensorArgMemberFn(operator+=, Add);
+GenUnaryTensorArgMemberFn(operator-=, Sub);
+GenUnaryTensorArgMemberFn(operator*=, EltwiseMult);
+GenUnaryTensorArgMemberFn(operator/=, Div);
 
-#define GenUnaryScalarArgMemberFunction(op, fn)                                \
+#define GenUnaryScalarArgMemberFn(op, fn)                                \
   template <typename DType> Tensor &Tensor::op(DType x) {                      \
     fn(*this, x, this);                                                        \
     return *this;                                                              \
   }                                                                            \
   template Tensor &Tensor::op<float>(float x)
 
-GenUnaryScalarArgMemberFunction(operator-=, Sub);
-GenUnaryScalarArgMemberFunction(operator+=, Add);
-GenUnaryScalarArgMemberFunction(operator*=, EltwiseMult);
-GenUnaryScalarArgMemberFunction(operator/=, Div);
+GenUnaryScalarArgMemberFn(operator-=, Sub);
+GenUnaryScalarArgMemberFn(operator+=, Add);
+GenUnaryScalarArgMemberFn(operator*=, EltwiseMult);
+GenUnaryScalarArgMemberFn(operator/=, Div);
 
 // ====================Tensor Operations=======================================
 void CopyDataToFrom(Tensor *dst, const Tensor &src, size_t num,
@@ -325,34 +325,35 @@ template <typename SType> void Tensor::SetValue(const SType x) {
 }
 template void Tensor::SetValue<float>(const float x);
 
-#define EltwiseUnaryTensorFn(fn, t, ret)                                       \
-  do {                                                                         \
-    TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {         \
-      ret->device()->Exec(                                                     \
-          [t, ret](Context *ctx) {                                             \
-            fn<DType, Lang>(t.Size(), t.blob(), ret->blob(), ctx);             \
-          },                                                                   \
-          {t.blob()}, {ret->blob()});                                          \
-    });                                                                        \
+#define EltwiseUnaryTensorFn(fn, t, ret)                               \
+  do {                                                                 \
+    TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \
+      ret->device()->Exec(                                             \
+          [t, ret](Context* ctx) {                                     \
+            fn<DType, Lang>(t.Size(), t.blob(), ret->blob(), ctx);     \
+          },                                                           \
+          {t.blob()}, {ret->blob()});                                  \
+    });                                                                \
   } while (0)
 
-#define GenUnaryTensorFunction(fn)                                             \
-  Tensor fn(const Tensor &t) {                                                 \
-    Tensor ret(t.shape(), t.device(), t.data_type());                          \
-    auto *retptr = &ret;                                                       \
-    EltwiseUnaryTensorFn(fn, t, retptr);                                       \
-    return ret;                                                                \
-  }
-
-GenUnaryTensorFunction(Abs);
-GenUnaryTensorFunction(Exp);
-GenUnaryTensorFunction(Log);
-GenUnaryTensorFunction(ReLU);
-GenUnaryTensorFunction(Sigmoid);
-GenUnaryTensorFunction(Sign);
-GenUnaryTensorFunction(Sqrt);
-GenUnaryTensorFunction(Square);
-GenUnaryTensorFunction(Tanh);
+#define GenUnaryTensorFn(fn)                          \
+  Tensor fn(const Tensor &t) {                        \
+    Tensor ret(t.shape(), t.device(), t.data_type()); \
+    auto *retptr = &ret;                              \
+    EltwiseUnaryTensorFn(fn, t, retptr);              \
+    return ret;                                       \
+  }                                                   \
+  void fn(const Tensor &in, Tensor *out) { EltwiseUnaryTensorFn(fn, in, out); }
+
+GenUnaryTensorFn(Abs);
+GenUnaryTensorFn(Exp);
+GenUnaryTensorFn(Log);
+GenUnaryTensorFn(ReLU);
+GenUnaryTensorFn(Sigmoid);
+GenUnaryTensorFn(Sign);
+GenUnaryTensorFn(Sqrt);
+GenUnaryTensorFn(Square);
+GenUnaryTensorFn(Tanh);
 
 // TODO(wangwei) conside async exec
 template <> float Sum<float>(const Tensor &t) {
@@ -402,28 +403,25 @@ Tensor Average(const Tensor &t, int axis) {
   }
 }
 
-Tensor Softmax(const Tensor &t, int axis) {
-  Tensor ret(t.shape(), t.device(), t.data_type());
-  Softmax(t, &ret, axis);
-  return ret;
+Tensor SoftMax(const Tensor &in, int axis) {
+  Tensor out(in.shape(), in.device(), in.data_type());
+  SoftMax(in, axis, &out);
+  return out;
 }
 
-void Softmax(const Tensor &t, Tensor *ret, int axis) {
-  int nrow = 1, ncol = t.Size(), size = ncol;
-  CHECK_GE(axis, -1);
-  CHECK_GT(t.shape().size(), 0u);
-  if (axis > -1) {
-    nrow = Product(t.shape(), 0, axis + 1);
-    CHECK_EQ(size % nrow, 0) << "Size = " << size << " nrow = " << nrow;
+void SoftMax(const Tensor &in, int axis, Tensor *out) {
+  size_t nrow = 1, ncol = in.Size(), size = ncol;
+  CHECK_GE(axis, 0);
+  if (axis > 0) {
+    nrow = Product(in.shape(), 0, axis);
+    CHECK_EQ(size % nrow, 0u) << "Size = " << size << " nrow = " << nrow;
     ncol = size / nrow;
   }
-  TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {
-    ret->device()->Exec(
-        [nrow, ncol, t, ret](Context *ctx) {
-          Softmax<DType, Lang>(nrow, ncol, t.blob(), ret->blob(), ctx);
-        },
-        {t.blob()}, {ret->blob()});
-  });
+  Exp(in, out);
+  out->Reshape(Shape{nrow, ncol});
+  Tensor sum(Shape{nrow}, in.device(), in.data_type());
+  SumColumns(*out, &sum);
+  DivColumn(sum, out);
 }
 
 #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret)                               \
@@ -439,7 +437,7 @@ void Softmax(const Tensor &t, Tensor *ret, int axis) {
     });                                                                        \
   } while (0)
 
-#define GenBinaryTensorFunction(op, fn)                                        \
+#define GenBinaryTensorFn(op, fn)                                        \
   Tensor op(const Tensor &lhs, const Tensor &rhs) {                            \
     Tensor ret(lhs.shape(), lhs.device(), lhs.data_type());                    \
     fn(lhs, rhs, &ret);                                                        \
@@ -449,11 +447,11 @@ void Softmax(const Tensor &t, Tensor *ret, int axis) {
     EltwiseBinaryTensorFn(fn, lhs, rhs, ret);                                  \
   }
 
-GenBinaryTensorFunction(operator+, Add);
-GenBinaryTensorFunction(operator-, Sub);
-GenBinaryTensorFunction(operator*, EltwiseMult);
-GenBinaryTensorFunction(operator/, Div);
-GenBinaryTensorFunction(Pow, Pow);
+GenBinaryTensorFn(operator+, Add);
+GenBinaryTensorFn(operator-, Sub);
+GenBinaryTensorFn(operator*, EltwiseMult);
+GenBinaryTensorFn(operator/, Div);
+GenBinaryTensorFn(Pow, Pow);
 
 #define EltwiseTensorScalarFn(fn, t, x, ret)                                   \
   do {                                                                         \
@@ -468,7 +466,7 @@ GenBinaryTensorFunction(Pow, Pow);
     });                                                                        \
   } while (0)
 
-#define GenTensorScalarFunction(op, fn)                                        \
+#define GenTensorScalarFn(op, fn)                                        \
   template <typename SType> Tensor op(const Tensor &t, SType x) {              \
     Tensor ret(t.shape(), t.device(), t.data_type());                          \
     fn(t, x, &ret);                                                            \
@@ -480,11 +478,15 @@ GenBinaryTensorFunction(Pow, Pow);
   template Tensor op<float>(const Tensor &t, float x);                         \
   template void fn<float>(const Tensor &t, const float x, Tensor *ret)
 
-GenTensorScalarFunction(operator+, Add);
-GenTensorScalarFunction(operator-, Sub);
-GenTensorScalarFunction(operator*, EltwiseMult);
-GenTensorScalarFunction(operator/, Div);
-GenTensorScalarFunction(Pow, Pow);
+GenTensorScalarFn(operator+, Add);
+GenTensorScalarFn(operator-, Sub);
+GenTensorScalarFn(operator*, EltwiseMult);
+GenTensorScalarFn(operator/, Div);
+GenTensorScalarFn(Pow, Pow);
+GenTensorScalarFn(operator<, LT);
+GenTensorScalarFn(operator<=, LE);
+GenTensorScalarFn(operator>, GT);
+GenTensorScalarFn(operator>=, GE);
 
 // ================Blas operations============================================
 Tensor Mult(const Tensor &lhs, const Tensor &rhs) {
@@ -633,8 +635,8 @@ void DivRow(const Tensor &v, Tensor *M) {
 /// Multiply column 'v' and each column of matrix M; write results into 'out'
 void MultColumn(const Tensor &v, Tensor *M) {
   CHECK(!M->transpose()) << "Not supported yet";
-  CHECK_EQ(M->nDim(), 2);
-  CHECK_EQ(v.nDim(), 1);
+  CHECK_EQ(M->nDim(), 2u);
+  CHECK_EQ(v.nDim(), 1u);
   CHECK_EQ(v.Size(), M->shape(0));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
@@ -650,8 +652,8 @@ void MultColumn(const Tensor &v, Tensor *M) {
 /// Multiply row 'v' with each row of matrix M; write results into 'out'
 void MultRow(const Tensor &v, Tensor *M) {
   CHECK(!M->transpose()) << "Not supported yet";
-  CHECK_EQ(M->nDim(), 2);
-  CHECK_EQ(v.nDim(), 1);
+  CHECK_EQ(M->nDim(), 2u);
+  CHECK_EQ(v.nDim(), 1u);
   CHECK_EQ(v.Size(), M->shape(1));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
@@ -673,8 +675,8 @@ void SumColumns(const Tensor &M, Tensor *v) {
     Tensor X = M.T();
     SumRows(X, v);
   } else {
-    CHECK_EQ(M.nDim(), 2);
-    CHECK_EQ(v->nDim(), 1);
+    CHECK_EQ(M.nDim(), 2u);
+    CHECK_EQ(v->nDim(), 1u);
     size_t nb_row = M.shape().at(0), nb_col = M.shape().at(1);
     CHECK_EQ(nb_row, v->Size());
 
@@ -688,8 +690,8 @@ void SumRows(const Tensor &M, Tensor *v) {
     Tensor X = M.T();
     SumColumns(X, v);
   } else {
-    CHECK_EQ(M.nDim(), 2);
-    CHECK_EQ(v->nDim(), 1);
+    CHECK_EQ(M.nDim(), 2u);
+    CHECK_EQ(v->nDim(), 1u);
     size_t nb_row = M.shape(0), nb_col = M.shape(1);
     CHECK_EQ(nb_col, v->Size());
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 98d91bf..ff865e0 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -220,6 +220,27 @@ void Outer(int m, int n, const Blob *lhs, const Blob *rhs, Blob *ret,
   LOG(FATAL) << "Not Implemented";
 }
 
+/// ret[i]=(input[i]<x)?1.f:0.f
+template <typename DType, typename Lang>
+void LT(int count, const Blob *input, float x, Blob *ret, Context *ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+/// ret[i]=(input[i]<=x)?1.f:0.f
+template <typename DType, typename Lang>
+void LE(int count, const Blob *input, float x, Blob *ret, Context *ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+/// ret[i]=(input[i]>x)?1.f:0.f
+template <typename DType, typename Lang>
+void GT(int count, const Blob *input, float x, Blob *ret, Context *ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+/// ret[i]=(input[i]>=x)?1.f:0.f
+template <typename DType, typename Lang>
+void GE(int count, const Blob *input, float x, Blob *ret, Context *ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
 // ===== BLAS functions, ref to http://docs.nvidia.com/cuda/cublas
 // ===== Level 1
 /// return the index of the element with the max value.
@@ -319,6 +340,30 @@ void GEMM(const bool transA, const bool transB, const size_t nrowA,
           Context *ctx) {
   LOG(FATAL) << "Not Implemented";
 }
-} // namespace singa
+/// ret[i]=(input[i]<x)?1.f:0.f
+template <typename DType, typename Lang>
+void LT(const size_t num, const Blob *in, const DType x, Blob *out,
+        Context *ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+/// ret[i]=(input[i]<=x)?1.f:0.f
+template <typename DType, typename Lang>
+void LE(const size_t num, const Blob *in, const DType x, Blob *out,
+        Context *ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+/// ret[i]=(input[i]>x)?1.f:0.f
+template <typename DType, typename Lang>
+void GT(const size_t num, const Blob *in, const DType x, Blob *out,
+        Context *ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+/// ret[i]=(input[i]>=x)?1.f:0.f
+template <typename DType, typename Lang>
+void GE(const size_t num, const Blob *in, const DType x, Blob *out,
+        Context *ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
 
+}  // namespace singa
 #endif  // SINGA_CORE_MATH_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 97da896..693f09c 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -19,6 +19,7 @@
 #define SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_
 #include "./tensor_math.h"
 #include "singa/core/common.h"
+#include <math.h>
 
 #ifdef USE_CBLAS
 #include <cblas.h>
@@ -51,6 +52,16 @@ void Add<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs,
 }
 
 template <>
+void Add<float, lang::Cpp>(int count, const Blob *input, float x, Blob *ret,
+                           Context *ctx) {
+  float *dptr = static_cast<float *>(ret->mutable_data());
+  const float *lptr = static_cast<const float *>(input->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = lptr[i] + x;
+  }
+}
+
+template <>
 void Sub<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs,
                            Blob *ret, Context *ctx) {
   // CHECK_EQ(ctx->stream, nullptr);
@@ -61,6 +72,7 @@ void Sub<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs,
     dptr[i] = lptr[i] - rptr[i];
   }
 }
+
 // sum all elements of input into ret
 // TODO(wangwei) optimize using omp
 template <>
@@ -74,53 +86,96 @@ void Sum<float, lang::Cpp>(int count, const Blob *input, float *ret,
   *ret = s;
 }
 
-// TODO(wangwei) optimize using omp
 template <>
-void SumRows<float, lang::Cpp>(int nrow, int ncol, const Blob *input, Blob *ret,
-                               Context *ctx) {
+void EltwiseMult<float, lang::Cpp>(int count, const Blob *input, float x,
+                                   Blob *ret, Context *ctx) {
   float *dptr = static_cast<float *>(ret->mutable_data());
-  const float *in = static_cast<const float *>(input->data());
-  memset(dptr, 0, ncol * sizeof(float));
-  for (int r = 0; r < nrow; r++) {
-    for (int c = 0; c < ncol; c++) {
-      dptr[c] += in[r * ncol + c];
-    }
+  const float *lptr = static_cast<const float *>(input->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = lptr[i] * x;
   }
 }
 
-// Sum the rows of the input matrix into a vector
-// TODO(wangwei) optimize using omp
 template <>
-void SumColumns<float, lang::Cpp>(int nrow, int ncol, const Blob *input,
-                                  Blob *ret, Context *ctx) {
+void EltwiseMult<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs,
+                                   Blob *ret, Context *ctx) {
   float *dptr = static_cast<float *>(ret->mutable_data());
-  const float *in = static_cast<const float *>(input->data());
-  memset(dptr, 0, ncol * sizeof(float));
-  for (int r = 0; r < nrow; r++) {
-    for (int c = 0; c < ncol; c++) {
-      dptr[r] += in[r * ncol + c];
-    }
+  const float *lptr = static_cast<const float *>(lhs->data());
+  const float *rptr = static_cast<const float *>(rhs->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = lptr[i] * rptr[i];
   }
 }
 
 template <>
-void EltwiseMult<float, lang::Cpp>(int count, const Blob *input, float x,
-                                   Blob *ret, Context *ctx) {
+void Exp<float, lang::Cpp>(int count, const Blob *input, Blob *ret,
+                           Context *ctx) {
   float *dptr = static_cast<float *>(ret->mutable_data());
   const float *lptr = static_cast<const float *>(input->data());
   for (int i = 0; i < count; i++) {
-    dptr[i] = lptr[i] * x;
+    dptr[i] = exp(lptr[i]);
   }
 }
 
 template <>
-void EltwiseMult<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs,
-                                   Blob *ret, Context *ctx) {
+void Log<float, lang::Cpp>(int count, const Blob *input, Blob *ret,
+                           Context *ctx) {
+  float *dptr = static_cast<float *>(ret->mutable_data());
+  const float *lptr = static_cast<const float *>(input->data());
+  for (int i = 0; i < count; i++) {
+    CHECK_GT(lptr[i], 0.f);
+    dptr[i] = log(lptr[i]);
+  }
+}
+
+template <>
+void Tanh<float, lang::Cpp>(int count, const Blob *input, Blob *ret,
+                            Context *ctx) {
+  float *dptr = static_cast<float *>(ret->mutable_data());
+  const float *lptr = static_cast<const float *>(input->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = tanh(lptr[i]);
+  }
+}
+
+template <>
+void ReLU<float, lang::Cpp>(int count, const Blob *input, Blob *ret,
+                            Context *ctx) {
+  float *dptr = static_cast<float *>(ret->mutable_data());
+  const float *lptr = static_cast<const float *>(input->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = (lptr[i] >= 0.f) ? lptr[i] : 0.f;
+  }
+}
+
+template <>
+void Sigmoid<float, lang::Cpp>(int count, const Blob *input, Blob *ret,
+                               Context *ctx) {
+  float *dptr = static_cast<float *>(ret->mutable_data());
+  const float *lptr = static_cast<const float *>(input->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = 1.f / (1.f + exp(-lptr[i]));
+  }
+}
+
+template <>
+void Pow<float, lang::Cpp>(int count, const Blob *input, float x, Blob *ret,
+                           Context *ctx) {
+  float *dptr = static_cast<float *>(ret->mutable_data());
+  const float *lptr = static_cast<const float *>(input->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = pow(lptr[i], x);
+  }
+}
+
+template <>
+void Pow<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs,
+                           Blob *ret, Context *ctx) {
   float *dptr = static_cast<float *>(ret->mutable_data());
   const float *lptr = static_cast<const float *>(lhs->data());
   const float *rptr = static_cast<const float *>(rhs->data());
   for (int i = 0; i < count; i++) {
-    dptr[i] = lptr[i] * rptr[i];
+    dptr[i] = pow(lptr[i], rptr[i]);
   }
 }
 
@@ -159,8 +214,15 @@ void Div<float, lang::Cpp>(const size_t num, const float alpha, const Blob *in,
                            Blob *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
+  for (size_t i = 0; i < num; i++) outPtr[i] = alpha / inPtr[i];
+}
+template <>
+void LT<float, lang::Cpp>(const size_t num, const Blob *in, const float x,
+                          Blob *out, Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *inPtr = static_cast<const float *>(in->data());
   for (size_t i = 0; i < num; i++) {
-    outPtr[i] = alpha / inPtr[i];
+    outPtr[i] = (inPtr[i] < x) ? 1.f : 0.f;
   }
 }
 
@@ -192,9 +254,38 @@ template <>
 void Set<float, lang::Cpp>(const size_t num, const float x, Blob *out,
                            Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
-  for (size_t i = 0; i < num; i++)
-    outPtr[i] = x;
+  for (size_t i = 0; i < num; i++) outPtr[i] = x;
+}
+template <>
+void LE<float, lang::Cpp>(const size_t num, const Blob *in, const float x,
+                          Blob *out, Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *inPtr = static_cast<const float *>(in->data());
+  for (size_t i = 0; i < num; i++) {
+    outPtr[i] = (inPtr[i] <= x) ? 1.f : 0.f;
+  }
+}
+
+template <>
+void GT<float, lang::Cpp>(const size_t num, const Blob *in, const float x,
+                          Blob *out, Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *inPtr = static_cast<const float *>(in->data());
+  for (size_t i = 0; i < num; i++) {
+    outPtr[i] = (inPtr[i] > x) ? 1.f : 0.f;
+  }
+}
+
+template <>
+void GE<float, lang::Cpp>(const size_t num, const Blob *in, const float x,
+                          Blob *out, Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *inPtr = static_cast<const float *>(in->data());
+  for (size_t i = 0; i < num; i++) {
+    outPtr[i] = (inPtr[i] >= x) ? 1.f : 0.f;
+  }
 }
+
 #ifdef USE_CBLAS
 template <>
 void Dot<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
@@ -224,7 +315,6 @@ void GEMM<float, lang::Cpp>(const bool transA, const bool transB,
 
 #endif  // USE_CBLAS
 
-
 }  // namespace singa
 
 #endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 26299ba..4a2ba66 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -73,25 +73,6 @@ void Sum<float, lang::Cuda>(int count, const Blob *input, float *ret,
   cuda::sum(count, in, ret);
 }
 
-// TODO(wangwei) optimize using stream
-template <>
-void SumRows<float, lang::Cuda>(int nrow, int ncol, const Blob *input,
-                                Blob *ret, Context *ctx) {
-  float *dptr = static_cast<float *>(ret->mutable_data());
-  const float *in = static_cast<const float *>(input->data());
-  cuda::sum_row(nrow, ncol, ncol, in, dptr);
-}
-
-// Sum the rows of the input matrix into a vector
-// TODO(wangwei) optimize using stream
-template <>
-void SumColumns<float, lang::Cuda>(int nrow, int ncol, const Blob *input,
-                                   Blob *ret, Context *ctx) {
-  float *dptr = static_cast<float *>(ret->mutable_data());
-  const float *in = static_cast<const float *>(input->data());
-  cuda::sum_col(nrow, ncol, ncol, in, dptr);
-}
-
 // follow the consistency guide of math API
 template <>
 void Div<float, lang::Cuda>(const size_t num, const float alpha, const Blob *in,
@@ -144,7 +125,42 @@ void GEMM<float, lang::Cuda>(const bool transA, const bool transB,
   CUBLAS_CHECK(cublasSgemm(handle, transb, transa, ncolB, nrowA, ncolA, &alpha,
                            BPtr, ldb, APtr, lda, &beta, CPtr, ldc));
 }
+
+template <>
+void GE<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                                   Blob* out, Context *ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::GE(num, inPtr, x, outPtr, ctx->stream);
+}
+template <>
+void GT<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                                   Blob* out,  Context *ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::GT(num, inPtr, x, outPtr, ctx->stream);
+}
+template <>
+void LE<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                                   Blob* out, Context *ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::LE(num, inPtr, x, outPtr, ctx->stream);
+}
+template <>
+void LT<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                                   Blob* out,  Context *ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::LT(num, inPtr, x, outPtr, ctx->stream);
+}
+
+
+
+
+
 }  // namespace singa
 
 #endif  // USE_CUDA
 #endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/activation.cc b/src/model/layer/activation.cc
new file mode 100644
index 0000000..464e24d
--- /dev/null
+++ b/src/model/layer/activation.cc
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/layer.h"
+#include "./activation.h"
+namespace singa {
+
+void Activation::Setup(const LayerConf& conf) {
+  Layer::Setup(conf);
+  mode_ = conf.type();
+  // Only RELU layers read a negative slope from their configuration.
+  const bool is_relu = (mode_ == "RELU");
+  if (is_relu) neg_slope_ = conf.relu_conf().negative_slope();
+}
+
+const Tensor Activation::Forward(int flag, const Tensor& input) {
+  // Applies the configured function to `input` and stashes on buf_ the tensor
+  // that Backward pops later: the input for RELU, the result for SIGMOID and
+  // TANH. `flag` is not consulted.
+  Tensor result;
+  if (mode_ == "RELU") {
+    result = ReLU(input);
+    buf_.push(input);
+  } else if (mode_ == "SIGMOID" || mode_ == "TANH") {
+    result = (mode_ == "TANH") ? Tanh(input) : Sigmoid(input);
+    buf_.push(result);
+  } else {
+    LOG(FATAL) << "Unkown activation: " << mode_;
+  }
+  return result;
+}
+
+const std::pair<Tensor, vector<Tensor>> Activation::Backward(
+    int flag, const Tensor& grad) {
+  vector<Tensor> param_grad;  // returned empty: nothing is added to it here
+  // inout is what Forward pushed: output for SIGMOID/TANH, input for RELU.
+  // RELU: the negative-slope term has to be scaled by grad as well.
+  Tensor input_grad, inout = buf_.top();
+  buf_.pop();
+  if (mode_ == "SIGMOID") {
+    input_grad = grad * inout * (inout * (-1.f) + 1.f);  // dy * y * (1 - y)
+  } else if (mode_ == "TANH") {
+    input_grad = grad * (inout * inout * (-1.f) + 1.f);  // dy * (1 - y * y)
+  } else if (mode_ == "RELU") {
+    input_grad = grad * ((inout > 0.f) + (inout <= 0.f) * neg_slope_);
+  } else {
+    LOG(FATAL) << "Unkown activation: " << mode_;
+  }
+  return std::make_pair(input_grad, param_grad);
+}
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/activation.h
----------------------------------------------------------------------
diff --git a/src/model/layer/activation.h b/src/model/layer/activation.h
new file mode 100644
index 0000000..1747577
--- /dev/null
+++ b/src/model/layer/activation.h
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SINGA_MODEL_LAYER_ACTIVATION_H_
+#define SINGA_MODEL_LAYER_ACTIVATION_H_
+#include <utility>
+#include <string>
+#include <vector>
+#include "singa/model/layer.h"
+
+namespace singa {
+class Activation : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "Activation"; }
+
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const LayerConf& conf) override;
+
+  /// \copydoc Layer::Forward(int flag, const Tensor&)
+  const Tensor Forward(int flag, const Tensor& input) override;
+
+  /// \copydoc Layer::Backward(int, const Tensor&);
+  const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                   const Tensor& grad) override;
+  /// Activation type string copied from the layer configuration by Setup().
+  const std::string Mode() const { return mode_; }
+  /// Negative slope read by Setup() for RELU layers; 0 if it was never set.
+  const float Negative_slope() const { return neg_slope_; }
+
+ protected:
+  std::string mode_;        // activation type, assigned in Setup()
+  std::stack<Tensor> buf_;  // tensors pushed by Forward, popped by Backward
+  float neg_slope_ = 0.f;   // defined value even when the mode is not RELU
+};
+}  // namespace singa
+#endif  // SINGA_MODEL_LAYER_ACTIVATION_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_activation.cc b/src/model/layer/cudnn_activation.cc
new file mode 100644
index 0000000..73c70d7
--- /dev/null
+++ b/src/model/layer/cudnn_activation.cc
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "singa_config.h"
+#ifdef USE_CUDNN
+#include "./cudnn_activation.h"
+#include <cudnn.h>
+
+#include "./cudnn_utils.h"
+#include "singa/core/common.h"
+#include "singa/utils/logging.h"
+
+namespace singa {
+CudnnActivation::~CudnnActivation() {
+  if (!has_init_cudnn_) return;  // descriptors exist only after InitCudnn()
+  CUDNN_CHECK(cudnnDestroyActivationDescriptor(acti_desc_));
+  CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_));
+}
+
+void CudnnActivation::InitCudnn(size_t size, DataType dtype) {
+  CHECK(!has_init_cudnn_);  // descriptors must not be created twice
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_));
+  CUDNN_CHECK(cudnnCreateActivationDescriptor(&acti_desc_));
+
+  // Describe the data as a 1 x 1 x 1 x size tensor in NCHW layout.
+  const auto elem_type = GetCudnnDataType(dtype);
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, elem_type,
+                                         1, 1, 1, size));
+
+  // Map the layer's mode string onto the corresponding cudnn enum value.
+  if (mode_ == "RELU") cudnn_mode_ = CUDNN_ACTIVATION_RELU;
+  else if (mode_ == "TANH") cudnn_mode_ = CUDNN_ACTIVATION_TANH;
+  else if (mode_ == "SIGMOID") cudnn_mode_ = CUDNN_ACTIVATION_SIGMOID;
+  else LOG(FATAL) << "Unkown activation: " << mode_;
+
+  nan_opt_ = CUDNN_PROPAGATE_NAN;
+  // Bind the chosen mode and NaN option to the activation descriptor.
+  CUDNN_CHECK(
+      cudnnSetActivationDescriptor(acti_desc_, cudnn_mode_, nan_opt_, 0.0f));
+  has_init_cudnn_ = true;
+}
+
+const Tensor CudnnActivation::Forward(int flag, const Tensor& input) {
+  // Runs the cudnn activation on `input` and returns the result. The cudnn
+  // descriptors are built on the first call from the input's size and type.
+  if (!has_init_cudnn_) InitCudnn(input.Size(), input.data_type());
+  Tensor output;
+  output.ResetLike(input);
+  output.device()->Exec([input, output, this](Context* ctx) {
+    Blob* src = input.blob();
+    Blob* dst = output.blob();
+    // alpha and beta are the two scaling factors handed to cudnn.
+    const float alpha = 1.0f, beta = 0.0f;
+#if CUDNN_VERSION_MAJOR == 5
+    CUDNN_CHECK(cudnnActivationForward(
+        ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_,
+        src->data(), &beta, this->desc_, dst->mutable_data()));
+#elif CUDNN_VERSION_MAJOR == 4
+    CUDNN_CHECK(cudnnActivationForward_v4(
+        ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_,
+        src->data(), &beta, this->desc_, dst->mutable_data()));
+#endif
+  }, {input.blob()}, {output.blob()});
+  // Save the tensor that Backward pops: the output for sigmoid/tanh, the
+  // input for relu. Any other mode saves nothing.
+  const bool keep_output = cudnn_mode_ == CUDNN_ACTIVATION_SIGMOID ||
+                           cudnn_mode_ == CUDNN_ACTIVATION_TANH;
+  if (keep_output) buf_.push(output);
+  else if (cudnn_mode_ == CUDNN_ACTIVATION_RELU) buf_.push(input);
+  return output;
+}
+
+const std::pair<Tensor, vector<Tensor>> CudnnActivation::Backward(
+    int flag, const Tensor& grad) {
+  // Pop the tensor saved by Forward; it is passed to cudnn as both x and y.
+  Tensor inout = buf_.top();
+  buf_.pop();
+  Tensor dx;
+  dx.ResetLike(grad);
+  dx.device()->Exec([dx, grad, inout, this](Context* ctx) {
+    Blob* dyblob = grad.blob();
+    Blob* dxblob = dx.blob();
+    Blob* saved = inout.blob();  // stands in for both the x and y arguments
+    const float alpha = 1.0f, beta = 0.0f;
+#if CUDNN_VERSION_MAJOR == 5
+    CUDNN_CHECK(cudnnActivationBackward(
+        ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_, saved->data(),
+        this->desc_, dyblob->data(), this->desc_, saved->data(), &beta,
+        this->desc_, dxblob->mutable_data()));
+#elif CUDNN_VERSION_MAJOR == 4
+    CUDNN_CHECK(cudnnActivationBackward_v4(
+        ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_, saved->data(),
+        this->desc_, dyblob->data(), this->desc_, saved->data(), &beta,
+        this->desc_, dxblob->mutable_data()));
+#endif
+  }, {grad.blob(), inout.blob()}, {dx.blob()});
+  vector<Tensor> param_grad;  // returned empty
+  return std::make_pair(dx, param_grad);
+}
+}  // namespace singa
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/cudnn_activation.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_activation.h b/src/model/layer/cudnn_activation.h
new file mode 100644
index 0000000..b572db7
--- /dev/null
+++ b/src/model/layer/cudnn_activation.h
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SINGA_MODEL_LAYER_CUDNN_ACTIVATION_H_
+#define SINGA_MODEL_LAYER_CUDNN_ACTIVATION_H_
+#include "singa_config.h"
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include <utility>
+#include <string>
+#include <vector>
+
+#include "./activation.h"
+#include "singa/core/common.h"
+#include "singa/model/layer.h"
+#include "singa/proto/core.pb.h"
+
+namespace singa {
+class CudnnActivation : public Activation {
+ public:
+  ~CudnnActivation();
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "CudnnActivation"; }
+
+  const Tensor Forward(int flag, const Tensor& input) override;
+  const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                   const Tensor& grad) override;
+
+  /// Create and configure the cudnn descriptors; CHECK-fails if called twice.
+  void InitCudnn(size_t size, DataType dtype);
+
+  const cudnnActivationMode_t CudnnMode() const { return cudnn_mode_; }
+
+ private:
+  bool has_init_cudnn_ = false;
+  cudnnActivationDescriptor_t acti_desc_ = nullptr;  // null until InitCudnn()
+  cudnnTensorDescriptor_t desc_ = nullptr;           // null until InitCudnn()
+  cudnnNanPropagation_t nan_opt_;     // assigned in InitCudnn()
+  cudnnActivationMode_t cudnn_mode_;  // assigned in InitCudnn()
+};
+}  // namespace
+#endif  // USE_CUDNN
+#endif  // SINGA_MODEL_LAYER_CUDNN_ACTIVATION_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_softmax.cc b/src/model/layer/cudnn_softmax.cc
new file mode 100644
index 0000000..bc7fe78
--- /dev/null
+++ b/src/model/layer/cudnn_softmax.cc
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "singa_config.h"
+#include "./cudnn_softmax.h"
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include "./cudnn_utils.h"
+#include "singa/utils/logging.h"
+namespace singa {
+CudnnSoftmax::~CudnnSoftmax() {  // desc_ is created only by InitCudnn()
+  if (has_init_cudnn_) CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_));
+}
+
+void CudnnSoftmax::InitCudnn(size_t size, DataType dtype) {
+  CHECK(!has_init_cudnn_);  // the descriptor must not be created twice
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_));
+  // NOTE(review): 1 x 1 x 1 x size looks like one instance; confirm for batches.
+  const auto elem_type = GetCudnnDataType(dtype);
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, elem_type,
+                                         1, 1, 1, size));
+  mode_ = CUDNN_SOFTMAX_MODE_INSTANCE;
+  algorithm_ = CUDNN_SOFTMAX_ACCURATE;
+  has_init_cudnn_ = true;
+}
+
+const Tensor CudnnSoftmax::Forward(int flag, const Tensor& input) {
+  // Computes the softmax of `input` through cudnn and returns it. The result
+  // is also pushed onto buf_ for the matching Backward call. The descriptor
+  // is created on the first call from the input's size and data type.
+  if (!has_init_cudnn_) InitCudnn(input.Size(), input.data_type());
+  Tensor output;
+  output.ResetLike(input);
+  output.device()->Exec([input, output, this](Context* ctx) {
+    Blob* inblob = input.blob(), * outblob = output.blob();
+    float alpha = 1.0f, beta = 0.0f;
+    // Check the returned status so a failed cudnn call is not ignored.
+    CUDNN_CHECK(cudnnSoftmaxForward(
+        ctx->cudnn_handle, this->algorithm_, this->mode_, &alpha, this->desc_,
+        inblob->data(), &beta, this->desc_, outblob->mutable_data()));
+  }, {input.blob()}, {output.blob()});
+  buf_.push(output);
+  return output;
+}
+
+const std::pair<Tensor, vector<Tensor>> CudnnSoftmax::Backward(
+    int flag, const Tensor& grad) {
+  // Pops the output saved by Forward; returns {dx, empty gradient list}.
+  Tensor dx, output = buf_.top();
+  buf_.pop();
+  dx.ResetLike(grad);
+  dx.device()->Exec([dx, grad, output, this](Context* ctx) {
+    Blob* dyblob = grad.blob(), * dxblob = dx.blob(), * yblob = output.blob();
+    float alpha = 1.0f, beta = 0.0f;
+    CUDNN_CHECK(cudnnSoftmaxBackward(  // do not ignore a failed cudnn call
+        ctx->cudnn_handle, this->algorithm_, this->mode_, &alpha, this->desc_,
+        yblob->data(), this->desc_, dyblob->data(), &beta, this->desc_,
+        dxblob->mutable_data()));
+  }, {grad.blob(), output.blob()}, {dx.blob()});
+  return std::make_pair(dx, vector<Tensor>());
+}
+}  // namespace singa
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/cudnn_softmax.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_softmax.h b/src/model/layer/cudnn_softmax.h
new file mode 100644
index 0000000..ee92d6f
--- /dev/null
+++ b/src/model/layer/cudnn_softmax.h
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SINGA_MODEL_LAYER_CUDNN_SOFTMAX_H_
+#define SINGA_MODEL_LAYER_CUDNN_SOFTMAX_H_
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include <utility>
+#include <string>
+#include <vector>
+
+#include "./softmax.h"
+#include "singa/core/common.h"
+#include "singa/model/layer.h"
+#include "singa/proto/core.pb.h"
+
+namespace singa {
+/// Softmax layer whose Forward/Backward call cudnnSoftmaxForward/Backward.
+class CudnnSoftmax : public Softmax {
+ public:
+  ~CudnnSoftmax();
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "CudnnSoftmax"; }
+
+  const Tensor Forward(int flag, const Tensor& input) override;
+  const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                   const Tensor& grad) override;
+  /// Init cudnn related data structures; Forward calls it lazily if needed.
+  void InitCudnn(size_t size, DataType dtype);
+
+ private:
+  bool has_init_cudnn_ = false;  ///< set to true at the end of InitCudnn()
+  cudnnTensorDescriptor_t desc_ = nullptr;  ///< shared by input and output
+  cudnnSoftmaxAlgorithm_t algorithm_ = CUDNN_SOFTMAX_ACCURATE;
+  cudnnSoftmaxMode_t mode_ = CUDNN_SOFTMAX_MODE_INSTANCE;
+};
+}  // namespace singa
+#endif  // USE_CUDNN
+#endif  // SINGA_MODEL_LAYER_CUDNN_SOFTMAX_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/softmax.cc b/src/model/layer/softmax.cc
new file mode 100644
index 0000000..813ebf0
--- /dev/null
+++ b/src/model/layer/softmax.cc
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "./softmax.h"
+namespace singa {
+
+void Softmax::Setup(const LayerConf& conf) {  // configures axis_ from conf
+  Layer::Setup(conf);  // base-class setup runs before the softmax field
+  axis_ = conf.softmax_conf().axis();  // default is 1
+}
+
+// Runs softmax on `input`; the result is also pushed on buf_ for Backward.
+const Tensor Softmax::Forward(int flag, const Tensor& input) {
+  // A 1-d input is first viewed as one row of shape {1, Size()}.
+  const Tensor y = input.nDim() == 1
+                       ? SoftMax(Reshape(input, Shape{1, input.Size()}), 0)
+                       : SoftMax(input, axis_);
+  buf_.push(y);
+  return buf_.top();
+}
+
+// Computes dL/dx for softmax from `grad` (dL/dy) and the y cached by Forward.
+const std::pair<Tensor, vector<Tensor>> Softmax::Backward(int flag,
+                                                          const Tensor& grad) {
+  // top()/pop() on an empty std::stack is undefined behaviour; abort instead.
+  CHECK(!buf_.empty()) << "Backward() called without a matching Forward()";
+  size_t nrow = 1, ncol = grad.Size();
+  if (grad.nDim() > 1 && axis_ > 0) {
+    nrow = Product(grad.shape(), 0, axis_);
+    ncol = Product(grad.shape(), axis_, grad.nDim());
+  }
+  Tensor input_grad = grad.Clone();
+  input_grad.Reshape(Shape{nrow, ncol});
+  Tensor y = buf_.top();
+  buf_.pop();
+  CHECK(y.shape() == input_grad.shape());
+  Tensor sigma = input_grad * y;
+  Tensor sum(Shape{nrow}, grad.device(), grad.data_type());
+  SumColumns(sigma, &sum);
+  // dL / dy_i = grad_i
+  // dy_j / dx_i = y_i - y_i^2 if i == j, and -y_i * y_j if i != j
+  // dL / dx_i = sum_j((dL / dy_j) * (dy_j / dx_i))
+  //           = y_i * (grad_i - sum), where sum = sum_j(grad_j * y_j)
+  SubColumn(sum, &input_grad);
+  input_grad = input_grad * y;
+  return std::make_pair(input_grad, vector<Tensor>{});  // no parameters
+}
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/softmax.h
----------------------------------------------------------------------
diff --git a/src/model/layer/softmax.h b/src/model/layer/softmax.h
new file mode 100644
index 0000000..ea3a70a
--- /dev/null
+++ b/src/model/layer/softmax.h
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SINGA_MODEL_LAYER_SOFTMAX_H_
+#define SINGA_MODEL_LAYER_SOFTMAX_H_
+#include "singa/model/layer.h"
+#include <stack>
+namespace singa {
+/// Softmax layer; Forward caches its output on buf_ for use by Backward.
+class Softmax : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "Softmax"; }
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const LayerConf& conf) override;
+
+  /// \copydoc Layer::Forward(int flag, const Tensor&)
+  const Tensor Forward(int flag, const Tensor& input) override;
+
+  /// \copydoc Layer::Backward(int flag, const Tensor&);
+  const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                   const Tensor& grad) override;
+
+  const int Axis() const { return axis_; }  ///< axis assigned by Setup()
+
+ protected:
+  int axis_ = 1;  ///< initialized so Axis() is defined before Setup() runs
+  std::stack<Tensor> buf_;  ///< Forward pushes outputs, Backward pops them
+};
+}  // namespace singa
+#endif  // SINGA_MODEL_LAYER_SOFTMAX_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_activation.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_activation.cc b/test/singa/test_activation.cc
new file mode 100644
index 0000000..9e34282
--- /dev/null
+++ b/test/singa/test_activation.cc
@@ -0,0 +1,133 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "../src/model/layer/activation.h"
+#include "gtest/gtest.h"
+#include <math.h> // exp, tanh
+
+using singa::Activation;
+// Setup() must copy the activation type and ReLU slope out of the LayerConf.
+TEST(Activation, Setup) {
+  Activation acti;
+  EXPECT_EQ("Activation", acti.layer_type());
+  singa::LayerConf conf;
+  conf.set_type("RELU");
+  conf.mutable_relu_conf()->set_negative_slope(0.5);
+
+  acti.Setup(conf);
+  // 0.5 is exactly representable in float, so an exact comparison is safe.
+  EXPECT_EQ("RELU", acti.Mode());
+  EXPECT_EQ(0.5f, acti.Negative_slope());
+}
+
+// Checks Activation::Forward against host-side references for each mode.
+TEST(Activation, Forward) {
+  const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -4.0};
+  size_t n = sizeof(x) / sizeof(float);
+  singa::Tensor in(singa::Shape{n});
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  float neg_slope = 0.5f;
+  std::string types[] = {"SIGMOID","TANH","RELU"};
+  for (int j = 0; j < 3; j++) {
+    Activation acti;
+    singa::LayerConf conf;
+    std::string layertype = types[j];
+    conf.set_type(layertype);
+    if (layertype == "RELU") {
+      singa::ReLUConf* reluconf = conf.mutable_relu_conf();
+      reluconf->set_negative_slope(neg_slope);
+    }
+    acti.Setup(conf);
+
+    singa::Tensor out = acti.Forward(0, in);
+    const float* yptr = out.data<const float*>();
+    EXPECT_EQ(n, out.Size());
+
+    float* y = new float[n];  // reference output computed on the host
+    if (acti.Mode() == "SIGMOID") {
+      for (size_t i = 0; i < n; i++)
+        y[i] = 1.f / (1.f + exp(-x[i]));
+    } else if (acti.Mode() == "TANH") {
+      for (size_t i = 0; i < n; i++)
+        y[i] = tanh(x[i]);
+    } else if (acti.Mode() == "RELU") {
+      // NOTE: this reference ignores neg_slope (negative inputs map to 0).
+      for (size_t i = 0; i < n; i++)
+        y[i] = (x[i] >= 0.f) ? x[i] : 0.f;
+    } else {
+      LOG(FATAL) << "Unkown activation: " << acti.Mode();
+    }
+    EXPECT_FLOAT_EQ(y[0], yptr[0]);
+    EXPECT_FLOAT_EQ(y[4], yptr[4]);
+    EXPECT_FLOAT_EQ(y[5], yptr[5]);
+    delete[] y;  // release the reference buffer allocated for this mode
+  }
+}
+
+// Checks Activation::Backward against host-side references for each mode.
+TEST(Activation, Backward) {
+  const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -4.0};
+  size_t n = sizeof(x) / sizeof(float);
+  singa::Tensor in(singa::Shape{n});
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  float neg_slope = 0.5f;
+  std::string types[] = {"SIGMOID","TANH","RELU"};
+  for (int j = 0; j < 3; j++) {
+    Activation acti;
+    singa::LayerConf conf;
+    std::string layertype = types[j];
+    conf.set_type(layertype);
+    if (layertype == "RELU") {
+      singa::ReLUConf* reluconf = conf.mutable_relu_conf();
+      reluconf->set_negative_slope(neg_slope);
+    }
+    acti.Setup(conf);
+    singa::Tensor out = acti.Forward(0, in);
+    const float* yptr = out.data<const float*>();
+
+    const float grad[] = {2.0f, -3.0f, 1.0f, 3.0f, -1.0f, -2.0};
+    singa::Tensor out_diff(singa::Shape{n});
+    out_diff.CopyDataFromHostPtr<float>(grad, n);
+    const auto in_diff = acti.Backward(0, out_diff);
+    const float* xptr = in_diff.first.data<const float*>();
+
+    float* dx = new float[n];  // reference gradient computed on the host
+    if (acti.Mode() == "SIGMOID") {
+      for (size_t i = 0; i < n; i++)
+        dx[i] = grad[i] * yptr[i] * (1. - yptr[i]);
+    } else if (acti.Mode() == "TANH") {
+      for (size_t i = 0; i < n; i++)
+        dx[i] = grad[i] * (1 - yptr[i] * yptr[i]);
+    } else if (acti.Mode() == "RELU") {
+      // NOTE(review): slope is added, not scaled by grad[i] -- confirm intent.
+      for (size_t i = 0; i < n; i++)
+        dx[i] = grad[i] * (x[i] > 0.f) + acti.Negative_slope() * (x[i] <= 0.f);
+    } else {
+      LOG(FATAL) << "Unkown activation: " << acti.Mode();
+    }
+    EXPECT_FLOAT_EQ(dx[0], xptr[0]);
+    EXPECT_FLOAT_EQ(dx[4], xptr[4]);
+    EXPECT_FLOAT_EQ(dx[5], xptr[5]);
+    delete[] dx;  // release the reference buffer allocated for this mode
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_activation.cc b/test/singa/test_cudnn_activation.cc
new file mode 100644
index 0000000..ee9f9b5
--- /dev/null
+++ b/test/singa/test_cudnn_activation.cc
@@ -0,0 +1,136 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#include "singa_config.h"
+#ifdef USE_CUDNN
+
+#include "singa/proto/core.pb.h"
+#include "../src/model/layer/cudnn_activation.h"
+#include "gtest/gtest.h"
+#include <math.h>  // exp tanh
+#include <cudnn.h>
+
+using singa::CudnnActivation;
+// Setup() + InitCudnn() must yield the cuDNN RELU mode and keep the slope.
+TEST(TCudnnActivation, Setup) {
+  CudnnActivation acti;
+  EXPECT_EQ("CudnnActivation", acti.layer_type());
+  singa::LayerConf conf;
+  conf.set_type("RELU");
+  conf.mutable_relu_conf()->set_negative_slope(0.5f);
+
+  acti.Setup(conf);
+  // Size 1 / float32 are placeholders; only mode and slope are checked below.
+  acti.InitCudnn(1, singa::kFloat32);
+  EXPECT_EQ(CUDNN_ACTIVATION_RELU, acti.CudnnMode());
+  EXPECT_EQ(0.5f, acti.Negative_slope());
+}
+
+// Checks CudnnActivation::Forward against host references for each mode.
+TEST(TCudnnActivation, Forward) {
+  const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -4.0};
+  size_t n = sizeof(x) / sizeof(float);
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape{n}, &cuda);
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  float neg_slope = 0.5f;
+  std::string types[] = {"SIGMOID", "TANH", "RELU"};
+  for (int j = 0; j < 3; j++) {
+    CudnnActivation acti;
+    singa::LayerConf conf;
+    std::string layertype = types[j];
+    conf.set_type(layertype);
+    if (layertype == "RELU") {
+      singa::ReLUConf* reluconf = conf.mutable_relu_conf();
+      reluconf->set_negative_slope(neg_slope);
+    }
+    acti.Setup(conf);  // no explicit InitCudnn() call in this test
+    singa::Tensor out = acti.Forward(0, in);
+    EXPECT_EQ(n, out.Size());
+    singa::CppCPU host(0, 1);
+    out.ToDevice(&host);
+    const float* yptr = out.data<const float*>();
+    float* y = new float[n];  // reference output computed on the host
+    if (acti.Mode() == "SIGMOID") {
+      for (size_t i = 0; i < n; i++) y[i] = 1.f / (1.f + exp(-x[i]));
+    } else if (acti.Mode() == "TANH") {
+      for (size_t i = 0; i < n; i++) y[i] = tanh(x[i]);
+    } else if (acti.Mode() == "RELU") {
+      for (size_t i = 0; i < n; i++) y[i] = (x[i] >= 0.f) ? x[i] : 0.f;
+    } else
+      LOG(FATAL) << "Unkown activation: " << acti.Mode();
+    EXPECT_FLOAT_EQ(y[0], yptr[0]);
+    EXPECT_FLOAT_EQ(y[4], yptr[4]);
+    EXPECT_FLOAT_EQ(y[5], yptr[5]);
+    delete[] y;  // release the reference buffer allocated for this mode
+  }
+}
+
+// Checks CudnnActivation::Backward against host references for each mode.
+TEST(TCudnnActivation, Backward) {
+  const float x[] = {2.0f, 3.0f, 3.0f, 7.f, 0.0f, 5.0, 1.5, 2.5, -2.5, 1.5};
+  size_t n = sizeof(x) / sizeof(float);
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape{n}, &cuda);
+  in.CopyDataFromHostPtr<float>(x, n);
+  float neg_slope = 0.5f;
+  std::string types[] = {"SIGMOID", "TANH", "RELU"};
+  for (int j = 0; j < 3; j++) {
+    CudnnActivation acti;
+    singa::LayerConf conf;
+    std::string layertype = types[j];
+    conf.set_type(layertype);
+    if (layertype == "RELU") {
+      singa::ReLUConf* reluconf = conf.mutable_relu_conf();
+      reluconf->set_negative_slope(neg_slope);
+    }
+    acti.Setup(conf);
+    acti.InitCudnn(n, singa::kFloat32);
+    singa::Tensor out = acti.Forward(0, in);
+    EXPECT_EQ(n, out.Size());
+    singa::CppCPU host(0, 1);
+    out.ToDevice(&host);
+    const float* yptr = out.data<const float*>();
+    const float grad[] = {2.0f, 1.0f, 2.0f, 0.0f, -2.0f,
+                          -1.0, 1.5,  2.5,  -1.5, -2.5};
+    singa::Tensor out_diff(singa::Shape{n}, &cuda);
+    out_diff.CopyDataFromHostPtr<float>(grad, n);
+    const auto ret = acti.Backward(0, out_diff);
+    singa::Tensor in_diff = ret.first;
+    in_diff.ToDevice(&host);
+    const float* xptr = in_diff.data<const float*>();
+    float* dx = new float[n];  // reference gradient computed on the host
+    if (acti.Mode() == "SIGMOID") {
+      for (size_t i = 0; i < n; i++) dx[i] = grad[i] * yptr[i] * (1. - yptr[i]);
+    } else if (acti.Mode() == "TANH") {
+      for (size_t i = 0; i < n; i++) dx[i] = grad[i] * (1. - yptr[i] * yptr[i]);
+    } else if (acti.Mode() == "RELU") {
+      // The negative-slope term is left out of this reference.
+      for (size_t i = 0; i < n; i++) dx[i] = grad[i] * (x[i] > 0.f);
+    } else
+      LOG(FATAL) << "Unkown activation: " << acti.Mode();
+    for (size_t i = 0; i < n; i++) {
+      EXPECT_NEAR(dx[i], xptr[i], 1e-7);
+    }
+    delete[] dx;  // release the reference buffer allocated for this mode
+  }
+}
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_dropout.cc b/test/singa/test_cudnn_dropout.cc
index e1a6333..32572d0 100644
--- a/test/singa/test_cudnn_dropout.cc
+++ b/test/singa/test_cudnn_dropout.cc
@@ -21,7 +21,7 @@
 #include "../src/model/layer/cudnn_dropout.h"
 #ifdef USE_CUDNN
 // cudnn dropout is added in cudnn 5
-#if CUDNN_MAJOR_VERSION >= 5
+#if CUDNN_VERSION_MAJOR >= 5
 
 #include "gtest/gtest.h"
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_softmax.cc b/test/singa/test_cudnn_softmax.cc
new file mode 100644
index 0000000..dcbf1ed
--- /dev/null
+++ b/test/singa/test_cudnn_softmax.cc
@@ -0,0 +1,107 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#include "singa_config.h"
+#ifdef USE_CUDNN
+
+#include "../src/model/layer/cudnn_softmax.h"
+#include "gtest/gtest.h"
+#include <math.h>  // exp
+#include <cudnn.h>
+
+using singa::CudnnSoftmax;
+// Setup() must copy the softmax axis; InitCudnn() is exercised with size 1.
+TEST(CudnnSoftmax, Setup) {
+  CudnnSoftmax sft;
+  EXPECT_EQ("CudnnSoftmax", sft.layer_type());
+
+  singa::LayerConf conf;
+  conf.mutable_softmax_conf()->set_axis(2);
+  sft.Setup(conf);
+
+  sft.InitCudnn(1, singa::kFloat32);
+  EXPECT_EQ(2, sft.Axis());
+}
+
+// Compares CudnnSoftmax::Forward on a 1-d tensor with a host-side softmax.
+TEST(CudnnSoftmax, Forward) {
+  const float x[] = {1.0f, 2.0f, 0.0f, -2.0f, -3.0f, -1.0};
+  size_t n = sizeof(x) / sizeof(float);
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape{n}, &cuda);
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  CudnnSoftmax sft;
+  singa::LayerConf conf;
+  conf.mutable_softmax_conf()->set_axis(1);
+  sft.Setup(conf);
+  sft.InitCudnn(n, singa::kFloat32);
+
+  singa::Tensor out = sft.Forward(0, in);
+  singa::CppCPU host(0, 1);
+  out.ToDevice(&host);  // move the result to the host device before reading
+  const float* yptr = out.data<const float*>();
+  EXPECT_EQ(n, out.Size());
+
+  float* y = new float[n];  // reference: exp(x_i) / sum_j exp(x_j)
+  float sigma = 0.f;
+  for (size_t i = 0; i < n; i++) sigma += exp(x[i]);
+  for (size_t i = 0; i < n; i++) y[i] = exp(x[i]) / sigma;
+  EXPECT_FLOAT_EQ(y[0], yptr[0]);
+  EXPECT_FLOAT_EQ(y[4], yptr[4]);
+  EXPECT_FLOAT_EQ(y[5], yptr[5]);
+  delete[] y;  // release the reference buffer
+}
+
+// Compares CudnnSoftmax::Backward on a 1-d tensor with a host-side gradient.
+TEST(CudnnSoftmax, Backward) {
+  const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -1.0};
+  size_t n = sizeof(x) / sizeof(float);
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape{n}, &cuda);
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  CudnnSoftmax sft;
+  singa::LayerConf conf;
+  conf.mutable_softmax_conf()->set_axis(1);
+  sft.Setup(conf);
+  singa::Tensor out = sft.Forward(0, in);
+  singa::CppCPU host(0, 1);
+  out.ToDevice(&host);
+  const float* yptr = out.data<const float*>();
+
+  const float grad[] = {2.0f, -3.0f, 1.0f, 3.0f, -1.0f, -2.0};
+  singa::Tensor out_diff(singa::Shape{n}, &cuda);
+  out_diff.CopyDataFromHostPtr<float>(grad, n);
+  const auto ret = sft.Backward(0, out_diff);
+  singa::Tensor in_diff = ret.first;
+  in_diff.ToDevice(&host);
+  const float* xptr = in_diff.data<const float*>();
+
+  float* dx = new float[n];  // reference: (grad_i - sum_j grad_j y_j) * y_i
+  float sigma = 0.f;
+  for (size_t i = 0; i < n; i++) sigma += grad[i] * yptr[i];
+  for (size_t i = 0; i < n; i++) dx[i] = (grad[i] - sigma) * yptr[i];
+  EXPECT_FLOAT_EQ(dx[0], xptr[0]);
+  EXPECT_FLOAT_EQ(dx[4], xptr[4]);
+  EXPECT_FLOAT_EQ(dx[5], xptr[5]);
+  delete[] dx;  // release the reference buffer
+}
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_softmax.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_softmax.cc b/test/singa/test_softmax.cc
new file mode 100644
index 0000000..da2a6ef
--- /dev/null
+++ b/test/singa/test_softmax.cc
@@ -0,0 +1,110 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "../src/model/layer/softmax.h"
+#include "gtest/gtest.h"
+#include <math.h> // exp
+
+using singa::Softmax;
+// Setup() must copy the softmax axis out of the LayerConf.
+TEST(Softmax, Setup) {
+  Softmax sft;
+  EXPECT_EQ("Softmax", sft.layer_type());
+
+  singa::LayerConf conf;
+  conf.mutable_softmax_conf()->set_axis(2);
+  sft.Setup(conf);
+
+  EXPECT_EQ(2, sft.Axis());
+}
+
+// Compares Softmax::Forward on a 2x3 tensor with a row-wise host softmax.
+TEST(Softmax, Forward) {
+  const float x[] = {1.0f, 2.0f, 0.0f, -2.0f, -3.0f, -1.0};
+  size_t n = sizeof(x) / sizeof(float);
+  size_t row = 2;
+  size_t col = 3;
+  singa::Tensor in(singa::Shape{row, col});
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  Softmax sft;
+  singa::LayerConf conf;
+  conf.mutable_softmax_conf()->set_axis(1);
+  sft.Setup(conf);
+
+  singa::Tensor out = sft.Forward(0, in);
+  const float* yptr = out.data<const float*>();
+  EXPECT_EQ(n, out.Size());
+
+  float* y = new float[n];
+  float* sigma = new float[row];  // per-row sum of exp(x)
+  for (size_t i = 0; i < row; i++)
+    sigma[i] = 0.f;
+  for (size_t i = 0; i < n; i++)
+    sigma[i / col] += exp(x[i]);
+  for (size_t i = 0; i < row; i++)
+    for (size_t j = 0; j < col; j++)
+      y[i * col + j] = exp(x[i * col + j]) / sigma[i];
+  EXPECT_FLOAT_EQ(y[0], yptr[0]);
+  EXPECT_FLOAT_EQ(y[4], yptr[4]);
+  EXPECT_FLOAT_EQ(y[5], yptr[5]);
+  delete[] sigma;  // release both reference buffers
+  delete[] y;
+}
+
+// Compares Softmax::Backward on a 2x3 tensor with a row-wise host gradient.
+TEST(Softmax, Backward) {
+  const float x[] = {1.0f, 2.0f, 0.0f, -2.0f, -3.0f, -1.0};
+  size_t n = sizeof(x) / sizeof(float);
+  size_t row = 2;
+  size_t col = 3;
+  singa::Tensor in(singa::Shape{row, col});
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  Softmax sft;
+  singa::LayerConf conf;
+  conf.mutable_softmax_conf()->set_axis(1);
+  sft.Setup(conf);
+  singa::Tensor out = sft.Forward(0, in);
+  const float* yptr = out.data<const float*>();
+
+  const float grad[] = {2.0f, -3.0f, 1.0f, 3.0f, -1.0f, -2.0};
+  singa::Tensor out_diff(singa::Shape{row, col});
+  out_diff.CopyDataFromHostPtr<float>(grad, n);
+  const auto in_diff = sft.Backward(0, out_diff);
+  const float* xptr = in_diff.first.data<const float*>();
+
+  // Reference: dx_ij = (grad_ij - sigma_i) * y_ij, sigma_i = sum_j grad_ij y_ij
+  float* dx = new float[n];
+  float* sigma = new float[row];
+  for (size_t i = 0; i < row; i++)
+    sigma[i] = 0.f;
+  for (size_t i = 0; i < n; i++)
+    sigma[i / col] += grad[i] * yptr[i];
+  for (size_t i = 0; i < row; i++)
+    for (size_t j = 0; j < col; j++)
+      dx[i * col + j] = (grad[i * col + j] - sigma[i]) * yptr[i * col +j];
+  EXPECT_FLOAT_EQ(dx[0], xptr[0]);
+  EXPECT_FLOAT_EQ(dx[4], xptr[4]);
+  EXPECT_FLOAT_EQ(dx[5], xptr[5]);
+  delete[] sigma;  // release both reference buffers
+  delete[] dx;
+}