You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/06/12 07:27:54 UTC

[2/5] incubator-singa git commit: SINGA-182 Clean math function APIs and implementations

SINGA-182 Clean math function APIs and implementations

Merge branch 'cuda' from #jinyangturbo.
Clean the CUDA-related code (tensor_math_cuda.h, math_kernel.h and math_kernel.cu)
by unifying the function arguments (names and argument order).
Need to reorder the functions.
Add Nrm2 for L2 norm using cblas and cublas.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/6d69047a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/6d69047a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/6d69047a

Branch: refs/heads/dev
Commit: 6d69047addc46e5c9f381b7e1d4cebd20ce9b2e3
Parents: 564c88a
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Sun Jun 12 12:08:48 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Sun Jun 12 12:15:11 2016 +0800

----------------------------------------------------------------------
 include/singa/core/tensor.h        |   2 +
 src/core/tensor/math_kernel.cu     | 656 +++++++++++++++++---------------
 src/core/tensor/math_kernel.h      |  93 ++---
 src/core/tensor/tensor.cc          |  14 +
 src/core/tensor/tensor_math.h      | 140 ++++---
 src/core/tensor/tensor_math_cpp.h  | 227 ++++++-----
 src/core/tensor/tensor_math_cuda.h | 384 +++++++++++++++----
 test/singa/test_tensor_math.cc     | 346 ++++++++---------
 8 files changed, 1092 insertions(+), 770 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d69047a/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 82bbe81..cd750c5 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -173,6 +173,8 @@ class Tensor {
   template <typename SType>
   Tensor &operator/=(const SType x);
 
+  float L2() const;
+
  protected:
   bool transpose_ = false;
   DataType data_type_ = kFloat32;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d69047a/src/core/tensor/math_kernel.cu
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index aed6add..b618f9b 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -35,36 +35,16 @@
 namespace singa {
 // Cuda Kernel Functions
 namespace cuda {
-__global__ void kernel_softmax_loss(const float *prob, const int *label,
-                                    float *loss, int n, int dim) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    float prob_of_truth = prob[index * dim + label[index]];
-    loss[index] -= std::log(max(prob_of_truth, FLT_MIN));
-  }
-}
-
-__global__ void kernel_softmax_gradient(float *grad, const int *label, int n,
-                                        int dim, float scale) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    int pos = index * dim + label[index];
-    grad[pos] = (grad[pos] - 1.0f) * scale;
-  }
-}
-
-__global__ void kernel_sum_vec(const float *data, float *sum, int n) {
+__global__ void KernelSum(const size_t n, const float *in, float *out) {
   int THREADS = blockDim.x;
 
   __shared__ float aux[CU1DBLOCK];
   int steps = (n - 1) / THREADS + 1;
-  aux[threadIdx.x] = data[threadIdx.x];
+  aux[threadIdx.x] = in[threadIdx.x];
 
   for (int i = 1; i < steps; ++i) {
     if (threadIdx.x + i * THREADS < n) {
-      aux[threadIdx.x] += data[threadIdx.x + i * THREADS];
+      aux[threadIdx.x] += in[threadIdx.x + i * THREADS];
     }
   }
 
@@ -83,432 +63,484 @@ __global__ void kernel_sum_vec(const float *data, float *sum, int n) {
   }
 
   __syncthreads();
-  *sum = aux[0];
+  *out = aux[0];
 }
 
-__global__ void kernel_sum_col(const float *src_mat_data, float *dst_vec_data,
-                               int rows, int cols, int stride) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < rows; index += num_threads) {
-    dst_vec_data[index] = 0.0f;
-    for (int k = 0; k < cols; k++) {
-      dst_vec_data[index] += src_mat_data[index * stride + k];
-    }
+__global__ void KernelAdd(const size_t n, const float *in1, const float *in2,
+                          float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = in1[i] + in2[i];
   }
 }
 
-__global__ void kernel_sum_row(const float *src_mat_data, float *dst_vec_data,
-                               int rows, int cols, int stride) {
-  int j = blockIdx.x;
-  int THREADS = blockDim.x;
-  if (j >= cols) {
-    return;
-  }
-
-  __shared__ float aux[CU1DBLOCK];
-  int steps = (rows - 1) / THREADS + 1;
-  aux[threadIdx.x] = src_mat_data[j + threadIdx.x * stride];
-  for (int i = 1; i < steps; ++i) {
-    if (threadIdx.x + i * THREADS < rows) {
-      aux[threadIdx.x] +=
-          src_mat_data[j + (threadIdx.x + i * THREADS) * stride];
-    }
+__global__ void KernelAdd(const size_t n, const float *in, const float x,
+                          float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = in[i] + x;
   }
+}
 
-  int total_threads = THREADS;
-  __syncthreads();
-  while (total_threads > 1) {
-    int half_point = ((1 + total_threads) >> 1);
-    if (threadIdx.x < half_point) {
-      if (threadIdx.x + half_point < total_threads) {
-        aux[threadIdx.x] += aux[threadIdx.x + half_point];
-      }
-    }
-    __syncthreads();
-    total_threads = ((total_threads + 1) >> 1);
+__global__ void KernelSub(const size_t n, const float *in1, const float *in2,
+                          float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = in1[i] - in2[i];
   }
-
-  __syncthreads();
-  dst_vec_data[j] = aux[0];
 }
 
-__global__ void kernel_add_vec_row(const float *src_vec_data,
-                                   const float *src_mat_data,
-                                   float *des_mat_data, int rows, int cols,
-                                   int stride) {
-  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  int j = blockIdx.y * blockDim.y + threadIdx.y;
-  int num_threads_x = blockDim.x * gridDim.x;
-  int num_threads_y = blockDim.y * gridDim.y;
-  int index = 0;
-  for (; i < cols && j < rows; i += num_threads_x, j += num_threads_y) {
-    index = j * stride + i;
-    des_mat_data[index] = src_mat_data[index] + src_vec_data[i];
+__global__ void KernelExp(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = std::exp(in[i]);
   }
 }
-__global__ void kernel_add(const float *src1, const float *src2, float *out,
-                           int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    out[index] = src1[index] + src2[index];
+
+__global__ void KernelLog(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = std::log(in[i]);
   }
 }
 
-__global__ void kernel_sub(const float *src1, const float *src2, float *out,
-                           int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    out[index] = src1[index] - src2[index];
+__global__ void KernelSigmoid(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = 1.0f / (1.0f + expf(-in[i]));
   }
 }
-__global__ void kernel_exp(const float *src_data, float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = std::exp(src_data[index]);
+__global__ void KernelSign(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    if (in[i] > 0.0f)
+      out[i] = 1.0f;
+    else if (in[i] < 0.0f)
+      out[i] = -1.0f;
+    else
+      out[i] = 0.0f;
   }
 }
 
-__global__ void kernel_log(const float *src_data, float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = std::log(src_data[index]);
+__global__ void KernelClamp(const size_t n, const float low, const float high,
+                            const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    if (in[i] > high)
+      out[i] = high;
+    else if (in[i] < low)
+      out[i] = low;
+    else
+      out[i] = in[i];
   }
 }
 
-__global__ void kernel_sigmoid(const float *src_data, float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = 1.0f / (1.0f + expf(-src_data[index]));
+__global__ void KernelRelu(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = max(in[i], 0.0f);
   }
 }
 
-__global__ void kernel_sigmoid_grad(const float *src_data, float *des_data,
-                                    int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = src_data[index] * (1.0f - src_data[index]);
+__global__ void KernelAbs(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] =  max(in[i], -in[i]);
   }
 }
 
-__global__ void kernel_relu(const float *src_data, float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = max(src_data[index], 0.0f);
+__global__ void KernelTanh(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = tanhf(in[i]);
   }
 }
 
-__global__ void kernel_relu_grad(const float *src_data, float *des_data,
-                                 int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = src_data[index] > 0.0f ? 1.0f : 0.0f;
+__global__ void KernelSoftplus(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = logf(1 + expf(in[i]));
   }
 }
-
-__global__ void kernel_tanh(const float *src_data, float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = tanhf(src_data[index]);
+__global__ void KernelSquare(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = in[i] * in[i];
   }
 }
-
-__global__ void kernel_tanh_grad(const float *src_data, float *des_data,
-                                 int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = (1.0f - src_data[index] * src_data[index]);
+__global__ void KernelSqrt(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = std::sqrt(in[i]);
   }
 }
 
-__global__ void kernel_softplus(const float *src_data, float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = logf(1 + expf(src_data[index]));
+__global__ void KernelPow(const size_t n, const float *in1, const float *in2,
+                          float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = std::pow(in1[i], in2[i]);
   }
 }
 
-__global__ void kernel_softplus_grad(const float *src_data, float *des_data,
-                                     int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = 1.0f / (1.0f + expf(-src_data[index]));
+__global__ void KernelPow(const size_t n, const float *in, const float x,
+                          float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = std::pow(in[i], x);
   }
 }
 
-__global__ void kernel_square(const float *src_data, float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = src_data[index] * src_data[index];
+__global__ void KernelMult(const size_t n, const float *in1, const float *in2,
+                           float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = in1[i] * in2[i];
   }
 }
 
-__global__ void kernel_square_grad(const float *src_data, float *des_data,
-                                   int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = 2 * src_data[index];
+__global__ void KernelMult(const size_t n, const float *in, const float x,
+                           float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = in[i] * x;
   }
 }
 
-__global__ void kernel_sqrt(const float *src_data, float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = std::sqrt(src_data[index]);
+__global__ void KernelDiv(const size_t n, const float *in1, const float *in2,
+                          float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = in1[i] / in2[i];
   }
 }
-
-__global__ void kernel_pow(const float *src_data_a, const float *src_data_b,
-                           float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = std::pow(src_data_a[index], src_data_b[index]);
+__global__ void KernelDiv(const size_t n, const float x, const float *in,
+                          float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = x / in[i];
   }
 }
-
-__global__ void kernel_mult(const float *src_data_a, const float *src_data_b,
-                            float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = src_data_a[index] * src_data_b[index];
+__global__ static void KernelSet(const size_t n, const float x, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = x;
   }
 }
 
-__global__ void kernel_mult(const float *src_data_a, const float x,
-                            float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = src_data_a[index] * x;
+__global__ void KernelThreshold(const size_t n, const float x, const float *in,
+                                float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = in[i] < x ? 1.0f : 0.0f;
   }
 }
 
-__global__ void kernel_div(const float *src_data_a, const float *src_data_b,
-                           float *des_data, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = src_data_a[index] / src_data_b[index];
+__global__ void KernelGE(const int num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] >= x ? 1.0f : 0.0f;
   }
 }
