Posted to commits@singa.apache.org by ka...@apache.org on 2015/11/16 07:08:55 UTC

[09/19] incubator-singa git commit: SINGA-80 New Blob Level and Address Level Math Operation Interface

SINGA-80 New Blob Level and Address Level Math Operation Interface

Temp commit, not compiled yet.
* Move functions in math_addr.cc and math_blob.cc into header files to simplify the compilation of template code.
* Add comments in math_blob.h file.
* Add shape checking.

TODO
* Remove functions like relu/softplus/sigmoid for Blob; each function
body contains only one line of code, which can be written directly
when calling the underlying functions.
* Update Blob class to implement helper functions, e.g., Reshape, shape(int k).
* Update math functions for GPU; there are mismatched APIs.
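
For orientation, a minimal sketch of how the new Blob-level interface is meant
to be called (illustrative only; as noted above, this commit is not compiled
yet):

  #include "singa/blob/math_blob.h"

  void example() {
    singa::Blob<float> A(2, 3), B(3, 4), C(2, 4);
    A.SetValue(1.0f);
    B.SetValue(2.0f);
    // C = 1.0 * A * B + 0.0 * C, with strict shape checking inside GEMM
    singa::GEMM(singa::cpu, 1.0f, 0.0f, A, B, &C);
    // or equivalently: singa::MMDot(singa::cpu, A, B, &C);
  }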


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/4b84dbe3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/4b84dbe3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/4b84dbe3

Branch: refs/heads/master
Commit: 4b84dbe30296985afafc88c08dc84f664cfc3617
Parents: 641eb31
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Mon Nov 9 14:10:40 2015 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Mon Nov 9 17:04:48 2015 +0800

----------------------------------------------------------------------
 include/singa/blob/math_addr.h | 201 +++++---
 include/singa/blob/math_blob.h | 955 +++++++++++++++++++++---------------
 include/singa/blob/singa_op.h  | 500 +++++++++----------
 include/singa/utils/blob.h     | 188 +++++--
 src/blob/math_addr.cc          | 120 -----
 src/blob/math_blob.cc          | 214 --------
 6 files changed, 1092 insertions(+), 1086 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b84dbe3/include/singa/blob/math_addr.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/math_addr.h b/include/singa/blob/math_addr.h
index 2a25a29..25cef07 100644
--- a/include/singa/blob/math_addr.h
+++ b/include/singa/blob/math_addr.h
@@ -21,113 +21,186 @@
 
 #ifndef SINGA_BLOB_MATH_ADDR_H_
 #define SINGA_BLOB_MATH_ADDR_H_
+extern "C" {
+    #include <cblas.h>
+}
+#ifdef USE_GPU
+#include <cuda_runtime.h>
+#endif
+#include "singa/blob/singa_op.h"
+#ifdef USE_GPU
+#include "cublas_v2.h"
+#endif
 
-namespace singa {
 
-const float * cpu_uni_vec(const int n);
+namespace singa {
 
-void cpu_gemm(const float * A, const float * B,
-const int m, const int n, const int k, const float alpha, const float beta,
-const bool TranA, const bool TranB, float * C);
+template<typename Dtype>
+void cpu_gemm(const Dtype * A, const Dtype * B,
+    const int m, const int n, const int k, const Dtype alpha, const Dtype beta,
+    const bool TranA, const bool TranB, Dtype * C) {
+  int lda, ldb;
+  CBLAS_TRANSPOSE tA, tB;
+  lda = TranA ? m : k;
+  ldb = TranB ? k : n;
+  tA = TranA ? CblasTrans : CblasNoTrans;
+  tB = TranB ? CblasTrans : CblasNoTrans;
+  cblas_sgemm(CblasRowMajor, tA, tB, m, n, k, alpha, A, lda,
+      B, ldb, beta, C, n);
+}
 
-void cpu_gemv(const float * A, const float * B, const int m, const int n,
-const float alpha, const float beta, const bool TranA, float * C);
 // should be very careful:
 // m is the length of B, n is the length of C, and A is an n*m matrix
+template<typename Dtype>
+void cpu_gemv(const Dtype * A, const Dtype * B, const int m, const int n,
+    const Dtype alpha, const Dtype beta, const bool TranA, Dtype * C) {
+  // rows/cols of A as stored; lda is the row stride of the row-major layout
+  int rows = TranA ? m : n;
+  int cols = TranA ? n : m;
+  CBLAS_TRANSPOSE tA = TranA ? CblasTrans : CblasNoTrans;
+  cblas_sgemv(CblasRowMajor, tA, rows, cols, alpha, A, cols, B, 1, beta, C, 1);
+}
 
-void cpu_axpy(const float * A, const int n, const float alpha, float * B);
+template<typename Dtype>
+void cpu_axpy(const Dtype * A, const int n, const Dtype alpha, Dtype * B) {
+  cblas_saxpy(n, alpha, A, 1, B, 1);
+}
 
-float cpu_dot(const float * A, const float * B, const int n);
+template<typename Dtype>
+Dtype cpu_dot(const Dtype * A, const Dtype * B, const int n) {
+  Dtype sum = 0;
+  for (int i = 0 ; i < n ; i++)
+    sum += A[i] * B[i];
+  return sum;
+}
 
 // element-wise
-template<typename Op>
-void cpu_e_f(const int n, const float alpha, float * A) {
-                for (int i = 0 ; i < n ; i++) {
-                                Op::Map(alpha, &A[i]);
-                }
+template<typename Op, typename Dtype>
+void cpu_e_f(const int n, const Dtype alpha, Dtype * A) {
+  for (int i = 0 ; i < n ; i++) {
+    Op::Map(alpha, &A[i]);
+  }
 }
 
-template<typename Op>
-void cpu_e_f(const int n, const float * A, const float alpha, float * B) {
-                for (int i = 0 ; i < n ; i++) {
-                                Op::Map(alpha, A[i], &B[i]);
-                }
+template<typename Op, typename Dtype>
+void cpu_e_f(const int n, const Dtype * A, const Dtype alpha, Dtype * B) {
+  for (int i = 0 ; i < n ; i++) {
+    Op::Map(alpha, A[i], &B[i]);
+  }
 }
 
-template<typename Op>
-void cpu_e_f(const int n, const float * A, const float * B,
-const float alpha, const float beta, float * C) {
-                for (int i = 0 ; i < n ; i++) {
-                                Op::Map(alpha, beta, A[i], B[i], &C[i]);
-                }
+template<typename Op, typename Dtype>
+void cpu_e_f(const int n, const Dtype * A, const Dtype * B,
+    const Dtype alpha, const Dtype beta, Dtype * C) {
+  for (int i = 0 ; i < n ; i++) {
+    Op::Map(alpha, beta, A[i], B[i], &C[i]);
+  }
 }
 // element-wise generalized operation defined in Op
 
 
 // matrix/vector expand/reduce
 
-template<typename Op>
-void cpu_reduce_f(const float * A, const int m, const int n, float * B) {
-                for (int i = 0 ; i < m ; i++) {
-                                Op::Map(A+i*n, n, B[i]);
-                }
+template<typename Op, typename Dtype>
+void cpu_reduce_f(const Dtype * A, const int m, const int n, Dtype * B) {
+  for (int i = 0 ; i < m ; i++) {
+    Op::Map(A+i*n, n, B[i]);
+  }
 }
 // reduce each row of A to an element of B e.g. the sum operation in softmax
-template<typename Op>
-void cpu_expand_f(const float * A, const int m, const int n, float * B) {
-                for (int i = 0 ; i < m ; i++) {
-                                Op::Map(A[i], n, B+i*n);
-                }
+template<typename Op, typename Dtype>
+void cpu_expand_f(const Dtype * A, const int m, const int n, Dtype * B) {
+  for (int i = 0 ; i < m ; i++) {
+    Op::Map(A[i], n, B+i*n);
+  }
 }
 // expand each element in A into a row of B
 
-#ifdef SINGA_GPU
-void gpu_gemm(const float * A, const float * B,
-const int m, const int n, const int k, const float alpha, const float beta,
-const bool TranA, const bool TranB, float * C);
+#ifdef USE_GPU
+template<typename Dtype>
+void gpu_gemm(const Dtype * A, const Dtype * B, const int m, const int n,
+    const int k, const Dtype alpha, const Dtype beta, const bool TranA,
+    const bool TranB, Dtype * C) {
+  int lda = TranA ? m : k;
+  int ldb = TranB ? k : n;
+  int ldc = n;
+  cublasOperation_t tA = (TranA == false) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t tB = (TranB == false) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasHandle_t handle;
+  cublasCreate(&handle);
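+  // cuBLAS is column-major, so compute C^T = B^T * A^T; the column-major
+  // result C^T then has the same memory layout as the row-major C.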
+  cublasSgemm(handle, tB, tA, n, m, k, &alpha, B, ldb,
+      A, lda, &beta, C, ldc);
+  cublasDestroy(handle);
+}
 
-void gpu_gemv(const float * A, const float * B, const int m, const int n,
-const float alpha, const float beta, const bool TranA, float * C);
+template<typename Dtype>
+void gpu_gemv(const Dtype * A, const Dtype * B, const int m, const int n,
+    const Dtype alpha, const Dtype beta, const bool TranA, Dtype * C) {
+  // A is stored row-major; cuBLAS sees it as its column-major transpose.
+  int rows = TranA ? n : m;
+  int cols = TranA ? m : n;
+  cublasOperation_t tA = TranA ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasHandle_t handle;
+  cublasCreate(&handle);
+  cublasSgemv(handle, tA, rows, cols, &alpha, A, rows, B, 1, &beta, C, 1);
+  cublasDestroy(handle);
+}
 
-void gpu_axpy(const float * A, const int n, const float alpha, float * B);
+template<typename Dtype>
+void gpu_axpy(const Dtype * A, const int n, const Dtype alpha, Dtype * B) {
+  cublasHandle_t handle;
+  cublasCreate(&handle);
+  cublasSaxpy(handle, n, &alpha, A, 1, B, 1);
+  cublasDestroy(handle);
+}
 
-float gpu_dot(const float * A, const float * B, const int n);
+template<typename Dtype>
+Dtype gpu_dot(const Dtype * A, const Dtype * B, const int n) {
+  cublasHandle_t handle;
+  cublasCreate(&handle);
+  Dtype result = 0.0;
+  cublasSdot(handle, n, A, 1, B, 1, &result);
+  cublasDestroy(handle);
+  return result;
+}
 
 // element-wise
-template<typename Op>
-void gpu_e_f(const int n, const float alpha, float * A) {
-    Op::CudaMap(alpha, A, n);
+template<typename Op, typename Dtype>
+void gpu_e_f(const int n, const Dtype alpha, Dtype * A) {
+  Op::CudaMap(alpha, A, n);
 }
 
-template<typename Op>
-void gpu_e_f(const int n, const float * A, const float alpha, float * B) {
-    Op::CudaMap(alpha, A, B, n);
+template<typename Op, typename Dtype>
+void gpu_e_f(const int n, const Dtype * A, const Dtype alpha, Dtype * B) {
+  Op::CudaMap(alpha, A, B, n);
 }
 
-template<typename Op>
-void gpu_e_f(const int n, const float * A, const float * B,
-const float alpha, const float beta, float * C) {
-    Op::CudaMap(alpha, beta, A, B, C, n);
+template<typename Op, typename Dtype>
+void gpu_e_f(const int n, const Dtype * A, const Dtype * B,
+    const Dtype alpha, const Dtype beta, Dtype * C) {
+  Op::CudaMap(alpha, beta, A, B, C, n);
 }
 // element-wise generalized operation defined in Op
 
 // matrix/vector expand/reduce
 
-template<typename Op>
-void gpu_reduce_f(const float * A, const int m, const int n, float * B) {
-                for (int i = 0 ; i < m ; i++) {
-                                Op::CudaMap(A+i*n, n, B[i]);
-                }
+template<typename Op, typename Dtype>
+void gpu_reduce_f(const Dtype * A, const int m, const int n, Dtype * B) {
+  for (int i = 0 ; i < m ; i++) {
+    Op::CudaMap(A+i*n, n, B[i]);
+  }
 }
 // reduce each row of A to an element of B e.g. the sum operation in softmax
-template<typename Op>
-void gpu_expand_f(const float * A, const int m, const int n, float * B) {
-                for (int i = 0 ; i < m ; i++) {
-                                Op::CudaMap(A[i], n, B+i*n);
-                }
+template<typename Op, typename Dtype>
+void gpu_expand_f(const Dtype * A, const int m, const int n, Dtype * B) {
+  for (int i = 0 ; i < m ; i++) {
+    Op::CudaMap(A[i], n, B+i*n);
+  }
 }
 // expand each element in A into a row of B
-#endif  // SINGA_GPU  
+#endif  // USE_GPU
 
 }  // namespace singa
 #endif  // SINGA_BLOB_MATH_ADDR_H_
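
To illustrate the Op-functor pattern that cpu_e_f/gpu_e_f rely on, here is a
minimal standalone sketch (the AddScalar struct and demo function are
illustrative, not part of this header):

  // Bi = Ai + alpha, expressed as a struct with a static Map method
  struct AddScalar {
    inline static void Map(float alpha, const float & a, float * b) {
      *b = a + alpha;
    }
  };

  void demo() {
    const int n = 4;
    float A[n] = {1, 2, 3, 4};
    float B[n];
    singa::cpu_e_f<AddScalar>(n, A, 2.0f, B);  // Bi = Ai + 2
  }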

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b84dbe3/include/singa/blob/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/math_blob.h b/include/singa/blob/math_blob.h
index 638f9cc..f3147e8 100644
--- a/include/singa/blob/math_blob.h
+++ b/include/singa/blob/math_blob.h
@@ -29,420 +29,569 @@
 
 
 namespace singa {
-/*********************Level-2 interface, called by user code*******************/
-
-int get_size(const std::vector<int>& shape);
-
-template <typename Dtype>
-bool check_shape_mv(const Blob<Dtype> & A, const Blob<Dtype> & B) {
-    if (A.shape().size() != 2) return false;
-    if (B.shape().size() != 1) return false;
-    if (A.shape().at(0) != B.shape().at(0)) return false;
-    return true;
-}
-
-template <typename Dtype>
-bool check_shape_equal(const Blob<Dtype> & A, const Blob<Dtype> & B,
-const Blob<Dtype> & C) {
-    int asize, bsize, csize;
-    asize = get_size(A.shape());
-    bsize = get_size(B.shape());
-    csize = get_size(C.shape());
-    if (asize != bsize) return false;
-    if (asize != csize) return false;
-    return true;
-}
-
-template <typename Dtype>
-bool check_shape_mmm(const Blob<Dtype> & A, const Blob<Dtype> & B,
-const Blob<Dtype> & C) {
-    if (A.shape().size() != 2) return false;
-    if (B.shape().size() != 2) return false;
-    if (C.shape().size() != 2) return false;
-    int a1, a2, b1, b2, c1, c2;
-    if (C.isTranspose()) return false;
-    a1 = A.isTranspose() ? A.shape().at(1) : A.shape().at(0);
-    a2 = A.isTranspose() ? A.shape().at(0) : A.shape().at(1);
-    b1 = B.isTranspose() ? B.shape().at(1) : B.shape().at(0);
-    b2 = B.isTranspose() ? B.shape().at(0) : B.shape().at(1);
-    c1 = C.shape().at(0);
-    c2 = C.shape().at(1);
-    if (a2 != b1) return false;
-    if (a1 != c1) return false;
-    if (b2 != c2) return false;
-    return true;
-}
-
-template <typename Dtype>
-bool check_shape_vvm(const Blob<Dtype> & A, const Blob<Dtype> & B,
-const Blob<Dtype> & C) {
-    if (A.shape().size() != 1) return false;
-    if (B.shape().size() != 1) return false;
-    if (C.shape().size() != 2) return false;
-    int a1, b1, c1, c2;
-    if (C.isTranspose()) return false;
-    a1 = A.shape().at(0);
-    b1 = B.shape().at(0);
-    c1 = C.shape().at(0);
-    c2 = C.shape().at(1);
-    if (a1 != c2) return false;
-    if (b1 != c1) return false;
-    return true;
-}
-
+enum XPU {cpu, gpu, any};
+
+/************* BLAS level 1 *****************/
+/**
+ * Scale each element of A with alpha, and put the result into B.
+ * Bi = alpha*Ai
+ * Use blas scale internally.
+ */
+template<typename Dtype>
+void Scale(XPU xpu, Dtype alpha, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  CHECK_EQ(A.count(), B->count());
+  if (xpu == cpu)
+    cpu_scale(A.count(), alpha, A.cpu_data(), B->mutable_cpu_data());
+#ifdef USE_GPU
+#endif
+}
+
+/**
+ * Element-wise operation: Bi = alpha*Ai+Bi. A and B should have the same size
+ */
+template<typename Dtype>
+void AXPY(XPU xpu, Dtype alpha, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  CHECK_EQ(A.count(), B->count());
+  if (xpu == cpu) {
+    cpu_axpy(A.cpu_data(), A.count(),
+        alpha, B->mutable_cpu_data());
+  }
+#ifdef USE_GPU
+  if (xpu == gpu) {
+    gpu_axpy(A.gpu_data(), A.count(),
+        alpha, B->mutable_gpu_data());
+  }
+#endif  // USE_GPU
+}
+
+/************* BLAS level 2 *****************/
+/**
+ * Matrix vector multiplication, C = alpha A(.T) * B + beta C.
+ * Strict shape checking:
+ * - dim of A == 2
+ * - columns of A(.T) == B.count()
+ * - rows of A(.T) == C.count()
+ *
+ * @param[in] alpha
+ * @param[in] beta
+ * @param[in] A, matrix
+ * @param[in] B, vector
+ * @param[in, out] C, vector
+ */
+template<typename Dtype>
+void GEMV(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype>& A,
+    const Blob<Dtype>& B, Blob<Dtype>* C) {
+  CHECK_EQ(A.shape().size(), 2) << "A must be a matrix";
+  int a1, a2, m, n;
+  a1 = A.transpose() ? A.shape(1) : A.shape(0);
+  a2 = A.transpose() ? A.shape(0) : A.shape(1);
+  m = B.count();
+  n = C->count();
+  CHECK_EQ(a2, m) << "# columns of A(.T) must = length of B";
+  CHECK_EQ(a1, n) << "# rows of A(.T) must = length of C";
+
+  bool TranA = A.transpose();
+  if (xpu == cpu) {
+    cpu_gemv(A.cpu_data(), B.cpu_data(), m, n, alpha, beta, TranA,
+        C->mutable_cpu_data());
+  }
+#ifdef USE_GPU
+  if (xpu == gpu) {
+    // gpu part
+    gpu_gemv(A.gpu_data(), B.gpu_data(), m, n, alpha, beta, TranA,
+        C->mutable_gpu_data());
+  }
+#endif  // USE_GPU
+}
+/**
+ * Matrix vector multiplication, C = A(.T) * B, transpose is considered.
+ * Loose shape checking:
+ * - dim of A >=2
+ * - A.count() % B.count() == 0
+ * - B.count() == C.count()
+ *
+ * @param[in] A input matrix
+ * @param[in] B input vector
+ * @param[out] C output vector
+ */
 template <typename Dtype>
-bool check_shape_mvv(const Blob<Dtype> & A, const Blob<Dtype> & B,
-const Blob<Dtype> & C) {
-    if (A.shape().size() != 2) return false;
-    if (B.shape().size() != 1) return false;
-    if (C.shape().size() != 1) return false;
-    int a1, a2, b1, c1;
-    a1 = A.isTranspose() ? A.shape().at(1) : A.shape().at(0);
-    a2 = A.isTranspose() ? A.shape().at(0) : A.shape().at(1);
-    b1 = B.shape().at(0);
-    c1 = C.shape().at(0);
-    if (a2 != b1) return false;
-    if (a1 != c1) return false;
-    return true;
-}
-
-/*****************************************************************************/
-// blob transformation
-
+void MVDot(XPU xpu, const Blob<Dtype>& A, const Blob<Dtype>& B,
+    Blob<Dtype>* C) {
+  GEMV(xpu, Dtype(1), Dtype(0), A, B, C);
+}
+
+/************* BLAS level 3 *****************/
+/**
+ * Matrix multiplication, C = alpha A*B + beta C, A, B and C are matrix.
+ *
+ * Transpose is considered for A and B.
+ * Strict shape checking:
+ * - all are matrix
+ * - shapes match for matrix multiplication
+ *
+ * @param[in] alpha
+ * @param[in] beta
+ * @param[in] A, matrix
+ * @param[in] B, matrix
+ * @param[in, out] C, matrix
+ */
 template <typename Dtype>
-Blob<Dtype>* Reshape(const Blob<Dtype> & A, const std::vector<int>& shape) {
-    Blob<Dtype>* res = new Blob<Dtype>();
-    res->Mirror(A);
-    res->Reshape(shape);
-    return res;
-}
-
-// the current reshape in blob.h is:
-// void Reshape(const std::vector<int>& shape);
-
+void GEMM(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype>& A,
+    const Blob<Dtype> & B, Blob<Dtype> * C) {
+  CHECK_EQ(A.shape().size(), 2);
+  CHECK_EQ(B.shape().size(), 2);
+  CHECK_EQ(C->shape().size(), 2);
+  int a1, a2, b1, b2, m, n;
+  CHECK(!C->transpose());
+  a1 = A.transpose() ? A.shape(1) : A.shape(0);
+  a2 = A.transpose() ? A.shape(0) : A.shape(1);
+  b1 = B.transpose() ? B.shape(1) : B.shape(0);
+  b2 = B.transpose() ? B.shape(0) : B.shape(1);
+  m = C->shape(0);
+  n = C->shape(1);
+  CHECK_EQ(a2, b1);
+  CHECK_EQ(a1, m);
+  CHECK_EQ(b2, n);
+
+  int k = A.transpose() ? A.shape(0) : A.shape(1);
+  bool TranA = A.transpose();
+  bool TranB = B.transpose();
+  if (xpu == cpu) {
+    cpu_gemm(A.cpu_data(), B.cpu_data(), m, n, k, alpha, beta,
+        TranA, TranB, C->mutable_cpu_data());
+  }
+#ifdef USE_GPU
+  if (xpu == gpu) {
+    // gpu part
+    gpu_gemm(A.gpu_data(), B.gpu_data(), m, n, k, alpha, beta,
+        TranA, TranB, C->mutable_gpu_data());
+  }
+#endif  // USE_GPU
+}
+/**
+ * Matrix multiplication, C = A(.T) * B(.T), transpose is considered.
+ * Strict shape checking:
+ * - all are matrix
+ * - shapes match for matrix multiplication
+ *
+ * @param[in] A input matrix
+ * @param[in] B input matrix
+ * @param[out] C output matrix
+ */
 template <typename Dtype>
-Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1) {
-    std::vector<int> tmpshape;
-    tmpshape.push_back(dim1);
-    return Reshape(A, tmpshape);
+void MMDot(XPU xpu, const Blob<Dtype>& A, const Blob<Dtype>& B,
+    Blob<Dtype>* C) {
+  GEMM(xpu, Dtype(1), Dtype(0), A, B, C);
 }
 
-template <typename Dtype>
-Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1, int dim2) {
-    std::vector<int> tmpshape;
-    tmpshape.push_back(dim1);
-    tmpshape.push_back(dim2);;
-    return Reshape(A, tmpshape);
-}
-
-template <typename Dtype>
-Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1, int dim2, int dim3) {
-    std::vector<int> tmpshape;
-    tmpshape.push_back(dim1);
-    tmpshape.push_back(dim2);
-    tmpshape.push_back(dim3);
-    return Reshape(A, tmpshape);
-}
-
-template <typename Dtype>
-Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1, int dim2,
-int dim3, int dim4) {
-    std::vector<int> tmpshape;
-    tmpshape.push_back(dim1);
-    tmpshape.push_back(dim2);
-    tmpshape.push_back(dim3);
-    tmpshape.push_back(dim4);
-    return Reshape(A, tmpshape);
-}
 
+/*********************** Inner and Outer product****************************/
+/**
+ * Inner product for two vectors.
+ * Loose shape checking, A.count() == B.count().
+ *
+ * @param[in] A, input vector (shape checking using A.count()).
+ * @param[in] B, input vector (shape checking using B.count()).
+ * @return inner product value.
+ */
 template <typename Dtype>
-Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1, int dim2,
-int dim3, int dim4, int dim5) {
-    std::vector<int> tmpshape;
-    tmpshape.push_back(dim1);
-    tmpshape.push_back(dim2);
-    tmpshape.push_back(dim3);
-    tmpshape.push_back(dim4);
-    tmpshape.push_back(dim5);
-    return Reshape(A, tmpshape);
-}
-
+Dtype VVDot(XPU xpu, const Blob<Dtype> & A, const Blob<Dtype> & B) {
+  Dtype res = 0;
+  CHECK_EQ(A.count(), B.count());
+  int n = A.count();
+  if (xpu == cpu) {
+    res = cpu_dot(A.cpu_data(), B.cpu_data(), n);
+  }
+#ifdef USE_GPU
+  if (xpu == gpu) {
+    // gpu part
+    res = gpu_dot(A.gpu_data(), B.gpu_data(), n);
+  }
+#endif  // USE_GPU
+  return res;
+}
+
+/**
+ * Outer product, C = A ** B, transpose is disabled.
+ * Loose shape checking, A.count() * B.count() == C.count()
+ *
+ * @param[in] A, input vector
+ * @param[in] B, input vector
+ * @param[out] C, output matrix
+ */
 template <typename Dtype>
-Blob<Dtype>* Transpose(const Blob<Dtype> & A) {
-    Blob<Dtype>* res = new Blob<Dtype>();
-    res->Mirror(A);
-    res->setTranspose();
-    return res;
-}
-// return A^T
-
-
-/*****************************************************************************/
-// class1 matrix operation
-
-
-void MMDot(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C);
-// A, B and C are matrix
-
-
-void MVDot(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C);
-// A is matrix,B and C are vector
-
-
-void VVDot(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C);
-// C is matrix,A and B are vector
-
-
-float VVdot(XPU xpu, const Blob<float> & A, const Blob<float> & B);
-// A and B are vectors
-
-
-void GEMM(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C, float alpha = 1, float beta = 1);
-// C = alpha*A*B+beta*C, A, B and C are matrix
-
-
-
-/*****************************************************************************/
-// class2 element-wise operation
-
-// element-wise generalized operation defined in Op
-
-
-template<typename Op>
-void E_Func(XPU xpu, Blob<float> * A, float alpha) {
-    if (xpu == cpu) {
-        int n = get_size(A->shape());
-        cpu_e_f<Op>(n, alpha, A->mutable_cpu_data());
-    }
-    #ifdef SINGA_GPU
-    if (xpu == gpu) {
-        // gpu part
-        int n = get_size(A->shape());
-        gpu_e_f<Op>(n, alpha, A->mutable_gpu_data());
-    }
-    #endif  // SINGA_GPU
-}
-
-template<typename Op>
-void E_Func(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha) {
-    if (check_shape_equal(A, *B, *B)) {
-        int n = get_size(A.shape());
-        if (xpu == cpu) {
-            cpu_e_f<Op>(n, A.cpu_data(), alpha, B->mutable_cpu_data());
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            // gpu part
-            gpu_e_f<Op>(n, A.gpu_data(), alpha, B->mutable_gpu_data());
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
+void OuterProduct(XPU xpu, const Blob<Dtype>& A, const Blob<Dtype>& B,
+    Blob<Dtype> * C) {
+  CHECK(!C->transpose());  // do not support C.T now.
+
+  int m = A.count();
+  int n = B.count();
+  CHECK_EQ(C->count(), m * n);
+
+  if (xpu == cpu) {
+    cpu_gemm(A.cpu_data(), B.cpu_data(), m, n, 1, 1, 0,
+        false, false, C->mutable_cpu_data());
+  }
+#ifdef USE_GPU
+  if (xpu == gpu) {
+    // gpu part
+    gpu_gemm(A.gpu_data(), B.gpu_data(), m, n, 1, 1, 0,
+        false, false, C->mutable_gpu_data());
+  }
+#endif  // USE_GPU
+}
+/*********************** Element-wise functions ***********************/
+/**
+ * Apply the function from Op for each element in A and put the result into B,
+ * i.e., Bi = Op(Ai).
+ * Loose shape checking, A.count() == B.count().
+ */
+template<typename Op, typename Dtype>
+void Map(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  CHECK_EQ(A.count(), B->count()) << "Blobs must have the same size";
+  if (xpu == cpu) {
+    cpu_e_f<Op>(A.count(), A.cpu_data(), B->mutable_cpu_data());
+  }
+#ifdef SINGA_GPU
+  if (xpu == gpu) {
+    // gpu part
+    gpu_e_f<Op>(A.count(), A.gpu_data(), B->mutable_gpu_data());
+  }
+#endif  // SINGA_GPU
+}
+
+/**
+ * Apply the function from Op for each element in A and B, and put the result
+ * into C, i.e., Ci = Op(Ai, Bi).
+ * Loose shape checking, A, B and C are of the same size.
+ */
+template<typename Op, typename Dtype>
+void Map(XPU xpu, const Blob<Dtype> & A, const Blob<Dtype> & B,
+    Blob<Dtype> * C) {
+  CHECK_EQ(A.count(), B.count()) << "Blobs must have the same size";
+  CHECK_EQ(A.count(), C->count()) << "Blobs must have the same size";
+  if (xpu == cpu) {
+    cpu_e_f<Op>(A.count(), A.cpu_data(), B.cpu_data(), C->mutable_cpu_data());
+  }
+#ifdef SINGA_GPU
+  if (xpu == gpu) {
+    // gpu part
+    gpu_e_f<Op>(A.count(), A.gpu_data(), B.gpu_data(), C->mutable_gpu_data());
+  }
+#endif  // SINGA_GPU
+}
+
+/**
+ * Bi = Op(alpha, Ai)
+ * Loose shape checking, A.count() == B.count().
+ */
+template<typename Op, typename Dtype>
+void Map(XPU xpu, Dtype alpha, const Blob<Dtype>& A, Blob<Dtype>* B) {
+  CHECK_EQ(A.count(), B->count()) << "Blobs must have the same size";
+  if (xpu == cpu) {
+    cpu_e_f<Op>(A.count(), A.cpu_data(), alpha, B->mutable_cpu_data());
+  }
+#ifdef SINGA_GPU
+#endif  // SINGA_GPU
+}
+/**
+ * Ci = Op(alpha, Ai, Bi)
+ * Loose shape checking, A, B and C are of the same size.
+ */
+template<typename Op, typename Dtype>
+void Map(XPU xpu, Dtype alpha, const Blob<Dtype>& A, const Blob<Dtype>& B,
+    Blob<Dtype>* C) {
+  CHECK_EQ(A.count(), B.count()) << "Blobs must have the same size";
+  if (xpu == cpu) {
+    cpu_e_f<Op>(A.count(), alpha, A.cpu_data(), B.cpu_data(),
+        C->mutable_cpu_data());
+  }
+#ifdef SINGA_GPU
+#endif  // SINGA_GPU
+}
+
+/**
+ * Currently use std::copy which has shown better performance than memcpy.
+ * http://stackoverflow.com/questions/4707012/c-memcpy-vs-stdcopy
+ * TODO(wangwei) test blas copy vs std::copy.
+ *
+ * Loose shape checking, A.count() == B.count().
+ */
+template<typename Dtype>
+void Copy(XPU xpu, const Blob<Dtype>& A, Blob<Dtype>* B) {
+  CHECK_EQ(A.count(), B->count()) << "Blobs must have the same size";
+  if (xpu == cpu)
+    std::copy(A.cpu_data(), A.cpu_data() + A.count(), B->mutable_cpu_data());
+  else {
+    LOG(FATAL) << "Not implemented";
+  }
+}
+
+/**
+ * C = A + B
+ * Implemented using Copy and AXPY.
+ */
+template<typename Dtype>
+void Add(XPU xpu, const Blob<Dtype> & A, const Blob<Dtype> & B,
+    Blob<Dtype> * C) {
+  Copy(xpu, A, C);
+  AXPY(xpu, Dtype(1), B, C);
+}
+
+/**
+ * C = A - B
+ * Implemented using Copy and AXPY.
+ */
+template<typename Dtype>
+void Sub(XPU xpu, const Blob<Dtype> & A, const Blob<Dtype> & B,
+    Blob<Dtype> * C) {
+  Copy(xpu, A, C);
+  AXPY(xpu, Dtype(-1), B, C);
+}
+
+/**
+ * C = A * B, implemented using
+ * Map(XPU, const Blob<Dtype>&, const Blob<Dtype>&, Blob<Dtype>*).
+ */
+template<typename Dtype>
+void Mult(XPU xpu, const Blob<Dtype> & A, const Blob<Dtype> & B,
+    Blob<Dtype> * C) {
+  Map<singa::op::Mult>(xpu, A, B, C);
+  // TODO(wangwei) use MKL's vector func
+}
+
+/**
+ * C = A / B, implemented using
+ * Map(XPU, const Blob<Dtype>&, const Blob<Dtype>&, Blob<Dtype>*).
+ */
+template<typename Dtype>
+void Div(XPU xpu, const Blob<Dtype> & A, const Blob<Dtype> & B,
+    Blob<Dtype> * C) {
+  Map<singa::op::Div>(xpu, A, B, C);
+  // TODO(wangwei) use MKL's vector func
+}
+/*************************1D<-->2D op/transform***************************/
+/**
+ * Add A to each row of B, i.e., Bij = alpha*Aj + beta*Bij
+ * Loose shape checking, B.count() % A.count() == 0.
+ * # rows of B = B.count() / A.count().
+ * Transpose is disabled.
+ */
+template<typename Dtype>
+void MVAdd(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype> & A,
+    Blob<Dtype> * B) {
+  CHECK_EQ(B->count() % A.count(), 0) << "#col of B not match length of A";
+  int m = A.count(), n = B->count() / m;
+  if (xpu == cpu) {
+    Blob<Dtype> one(n);
+    one.SetValue(1);
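+    // broadcast A to every row of B via a rank-1 GEMM:
+    // B = alpha * ones * A + beta * B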
+    cpu_gemm(one.cpu_data(), A.cpu_data(), n, m, 1, alpha, beta,
+        false, false, B->mutable_cpu_data());
+  }
+#ifdef USE_GPU
+  if (xpu == gpu) {
+    singa_gpu_add_vec_row(B->gpu_data(),
+        A.gpu_data(), A.gpu_data(), m, n, n);
+    // gpu part
+  }
+#endif  // USE_GPU
+}
+/**
+ * Add A to each row of B, i.e., Bij = Aj + Bij
+ * Loose shape checking, B.count() % A.count() == 0.
+ * # rows of B = B.count() / A.count().
+ * Transpose is disabled.
+ */
+template<typename Dtype>
+void MVAdd(XPU xpu, const Blob<Dtype> & A, Blob<Dtype>* B) {
+  MVAdd(xpu, Dtype(1), Dtype(1), A, B);
+}
+
+/**
+ * Copy A to each row of B
+ * Loose shape checking, B.count() % A.count() == 0,
+ * # rows of B = B.count() / A.count().
+ * Transpose is disabled.
+ */
+template<typename Dtype>
+void Repmat(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  MVAdd(xpu, Dtype(1), Dtype(0), A, B);
+}
+
+/**
+ * Add each col of matrix A to vector B, i.e., Bi = \sum_j {alpha*Aij}+beta*Bi
+ * Loose shape checking, A.count() % B.count() == 0.
+ * # cols of A = A.count() / B.count().
+ * Transpose is disabled.
+ */
+template<typename Dtype>
+void MVSum(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype> & A,
+    Blob<Dtype> * B) {
+  CHECK_EQ(A.count() % B->count(), 0) << "length of B must = # of rows of A";
+
+  int m = B->count(), n = A.count() / m;
+  if (xpu == cpu) {
+    Blob<Dtype> one(n);
+    one.SetValue(1);
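+    // accumulate the columns of A via GEMM with a ones vector:
+    // B = alpha * A * ones + beta * B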
+    cpu_gemm(A.cpu_data(), one.cpu_data(), m, 1, n, alpha, beta,
+        false, false, B->mutable_cpu_data());
+  }
+#ifdef USE_GPU
+  if (xpu == gpu) {
+    singa_gpu_sum_col(A.gpu_data(), B->gpu_data(), m, n, n);
+    // gpu part
+  }
+#endif  // USE_GPU
+}
+/**
+ * Reduce each row of A to an element of B.
+ * Loose shape checking, A.count() % B.count() == 0.
+ * # cols of A = A.count() / B.count().
+ */
+template<typename Op, typename Dtype>
+void Reduce2D(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  CHECK_EQ(A.count() % B->count(), 0) << "Row size not match B length";
+  int m = B->count(), n = A.count() / m;
+  if (xpu == cpu) {
+    cpu_reduce_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
+  }
+#ifdef SINGA_GPU
+  if (xpu == gpu) {
+    // gpu part
+    gpu_reduce_f<Op>(A.gpu_data(), m, n, B->mutable_gpu_data());
+  }
+#endif  // SINGA_GPU
+}
+/**
+ * Duplicate each element of A into a row of B.
+ * Loose shape checking, B.count() % A.count() == 0.
+ * # cols of B = B.count() / A.count().
+ */
+template<typename Op, typename Dtype>
+void Expand2D(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  CHECK_EQ(B->count() % A.count(), 0) << "Row size of B not match length of A";
+  int m = A.count(), n = B->count() / m;
+  if (xpu == cpu) {
+    cpu_expand_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
+  }
+#ifdef SINGA_GPU
+  if (xpu == gpu) {
+    // gpu part
+    gpu_expand_f<Op>(A.gpu_data(), m, n, B->mutable_gpu_data());
+  }
+#endif  // SINGA_GPU
+}
+
+/***********************************************************/
+/**
+ * Apply the function from Op for each element in A, Op(Ai).
+ * @param A
+ */
+template<typename Op, typename Dtype>
+void Map(XPU xpu, Blob<Dtype>* A) {
+  if (xpu == cpu) {
+    cpu_e_f<Op>(A->count(), A->mutable_cpu_data());
+  }
+#ifdef SINGA_GPU
+  if (xpu == gpu) {
+    // gpu part
+    gpu_e_f<Op>(A->count(), A->mutable_gpu_data());
+  }
+#endif  // SINGA_GPU
+}
+
+/**
+ * B = e ^ A
+ */
+template<typename Dtype>
+void Exp(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  Map<singa::op::Exp>(xpu, A, B);
+}
+
+/**
+ * element-wise operation: b = 1.0f / (1.0f + expf(-a));
+ */
+template<typename Dtype>
+inline void Sigmoid(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::Sigmoid>(xpu, A, B);
+}
+
+/**
+ * element-wise operation: b = a * ( 1.0f - a );
+ */
+template<typename Dtype>
+inline void SigmoidGrad(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::SigmoidGrad>(xpu, A, B);
+}
+
+/**
+ * element-wise operation: b = std::max( a, 0)
+ */
+template<typename Dtype>
+inline void Relu(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::Relu>(xpu, A, B);
+}
+
+/**
+ * element-wise operation: b = a > 0 ? 1: 0;
+ */
+template<typename Dtype>
+inline void ReluGrad(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::ReluGrad>(xpu, A, B);
+}
+
+/**
+ * element-wise operation: b = tanh(a);
+ */
+template<typename Dtype>
+inline void Tanh(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::Tanh>(xpu, A, B);
+}
+
+/**
+ * B = 1- A^2
+ */
+template<typename Dtype>
+inline void TanhGrad(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::TanhGrad>(xpu, A, B);
+}
+/**
+ * B = log(1+exp(A))
+ */
+template<typename Dtype>
+inline void Softplus(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::Softplus>(xpu, A, B);
+}
+
+/**
+ * B = 1.0f / (1.0f + expf(-A));
+ */
+template<typename Dtype>
+inline void SoftplusGrad(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::SoftplusGrad>(xpu, A, B);
+}
+
+template<typename Dtype>
+inline void Square(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::Square>(xpu, A, B);
+}
+
+template<typename Dtype>
+inline void SquareGrad(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::SquareGrad>(xpu, A, B);
 }
 
-template<typename Op>
-void E_Func(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C, float alpha, float beta) {
-    if (check_shape_equal(A, B, *C)) {
-        int n = get_size(A.shape());
-        if (xpu == cpu) {
-            cpu_e_f<Op>(n, A.cpu_data(), B.cpu_data(), alpha, beta,
-            C->mutable_cpu_data());
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            // gpu part
-            gpu_e_f<Op>(n, A.gpu_data(), B.gpu_data(), alpha, beta,
-            C->mutable_gpu_data());
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-}
-
-
-inline void Set(XPU xpu, Blob<float> * A, float alpha) {
-    E_Func<singa::op::Set>(xpu, A, alpha);
+template<typename Dtype>
+inline void Sqrt(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+    Map<singa::op::Sqrt>(xpu, A, B);
+}
+
+/**
+ * B = A < alpha ? 1:0;
+ */
+template<typename Dtype>
+inline void Threshold(XPU xpu, Dtype alpha, const Blob<Dtype> & A,
+    Blob<Dtype> * B) {
+  Map<singa::op::Threshold>(xpu, alpha, A, B);
 }
-// element-wise operation: Ai = alpha
-
-
-inline void Scale(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha) {
-    E_Func<singa::op::Scale>(xpu, A, B, alpha);
-}
-// element-wise operation: Bi = alpha*Ai
-
-inline void Exp(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha = 2.71) {
-    E_Func<singa::op::Exp>(xpu, A, B, alpha);
-}
-// element-wise operation: Bi = alpha^Ai
-
-inline void Exp_grad(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha = 2.71) {
-    E_Func<singa::op::Exp_grad>(xpu, A, B, alpha);
-}
-// element-wise operation: Bi = Ai*log(alpha)
-
-inline void Gsigmoid(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha) {
-    E_Func<singa::op::Gsigmoid>(xpu, A, B, alpha);
-}
-// element-wise operation: b = 1.0f / (1.0f + expf(-a * alpha));
-
-inline void Gsigmoid_grad(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha) {
-    E_Func<singa::op::Gsigmoid_grad>(xpu, A, B, alpha);
-}
-// element-wise operation: b = alpha * a * ( 1.0f - a );
-
-inline void Grelu(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha = 0) {
-    E_Func<singa::op::Grelu>(xpu, A, B, alpha);
-}
-// element-wise operation: b = ( 1 - alpha ) * std::max( a, 0.0f ) + alpha * a;
-
-inline void Grelu_grad(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha = 0) {
-    E_Func<singa::op::Grelu_grad>(xpu, A, B, alpha);
-}
-// element-wise operation: b = a > 0.0f ? 1.0f : alpha;
-
-inline void Gtanh(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha) {
-    E_Func<singa::op::Gtanh>(xpu, A, B, alpha);
-}
-// element-wise operation: b = tanhf( a * alpha );
-
-inline void Gtanh_grad(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha) {
-    E_Func<singa::op::Gtanh_grad>(xpu, A, B, alpha);
-}
-// element-wise operation: b = alpha * ( 1.0f - a * a );
-
-inline void Softplus(XPU xpu, const Blob<float> & A, Blob<float> * B) {
-    E_Func<singa::op::Softplus>(xpu, A, B, 0);
-}
-// element-wise operation: b = logf(1 + expf(a));
-
-inline void Softplus_grad(XPU xpu, const Blob<float> & A, Blob<float> * B) {
-    E_Func<singa::op::Softplus_grad>(xpu, A, B, 0);
-}
-// element-wise operation: b = 1.0f / (1.0f + expf(-a));
-
-inline void Square(XPU xpu, const Blob<float> & A, Blob<float> * B) {
-    E_Func<singa::op::Square>(xpu, A, B, 0);
-}
-// element-wise operation: b = a * a;
-
-inline void Square_grad(XPU xpu, const Blob<float> & A, Blob<float> * B) {
-    E_Func<singa::op::Square_grad>(xpu, A, B, 0);
-}
-// element-wise operation: b = 2 * sqrt(a);
-
-inline void Sqrt(XPU xpu, const Blob<float> & A, Blob<float> * B) {
-    E_Func<singa::op::Sqrt>(xpu, A, B, 0);
-}
-// element-wise operation: b = sqrt(a);
-
-inline void Threshold(XPU xpu, const Blob<float> & A, float alpha,
-Blob<float> * B) {
-    E_Func<singa::op::Threshold>(xpu, A, B, alpha);
-}
-// element-wise operation: b =  a < alpha ? 1.0f : 0.0f;
-
-inline void Add(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C) {
-    E_Func<singa::op::Add>(xpu, A, B, C, 0, 0);
-}
-// element-wise operation: Ci = Ai+Bi  A,B and C should have the same size
-
-inline void Sub(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C) {
-    E_Func<singa::op::Sub>(xpu, A, B, C, 0, 0);
-}
-// element-wise operation: Ci = Ai-Bi  A,B and C should have the same size
-
-inline void Mult(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C) {
-    E_Func<singa::op::Mult>(xpu, A, B, C, 0, 0);
-}
-// element-wise operation: Ci = Ai*Bi  A,B and C should have the same size
-
-inline void Div(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C) {
-    E_Func<singa::op::Div>(xpu, A, B, C, 0, 0);
-}
-// element-wise operation: Ci = Ai/Bi  A,B and C should have the same size
-
-
-void AXPY(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha);
-// element-wise operation: Bi = alpha*Ai+Bi  A and B should have the same size
-
-/*****************************************************************************/
-// class3 matrix-vector expand/reduce operation
-
-template<typename Op>
-void Reduce_F(XPU xpu, const Blob<float> & A, Blob<float> * B) {
-    if (check_shape_mv(A, *B)) {
-        int m = get_size(B->shape());
-        int n = get_size(A.shape()) / m;
-        if (xpu == cpu) {
-            cpu_reduce_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            // gpu part
-            gpu_reduce_f<Op>(A.gpu_data(), m, n, B->mutable_gpu_data());
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-}
-// reduce each row of A to an element of B e.g. the sum operation in softmax
-
-template<typename Op>
-void Expand_F(XPU xpu, const Blob<float> & A, Blob<float> * B) {
-    if (check_shape_mv(*B, A)) {
-        int m = get_size(A.shape());
-        int n = get_size(B->shape()) / m;
-        if (xpu == cpu) {
-            cpu_expand_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            // gpu part
-            gpu_expand_f<Op>(A.gpu_data(), m, n, B->mutable_gpu_data());
-        }
-        #endif  // SINGA_GPU  
-    } else {
-        // report errors here
-    }
-}
-// expand each element in A into a row of B
-
-void Repmat(XPU xpu, const Blob<float> & A, Blob<float> * B);
-// A is a vector, B is a matrix , let each row of B to be A
-
-void MVAdd(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha, float beta);
-// A is a vector, B is a matrix , Bij = alpha*Ai+beta*Bij
-// will use gemm. faster than general expand_f
-
-void MVSum(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha, float beta);
-// A is a vector, B is a matrix , Ai = \sigma_j_{alpha*Bij}+beta*Ai
-// will use gemm. faster than general reduce_f
-
-
 }  // end of namespace singa
 
 #endif  // SINGA_BLOB_MATH_BLOB_H_
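
As a quick illustration of the loose shape checking used by the element-wise
helpers above (a hypothetical snippet, not part of the commit):

  singa::Blob<float> A(4, 5), B(20);      // same count, different shapes
  singa::AXPY(singa::cpu, 2.0f, A, &B);   // ok: only A.count() == B->count() is required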

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b84dbe3/include/singa/blob/singa_op.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/singa_op.h b/include/singa/blob/singa_op.h
index 3747568..1131c5d 100644
--- a/include/singa/blob/singa_op.h
+++ b/include/singa/blob/singa_op.h
@@ -33,314 +33,318 @@
 #endif  // SINGA_GPU
 
 namespace singa {
-    enum XPU { cpu, gpu, any};
 
 namespace op {
-struct Set {
-    inline static void Map(float alpha, float * a) {
-        *a = alpha;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha, float * a, int n) {
-        singa::singa_gpu_set_value(a, alpha, n);
-    }
-    #endif  // SINGA_GPU
-};
 
-struct Scale {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = alpha * a;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_scale(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+/**
+ * b = e^a
+ */
+struct Exp {
+  inline static void Map(const float & a, float * b) {
+    *b = exp(a);
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(float alpha,  const float * a,
+      float * b, int n) {
+    singa::singa_gpu_exp(a, b, alpha, n);
+  }
+#endif  // SINGA_GPU
 };
-
-struct Scale_grad {
-    inline static void Map(float alpha,  float * output) {
-        *output = alpha;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  float * output, int n) {
-        singa::singa_gpu_scale_grad(output, alpha, n);
-    }
-    #endif  // SINGA_GPU
+/**
+ * b = log(a), base is e
+ */
+struct Log {
+  inline static void Map(const float & a, float *b) {
+    *b = log(a);
+  }
+};
+
+struct Sigmoid {
+  inline static void Map(const float & a, float * b) {
+    *b = 1.0f / (1.0f + expf(-a));
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a,
+      float * b, int n) {
+    singa::singa_gpu_sigmoid(a, b, 1, n);
+  }
+#endif  // SINGA_GPU
 };
-
-struct Exp {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = pow(a, alpha);
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_exp(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct SigmoidGrad {
+  inline static void Map(const float & a, float * b) {
+    *b = a * (1.0f - a);
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a, float * b, int n) {
+    singa::singa_gpu_sigmoid_grad(a, b, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Exp_grad {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        // log is the natrual log based on e
-        *b = a * log(alpha);
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_exp_grad(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct Relu {
+  inline static void Map(const float & a, float * b) {
+    *b = std::max(a, 0.0f);
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a, float * b, int n) {
+    singa::singa_gpu_relu(a, b, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Gsigmoid {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = 1.0f / (1.0f + expf(-a * alpha));
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_sigmoid(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct ReluGrad {
+  inline static void Map(const float & a, float * b) {
+    *b = a > 0 ? 1 : 0;
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a, float * b, int n) {
+    singa::singa_gpu_relu_grad(a, b, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Gsigmoid_grad {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = alpha * a * (1.0f - a);
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_sigmoid_grad(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct Tanh {
+  inline static void Map(const float & a, float * b) {
+    *b = tanhf(a);
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a, float * b, int n) {
+    singa::singa_gpu_tanh(a, b, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Grelu {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = (1 - alpha) * std::max(a, 0.0f) + alpha * a;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_relu(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct TanhGrad {
+  inline static void Map(const float & a, float * b) {
+    *b = 1 - a * a;
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a, float * b, int n) {
+    singa::singa_gpu_tanh_grad(a, b, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Grelu_grad {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = a > 0.0f ? 1.0f : alpha;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_relu_grad(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct Softplus {
+  inline static void Map(const float & a, float * b) {
+    *b = logf(1 + expf(a));
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a, float * b, int n) {
+    singa::singa_gpu_softplus(a, b, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Gtanh {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = tanhf(a * alpha);
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_tanh(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct SoftplusGrad {
+  inline static void Map(const float & a, float * b) {
+    *b = 1.0f / (1.0f + expf(-a));
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a,
+      float * b, int n) {
+    singa::singa_gpu_softplus_grad(a, b, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Gtanh_grad {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = alpha * (1.0f - a * a);
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_tanh_grad(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct Square {
+  inline static void Map(const float & a, float * b) {
+    *b = a * a;
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a,
+      float * b, int n) {
+    singa::singa_gpu_square(a, b, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Softplus {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = logf(1 + expf(a));
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_softplus(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct SquareGrad {
+  inline static void Map(const float & a, float * b) {
+    *b = 2 * sqrt(a);
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a,
+      float * b, int n) {
+    singa::singa_gpu_square_grad(a, b, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Softplus_grad {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = 1.0f / (1.0f + expf(-a));
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_softplus_grad(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct Sqrt {
+  inline static void Map(const float & a, float * b) {
+    *b = sqrt(a);
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a,
+      float * b, int n) {
+    singa::singa_gpu_sqrt(a, b, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Square {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = a * a;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_square(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+/*********************************************************************/
+/**
+ * c = pow(a, b), i.e., c = a^b
+ */
+struct Pow {
+  inline static void Map(const float & a, const float &b, float * c) {
+    *c = pow(a, b);
+  }
+};
+
+struct Mult {
+  inline static void Map(const float & a, const float & b, float * c) {
+    *c =  a * b;
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float* a, const float* b, float* c, int n) {
+    singa::singa_gpu_mult(a, b, c, 1, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Square_grad {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = 2 * sqrt(a);
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_square_grad(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+struct Div {
+  inline static void Map(const float & a, const float & b, float * c) {
+    *c =  a / b;
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a,
+      const float * b, float * c, int n) {
+    singa::singa_gpu_div(a, b, c, 1, 1, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Sqrt {
-    inline static void Map(float alpha,  const float & a, float * b) {
-        *b = sqrt(a);
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_sqrt(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+
+/*********************************************************************/
+struct Set {
+  inline static void Map(float alpha, float * a) {
+    *a = alpha;
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(float alpha, float * a, int n) {
+    singa::singa_gpu_set_value(a, alpha, n);
+  }
+#endif  // SINGA_GPU
 };
 
 struct Threshold {
-    inline static void Map(float alpha, const float & a, float * b) {
-        *b =  a < alpha ? 1.0f : 0.0f;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha,  const float * a,
-    float * b, int n) {
-        singa::singa_gpu_threshold(a, b, alpha, n);
-    }
-    #endif  // SINGA_GPU
+  inline static void Map(float alpha, const float & a, float * b) {
+    *b =  a < alpha ? 1.0f : 0.0f;
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(float alpha,  const float * a,
+      float * b, int n) {
+    singa::singa_gpu_threshold(a, b, alpha, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Add {
-    inline static void Map(float alpha, float beta, const float & a,
-    const float & b, float * c) {
-        *c =  a + b;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha, float beta, const float * a,
-    const float * b, float * c, int n) {
-        singa::singa_gpu_add(a, b, c, alpha, beta, n);
+/**********************************/
+struct Expand_Div {
+  inline static void Map(const float & a, int n, float * b) {
+    for (int i = 0 ; i < n ; i++) {
+      b[i] /= a;
     }
-    #endif  // SINGA_GPU
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float & a, int n, float * b) {
+    singa::singa_gpu_scale(b, b, a, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Sub {
-    inline static void Map(float alpha, float beta, const float & a,
-    const float & b, float * c) {
-        *c =  a - b;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha, float beta, const float * a,
-    const float * b, float * c, int n) {
-        singa::singa_gpu_sub(a, b, c, alpha, beta, n);
+struct Repmat {
+  inline static void Map(const float & a, int n, float * b) {
+    for (int i = 0 ; i < n ; i++) {
+      b[i] = a;
     }
-    #endif  // SINGA_GPU
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float & a, int n, float * b) {
+    singa::singa_gpu_set_value(b, a, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Mult {
-    inline static void Map(float alpha, float beta, const float & a,
-    const float & b, float * c) {
-        *c =  a * b;
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha, float beta, const float * a,
-    const float * b, float * c, int n) {
-        singa::singa_gpu_mult(a, b, c, alpha, beta, n);
-    }
-    #endif  // SINGA_GPU
-};
 
-struct Div {
-    inline static void Map(float alpha, float beta, const float & a,
-    const float & b, float * c) {
-        *c =  a / b;
+struct Scale {
+    inline static void Map(float alpha,  const float & a, float * b) {
+        *b = alpha * a;
     }
     #ifdef SINGA_GPU
-    inline static void CudaMap(float alpha, float beta, const float * a,
-    const float * b, float * c, int n) {
-        singa::singa_gpu_div(a, b, c, alpha, beta, n);
+    inline static void CudaMap(float alpha,  const float * a,
+    float * b, int n) {
+        singa::singa_gpu_scale(a, b, alpha, n);
     }
     #endif  // SINGA_GPU
 };
 
-struct Sum {
-    inline static void Map(const float * a, int n, float * b) {
-        *b = 0;
-        for (int i = 0 ; i < n ; i++) {
-                    *b += a[i];
-        }
+struct Scale_grad {
+    inline static void Map(float alpha,  float * output) {
+        *output = alpha;
     }
     #ifdef SINGA_GPU
-    inline static void CudaMap(const float * a, int n, float * b) {
-        float *sum = NULL;
-        cudaMalloc(<void**>(&sum), n*sizeof(float));
-
-        singa::singa_gpu_sum_vec(a, sum, n);
-
-        cudaMemcpyAsync(b, sum, sizeof(float), cudaMemcpyDeviceToDevice);
-        cudaFree(sum);
+    inline static void CudaMap(float alpha,  float * output, int n) {
+        singa::singa_gpu_scale_grad(output, alpha, n);
     }
     #endif  // SINGA_GPU
 };
 
-struct Expand_Div {
-    inline static void Map(const float & a, int n, float * b) {
-        for (int i = 0 ; i < n ; i++) {
-                    b[i] /= a;
-        }
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(const float & a, int n, float * b) {
-        singa::singa_gpu_scale(b, b, a, n);
-    }
-    #endif  // SINGA_GPU
+struct ExpGrad {
+  inline static void Map(float alpha,  const float & a, float * b) {
+    // log is the natural log, base e
+    *b = a * log(alpha);
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(float alpha,  const float * a,
+      float * b, int n) {
+    singa::singa_gpu_exp_grad(a, b, alpha, n);
+  }
+#endif  // SINGA_GPU
 };
 
-struct Repmat {
-    inline static void Map(const float & a, int n, float * b) {
-        for (int i = 0 ; i < n ; i++) {
-                    b[i] = a;
-        }
-    }
-    #ifdef SINGA_GPU
-    inline static void CudaMap(const float & a, int n, float * b) {
-        singa::singa_gpu_set_value(b, a, n);
+struct Sum {
+  inline static void Map(const float * a, int n, float * b) {
+    *b = 0;
+    for (int i = 0 ; i < n ; i++) {
+      *b += a[i];
     }
-    #endif  // SINGA_GPU
+  }
+#ifdef SINGA_GPU
+  inline static void CudaMap(const float * a, int n, float * b) {
+    float *sum = NULL;
+    cudaMalloc(reinterpret_cast<void**>(&sum), n * sizeof(float));
+
+    singa::singa_gpu_sum_vec(a, sum, n);
+
+    cudaMemcpyAsync(b, sum, sizeof(float), cudaMemcpyDeviceToDevice);
+    cudaFree(sum);
+  }
+#endif  // SINGA_GPU
 };
 
 };  // namespace op
 
 };  // namespace singa
 
-
-
 #endif  // SINGA_BLOB_SINGA_OP_H_
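
To show how these Op structs plug into the Blob-level Map helpers, here is a
hypothetical custom op (illustrative only, not part of the commit):

  // Bi = (Ai > alpha) ? Ai : 0, a hypothetical clamp-below op
  struct ClampBelow {
    inline static void Map(float alpha, const float & a, float * b) {
      *b = a > alpha ? a : 0.0f;
    }
  };
  // usage: singa::Map<ClampBelow>(singa::cpu, 0.5f, A, &B);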

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b84dbe3/include/singa/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/blob.h b/include/singa/utils/blob.h
index 0ebf8fd..eecb674 100644
--- a/include/singa/utils/blob.h
+++ b/include/singa/utils/blob.h
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -121,21 +121,53 @@ class Blob {
  public:
   Blob() {}
   explicit Blob(const std::vector<int>& shape) { Reshape(shape); }
+  explicit Blob(int count) { Reshape(count); }
+  explicit Blob(int a, int b) { Reshape(a, b); }
+  explicit Blob(int a, int b, int c) { Reshape(a, b, c); }
+  explicit Blob(int a, int b, int c, int d) { Reshape(a, b, c, d); }
   /**
-   * @brief Change the dimensions of the blob, allocating new memory if
-   *        necessary.
+   * Change the shape of the blob, re-allocating memory if the Blob's size() changes.
    *
-   * This function can be called both to create an initial allocation
-   * of memory, and to adjust the dimensions of a top blob during Layer::Reshape
-   * or Layer::Forward. When changing the size of blob, memory will only be
-   * reallocated if sufficient memory does not already exist, and excess memory
-   * will never be freed.
-   *
-   * Note that reshaping an input blob and immediately calling Net::Backward is
-   * an error; either Net::Forward or Net::Reshape need to be called to
-   * propagate the new input shape to higher layers.
+   * @param[in] shape specifies the size of each dimension; shape[0] is the highest
+   * dimension, i.e., stride[0] = shape[1] * shape[2] * ...
    */
   void Reshape(const std::vector<int>& shape);
+  /**
+   * Helper for Reshape(const std::vector<int>& shape) with shape.size() = 1.
+   *
+   * @see Reshape(const std::vector<int>&).
+   * @param[in] count total number of elements.
+   */
+  void Reshape(int count);
+  /**
+   * Helper for Reshape(const std::vector<int>& shape) with shape.size() = 2.
+   *
+   * @param[in] a the highest dimension size, i.e., a = shape[0]. E.g., a could
+   * be the batch size.
+   * @param[in] b, b = shape[1]; e.g., b could be the length of the feature vector.
+   */
+  void Reshape(int a, int b);
+  /**
+   * Helper for Reshape(const std::vector<int>& shape) with shape.size() = 3.
+   *
+   * @param[in] a, a = shape[0]
+   * @param[in] b, b = shape[1]
+   * @param[in] c, c = shape[2]
+   */
+  void Reshape(int a, int b, int c);
+  /**
+   * Helper for Reshape(const std::vector<int>& shape) with shape.size() = 4.
+   *
+   * @param[in] a, a = shape[0]
+   * @param[in] b, b = shape[1]
+   * @param[in] c, c = shape[2]
+   * @param[in] d, d = shape[3]
+   */
+  void Reshape(int a, int b, int c, int d);
+  /**
+   * Reshape to the shape of the *other* Blob.
+   * @param[in] other
+   */
   void ReshapeLike(const Blob& other);
   /**
    * @brief Copy from a source Blob.
@@ -149,20 +181,45 @@ class Blob {
   void CopyFrom(const Blob<Dtype>& source, bool reshape);
   void FromProto(const singa::BlobProto& proto);
   void ToProto(singa::BlobProto* proto) const;
+  void SetValue(Dtype v);
+  /**
+   * Compute the sum of absolute values (L1 norm) of the data.
+   */
+  Dtype asum_data() const;
   /**
-   * @brief Set the data_ shared_ptr to point to the SyncedMemory holding the
-   *        data_ of Blob other -- useful in Layer&s which simply perform a copy
-   *        in their Forward pass.
+   * Sum all elements
+   */
+  Dtype sum_data() const;
+  /**
+   * Share data with the other Blob.
+   * Set the data_ shared_ptr to point to the SyncedMemory holding the data_
+   * of Blob other.
    *
-   * This deallocates the SyncedMemory holding this Blob's data_, as
+   * It may deallocate the SyncedMemory holding this Blob's data_, as
    * shared_ptr calls its destructor when reset with the "=" operator.
    */
   void ShareData(const Blob& other);
   void Swap(Blob& other);
+  /**
+   * @return the shape vector.
+   */
   inline const std::vector<int>& shape() const { return shape_; }
-  inline int count() const { return count_; }
-  inline const int version() const { return version_; }
-  inline void set_version(int v) { version_ = v; }
+  /**
+   * @return the size of the k-th dimension.
+   */
+  inline const int shape(int k) const {
+    CHECK_LT(k, shape_.size());
+    return shape_.at(k);
+  }
+  inline int count() const {
+    return count_;
+  }
+  inline const int version() const {
+    return version_;
+  }
+  inline void set_version(int v) {
+    version_ = v;
+  }
   inline const Dtype* cpu_data() const {
     CHECK(data_);
     return static_cast<const Dtype*>(data_->cpu_data());
@@ -183,34 +240,90 @@ class Blob {
     CHECK(data_);
     return static_cast<Dtype*>(data_->mutable_gpu_data());
   }
-  /// @brief Compute the sum of absolute values (L1 norm) of the data.
-  Dtype asum_data() const;
-  Dtype sum_data() const;
-  inline void setTranspose() {
-    isTranspose_ = !isTranspose_;
-  }
-  inline bool isTranspose() const {
-    return isTranspose_;
+  inline void set_transpose() {
+    transpose_ = true;
   }
-  inline void Mirror(const Blob<Dtype> & other) {
-    data_ = other.data_;
-    shape_ = other.shape_;
-    count_ = other.count_;
-    capacity_ = other.capacity_;
-    version_ = other.version_;
-    isTranspose_ = other.isTranspose_;
+  inline bool transpose() const {
+    return transpose_;
   }
 
-
  protected:
   std::shared_ptr<SyncedMemory> data_ = nullptr;
   std::vector<int> shape_;
   int count_ = 0;
   int capacity_ = 0;
   int version_ = -1;
-  bool isTranspose_ = false;
+  bool transpose_ = false;
 };  // class Blob
 
+/**
+ * Reshape a Blob.
+ * @return a new Blob with the given shape; it shares the internal data_ with
+ * the original Blob, i.e., no memory copy or allocation.
+ */
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, const std::vector<int>& shape) {
+  Blob<Dtype>* res = new Blob<Dtype>(A);
+  res->Reshape(shape);
+  return res;
+}
+
+/**
+ * Helper of Reshape(const Blob<Dtype>&, const std::vector<int>&).
+ */
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int count) {
+  std::vector<int> tmpshape;
+  tmpshape.push_back(count);
+  return Reshape(A, tmpshape);
+}
+/**
+ * Helper of Reshape(const Blob<Dtype>&, const std::vector<int>&).
+ */
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim0, int dim1) {
+  std::vector<int> tmpshape;
+  tmpshape.push_back(dim0);
+  tmpshape.push_back(dim1);
+  return Reshape(A, tmpshape);
+}
+/**
+ * Helper of Reshape(const Blob<Dtype>&, const std::vector<int>&).
+ */
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim0, int dim1, int dim2) {
+  std::vector<int> tmpshape;
+  tmpshape.push_back(dim0);
+  tmpshape.push_back(dim1);
+  tmpshape.push_back(dim2);
+  return Reshape(A, tmpshape);
+}
+/**
+ * Helper of Reshape(const Blob<Dtype>&, const std::vector<int>&).
+ */
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim0, int dim1,
+    int dim2, int dim3) {
+  std::vector<int> tmpshape;
+  tmpshape.push_back(dim0);
+  tmpshape.push_back(dim1);
+  tmpshape.push_back(dim2);
+  tmpshape.push_back(dim3);
+  return Reshape(A, tmpshape);
+}
+
+/**
+ * @return a new Blob which shares all internal members with the input Blob
+ * except that the transpose_ field is set to true.
+ */
+template <typename Dtype>
+Blob<Dtype>* Transpose(const Blob<Dtype> & A) {
+  Blob<Dtype>* res = new Blob<Dtype>(A);
+  res->set_transpose();
+  return res;
+}
+
+// TODO(wangwei) remove mshadow functions.
 using namespace mshadow;
 using mshadow::cpu;
 
@@ -249,6 +362,7 @@ inline Tensor<cpu, 1> Tensor1(Blob<float>* blob) {
   return tensor;
 }
 
+
 }  // namespace singa
 
 #endif  // SINGA_UTILS_BLOB_H_
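
Taken together, the new Blob helpers support the following usage pattern (an
illustrative sketch based only on the declarations in this patch, not code
from the repository):

    singa::Blob<float> A(4, 3);         // same as Reshape({4, 3}); count() == 12
    A.SetValue(1.0f);                   // fill the data with 1.0
    CHECK_EQ(A.shape(0), 4);            // per-dimension accessor
    singa::Blob<float>* At = singa::Transpose(A);      // shares data_, sets transpose_
    singa::Blob<float>* B2 = singa::Reshape(A, 2, 6);  // shares data_, shape {2, 6}
    delete At;                          // deletes the wrappers only;
    delete B2;                          // data_ is reference-counted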

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b84dbe3/src/blob/math_addr.cc
----------------------------------------------------------------------
diff --git a/src/blob/math_addr.cc b/src/blob/math_addr.cc
deleted file mode 100644
index fb1c42e..0000000
--- a/src/blob/math_addr.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#include "singa/blob/math_addr.h"
-extern "C" {
-    #include <cblas.h>
-}
-#ifdef SINGA_GPU
-#include <cuda_runtime.h>
-#endif
-#include "singa/blob/singa_op.h"
-#ifdef SINGA_GPU
-#include "cublas_v2.h"
-#endif
-
-namespace singa {
-
-const float * cpu_uni_vec(const int n) {
-    float * res = new float[n];
-    for (int i = 0; i < n; i++)
-        res[i] = 1.0;
-    return res;
-}
-
-void cpu_gemm(const float * A, const float * B, const int m, const int n,
-const int k, const float alpha, const float beta,
-const bool TranA, const bool TranB, float * C) {
-    int lda, ldb;
-    CBLAS_TRANSPOSE tA, tB;
-    lda = TranA ? m : k;
-    ldb = TranB ? k : n;
-    tA = TranA ? CblasTrans : CblasNoTrans;
-    tB = TranB ? CblasTrans : CblasNoTrans;
-    cblas_sgemm(CblasRowMajor, tA, tB, m, n, k, alpha, A, lda,
-    B, ldb, beta, C, n);
-}
-
-void cpu_gemv(const float * A, const float * B, const int m, const int n,
-const float alpha, const float beta, const bool TranA, float * C) {
-    CBLAS_TRANSPOSE tA;
-    tA = TranA ? CblasTrans : CblasNoTrans;
-    cblas_sgemv(CblasRowMajor, tA, m, n, alpha, A, n, B, 1, beta, C, 1);
-}
-
-void cpu_axpy(const float * A, const int n, const float alpha, float * B) {
-    cblas_saxpy(n, alpha, A, 1, B, 1);
-}
-
-float cpu_dot(const float * A, const float * B, const int n) {
-    float sum = 0;
-    for (int i = 0 ; i < n ; i++)
-        sum += A[i] * B[i];
-    return sum;
-}
-
-#ifdef SINGA_GPU
-// Trick: swap A and B
-void gpu_gemm(const float * A, const float * B, const int m, const int n,
-const int k, const float alpha, const float beta, const bool TranA,
-const bool TranB, float * C) {
-    int lda = TranA ? m : k;
-    int ldb = TranB ? k : n;
-    int ldc = n;
-    cublasOperation_t tA = (TranA == false) ? CUBLAS_OP_N : CUBLAS_OP_T;
-    cublasOperation_t tB = (TranB == false) ? CUBLAS_OP_N : CUBLAS_OP_T;
-    cublasHandle_t handle;
-    cublasCreate(&handle);
-    cublasSgemm(handle, tB, tA, n, m, k, &alpha, B, ldb,
-    A, lda, &beta, C, ldc);
-    cublasDestroy(handle);
-}
-
-void gpu_gemv(const float * A, const float * B, const int m, const int n,
-const float alpha, const float beta, const bool TranA, float * C) {
-    int lda = n;
-    cublasOperation_t tA = (TranA == true) ? CUBLAS_OP_N : CUBLAS_OP_T;
-    cublasHandle_t handle;
-    cublasCreate(&handle);
-    cublasSgemv(handle, tA, n, m, &alpha , A, lda, B, 1, &beta, C, 1);
-    cublasDestroy(handle);
-}
-
-
-void gpu_axpy(const float * A, const int n, const float alpha, float * B) {
-    cublasHandle_t handle;
-    cublasCreate(&handle);
-    cublasSaxpy(handle, n, &alpha, A, 1, B, 1);
-    cublasDestroy(handle);
-}
-
-
-float gpu_dot(const float * A, const float * B, const int n) {
-    cublasHandle_t handle;
-    cublasCreate(&handle);
-    float result = 0.0;
-    cublasSdot(handle, n, A, 1, B, 1, &result);
-    cublasDestroy(handle);
-    return result;
-}
-#endif
-
-}  // namespace singa
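
These BLAS wrappers are deleted because they move, templated, into
include/singa/blob/math_addr.h (see the first hunk of this commit). The
row-major calling convention of cpu_gemm is easiest to see with a small
example (illustrative only, not part of the patch):

    // C(m x n) = alpha * A(m x k) * B(k x n) + beta * C; row-major, no transpose.
    float A[6] = {1, 2, 3, 4, 5, 6};    // 2 x 3
    float B[6] = {1, 0, 0, 1, 1, 1};    // 3 x 2
    float C[4] = {0, 0, 0, 0};          // 2 x 2
    singa::cpu_gemm<float>(A, B, 2, 2, 3, 1.0f, 0.0f, false, false, C);
    // C is now {4, 5, 10, 11}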

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4b84dbe3/src/blob/math_blob.cc
----------------------------------------------------------------------
diff --git a/src/blob/math_blob.cc b/src/blob/math_blob.cc
deleted file mode 100644
index 083d3e5..0000000
--- a/src/blob/math_blob.cc
+++ /dev/null
@@ -1,214 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#include "singa/blob/math_blob.h"
-#ifdef SINGA_GPU
-#include "singa/blob/math_kernel.h"
-#endif  // SINGA_GPU
-
-namespace singa {
-
-/*****************************************************************************/
-// shape_check function
-
-int get_size(const std::vector<int>& shape) {
-    int sum = 1;
-    for (unsigned int i = 0; i < shape.size(); i++) sum *= shape[i];
-    return sum;
-}
-
-/*****************************************************************************/
-// class1 matrix operation
-
-
-void GEMM(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C, float alpha, float beta) {
-    if (check_shape_mmm(A, B, *C)) {
-        int m = C->shape().at(0);
-        int n = C->shape().at(1);
-        int k = A.isTranspose() ? A.shape().at(0) : A.shape().at(1);
-        bool TranA = A.isTranspose();
-        bool TranB = B.isTranspose();
-        if (xpu == cpu) {
-            cpu_gemm(A.cpu_data(), B.cpu_data(), m, n, k, alpha, beta,
-            TranA, TranB, C->mutable_cpu_data());
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            // gpu part
-            gpu_gemm(A.gpu_data(), B.gpu_data(), m, n, k, alpha, beta,
-            TranA, TranB, C->mutable_gpu_data());
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-}
-// C = alpha*A*B+beta*C, A, B and C are matrix
-
-void MMDot(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C) {
-    GEMM(xpu, A, B, C, 1, 0);
-}
-// A,B and C are matrix
-
-
-void MVDot(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C) {
-    if (check_shape_mvv(A, B, *C)) {
-        int m = B.shape().at(0);
-        int n = C->shape().at(0);
-        bool TranA = A.isTranspose();
-        if (xpu == cpu) {
-            cpu_gemv(A.cpu_data(), B.cpu_data(), m, n, 1, 0, TranA,
-            C->mutable_cpu_data());
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            // gpu part
-            gpu_gemv(A.gpu_data(), B.gpu_data(), m, n, 1, 0, TranA,
-            C->mutable_gpu_data());
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-}
-// A is matrix,B and C are vector
-
-
-void VVDot(XPU xpu, const Blob<float> & A, const Blob<float> & B,
-Blob<float> * C) {
-    if (check_shape_vvm(A, B, *C)) {
-        int m = C->shape().at(0);
-        int n = C->shape().at(1);
-        if (xpu == cpu) {
-            cpu_gemm(A.cpu_data(), B.cpu_data(), m, n, 1, 1, 0,
-            false, false, C->mutable_cpu_data());
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            // gpu part
-            gpu_gemm(A.gpu_data(), B.gpu_data(), m, n, 1, 1, 0,
-            false, false, C->mutable_gpu_data());
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-}
-// C is matrix,A and B are vector
-
-
-float VVdot(XPU xpu, const Blob<float> & A, const Blob<float> & B) {
-    float res = 0;
-    if (check_shape_equal(A, B, B)) {
-        int n = get_size(A.shape());
-        if (xpu == cpu) {
-            res = cpu_dot(A.cpu_data(), B.cpu_data(), n);
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            // gpu part
-            res = gpu_dot(A.gpu_data(), B.gpu_data(), n);
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-    return res;
-}
-// A and B are vectors
-
-void AXPY(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha) {
-    if (check_shape_equal(A, *B, *B)) {
-        if (xpu == cpu) {
-            cpu_axpy(A.cpu_data(), get_size(A.shape()),
-            alpha, B->mutable_cpu_data());
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            gpu_axpy(A.gpu_data(), get_size(A.shape()),
-            alpha, B->mutable_gpu_data());
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-}
-// element-wise operation: Bi = alpha*Ai+Bi  A and B should have the same size
-
-inline void Repmat(XPU xpu, const Blob<float> & A, Blob<float> * B) {
-    MVAdd(xpu, A, B, 1, 0);
-}
-// A is a vector, B is a matrix , let each row of B to be A
-
-void MVAdd(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha, float beta) {
-    if (check_shape_mv(*B, A)) {
-        int m = get_size(A.shape());
-        int n = get_size(B->shape()) / m;
-        if (xpu == cpu) {
-            const float * univ = cpu_uni_vec(n);
-            cpu_gemm(A.cpu_data(), univ, m, n, 1, alpha, beta,
-            false, false, B->mutable_cpu_data());
-            delete univ;
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            singa_gpu_add_vec_row(B->gpu_data(),
-            A.gpu_data(), A.gpu_data(), m, n, n);
-            // gpu part
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-}
-// A is a vector, B is a matrix , Bij = alpha*Ai+beta*Bij
-// will use gemm. faster than general expand_f
-
-void MVSum(XPU xpu, const Blob<float> & A, Blob<float> * B,
-float alpha, float beta) {
-    if (check_shape_mv(A, *B)) {
-        int m = get_size(B->shape());
-        int n = get_size(A.shape()) / m;
-        if (xpu == cpu) {
-            const float * univ = cpu_uni_vec(n);
-            cpu_gemm(A.cpu_data(), univ, m, 1, n, alpha, beta,
-            false, false, B->mutable_cpu_data());
-            delete univ;
-        }
-        #ifdef SINGA_GPU
-        if (xpu == gpu) {
-            singa_gpu_sum_col(A.gpu_data(), B->gpu_data(), m, n, n);
-            // gpu part
-        }
-        #endif  // SINGA_GPU
-    } else {
-        // report errors here
-    }
-}
-// B is a vector, A is a matrix , Bi = \sigma_j_{alpha*Aij}+beta*Bi
-// will use gemm. faster than general reduce_f
-
-}  // namespace singa
-
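
Like math_addr.cc above, these Blob-level routines move into the header
include/singa/blob/math_blob.h as templates. A typical call after this change
(a sketch; cpu is the XPU enumerator tested by the deleted code above):

    singa::Blob<float> A(2, 3), B(3, 4), C(2, 4);
    // ... fill A and B through mutable_cpu_data() ...
    singa::MMDot(cpu, A, B, &C);        // C = A * B, with shape checking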