-
-__global__ static void kernel_set_value(float *data, float value, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    data[index] = value;
+__global__ void KernelGT(const int num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] > x ? 1.0f : 0.0f;
   }
 }
-
-__global__ void kernel_threshold(const float *src_data, float *des_data,
-                                 float alpha, int n) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < n; index += num_threads) {
-    des_data[index] = src_data[index] < alpha ? 1.0f : 0.0f;
+__global__ void KernelLE(const int num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] <= x ? 1.0f : 0.0f;
   }
 }
-void sum(int n, const float *in, float *out) {
-  int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n;
-  //  here, we only need one block
-  int num_blocks = 1;
 
-  kernel_sum_vec << <num_blocks, threads_per_block>>> (in, out, n);
+__global__ void KernelLT(const int num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] < x ? 1.0f : 0.0f;
+  }
 }
 
-void sum_row(int rows, int cols, int stride, const float *in, float *out) {
-  int threads_per_block = rows > CU1DBLOCK ? CU1DBLOCK : rows;
-  int num_blocks = cols;
+// ********************************
+// Functions call kernels
+// ********************************
 
-  kernel_sum_row << <num_blocks, threads_per_block>>>
-      (in, out, rows, cols, stride);
+void set(const size_t n, const float v, float *out, cudaStream_t s) {
+  KernelSet <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, v, out);
 }
 
-void sum_col(int rows, int cols, int stride, const float *in, float *out) {
-  int threads_per_block = cols > CU1DBLOCK ? CU1DBLOCK : cols;
-  int num_blocks = rows;
+void abs(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelAbs <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+}
 
-  kernel_sum_col << <num_blocks, threads_per_block>>>
-      (in, out, rows, cols, stride);
+void sign(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelSign <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 }
-void add_row(int rows, int cols, int stride, const float *in_row,
-             const float *in_mat, float *out) {
-  dim3 threads_per_block(CU2DBLOCK_X, CU2DBLOCK_Y);
-  dim3 num_blocks(
-      cols / threads_per_block.x + (cols % threads_per_block.x == 0 ? 0 : 1),
-      rows / threads_per_block.y + (rows % threads_per_block.y == 0 ? 0 : 1));
-  kernel_add_vec_row << <num_blocks, threads_per_block>>>
-      (in_row, in_mat, out, rows, cols, stride);
+
+void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 }
-void add(int n, const float *a, const float *b, float *out) {
-  kernel_add << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+
+void log(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 }
-void sub(int n, const float *a, const float *b, float *out) {
-  kernel_sub << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+
+void sqrt(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelSqrt <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 }
-void exp(int n, const float *in, float *out) {
-  kernel_exp << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+
+void square(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelSquare <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 }
 
-void log(int n, const float *in, float *out) {
-  kernel_log << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void tanh(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelTanh <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
 }
 
-void sigmoid(int n, const float *in, float *out) {
-  kernel_sigmoid << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void relu(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelRelu <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+}
+void sigmoid(const int n, const float *in, float *out, cudaStream_t s) {
+  KernelSigmoid <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+}
+void softplus(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelSoftplus <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+}
+void clamp(const size_t n, const float low, const float high, const float *in,
+           float *out, cudaStream_t s) {
+  KernelClamp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, low, high, in, out);
 }
 
-void sigmoid_grad(int n, const float *in, float *out) {
-  kernel_sigmoid_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void pow(const size_t n, const float *in, const float x, float *out,
+         cudaStream_t s) {
+  KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
 }
 
-void relu(int n, const float *in, float *out) {
-  kernel_relu << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void add(const size_t n, const float *in, const float x, float *out,
+         cudaStream_t s) {
+  KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
 }
 
-void relu_grad(int n, const float *in, float *out) {
-  kernel_relu_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void mult(const size_t n, const float *in, const float x, float *out,
+          cudaStream_t s) {
+  KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
 }
 
-void tanh(int n, const float *in, float *out) {
-  kernel_tanh << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void div(const size_t n, const float x, const float *in, float *out,
+          cudaStream_t s) {
+  KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
 }
 
-void tanh_grad(int n, const float *in, float *out) {
-  kernel_tanh_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void threshold(const size_t n, const float x, const float *in, float *out,
+               cudaStream_t s) {
+  KernelThreshold <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
 }
 
-void softplus(int n, const float *in, float *out) {
-  kernel_softplus << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void gt(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+void ge(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+void lt(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+void le(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
 }
 
-void softplus_grad(int n, const float *in, float *out) {
-  kernel_softplus_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void pow(const size_t n, const float *in1, const float *in2, float *out,
+         cudaStream_t s) {
+  KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
 }
 
-void square(int n, const float *in, float *out) {
-  kernel_square << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void add(const size_t n, const float *in1, const float *in2, float *out,
+         cudaStream_t s) {
+  KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
 }
 
-void square_grad(int n, const float *in, float *out) {
-  kernel_square_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void sub(const size_t n, const float *in1, const float *in2, float *out,
+         cudaStream_t s) {
+  KernelSub <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
 }
 
-void sqrt(int n, const float *in, float *out) {
-  kernel_sqrt << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void mult(const size_t n, const float *in1, const float *in2, float *out,
+          cudaStream_t s) {
+  KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
 }
 
-void pow(int n, const float *a, const float *b, float *out) {
-  kernel_pow << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+void div(const size_t n, const float *in1, const float *in2, float *out,
+         cudaStream_t s) {
+  KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
 }
 
-void mult(int n, const float *a, const float *b, float *out) {
-  kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+void sum(const size_t n, const float *in, float *out, cudaStream_t s) {
+  int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n;
+  //  here, we only need one block
+  int num_blocks = 1;
+  KernelSum <<<num_blocks, threads_per_block>>> (n, in, out);
+}
+/*
+void square_grad(int n, const float *in, float *out, cudaStream_t s) {
+  kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
-void mult(int n, const float *a, const float x, float *out) {
-  kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, x, out, n);
+void tanh_grad(int n, const float *in, float *out, cudaStream_t s) {
+  kernel_tanh_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
-void div(int n, const float *a, const float *b, float *out) {
-  kernel_div << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+
+void relu_grad(int n, const float *in, float *out, cudaStream_t s) {
+  kernel_relu_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
-void set_value(int n, float v, float *out) {
-  kernel_set_value << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (out, v, n);
+
+void sigmoid_grad(int n, const float *in, float *out, cudaStream_t s) {
+  kernel_sigmoid_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
-void threshold(int n, float alpha, const float *in, float *out) {
-  kernel_threshold << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, alpha, n);
+void softplus_grad(int n, const float *in, float *out, cudaStream_t s) {
+  kernel_softplus_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
 }
 
-// follow the consistency guide for math API
-__global__ void KernelDiv(const size_t num, const float alpha, const float *in,
-                          float *out) {
-  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
-       idx += blockDim.x * gridDim.x) {
-    out[idx] = alpha / in[idx];
+
+__global__ void kernel_sum_col(const float *src_mat_data, float *dst_vec_data,
+                               int rows, int cols, int stride) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < rows; index += num_threads) {
+    dst_vec_data[index] = 0.0f;
+    for (int k = 0; k < cols; k++) {
+      dst_vec_data[index] += src_mat_data[index * stride + k];
+    }
   }
 }
 
-__global__ void KernelGE(const int num, const float *in, const float x,
-                         float *out) {
-  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
-       idx += blockDim.x * gridDim.x) {
-    out[idx] = in[idx] >= x ? 1.0f : 0.0f;
+__global__ void kernel_sum_row(const float *src_mat_data, float *dst_vec_data,
+                               int rows, int cols, int stride) {
+  int j = blockIdx.x;
+  int THREADS = blockDim.x;
+  if (j >= cols) {
+    return;
   }
-}
-__global__ void KernelGT(const int num, const float *in, const float x,
-                         float *out) {
-  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
-       idx += blockDim.x * gridDim.x) {
-    out[idx] = in[idx] > x ? 1.0f : 0.0f;
+
+  __shared__ float aux[CU1DBLOCK];
+  int steps = (rows - 1) / THREADS + 1;
+  aux[threadIdx.x] = src_mat_data[j + threadIdx.x * stride];
+  for (int i = 1; i < steps; ++i) {
+    if (threadIdx.x + i * THREADS < rows) {
+      aux[threadIdx.x] +=
+          src_mat_data[j + (threadIdx.x + i * THREADS) * stride];
+    }
   }
-}
-__global__ void KernelLE(const int num, const float *in, const float x,
-                         float *out) {
-  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
-       idx += blockDim.x * gridDim.x) {
-    out[idx] = in[idx] <= x ? 1.0f : 0.0f;
+
+  int total_threads = THREADS;
+  __syncthreads();
+  while (total_threads > 1) {
+    int half_point = ((1 + total_threads) >> 1);
+    if (threadIdx.x < half_point) {
+      if (threadIdx.x + half_point < total_threads) {
+        aux[threadIdx.x] += aux[threadIdx.x + half_point];
+      }
+    }
+    __syncthreads();
+    total_threads = ((total_threads + 1) >> 1);
   }
+
+  __syncthreads();
+  dst_vec_data[j] = aux[0];
 }
 
-__global__ void KernelLT(const int num, const float *in, const float x,
-                         float *out) {
-  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
-       idx += blockDim.x * gridDim.x) {
-    out[idx] = in[idx] < x ? 1.0f : 0.0f;
+
+__global__ void kernel_add_vec_row(const float *src_vec_data,
+                                   const float *src_mat_data,
+                                   float *des_mat_data, int rows, int cols,
+                                   int stride) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  int j = blockIdx.y * blockDim.y + threadIdx.y;
+  int num_threads_x = blockDim.x * gridDim.x;
+  int num_threads_y = blockDim.y * gridDim.y;
+  int index = 0;
+  for (; i < cols && j < rows; i += num_threads_x, j += num_threads_y) {
+    index = j * stride + i;
+    des_mat_data[index] = src_mat_data[index] + src_vec_data[i];
   }
 }
 
-__global__ void KernelSet(const size_t num, const float x, float *out) {
-  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
-       idx += blockDim.x * gridDim.x) {
-    out[idx] = x;
+__global__ void kernel_sigmoid_grad(const float *src_data, float *des_data,
+                                    int n) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < n; index += num_threads) {
+    des_data[index] = src_data[index] * (1.0f - src_data[index]);
   }
 }
 
-void Set(const size_t num, const float x, float *out, cudaStream_t s) {
-  KernelSet << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, x, out);
+
+__global__ void kernel_relu_grad(const float *src_data, float *des_data,
+                                 int n) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < n; index += num_threads) {
+    des_data[index] = src_data[index] > 0.0f ? 1.0f : 0.0f;
+  }
 }
-void Div(const size_t num, float alpha, const float *in, float *out,
-         cudaStream_t s) {
-  KernelDiv << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, alpha, in, out);
+
+__global__ void kernel_tanh_grad(const float *src_data, float *des_data,
+                                 int n) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < n; index += num_threads) {
+    des_data[index] = (1.0f - src_data[index] * src_data[index]);
+  }
 }
 
-void GT(const size_t num, const float *in, const float x, float *out,
-        cudaStream_t s) {
-  KernelGT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+
+__global__ void kernel_softplus_grad(const float *src_data, float *des_data,
+                                     int n) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < n; index += num_threads) {
+    des_data[index] = 1.0f / (1.0f + expf(-src_data[index]));
+  }
 }
-void GE(const size_t num, const float *in, const float x, float *out,
-        cudaStream_t s) {
-  KernelGE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+__global__ void KernelSquareGrad(const float *src_data, float *des_data,
+                                   int n) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < n; index += num_threads) {
+    des_data[index] = 2 * src_data[index];
+  }
 }
-void LT(const size_t num, const float *in, const float x, float *out,
-        cudaStream_t s) {
-  KernelLT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+__global__ void kernel_softmax_loss(const float *prob, const int *label,
+                                    float *loss, int n, int dim) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < n; index += num_threads) {
+    float prob_of_truth = prob[index * dim + label[index]];
+    loss[index] -= std::log(max(prob_of_truth, FLT_MIN));
+  }
 }
-void LE(const size_t num, const float *in, const float x, float *out,
-        cudaStream_t s) {
-  KernelLE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+__global__ void kernel_softmax_gradient(float *grad, const int *label, int n,
+                                        int dim, float scale) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < n; index += num_threads) {
+    int pos = index * dim + label[index];
+    grad[pos] = (grad[pos] - 1.0f) * scale;
+  }
 }
+*/
+
 
 }  // namespace cuda
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d69047a/src/core/tensor/math_kernel.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 5c906a9..d8a58a5 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -31,65 +31,66 @@ namespace singa {
 
 // TODO(wangwei) make all function templates.
 namespace cuda {
-void sum(int n, const float *in, float *out);
 
-void sum_row(int rows, int cols, int stride, const float *in, float *out);
-
-void sum_col(int rows, int cols, int stride, const float *in, float *out);
-
-void add_row(int rows, int cols, int stride, const float *in_row,
-             const float *in_mat, float *out);
-
-void add(int n, const float *a, const float *b, float *out);
-
-void sub(int n, const float *a, const float *b, float *out);
-
-void exp(int n, const float *in, float *out);
-
-void log(int n, const float *in, float *out);
-
-void sigmoid(int n, const float *in, float *out);
-
-void sigmoid_grad(int n, const float *in, float *out);
-
-void relu(int n, const float *in, float *out);
-
-void relu_grad(int n, const float *in, float *out);
-
-void tanh(int n, const float *in, float *out);
-
-void tanh_grad(int n, const float *in, float *out);
+// 0 input
+void set(const size_t n, const float v, float *out, cudaStream_t s);
+
+// 1 input
+void abs(const size_t n, const float *in, float *out, cudaStream_t s);
+void sign(const size_t n, const float *in, float *out, cudaStream_t s);
+void exp(const size_t n, const float *in, float *out, cudaStream_t s);
+void log(const size_t n, const float *in, float *out, cudaStream_t s);
+void sqrt(const size_t n, const float *in, float *out, cudaStream_t s);
+void square(const size_t n, const float *in, float *out, cudaStream_t s);
+void tanh(const size_t n, const float *in, float *out, cudaStream_t s);
+void relu(const size_t n, const float *in, float *out, cudaStream_t s);
+void sigmoid(const int n, const float *in, float *out, cudaStream_t s);
+void softplus(const size_t n, const float *in, float *out, cudaStream_t s);
+void clamp(const size_t n, const float low, const float high, const float *in,
+           float *out, cudaStream_t s);
+
+void pow(const size_t n, const float *in, const float x, float *out,
+         cudaStream_t s);
 
-void softplus(int n, const float *in, float *out);
+void add(const size_t n, const float *in, const float x, float *out,
+         cudaStream_t s);
 
-void softplus_grad(int n, const float *in, float *out);
+void mult(const size_t n, const float *in, const float x, float *out,
+          cudaStream_t s);
 
-void square(int n, const float *in, float *out);
+void div(const size_t n, const float x, const float *in, float *out,
+         cudaStream_t s);
 
-void square_grad(int n, const float *in, float *out);
+void threshold(const size_t n, const float x, const float *in, float *out,
+               cudaStream_t s);
 
-void sqrt(int n, const float *in, float *out);
+void gt(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s);
+void ge(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s);
+void lt(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s);
+void le(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s);
 
-void pow(int n, const float *a, const float *b, float *out);
+// 2 inputs
+void pow(const size_t n, const float *in1, const float *in2, float *out,
+         cudaStream_t s);
 
-void mult(int n, const float *a, const float *b, float *out);
+void add(const size_t n, const float *in1, const float *in2, float *out,
+         cudaStream_t s);
 
-void mult(int n, const float *a, const float x, float *out);
+void sub(const size_t n, const float *in1, const float *in2, float *out,
+         cudaStream_t s);
 
-void div(int n, const float *a, const float *b, float *out);
+void mult(const size_t n, const float *in1, const float *in2, float *out,
+          cudaStream_t s);
 
-void set_value(int n, float v, float *out);
+void div(const size_t n, const float *in1, const float *in2, float *out,
+         cudaStream_t s);
 
-void threshold(int n, float alpha, const float *in, float *out);
+void sum(const size_t n, const float *in, float *out, cudaStream_t s);
 
-// follow the consistency guide for math API
-void Div(const size_t num, const float x, const float *in, float *out,
-         cudaStream_t s);
-void Set(const size_t num, const float x, float *out, cudaStream_t s);
-void GT(size_t num, const float *in, const float x, float *out, cudaStream_t s);
-void GE(size_t num, const float *in, const float x, float *out, cudaStream_t s);
-void LT(size_t num, const float *in, const float x, float *out, cudaStream_t s);
-void LE(size_t num, const float *in, const float x, float *out, cudaStream_t s);
 }  // cuda
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d69047a/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index f4e9da2..e62386a 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -219,6 +219,8 @@ GenUnaryScalarArgMemberFn(operator+=, Add);
 GenUnaryScalarArgMemberFn(operator*=, EltwiseMult);
 GenUnaryScalarArgMemberFn(operator/=, Div);
 
+
+
 // ====================Tensor Operations=======================================
 void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
                     const size_t dst_offset, const size_t src_offset) {
@@ -309,6 +311,18 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
   } while (0)
 
 // =============Element-wise operations====================================
+/// L2 norm of the tensor. Named L2 rather than Nrm2 to avoid a name conflict.
+float Tensor::L2() const {
+  float nrm = 0.0f;
+  TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
+    device_->Exec([&nrm, this](Context *ctx) {
+      DType ret;
+      Nrm2<DType, Lang>(this->Size(), this->blob(), &ret, ctx);
+      nrm = TypeCast<DType, float>(ret);
+    }, {this->blob()}, {});
+  });
+  return nrm;
+}
 template <typename SType>
 void Tensor::SetValue(const SType x) {
   CHECK_EQ(sizeof(SType), SizeOf(data_type_));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d69047a/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index b5d0ba9..b86e1cb 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -48,41 +48,45 @@ namespace singa {
 /// 7. Use size_t for the number of elements, rows or columns.
 /// 8. Use the same name for the Tensor and Blob level math functions.
 
-// =============Element-wise operations====================================
+// **************************************
+// Element-wise functions
+// **************************************
+
 /// out[i] = |in[i]|
 template <typename DType, typename Lang>
 void Abs(const size_t num, const Blob *in, Blob *out, Context *ctx) {
   LOG(FATAL) << "Abs Not Implemented";
 }
 
-/// out = in + x
+/// out[i] = in[i] + x
 template <typename DType, typename Lang>
 void Add(const size_t num, const Blob *in, const DType x, Blob *out,
          Context *ctx) {
   LOG(FATAL) << "Add Not Implemented";
 }
 
-/// out = in1 + in2
+/// out[i] = in1[i] + in2[i]
 template <typename DType, typename Lang>
 void Add(const size_t num, const Blob *in1, const Blob *in2, Blob *out,
          Context *ctx) {
   LOG(FATAL) << "Add-Pair Not Implemented";
 }
-/// Element-wise operation, clamp every element into [low, high]
-/// if x>high, then x=high; if x<low, then x=low.
+/// Clamp every element into [low, high]
+/// if in[i]>high, then out[i]=high; if in[i]<low, then out[i]=low.
 template <typename DType, typename Lang>
 void Clamp(const size_t num, const DType low, const DType high, const Blob *in,
            Blob *out, Context *ctx) {
   LOG(FATAL) << "Clamp Not Implemented";
 }
 
-/// out = x / in
+/// out[i] = x / in[i]
 template <typename DType, typename Lang>
 void Div(const size_t num, const DType x, const Blob *in, Blob *out,
          Context *ctx) {
   LOG(FATAL) << "Div Not Implemented";
 }
 
+/// out[i] = in[i] / x
 template <typename DType, typename Lang>
 void Div(const size_t num, const Blob *in, const DType x, Blob *out,
          Context *ctx) {
@@ -90,21 +94,21 @@ void Div(const size_t num, const Blob *in, const DType x, Blob *out,
   EltwiseMult<DType, Lang>(num, in, DType(1) / x, out, ctx);
 }
 
-/// out = in1 / in2
+/// out[i] = in1[i] / in2[i]
 template <typename DType, typename Lang>
 void Div(const size_t num, const Blob *in1, const Blob *in2, Blob *out,
          Context *ctx) {
   LOG(FATAL) << "Div-Pair Not Implemented";
 }
 
-/// out = in * x
+/// out[i] = in[i] * x
 template <typename DType, typename Lang>
 void EltwiseMult(const size_t num, const Blob *in, const DType x, Blob *out,
                  Context *ctx) {
   LOG(FATAL) << "EltwiseMult Not Implemented";
 }
 
-/// out = in2 * in2
+/// out[i] = in1[i] * in2[i]
 template <typename DType, typename Lang>
 void EltwiseMult(const size_t num, const Blob *in1, const Blob *in2, Blob *out,
                  Context *ctx) {
@@ -146,31 +150,32 @@ void GT(const size_t num, const Blob *in, const DType x, Blob *out,
         Context *ctx) {
   LOG(FATAL) << "GT Not Implemented";
 }
-/// Element-wise operation, do v^x for every v from the in tensor
+/// out[i] = pow(in[i], x)
 template <typename DType, typename Lang>
 void Pow(const size_t num, const Blob *in, const DType x, Blob *out,
          Context *ctx) {
   LOG(FATAL) << "Pow Not Implemented";
 }
 
-/// Element-wise operation, do v^x for every v from the lhs and every x from rhs
+/// out[i]=pow(in1[i], in2[i])
 template <typename DType, typename Lang>
 void Pow(const size_t num, const Blob *in1, const Blob *in2, Blob *out,
          Context *ctx) {
   LOG(FATAL) << "Pow-Pair Not Implemented";
 }
 
-/// Element-wise operation, out[i]=max(0, in[i])
+/// out[i]=max(0, in[i])
 template <typename DType, typename Lang>
 void ReLU(const size_t num, const Blob *in, Blob *out, Context *ctx) {
   LOG(FATAL) << "ReLU Not Implemented";
 }
 
+/// out[i] = x
 template <typename DType, typename Lang>
 void Set(const size_t num, const DType x, Blob *out, Context *ctx) {
   LOG(FATAL) << "Set Not Implemented";
 }
-/// Element-wise operation, out[i]=sigmoid([in[i])
+/// out[i]=sigmoid(in[i])
 template <typename DType, typename Lang>
 void Sigmoid(const size_t num, const Blob *in, Blob *out, Context *ctx) {
   LOG(FATAL) << "Sigmoid Not Implemented";
@@ -181,85 +186,47 @@ template <typename DType, typename Lang>
 void Sign(const size_t num, const Blob *in, Blob *out, Context *ctx) {
   LOG(FATAL) << "Sign Not Implemented";
 }
-/// Element-wise operation, out[i]=sqrt([in[i])
+/// out[i]=sqrt(in[i])
 template <typename DType, typename Lang>
 void Sqrt(const size_t num, const Blob *in, Blob *out, Context *ctx) {
   LOG(FATAL) << "Sqrt Not Implemented";
 }
 
-/// Element-wise operation, out[i]=square([in[i])
+/// out[i]=square(in[i])
 template <typename DType, typename Lang>
 void Square(const size_t num, const Blob *in, Blob *out, Context *ctx) {
-  LOG(FATAL) << "Square Not Implemented";
+  EltwiseMult<DType, Lang>(num, in, in, out, ctx);
 }
 
-/// out =  in - x
+/// out[i] =  in[i] - x
 template <typename DType, typename Lang>
 void Sub(const size_t num, const Blob *in, const DType x, Blob *out,
          Context *ctx) {
   Add<DType, Lang>(num, in, -x, out, ctx);
 }
 
-/// out = in1 - in2
+/// out[i] = in1[i] - in2[i]
 template <typename DType, typename Lang>
 void Sub(const size_t num, const Blob *in1, const Blob *in2, Blob *out,
          Context *ctx) {
   LOG(FATAL) << "Sub-Pair Not Implemented";
 }
+
 /// sum all elements of in into out
 template <typename DType, typename Lang>
 void Sum(const size_t num, const Blob *in, DType *out, Context *ctx) {
   LOG(FATAL) << "Sum Not Implemented";
 }
 
-/// Element-wise operation, out[i]=tanh([in[i])
+/// out[i]=tanh(in[i])
 template <typename DType, typename Lang>
 void Tanh(const size_t num, const Blob *in, Blob *out, Context *ctx) {
   LOG(FATAL) << "Tanh Not Implemented";
 }
 
-// =========== Matrix operations ===========================================
-/// Add the vector v to every column of A as the column of out
-template <typename DType, typename Lang>
-void AddCol(const size_t nrow, const size_t ncol, const Blob *A, const Blob *v,
-            Blob *out, Context *ctx) {
-  LOG(FATAL) << "AddCol Not Implemented";
-}
-// TODO(wangwei) unify AddRow and AddCol.
-/// Add the vector v to every row of A as the row of out
-template <typename DType, typename Lang>
-void AddRow(const size_t nrow, const size_t ncol, const Blob *A, const Blob *v,
-            Blob *out, Context *ctx) {
-  LOG(FATAL) << "AddRow Not Implemented";
-}
-/// outer-product.
-/// in1 and in2 are vectors of len m and n. out is matrix of shape m * n
-template <typename DType, typename Lang>
-void Outer(const size_t m, const size_t n, const Blob *in1, const Blob *in2,
-           Blob *out, Context *ctx) {
-  LOG(FATAL) << "Outer Not Implemented";
-}
-// Do softmax for each row invidually
-template <typename DType, typename Lang>
-void Softmax(const size_t nrow, const size_t ncol, const Blob *in, Blob *out,
-             Context *ctx) {
-  LOG(FATAL) << "Softmax Not Implemented";
-}
-/// Sum the columns of the in matrix into a vector
-template <typename DType, typename Lang>
-void SumColumns(const size_t nrow, const size_t ncol, const Blob *in, Blob *out,
-                Context *ctx) {
-  LOG(FATAL) << "SumColumns Not Implemented";
-}
-// TODO(wangwei) unify SumRow and SumCol.
-/// Sum the rows of the in matrix into a vector
-template <typename DType, typename Lang>
-void SumRows(const size_t nrow, const size_t ncol, const Blob *in, Blob *out,
-             Context *ctx) {
-  LOG(FATAL) << "SumRows Not Implemented";
-}
-
-// ================Random functions===========================================
+// **************************************
+// Random functions
+// **************************************
 /// Each element of out would be 1 with prob p and 0 with 1-p. 0<= p <= 1
 // Get the random generator from 'ctx'
 // If DType is not float, then convert the threshold to DType
@@ -282,7 +249,10 @@ void Uniform(const size_t num, const float low, const float high, Blob *out,
   LOG(FATAL) << "Uniform Not Implemented";
 }
 
-// ===== BLAS functions, ref to http://docs.nvidia.com/cuda/cublas
+// *********************************************************
+// BLAS functions, ref to http://docs.nvidia.com/cuda/cublas
+// *********************************************************
+
 /// outurn the index of the element with the max value.
 template <typename DType, typename Lang>
 void Amax(const size_t num, const Blob *in, size_t *out, Context *ctx) {
@@ -307,12 +277,19 @@ void Axpy(const size_t num, const DType alpha, const Blob *in, Blob *out,
   LOG(FATAL) << "Axpy Not Implemented";
 }
 
+/// out = ||in||_2, i.e., the L2 (Euclidean) norm of in (not squared).
+template <typename DType, typename Lang>
+void Nrm2(const size_t num, const Blob *in, float *out, Context *ctx) {
+  LOG(FATAL) << "Nrm2 Not Implemented";
+}
+
 /// out *= x
 template <typename DType, typename Lang>
 void Scale(const size_t num, const DType x, Blob *out, Context *ctx) {
   LOG(FATAL) << "Scale Not Implemented";
 }
 
+/// inner product of array in1 and in2
 template <typename DType, typename Lang>
 void Dot(const size_t num, const Blob *in1, const Blob *in2, DType *out,
          Context *ctx) {
@@ -346,5 +323,44 @@ void GEMM(const bool transA, const bool transB, const size_t nrowA,
   LOG(FATAL) << "GEMM Not Implemented";
 }
 
+// **************************************
+// Matrix functions
+// **************************************
+/*
+/// Add the vector v to every column of A as the column of out
+template <typename DType, typename Lang>
+void AddCol(const size_t nrow, const size_t ncol, const Blob *A, const Blob *v,
+            Blob *out, Context *ctx) {
+  LOG(FATAL) << "AddCol Not Implemented";
+}
+// TODO(wangwei) unify AddRow and AddCol.
+/// Add the vector v to every row of A as the row of out
+template <typename DType, typename Lang>
+void AddRow(const size_t nrow, const size_t ncol, const Blob *A, const Blob *v,
+            Blob *out, Context *ctx) {
+  LOG(FATAL) << "AddRow Not Implemented";
+}
+/// outer-product.
+/// in1 and in2 are vectors of len m and n. out is matrix of shape m * n
+template <typename DType, typename Lang>
+void Outer(const size_t m, const size_t n, const Blob *in1, const Blob *in2,
+           Blob *out, Context *ctx) {
+  LOG(FATAL) << "Outer Not Implemented";
+}
+
+/// Sum the columns of the in matrix into a vector
+template <typename DType, typename Lang>
+void SumColumns(const size_t nrow, const size_t ncol, const Blob *in, Blob *out,
+                Context *ctx) {
+  LOG(FATAL) << "SumColumns Not Implemented";
+}
+// TODO(wangwei) unify SumRow and SumCol.
+/// Sum the rows of the in matrix into a vector
+template <typename DType, typename Lang>
+void SumRows(const size_t nrow, const size_t ncol, const Blob *in, Blob *out,
+             Context *ctx) {
+  LOG(FATAL) << "SumRows Not Implemented";
+}
+*/
 }  // namespace singa
 #endif  // SINGA_CORE_MATH_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d69047a/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 2c5c272..0b280a3 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -241,7 +241,7 @@ void Sqrt<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
     outPtr[i] = sqrt(inPtr[i]);
   }
 }
-
+/*
 template <>
 void Square<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
                               Context *ctx) {
@@ -251,6 +251,7 @@ void Square<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
     outPtr[i] = inPtr[i] * inPtr[i];
   }
 }
+*/
 
 template <>
 void Sub<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
@@ -287,101 +288,6 @@ void Tanh<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
   }
 }
 
-// =========Matrix operations ================================================
-
-template <>
-void AddCol<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                              const Blob *A, const Blob *v, Blob *out,
-                              Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *APtr = static_cast<const float *>(A->data());
-  const float *vPtr = static_cast<const float *>(v->data());
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    for (size_t c = 0; c < ncol; c++) {
-      outPtr[offset + c] = APtr[offset + c] + vPtr[r];
-    }
-  }
-}
-
-template <>
-void AddRow<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                              const Blob *A, const Blob *v, Blob *out,
-                              Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *APtr = static_cast<const float *>(A->data());
-  const float *vPtr = static_cast<const float *>(v->data());
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    for (size_t c = 0; c < ncol; c++) {
-      outPtr[offset + c] = APtr[offset + c] + vPtr[c];
-    }
-  }
-}
-template <>
-void Outer<float, lang::Cpp>(const size_t m, const size_t n, const Blob *in1,
-                             const Blob *in2, Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *in1Ptr = static_cast<const float *>(in1->data());
-  const float *in2Ptr = static_cast<const float *>(in2->data());
-  for (size_t r = 0; r < m; r++) {
-    size_t offset = r * n;
-    for (size_t c = 0; c < n; c++) {
-      outPtr[offset + c] = in1Ptr[r] * in2Ptr[c];
-    }
-  }
-}
-template <>
-void Softmax<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                               const Blob *in, Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  float *bPtr = new float[ncol];
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    float denom = 0.f;
-    for (size_t c = 0; c < ncol; c++) {
-      bPtr[c] = exp(inPtr[offset + c]);
-      denom += bPtr[c];
-    }
-    for (size_t c = 0; c < ncol; c++) {
-      size_t idx = offset + c;
-      outPtr[idx] = bPtr[c] / denom;
-    }
-  }
-  delete bPtr;
-}
-
-template <>
-void SumColumns<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                                  const Blob *in, Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  for (size_t c = 0; c < ncol; c++) {
-    outPtr[c] = 0.f;
-  }
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    for (size_t c = 0; c < ncol; c++) {
-      outPtr[c] += inPtr[offset + c];
-    }
-  }
-}
-
-template <>
-void SumRows<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                               const Blob *in, Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    outPtr[r] = 0.f;
-    for (size_t c = 0; c < ncol; c++) {
-      outPtr[r] += inPtr[offset + c];
-    }
-  }
-}
-
 // ===============Random operations==========================================
 template <>
 void Bernoulli<float, lang::Cpp>(const size_t num, const float p, Blob *out,
@@ -440,18 +346,26 @@ void DGMM<float, lang::Cpp>(const bool side_right, const size_t nrow,
 
 #ifdef USE_CBLAS
 template <>
+void Amax<float, lang::Cpp>(const size_t num, const Blob *in, size_t *out,
+                            Context *ctx) {
+  const float *inPtr = static_cast<const float *>(in->data());
+  *out = cblas_isamax(num, inPtr, 1);
+}
+
+template <>
+void Asum<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
+                            Context *ctx) {
+  const float *inPtr = static_cast<const float *>(in->data());
+  *out = cblas_sasum(num, inPtr, 1);
+}
+
+template <>
 void Axpy<float, lang::Cpp>(const size_t num, const float alpha, const Blob *in,
                             Blob *out, Context *ctx) {
   const float *inPtr = static_cast<const float *>(in->data());
   float *outPtr = static_cast<float *>(out->mutable_data());
   cblas_saxpy(num, alpha, inPtr, 1, outPtr, 1);
 }
-template <>
-void Scale<float, lang::Cpp>(const size_t num, const float x, Blob *out,
-                             Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  cblas_sscal(num, x, outPtr, 1);
-}
 
 template <>
 void Dot<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
@@ -461,6 +375,19 @@ void Dot<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
   *out = cblas_sdot(num, in1Ptr, 1, in2Ptr, 1);
 }
 template <>
+void Scale<float, lang::Cpp>(const size_t num, const float x, Blob *out,
+                             Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  cblas_sscal(num, x, outPtr, 1);
+}
+template <>
+void Nrm2<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
+                            Context *ctx) {
+  const float *inPtr = static_cast<const float *>(in->data());
+  *out = cblas_snrm2(num, inPtr, 1);
+}
+
+template <>
 void GEMV<float, lang::Cpp>(bool trans, const size_t m, const size_t n,
                             const float alpha, const Blob *A, const Blob *v,
                             const float beta, Blob *out, Context *ctx) {
@@ -587,6 +514,102 @@ void GEMV<float, lang::Cpp>(bool trans, const size_t m, const size_t n,
 }
 
 #endif  // USE_CBLAS
+
+// =========Matrix operations ================================================
+/*
+template <>
+void AddCol<float, lang::Cpp>(const size_t nrow, const size_t ncol,
+                              const Blob *A, const Blob *v, Blob *out,
+                              Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *APtr = static_cast<const float *>(A->data());
+  const float *vPtr = static_cast<const float *>(v->data());
+  for (size_t r = 0; r < nrow; r++) {
+    size_t offset = r * ncol;
+    for (size_t c = 0; c < ncol; c++) {
+      outPtr[offset + c] = APtr[offset + c] + vPtr[r];
+    }
+  }
+}
+
+template <>
+void AddRow<float, lang::Cpp>(const size_t nrow, const size_t ncol,
+                              const Blob *A, const Blob *v, Blob *out,
+                              Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *APtr = static_cast<const float *>(A->data());
+  const float *vPtr = static_cast<const float *>(v->data());
+  for (size_t r = 0; r < nrow; r++) {
+    size_t offset = r * ncol;
+    for (size_t c = 0; c < ncol; c++) {
+      outPtr[offset + c] = APtr[offset + c] + vPtr[c];
+    }
+  }
+}
+template <>
+void Outer<float, lang::Cpp>(const size_t m, const size_t n, const Blob *in1,
+                             const Blob *in2, Blob *out, Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *in1Ptr = static_cast<const float *>(in1->data());
+  const float *in2Ptr = static_cast<const float *>(in2->data());
+  for (size_t r = 0; r < m; r++) {
+    size_t offset = r * n;
+    for (size_t c = 0; c < n; c++) {
+      outPtr[offset + c] = in1Ptr[r] * in2Ptr[c];
+    }
+  }
+}
+template <>
+void Softmax<float, lang::Cpp>(const size_t nrow, const size_t ncol,
+                               const Blob *in, Blob *out, Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *inPtr = static_cast<const float *>(in->data());
+  float *bPtr = new float[ncol];
+  for (size_t r = 0; r < nrow; r++) {
+    size_t offset = r * ncol;
+    float denom = 0.f;
+    for (size_t c = 0; c < ncol; c++) {
+      bPtr[c] = exp(inPtr[offset + c]);
+      denom += bPtr[c];
+    }
+    for (size_t c = 0; c < ncol; c++) {
+      size_t idx = offset + c;
+      outPtr[idx] = bPtr[c] / denom;
+    }
+  }
+  delete bPtr;
+}
+
+template <>
+void SumColumns<float, lang::Cpp>(const size_t nrow, const size_t ncol,
+                                  const Blob *in, Blob *out, Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *inPtr = static_cast<const float *>(in->data());
+  for (size_t c = 0; c < ncol; c++) {
+    outPtr[c] = 0.f;
+  }
+  for (size_t r = 0; r < nrow; r++) {
+    size_t offset = r * ncol;
+    for (size_t c = 0; c < ncol; c++) {
+      outPtr[c] += inPtr[offset + c];
+    }
+  }
+}
+
+template <>
+void SumRows<float, lang::Cpp>(const size_t nrow, const size_t ncol,
+                               const Blob *in, Blob *out, Context *ctx) {
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float *inPtr = static_cast<const float *>(in->data());
+  for (size_t r = 0; r < nrow; r++) {
+    size_t offset = r * ncol;
+    outPtr[r] = 0.f;
+    for (size_t c = 0; c < ncol; c++) {
+      outPtr[r] += inPtr[offset + c];
+    }
+  }
+}
+*/
 }  // namespace singa
 
 #endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d69047a/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index f9841a3..e2597d5 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -24,105 +24,336 @@
 #include "./math_kernel.h"
 #include "singa/utils/cuda_utils.h"
 #include "singa/core/common.h"
+#include <cuda_runtime.h>
+#include <cublas_v2.h>
+#include "singa/utils/cuda_utils.h"
 
 namespace singa {
-// =================Elementwise operations===================================
+
+/// out[i] = |in[i]|
+template <>
+void Abs<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                            Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::abs(num, inPtr, outPtr, ctx->stream);
+}
+/// out = in + x
+template <>
+void Add<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                            Blob* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::add(num, inPtr, x, outPtr, ctx->stream);
+}
+/// out = in1 + in2
+template <>
+void Add<float, lang::Cuda>(const size_t num, const Blob* in1, const Blob* in2,
+                            Blob* out, Context* ctx) {
+  const float* inPtr1 = static_cast<const float*>(in1->data());
+  const float* inPtr2 = static_cast<const float*>(in2->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::add(num, inPtr1, inPtr2, outPtr, ctx->stream);
+}
+/// Element-wise operation, clamp every element into [low, high]
+/// if x>high, then x=high; if x<low, then x=low.
+template <>
+void Clamp<float, lang::Cuda>(const size_t num, const float low,
+                              const float high, const Blob* in, Blob* out,
+                              Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::clamp(num, low, high, inPtr, outPtr, ctx->stream);
+}
+/// out = in1 / in2
+template <>
+void Div<float, lang::Cuda>(const size_t num, const Blob* in1, const Blob* in2,
+                            Blob* out, Context* ctx) {
+  const float* inPtr1 = static_cast<const float*>(in1->data());
+  const float* inPtr2 = static_cast<const float*>(in2->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::div(num, inPtr1, inPtr2, outPtr, ctx->stream);
+}
+
+template <>
+void Div<float, lang::Cuda>(const size_t num, const float x, const Blob* in,
+                            Blob* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::div(num, x, inPtr, outPtr, ctx->stream);
+}
+
+/// out = in * x
+template <>
+void EltwiseMult<float, lang::Cuda>(const size_t num, const Blob* in,
+                                    const float x, Blob* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::mult(num, inPtr, x, outPtr, ctx->stream);
+}
+/// out = in1 * in2
+template <>
+void EltwiseMult<float, lang::Cuda>(const size_t num, const Blob* in1,
+                                    const Blob* in2, Blob* out, Context* ctx) {
+  const float* inPtr1 = static_cast<const float*>(in1->data());
+  const float* inPtr2 = static_cast<const float*>(in2->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::mult(num, inPtr1, inPtr2, outPtr, ctx->stream);
+}
+/// Base is e. out[i]=e^in[i]
+template <>
+void Exp<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                            Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::exp(num, inPtr, outPtr, ctx->stream);
+}
+
+template <>
+void GE<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                           Blob* out, Context* ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::ge(num, inPtr, x, outPtr, ctx->stream);
+}
+
+template <>
+void GT<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                           Blob* out, Context* ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::gt(num, inPtr, x, outPtr, ctx->stream);
+}
+
+template <>
+void LE<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                           Blob* out, Context* ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::le(num, inPtr, x, outPtr, ctx->stream);
+}
+
+/// Natural logarithm, the base is e (Neper's number): out[i]=ln(in[i]).
+template <>
+void Log<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                            Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::log(num, inPtr, outPtr, ctx->stream);
+}
+template <>
+void LT<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                           Blob* out, Context* ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::lt(num, inPtr, x, outPtr, ctx->stream);
+}
+
+/// Element-wise operation, out[i] = in[i]^x
+template <>
+void Pow<float, lang::Cuda>(const size_t num, const Blob* in, const float x,
+                            Blob* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::pow(num, inPtr, x, outPtr, ctx->stream);
+}
+/// Element-wise operation, out[i] = in1[i]^in2[i]
 template <>
-void Add<float, lang::Cuda>(const size_t num, const Blob *in1, const Blob *in2,
-                            Blob *out, Context *ctx) {
-  const float *in1Ptr = static_cast<const float *>(in1->data());
-  const float *in2Ptr = static_cast<const float *>(in2->data());
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  cuda::add(num, in1Ptr, in2Ptr, outPtr);
+void Pow<float, lang::Cuda>(const size_t num, const Blob* in1, const Blob* in2,
+                            Blob* out, Context* ctx) {
+  const float* inPtr1 = static_cast<const float*>(in1->data());
+  const float* inPtr2 = static_cast<const float*>(in2->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::pow(num, inPtr1, inPtr2, outPtr, ctx->stream);
 }
 
-// follow the consistency guide of math API
+/// Element-wise operation, out[i]=max(0, in[i])
 template <>
-void Div<float, lang::Cuda>(const size_t num, const float x, const Blob *in,
-                            Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  cuda::Div(num, x, inPtr, outPtr, ctx->stream);
+void ReLU<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::relu(num, inPtr, outPtr, ctx->stream);
 }
 
+/// out[i] = x
 template <>
-void EltwiseMult<float, lang::Cuda>(const size_t num, const Blob *in,
-                                    const float x, Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  cuda::mult(num, inPtr, x, outPtr);
+void Set<float, lang::Cuda>(const size_t num, const float x, Blob* out,
+                            Context* ctx) {
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::set(num, x, outPtr, ctx->stream);
 }
+/// Element-wise operation, out[i]=sigmoid(in[i])
 template <>
-void GE<float, lang::Cuda>(const size_t num, const Blob *in, const float x,
-                           Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  cuda::GE(num, inPtr, x, outPtr, ctx->stream);
+void Sigmoid<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                                Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::sigmoid(num, inPtr, outPtr, ctx->stream);
 }
+// out[i] = sign(in[i])
 template <>
-void GT<float, lang::Cuda>(const size_t num, const Blob *in, const float x,
-                           Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  cuda::GT(num, inPtr, x, outPtr, ctx->stream);
+void Sign<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::sign(num, inPtr, outPtr, ctx->stream);
 }
+
+/// Element-wise operation, out[i]=sqrt(in[i])
+template <>
+void Sqrt<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::sqrt(num, inPtr, outPtr, ctx->stream);
+}
+
+/// Element-wise operation, out[i]=in[i]^2
 template <>
-void LE<float, lang::Cuda>(const size_t num, const Blob *in, const float x,
-                           Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  cuda::LE(num, inPtr, x, outPtr, ctx->stream);
+void Square<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                               Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::square(num, inPtr, outPtr, ctx->stream);
 }
+/// out = in1 - in2
 template <>
-void LT<float, lang::Cuda>(const size_t num, const Blob *in, const float x,
-                           Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  cuda::LT(num, inPtr, x, outPtr, ctx->stream);
+void Sub<float, lang::Cuda>(const size_t num, const Blob* in1, const Blob* in2,
+                            Blob* out, Context* ctx) {
+  const float* inPtr1 = static_cast<const float*>(in1->data());
+  const float* inPtr2 = static_cast<const float*>(in2->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::sub(num, inPtr1, inPtr2, outPtr, ctx->stream);
 }
+
+/// sum all elements of input into out
 template <>
-void Set<float, lang::Cuda>(const size_t num, const float x, Blob *out,
-                            Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  cuda::Set(num, x, outPtr, ctx->stream);
+void Sum<float, lang::Cuda>(const size_t num, const Blob* in, float* out,
+                            Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  cuda::sum(num, inPtr, out, ctx->stream);
 }
-// TODO(wangwei) optimize using stream
+
+/// Element-wise operation, out[i]=tanh(in[i])
 template <>
-void Square<float, lang::Cuda>(const size_t num, const Blob *in, Blob *out,
-                               Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in->data());
-  cuda::square(num, inPtr, outPtr);
+void Tanh<float, lang::Cuda>(const size_t num, const Blob* in, Blob* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::tanh(num, inPtr, outPtr, ctx->stream);
 }
-// TODO(wangwei) optimize using stream
+
+// ================Random functions===========================================
+/// Each element of out would be 1 with prob p and 0 with 1-p. 0<= p <= 1
+// Get the random generator from 'ctx'
+// If DType is not float, then convert the threshold to DType
 template <>
-void Sub<float, lang::Cuda>(const size_t num, const Blob *in1, const Blob *in2,
-                            Blob *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *in1Ptr = static_cast<const float *>(in1->data());
-  const float *in2Ptr = static_cast<const float *>(in2->data());
-  cuda::sub(num, in1Ptr, in2Ptr, outPtr);
+void Bernoulli<float, lang::Cuda>(const size_t num, const float p, Blob* out,
+                                  Context* ctx) {
+  auto rgen = ctx->curand_generator;
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  CURAND_CHECK(curandGenerateUniform(rgen, outPtr, num));
+  cuda::threshold(num, p, outPtr, outPtr, ctx->stream);
 }
-// sum all elements of input into ret
-// TODO(wangwei) optimize using stream
+
+// The random generator should be extracted from ctx.
+// If DType is not float, then convert the low and high to DType
 template <>
-void Sum<float, lang::Cuda>(const size_t num, const Blob *in, float *out,
-                            Context *ctx) {
-  const float *inPtr = static_cast<const float *>(in->data());
-  cuda::sum(num, inPtr, out);
+void Uniform<float, lang::Cuda>(const size_t num, const float low,
+                                const float high, Blob* out, Context* ctx) {
+  auto rgen = ctx->curand_generator;
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  CURAND_CHECK(curandGenerateUniform(rgen, outPtr, num));
+  cuda::mult(num, outPtr, high - low, outPtr, ctx->stream);
+  cuda::add(num, outPtr, low, outPtr, ctx->stream);
+}
+
+// The random generator should be extracted from ctx.
+// If DType is not float, then convert the mean and std to DType
+template <>
+void Gaussian<float, lang::Cuda>(const size_t num, const float mean,
+                                 const float std, Blob* out, Context* ctx) {
+  auto rgen = ctx->curand_generator;
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  CURAND_CHECK(curandGenerateNormal(rgen, outPtr, num, mean, std));
 }
 
 // =========================Blas operations==================================
+// ref to http://docs.nvidia.com/cuda/cublas
+template <>
+void Amax<float, lang::Cuda>(const size_t num, const Blob* in, size_t* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
+  int idx = 1;
+  CUBLAS_CHECK(cublasIsamax(handle, num, inPtr, 1, &idx));
+  *out = idx - 1;  // cublas index starts from 1
+}
+
+/// return the index of the element with the min absolute value (cublasIsamin).
+template <>
+void Amin<float, lang::Cuda>(const size_t num, const Blob* in, size_t* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
+  int idx = 1;
+  CUBLAS_CHECK(cublasIsamin(handle, num, inPtr, 1, &idx));
+  *out = idx - 1;
+}
+
+/// out = sum |x| for all x in in
+template <>
+void Asum<float, lang::Cuda>(const size_t num, const Blob* in, float* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
+  CUBLAS_CHECK(cublasSasum(handle, num, inPtr, 1, out));
+}
+
+/// out = alpha * in + out
+template <>
+void Axpy<float, lang::Cuda>(const size_t num, const float alpha,
+                             const Blob* in, Blob* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
+  CUBLAS_CHECK(cublasSaxpy(handle, num, &alpha, inPtr, 1, outPtr, 1));
+}
+
+/// out = \sum_i in1[i] * in2[i]
+template <>
+void Dot<float, lang::Cuda>(const size_t num, const Blob* in1, const Blob* in2,
+                            float* out, Context* ctx) {
+  const float* inPtr1 = static_cast<const float*>(in1->data());
+  const float* inPtr2 = static_cast<const float*>(in2->data());
+  auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
+  CUBLAS_CHECK(cublasSdot(handle, num, inPtr1, 1, inPtr2, 1, out));
+}
+template <>
+void Nrm2<float, lang::Cuda>(const size_t num, const Blob* in, float* out,
+                             Context* ctx) {  // *out = L2 norm of in[0..num)
+  auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
+  const float* inPtr = static_cast<const float*>(in->data());
+  CUBLAS_CHECK(cublasSnrm2(handle, num, inPtr, 1, out));  // check the status
+}
+template <>
+void Scale<float, lang::Cuda>(const size_t num, const float x, Blob* out,
+                              Context* ctx) {
+  auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  CUBLAS_CHECK(cublasSscal(handle, num, &x, outPtr, 1));
+}
 // NOTE: cublas uses column major order.
 // http://peterwittek.com/cublas-matrix-c-style.html
 template <>
 void DGMM<float, lang::Cuda>(const bool side_right, const size_t nrow,
-                             const size_t ncol, const Blob *M, const Blob *v,
-                             Blob *out, Context *ctx) {
+                             const size_t ncol, const Blob* M, const Blob* v,
+                             Blob* out, Context* ctx) {
   auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
-  const float *MPtr = static_cast<const float *>(M->data());
-  const float *vPtr = static_cast<const float *>(v->data());
-  float *outPtr = static_cast<float *>(out->mutable_data());
+  const float* MPtr = static_cast<const float*>(M->data());
+  const float* vPtr = static_cast<const float*>(v->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
   if (side_right) {
     CUBLAS_CHECK(cublasSdgmm(handle, CUBLAS_SIDE_LEFT, ncol, nrow, MPtr, ncol,
                              vPtr, 1, outPtr, ncol));
@@ -133,11 +364,11 @@ void DGMM<float, lang::Cuda>(const bool side_right, const size_t nrow,
 }
 template <>
 void GEMV<float, lang::Cuda>(bool trans, const size_t m, const size_t n,
-                             const float alpha, const Blob *A, const Blob *v,
-                             const float beta, Blob *out, Context *ctx) {
-  const float *APtr = static_cast<const float *>(A->data());
-  const float *vPtr = static_cast<const float *>(v->data());
-  float *outPtr = static_cast<float *>(out->mutable_data());
+                             const float alpha, const Blob* A, const Blob* v,
+                             const float beta, Blob* out, Context* ctx) {
+  const float* APtr = static_cast<const float*>(A->data());
+  const float* vPtr = static_cast<const float*>(v->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
   auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
   if (!trans)
     CUBLAS_CHECK(cublasSgemv(handle, CUBLAS_OP_T, n, m, &alpha, APtr, n, vPtr,
@@ -152,16 +383,16 @@ template <>
 void GEMM<float, lang::Cuda>(const bool transA, const bool transB,
                              const size_t nrowA, const size_t ncolB,
                              const size_t ncolA, const float alpha,
-                             const Blob *A, const Blob *B, const float beta,
-                             Blob *C, Context *ctx) {
+                             const Blob* A, const Blob* B, const float beta,
+                             Blob* C, Context* ctx) {
   auto transa = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   auto transb = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
   int lda = transA ? nrowA : ncolA;
   int ldb = transB ? ncolA : ncolB;
   int ldc = ncolB;
-  const float *APtr = static_cast<const float *>(A->data());
-  const float *BPtr = static_cast<const float *>(B->data());
-  float *CPtr = static_cast<float *>(C->mutable_data());
+  const float* APtr = static_cast<const float*>(A->data());
+  const float* BPtr = static_cast<const float*>(B->data());
+  float* CPtr = static_cast<float*>(C->mutable_data());
   auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
   CUBLAS_CHECK(cublasSgemm(handle, transb, transa, ncolB, nrowA, ncolA, &alpha,
                            BPtr, ldb, APtr, lda, &beta, CPtr, ldc));
@@ -171,4 +402,3 @@ void GEMM<float, lang::Cuda>(const bool transA, const bool transB,
 
 #endif  // USE_CUDA
 #endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_
-