You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/03 07:48:36 UTC

[31/60] incubator-singa git commit: SINGA-167 - Add Tensor Math function APIs

SINGA-167 - Add Tensor Math function APIs

Add basic linalg functions for Tensor

Add blas functions for Tensor.

Unify gemm and gemv in Tensor::Mult

This commit also contains code for the Param class, which will be removed in the next commit.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/02851fac
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/02851fac
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/02851fac

Branch: refs/heads/dev
Commit: 02851fac11ae6455b60d1cd5be4c2b6f142696cf
Parents: e36bc92
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Fri May 13 21:00:48 2016 +0800
Committer: wangwei <wa...@gmail.com>
Committed: Tue May 17 00:40:23 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                       |   2 +-
 include/singa/core/math.h            | 273 ---------------------
 include/singa/core/tensor.h          | 285 +++++++++++-----------
 include/singa/model/layer.h          |  23 +-
 include/singa/model/param.h          |  97 ++++++++
 src/core/device/device.cc            |   1 +
 src/core/math/cpp_math.cc            |  54 -----
 src/core/math/cuda_math.cc           |  48 ----
 src/core/math/opencl_math.cc         |  24 --
 src/core/tensor/tensor.cc            | 379 ++++++++++++++++++++++++++----
 src/core/tensor/tensor_math.h        | 302 ++++++++++++++++++++++++
 src/core/tensor/tensor_math_cpp.h    |  57 +++++
 src/core/tensor/tensor_math_cuda.h   |  53 +++++
 src/core/tensor/tensor_math_opencl.h |  28 +++
 src/model/layer/layer.cc             |   8 +
 src/proto/layer.proto                |  22 +-
 test/singa/test_cpp_math.cc          |   4 +-
 test/singa/test_tensor.cc            |  35 +--
 test/singa/test_tensor_math.cc       |  84 +++++++
 19 files changed, 1135 insertions(+), 644 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 21b3804..67a82e5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,6 @@
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6)
 PROJECT(singa)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11")
 
 # Flags
 IF(UNIX OR APPLE)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/include/singa/core/math.h
----------------------------------------------------------------------
diff --git a/include/singa/core/math.h b/include/singa/core/math.h
deleted file mode 100644
index 511d9ee..0000000
--- a/include/singa/core/math.h
+++ /dev/null
@@ -1,273 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef SINGA_CORE_MATH_H_
-#define SINGA_CORE_MATH_H_
-#include <type_traits>
-#include "singa/core/common.h"
-#include "singa/utils/logging.h"
-
-namespace singa {
-
-/// \file math.h Math functions for linear algebra, neural net and random
-/// operations.
-/// All functions have a template argument, DType for DataType, Lib for the
-/// backend library, e.g., lib::Cublas, lib::Cudnn, etc.
-
-/// Some operations would have many config/hyper-parameters, e.g., Conv, and
-/// these config vary among diff implementations, e.g., cuda/cudnn/opencl.
-/// To separate the modules, we pass a OpConf pointer to the Tensor Op function.
-/// The specific fields are implemented by inheriting OpConf, and casting the
-/// pointer between the base and the sub-class.
-class OpConf {
- public:
-  template <typename T>
-  T* CastTo() {
-    static_assert(std::is_base_of<OpConf, T>::value,
-                  "The cast type must be a sub-class of OpConf");
-    return static_cast<T*>(this);
-  }
-};
-
-// ================Linear algebra functions====================================
-template <typename DType, typename Lib>
-void Sum(int count, const Blob* input, DType* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void Abs(int count, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void Sign(int count, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// Base is e, Neper number
-template <typename DType, typename Lib>
-void Exp(int count, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// Natual logarithm, the base is e, Neper number.
-template <typename DType, typename Lib>
-void Log(int count, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void Sqrt(int count, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void Tanh(int count, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void Sigmoid(int count, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// Do v^x for every v from the input tensor
-template <typename DType, typename Lib>
-void Pow(int count, DType x, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// Do v^x for every v from the lhs and every x from rhs
-template <typename DType, typename Lib>
-void Pow(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// Clamp every element into [low, high]
-template <typename DType, typename Lib>
-void Clamp(int count, DType low, DType high, const Blob* input, Blob* ret,
-           Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret = x + input
-template <typename DType, typename Lib>
-void Add(int count, DType x, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret = x * input
-/// div could be enabled by calling Mult with 1/x
-template <typename DType, typename Lib>
-void Mult(int count, DType x, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret = lhs + rhs
-template <typename DType, typename Lib>
-void Add(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret = lhs - rhs
-template <typename DType, typename Lib>
-void Sub(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret = lhs * rhs
-template <typename DType, typename Lib>
-void Mult(int count, const Blob* lhs, const Blob* rhs, Blob* ret,
-          Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret = lhs / rhs
-template <typename DType, typename Lib>
-void Div(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// outer-product.
-/// lhs and rhs are vectors of len m and n. ret is matrix of shape m * n
-template <typename DType, typename Lib>
-void Outer(int m, int n, const Blob* lhs, const Blob* rhs, Blob* ret,
-           Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-// TODO(wangwei) unify SumRow and SumCol.
-/// Sum the rows of the input matrix into a vector
-template <typename DType, typename Lib>
-void SumRow(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-/// Sum the rows of the input matrix into a vector
-template <typename DType, typename Lib>
-void SumCol(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-// TODO(wangwei) unify AddRow and AddCol.
-/// Add the vector v to every row of A as the row of ret
-template <typename DType, typename Lib>
-void AddRow(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret,
-            Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// Add the vector v to every column of A as the column of ret
-template <typename DType, typename Lib>
-void AddCol(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret,
-            Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-// ===== BLAS functions, ref to http://docs.nvidia.com/cuda/cublas
-// ===== Level 1
-/// return the index of the element with the max value.
-template <typename DType, typename Lib>
-void Amax(int count, const Blob* input, int* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// return the index of the element with the min value.
-template <typename DType, typename Lib>
-void Amin(int count, const Blob* input, int* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-/// ret = sum |x| for all x in input
-template <typename DType, typename Lib>
-void Asum(int count, const Blob* input, DType* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret = alpha * input + ret
-template <typename DType, typename Lib>
-void Axpy(int count, DType alpha, const Blob* input, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret *= x
-template <typename DType, typename Lib>
-void Scale(int count, DType x, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void Dot(int count, const Blob* lhs, const Blob* rhs, DType* ret,
-         Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-// ===== Level 2
-/// ret = alpha * op(A) * v + beta * ret.
-/// op(A) = A if trans = false; A^T otherwise; rows(A) = m, cols(A) = n.
-template <typename DType, typename Lib>
-void GEMV(bool trans, int m, int n, DType alpha, const Blob* A, const Blob* v,
-          DType beta, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-// ===== Level 3
-/// ret = alpha * op(A) * op(B) + beta * ret.
-/// op(A) = A if trans = false; A^T otherwise; rows(A) = m, cols(A) = n.
-template <typename DType, typename Lib>
-void GEMV(bool transA, bool transB, int m, int n, int k, DType alpha,
-          const Blob* A, const Blob* B, DType beta, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-// ================Random functions===========================================
-// The random generator should be extracted from ctx.
-template <typename DType, typename Lib>
-void Uniform(int count, DType low, DType high, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void Gaussian(int count, DType mean, DType std, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// each element of ret would be 1 with prob p and 0 with 1-p. 0<= p <= 1
-template <typename DType, typename Lib>
-void Bernoulli(int count, DType p, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-/// ret[i] would be 1 with prob p[i] and 0 with 1-p[i]. 0<= p[i] <= 1
-template <typename DType, typename Lib>
-void Bernoulli(int count, const Blob* p, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-// ================Neural net functions=======================================
-/// Do 2D conv.
-/// c is input image channel, w is input width, h is input height
-/// nb_kernel is output channel, kw, and kh are kenerl width and height
-/*
-template <typename DType, typename Lib>
-void Conv2D(int c, int w, int h, int nb_kernel, int kw, int kh,
-           const Blob* input, const Blob* kernel, Blob* ret, Context* ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-*/
-}  // namespace singa
-
-#endif  // SINGA_CORE_MATH_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 725f657..4278078 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -20,23 +20,29 @@
 #define SINGA_CORE_TENSOR_H_
 
 #include <vector>
+#include <tuple>
 
 #include "singa/core/common.h"
 #include "singa/core/device.h"
-#include "singa/core/math.h"
 #include "singa/proto/core.pb.h"
 #include "singa/utils/logging.h"
 
 using std::vector;
+using std::tuple;
 namespace singa {
 
 typedef vector<int> Shape;
-inline int Product(Shape shape) {
-  if (shape.size() == 0)
-    return 0;
+inline int Product(Shape::const_iterator begin, Shape::const_iterator end) {
+  CHECK(begin != end);  // the range must be non-empty
   int v = 1;
-  for (auto s : shape)
-    v *= s;
+  for (auto it = begin; it != end; ++it)
+    v *= *it;
   return v;
 }
+
+/// Return the number of elements of a tensor with 'shape' (0 if empty).
+inline int Product(const Shape& shape) {
+  if (shape.empty()) return 0;
+  return Product(shape.begin(), shape.end());
+}
 
@@ -60,19 +66,20 @@ inline int SizeOf(DataType t) {
 class Tensor {
  public:
   ~Tensor();
-  Tensor() = default;
-  explicit Tensor(const Shape& shape, DataType dtype = kFloat32);
+  Tensor();
+  explicit Tensor(Shape&& shape, DataType dtype = kFloat32);
+  explicit Tensor(const Shape& shape, DataType dtype = kFloat32);
+  Tensor(Shape&& shape, Device* dev, DataType dtype = kFloat32);
   Tensor(const Shape& shape, Device* dev, DataType dtype = kFloat32);
 
   /// Copy Tensor to share the internal data.  No deep copy.
   Tensor(const Tensor& from);
-
   /// Copy Tensor to share the internal data.  No deep copy.
   Tensor(Tensor&& from);
 
   /// For functions in xx_math.cc to access the blob.
   /// Users should not operate against Blob directly.
-  /// It will malloc memory for the tensor if not allocated before.
+  /// blob_ is allocated in constructors.
   Blob* blob() const {
     return blob_;
   }
@@ -82,9 +89,9 @@ class Tensor {
   }
 
   /// Return immutable Tensor values with given type.
-  template <typename T>
-  const T* data() {
-    return static_cast<const T*> (blob()->data());
+  template <typename DType>
+  const DType* data() const {
+    return static_cast<const DType*>(blob()->data());
   }
 
   /// data type, including kFloat16, kFloat32, kInt
@@ -96,20 +103,28 @@ class Tensor {
     return shape_;
   }
 
+  int nDim() const {
+    return static_cast<int>(shape_.size());
+  }
+
   bool transpose() const {
     return transpose_;
   }
 
+  /// Return the total number of elements.
   int Size() const {
     return blob_->size() / SizeOf(data_type_);
   }
 
+  /// Return the memory size in bytes.
   int MemSize() const {
     return blob_->size();
   }
 
+  /// Reset the tensor shape; the blob may be reallocated if MemSize() changes.
   void ReShape(const Shape& shape);
 
+  /// Reset the data type; the blob would be reallocated if the type changes.
   void AsType(DataType type);
 
   /// Reset the device.
@@ -119,8 +134,9 @@ class Tensor {
   /// Equivalent to ToDevice(host_dev).
   void ToHost();
 
-  /// For init the tensor values, copy 'size' bytes data.
-  void CopyDataFromHostPtr(const void* src, size_t size);
+  /// Initialize the tensor values by copying 'num' elements from 'src'.
+  template <typename DType>
+  void CopyDataFromHostPtr(const DType* src, int num);
 
   /// Copy data from another Tensor which may be on a diff device.
   /// Meta data would not be copied!
@@ -141,49 +157,39 @@ class Tensor {
   /// Copy the meta info with data blob shared.
   void operator=(Tensor&& t);
 
+
   void operator+=(const Tensor& t);
-  /*
-  void operator+=(Tensor&& t);
+  // void operator+=(Tensor&& t);
   void operator-=(const Tensor& t);
-  void operator-=(Tensor&& t);
+  // void operator-=(Tensor&& t);
   void operator*=(const Tensor& t);
-  void operator*=(Tensor&& t);
+  // void operator*=(Tensor&& t);
   void operator/=(const Tensor& t);
-  void operator/=(Tensor&& t);
+  // void operator/=(Tensor&& t);
 
   // Scalar operations.
 
   /// T is a scalar type
-  template <typename T>
-  void operator+=(const T x);
+  template <typename DType>
+  void operator+=(const DType x);
 
   /// T is a scalar type
-  template <typename T>
-  void operator-=(const T x);
+  template <typename DType>
+  void operator-=(const DType x);
 
   /// T is a scalar type
-  template <typename T>
-  void operator*=(const T x);
+  template <typename DType>
+  void operator*=(const DType x);
 
   /// T is a scalar type
-  template <typename T>
-  void operator/=(const T x);
-
-  void Log(int base = 2);
-  void Tanh();
-  void Sigmoid();
-  void ReLU();
-
-  // random functions.
-  void Uniform(float low, float high);
-  template <typename T>
-  void Gaussian(float mean, float std);
+  template <typename DType>
+  void operator/=(const DType x);
 
   /// save Tensor into a proto msg
   // void ToProto(TensorProto* t);
   /// load Tensor from proto msg
   // void FromProto(const TensorProto& t);
-  */
+
  protected:
   bool transpose_ = false;
   DataType data_type_ = kFloat32;
@@ -194,142 +200,131 @@ class Tensor {
   Shape shape_;
 };
 
-/// For tensors with sparse content, e.g., missing columns or rows.
+// For tensors with sparse content, e.g., missing columns or rows.
 // class SparseTensor : public Tensor {};
 
-// ==================Simple Linear Algebra Operations=========================
-/*
-Tensor Tanh(const Tensor& t);
-Tensor Log(const Tensor& t);
-Tensor Sigmoid(const Tensor& t);
-Tensor ReLU(const Tensor& t);
-Tensor Softmax(const Tensor& t);
-*/
+/// Copy 'num' elements of src to dst.
+/// The first 'src_offset' ('dst_offset') elements will be skipped.
 void CopyData(Tensor* dst,
               const Tensor& src,
-              int msize,
+              int num,
               int src_offset = 0,
               int dst_offset = 0);
 
-// element-wise ops
+/// Copy 'nBytes' bytes of src data to dst.
+/// The first 'src_offset' ('dst_offset') bytes will be skipped.
+void CopyRawData(Tensor* dst,
+                 const Tensor& src,
+                 int nBytes,
+                 int src_offset = 0,
+                 int dst_offset = 0);
+
+// ==================Simple Linear Algebra Operations=========================
+Tensor Abs(const Tensor& t);
+Tensor Exp(const Tensor& t);
+Tensor Log(const Tensor& t);
+Tensor ReLU(const Tensor& t);
+Tensor Sigmoid(const Tensor& t);
+Tensor Sign(const Tensor& t);
+Tensor Sqrt(const Tensor& t);
+Tensor Tanh(const Tensor& t);
+
+/// Regard the internal data as a 2D matrix with shape_[0]*...*shape_[axis]
+/// rows and shape_[axis+1]*...*shape_[nDim()-1] columns,
+/// and do softmax along each row.
+Tensor Softmax(const Tensor& t, int axis = -1);
+void Softmax(const Tensor& t, Tensor* ret, int axis = -1);
+
+/// Element-wise operation, ret[i]=t[i]^x
+template<typename DType>
+Tensor Pow(const Tensor& t, DType x);
+/// Element-wise operation, ret[i]=t[i]^x
+template<typename DType>
+void Pow(const Tensor& t, DType x, Tensor* ret);
+/// Element-wise operation, ret[i]=base[i]^exp[i]
+Tensor Pow(const Tensor& base, Tensor exp);
+/// Element-wise operation, ret[i]=base[i]^exp[i]
+void Pow(const Tensor& base, const Tensor& exp, Tensor* ret);
 
 Tensor operator+(const Tensor& lhs, const Tensor& rhs);
 void Add(const Tensor& lhs, const Tensor& rhs, Tensor* ret);
-/*
 Tensor operator-(const Tensor& lhs, const Tensor& rhs);
 void Sub(const Tensor& lhs, const Tensor& rhs, Tensor* ret);
 Tensor operator*(const Tensor& lhs, const Tensor& rhs);
-void operator*(const Tensor& lhs, const Tensor& rhs, Tensor* ret);
+void EltwiseMult(const Tensor& lhs, const Tensor& rhs, Tensor* ret);
 Tensor operator/(const Tensor& lhs, const Tensor& rhs);
-void operator/(const Tensor& lhs, const Tensor& rhs, Tensor* ret);
+void Div(const Tensor& lhs, const Tensor& rhs, Tensor* ret);
 
-template <typename T>
-Tensor operator+(const T x, const Tensor& t);
-template <typename T>
-void operator+(const T x, const Tensor& t, Tensor* ret);
+template <typename DType>
+Tensor operator+(const Tensor& t, DType x);
+template <typename DType>
+void Add(const Tensor& t, DType x, Tensor* ret);
 
-template <typename T>
-Tensor operator-(const T x, const Tensor& t);
-template <typename T>
-void operator-(const T x, const Tensor& t, Tensor* ret);
+template <typename DType>
+Tensor operator-(const Tensor& t, DType x);
+template <typename DType>
+void Sub(const Tensor& t, DType x, Tensor* ret);
 
-template <typename T>
-Tensor operator*(const T x, const Tensor& t);
-template <typename T>
-void operator*(const T x, const Tensor& t, Tensor* ret);
+template <typename DType>
+Tensor operator*(const Tensor& t, DType x);
+template <typename DType>
+void EltwiseMult(const Tensor& t, DType x, Tensor* ret);
 
-template <typename T>
-Tensor operator/(const T x, const Tensor& t);
-template <typename T>
-void operator/(const T x, const Tensor& t, Tensor* ret);
+template <typename DType>
+Tensor operator/(const Tensor& t, DType x);
+template <typename DType>
+void Div(const Tensor& t, DType x, Tensor* ret);
 
 //================Blas operations============================================
+// ===== Level 1
+// TODO(wangwei) make amax/amin/asum a member function of tensor
+// void Amax(Tensor, Context* ctx); Get the index of the max value in a vector
+// void Asum(Tensor, Context* ctx);
+
+// template <typename DType>
+// void Axpy(DType x, const Blob& t, Blob* ret, Context* ctx);
+
+/// Do matrix vector multiplication or matrix matrix multiplication depending
+/// on the Tensor shape.  ret = lhs * rhs
+template <typename DType>
 Tensor Mult(const Tensor& lhs, const Tensor& rhs);
+/// Do matrix vector multiplication or matrix matrix multiplication depending
+/// on the Tensor shape.  ret = lhs * rhs
+template <typename DType>
 void Mult(const Tensor& lhs, const Tensor& rhs, Tensor* ret);
 
-tempalte<typename T> T Dot(const Tensor& lhs, const Tensor& rhs);
-
-//================Neural Net operations======================================
+/// Do matrix vector multiplication or matrix matrix multiplication depending
+/// on the Tensor shape.  ret = alpha * lhs * rhs + beta * ret
+template <typename DType>
+Tensor Mult(DType alpha, const Tensor& lhs, DType beta, const Tensor& rhs);
+/// Do matrix vector multiplication or matrix matrix multiplication depending
+/// on the Tensor shape.  ret = alpha * lhs * rhs + beta * ret
+template <typename DType>
+void Mult(DType alpha, const Tensor& lhs, DType beta, const Tensor& rhs,
+          Tensor* ret);
 
-/// Convolution Op. 'Conf' is ConvConf;
-void Conv(const OpConf* conf,
-          const Tensor& input,
-          const Tensor& W,
-          const Tensor &b,
-          Tensor* ret);
+// template <typename DType> DType Dot(const Tensor& lhs, const Tensor& rhs);
 
 //================Random operations==========================================
-Tensor Uniform(float low, float high, const Shape& shape, Device* dev);
-
-Tensor Gaussian(float mean, float std, const Shape& shape, Device* dev);
-*/
-//============================================================================
-/// typedef DType accroding to type value.
-/// DType would be used in the code block __VA_ARGS__.
-#define TYPE_SWITCH(type, DType, ...)                               \
-  do {                                                              \
-    switch (type) {                                                 \
-      case kFloat32: {                                              \
-        typedef float DType;                                        \
-        { __VA_ARGS__ }                                             \
-        break;                                                      \
-      }                                                             \
-      case kInt: {                                                  \
-        typedef int DType;                                          \
-        { __VA_ARGS__ }                                             \
-        break;                                                      \
-      }                                                             \
-      case kChar: {                                                 \
-        typedef char DType;                                         \
-        { __VA_ARGS__ }                                             \
-        break;                                                      \
-      }                                                             \
-      default:                                                      \
-        LOG(FATAL) << "Unknow data type = " << DataType_Name(type); \
-    }                                                               \
-  } while (0)
-
-/// typedef DType and Lib according to values of type and lib respectively.
-/// type is from DataType, and lib is from LibType.
-/// DType and Lib would be used in __VA_ARGS__.
-#define TYPE_LIB_SWITCH(dtype, DType, ltype, Lib, ...)                 \
-  do {                                                               \
-    const int _SwitchShift = 3;                                      \
-    int _SwitchHash = ((dtype) << _SwitchShift) + (ltype);                 \
-    switch (_SwitchHash) {                                           \
-      case ((kFloat32 << _SwitchShift) + kCuda): {                   \
-        typedef float DType;                                          \
-        typedef lib::Cuda Lib;                                            \
-        { __VA_ARGS__ }                                              \
-        break;                                                       \
-      }                                                              \
-      case ((kFloat32 << _SwitchShift) + kCudnn): {                  \
-        typedef float DType;                                          \
-        typedef lib::Cudnn Lib;                                           \
-        { __VA_ARGS__ }                                              \
-        break;                                                       \
-      }                                                              \
-      case ((kFloat32 << _SwitchShift) + kCpp): {                    \
-        typedef float DType;                                          \
-        typedef lib::Cpp Lib;                                             \
-        { __VA_ARGS__ }                                              \
-        break;                                                       \
-      }                                                              \
-      case ((kFloat32 << _SwitchShift) + kOpencl): {                \
-        typedef float DType;                                          \
-        typedef lib::Opencl Lib;                                          \
-        { __VA_ARGS__ }                                              \
-        break;                                                       \
-      }                                                              \
-      default:                                                       \
-        LOG(FATAL) << "Unknown combination of data type "            \
-                   << DataType_Name(dtype) << " and library "        \
-                   << LibType_Name(ltype);                             \
-    }                                                                \
-  } while (0)
-
-
+/// For each element x set x = 1 if random() < p; otherwise x = 0.
+Tensor Bernoulli(float p, Blob* t);
+/// Fill in Tensor 't' following uniform distribution.
+Tensor Uniform(float low, float high, Blob* t);
+/// Fill in Tensor 't' following Gaussian distribution.
+Tensor Gaussian(float mean, float std, Blob* t);
 
+//================Neural Net operations======================================
+// following API of cudnn, e.g., conv, pool, lrn, batchnorm, softmax
+void ConvFwd(const ConvConf& conf, const Tensor& x, const Tensor& w, Tensor* y);
+void ConvBwdBias(const ConvConf& conf, const Tensor& dy, Tensor* db);
+void ConvBwdFilter(const ConvConf& conf, const Tensor& dy, const Tensor& x,
+                   Tensor* dw);
+void ConvBwdData(const ConvConf& conf, const Tensor& dy, const Tensor& w,
+                 Tensor* dx);
+void PoolFwd(const PoolConf& conf, const Tensor& x, Tensor* y,
+             Tensor* mask = nullptr);
+void PoolBwd(const PoolConf& conf, const Tensor& y, const Tensor& dy,
+             const Tensor& x, Tensor* dx);
 }  // namespace singa
 
 #endif  // SINGA_CORE_TENSOR_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index 37f3fa8..7b9b6d4 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -45,7 +45,9 @@ class Layer {
   }
 
   /// Set meta data fields configured in 'conf' (a proto message).
-  virtual void Setup(const LayerConf& conf) {}
+  virtual void Setup(const LayerConf& conf) {
+    name_ = conf.name();
+  }
 
   /// Do feature transformation for given 'input' Tensor.
   /// It is the forward pass for feed-forward nets and rnn nets.
@@ -67,6 +69,7 @@ class Layer {
                                                const vector<Tensor>& input) {
     return vector<Tensor>{};
   }
+  // return <dx> and <dw (ParamGrad)>
 
   /// Move the layer (including its parameters and other Tensor) onto the given
   /// device
@@ -82,28 +85,26 @@ class Layer {
   }
 
   /// Serialize the layer info, including params)_, into a LayerConf message.
-  virtual std::string ToProto(LayerConf* param) const = 0;
+  virtual std::string ToProto(LayerConf* conf) const {
+    conf->set_name(name_);
+    return std::string();  // flowing off the end here would be UB
+  }
 
   /// Serialize the layer info, including params_, into a string representing
   /// a LayerParameter message.
-  /*
-  std::string ToProtoStr() const {
-    std:: string str;
-    SerializeToString(&str);
-  }
-  */
+  std::string ToProtoStr() const;
 
   /// Return all Param instances of this layer.
-  const vector<void*> params() const { return params_; }
+  /// Each layer could cache the Param objects; to save memory, it can
+  /// instead create them when this function is called.
+  const vector<Param*> GetParam();
 
   /// Each layer instance would optionally have a name.
   /// Used for debugging and logging.
   const std::string name() const { return name_; }
 
-
  protected:
   std::string name_;
-  std::vector<void*> params_;
 };
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/include/singa/model/param.h
----------------------------------------------------------------------
diff --git a/include/singa/model/param.h b/include/singa/model/param.h
new file mode 100644
index 0000000..b859b1c
--- /dev/null
+++ b/include/singa/model/param.h
@@ -0,0 +1,115 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_MODEL_PARAM_H_
+#define SINGA_MODEL_PARAM_H_
+#include "singa/core/tensor.h"
+#include <vector>
+#include <string>
+#include <utility>
+using std::vector;
+using std::string;
+namespace singa {
+/// Base Param class for storing a set of parameters, e.g., a weight matrix or
+/// a bias vector.
+/// It holds one Tensor for the parameter values and one for their gradients.
+class Param {
+ public:
+  /// Creates an empty Param; required by CreateParam("default").
+  Param() = default;
+  /// Virtual because this is a base class handed out as Param* by
+  /// CreateParam.
+  virtual ~Param();
+  /// NOTE(review): ParamSpec must be declared by an included header
+  /// (presumably a generated proto) -- confirm.
+  Param(const ParamSpec& conf);
+  Param(Param&& p);
+  Param(const Param& p);
+  void operator=(Param&& p);
+  void operator=(const Param& p);
+
+  Tensor& value() {
+    return value_;
+  }
+
+  Tensor& grad() {
+    return grad_;
+  }
+
+  void set_value(const Tensor& t) {
+    value_ = t;
+  }
+
+  void set_value(Tensor&& t) {
+    value_ = std::move(t);
+  }
+
+  /// Both set_grad overloads mark the gradient as valid.
+  void set_grad(const Tensor& t) {
+    isGradValid_ = true;
+    grad_ = t;
+  }
+
+  void set_grad(Tensor&& t) {
+    isGradValid_ = true;
+    grad_ = std::move(t);
+  }
+
+  // void Compress();
+  // string ToString();
+
+ protected:
+  string name_;
+  Tensor value_;
+  /// Gradient tensor returned by grad().
+  Tensor grad_;
+  /// Set to true by set_grad().
+  bool isGradValid_ = false;
+  float lr_mult_ = 1.0f, decay_mult_ = 1.0f;
+};
+
+class ParamGrad {
+// Returns the grad tensor, or data to recover the grad tensor; e.g., if
+// W = U * V, then ParamGrad could just store U and V. Provide functions for
+// serialization and deserialization.
+};
+
+// The updater just copies the ParamGrad to a device and submits ops to that
+// device, e.g., add grad; check update condition; apply sgd; copy back.
+// Consider RPC (no RDMA).
+
+/// Factory for Param instances; only the "default" type is supported, and any
+/// other type is a fatal error. The caller owns the returned pointer.
+/// Declared inline because this header may be included by several
+/// translation units.
+inline Param* CreateParam(const string& type) {
+  Param* p = nullptr;
+  if (type == "default")
+    p = new Param();
+  else
+    LOG(FATAL) << "Currently param type " << type << " is not implemented. "
+               << "Pls use the 'default' type";
+  return p;
+}
+
+}  // namespace singa
+
+#endif  // SINGA_MODEL_PARAM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/device/device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/device.cc b/src/core/device/device.cc
index 5bdab6f..4976a32 100644
--- a/src/core/device/device.cc
+++ b/src/core/device/device.cc
@@ -49,6 +49,7 @@ void Device::FreeBlob(Blob* blob) {
 
 void Device::CopyData(Blob* dst, const Blob& src, int len, int dst_offset,
                       int src_offset) {
+  // Host memcpy: offsets are applied to Byte pointers and len goes to memcpy.
   memcpy(reinterpret_cast<Byte*>(dst->mutable_data()) + dst_offset,
          (const Byte*)src.data() + src_offset, len);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/math/cpp_math.cc
----------------------------------------------------------------------
diff --git a/src/core/math/cpp_math.cc b/src/core/math/cpp_math.cc
deleted file mode 100644
index 638d693..0000000
--- a/src/core/math/cpp_math.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "singa/core/math.h"
-#include "singa/core/common.h"
-
-#ifdef USE_CBLAS
-#include <cblas.h>
-#endif
-
-namespace singa {
-template<>
-void Add<float, lib::Cpp>(int count,
-                     const Blob* lhs,
-                     const Blob* rhs,
-                     Blob* ret,
-                     Context* ctx) {
-  // CHECK_EQ(ctx->stream, nullptr);
-  float *dptr = static_cast<float*>(ret->mutable_data());
-  const float *lptr = static_cast<const float*>(lhs->data());
-  const float *rptr = static_cast<const float*>(rhs->data());
-  for (int i = 0; i < count; i++) {
-    dptr[i] = lptr[i] + rptr[i];
-  }
-}
-
-#ifdef USE_CBLAS
-template<>
-void Dot<float, lib::Cpp>(int count,
-                     const Blob* lhs,
-                     const Blob* rhs,
-                     float* ret,
-                     Context* ctx) {
-  float dptr = ret->mutable_data(), lptr = lhs->data(), rptr = rhs->data();
-  *ret = cblas_sdot(count, lptr, 1, rptr, 1);
-}
-
-#endif
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/math/cuda_math.cc
----------------------------------------------------------------------
diff --git a/src/core/math/cuda_math.cc b/src/core/math/cuda_math.cc
deleted file mode 100644
index 1cff1c2..0000000
--- a/src/core/math/cuda_math.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "singa/core/math.h"
-#include "singa/core/common.h"
-
-
-namespace singa {
-
-#ifdef USE_CUDA
-template<>
-void Add<float, lib::Cuda>(int count, const Blob* lhs, const Blob* rhs,
-                        Blob* ret, Context* ctx) {
-  cublasSetStream(ctx->handle, ctx->stream);
-  cublasScopy(ctx->handle, count, lhs->data(), 1, ret->mutable_data(), 1);
-  cublasSaxpy(ctx->handle, 1.0f, rhs->data(), 1, ret->mutable_data(), 1);
-}
-
-#ifdef USE_CUDNN
-template<>
-void Conv<float, lib::Cudnn>(const OpConf *conf,
-          const Blob* input,
-          const Blob* W,
-          const Blob* b,
-          Blob* ret,
-          Context* ctx) {
-  // auto conv_conf = conf->CastTo<ConvConf>();
-  // conv op
-}
-
-#endif
-#endif
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/math/opencl_math.cc
----------------------------------------------------------------------
diff --git a/src/core/math/opencl_math.cc b/src/core/math/opencl_math.cc
deleted file mode 100644
index 7012610..0000000
--- a/src/core/math/opencl_math.cc
+++ /dev/null
@@ -1,24 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "singa/core/math.h"
-
-namespace singa {
-
-
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 8fdc2ed..51b785e 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -15,28 +15,42 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include "singa/core/tensor.h"
-#include "singa/core/math.h"
+#include "./tensor_math.h"
+#include "./tensor_math_cpp.h"
+#include "./tensor_math_cuda.h"
+#include "./tensor_math_opencl.h"
 
 namespace singa {
+
 Tensor::~Tensor() {
   if (blob_ != nullptr && blob_->DecRefCount() == 0)
     device_->FreeBlob(blob_);
   blob_ = nullptr;
 }
 
+/// Creates an empty tensor on the host device; no blob is allocated here.
+Tensor::Tensor() : device_(&hostDeviceSingleton) {}
+
 Tensor::Tensor(const Shape& shape, DataType dtype)
     : data_type_(dtype), device_(&hostDeviceSingleton), shape_(shape) {
   device_ = &hostDeviceSingleton;
   blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
 }
-
+/// Rvalue-shape overloads: 'shape' is a named rvalue reference, i.e. an
+/// lvalue, so it must be std::move'd explicitly or it would be copied.
+Tensor::Tensor(Shape&& shape, DataType dtype)
+    : data_type_(dtype), device_(&hostDeviceSingleton),
+      shape_(std::move(shape)) {
+  blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
+}
 Tensor::Tensor(const Shape& shape, Device* device, DataType dtype)
     : data_type_(dtype), device_(device), shape_(shape) {
   blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
 }
-
+Tensor::Tensor(Shape&& shape, Device* device, DataType dtype)
+    : data_type_(dtype), device_(device), shape_(std::move(shape)) {
+  blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
+}
 Tensor::Tensor(const Tensor& t)
     : transpose_(t.transpose_),
       data_type_(t.data_type_),
@@ -50,7 +64,7 @@ Tensor::Tensor(Tensor&& t)
     : transpose_(t.transpose_),
       data_type_(t.data_type_),
       device_(t.device_),
-      shape_(t.shape_) {
+      shape_(std::move(t.shape_)) {
   blob_ = t.blob_;
   t.blob_ = nullptr;
 }
@@ -90,18 +104,26 @@ void Tensor::ToHost() {
   ToDevice(device_->host());
 }
 
-void Tensor::CopyDataFromHostPtr(const void* src, size_t size) {
+/// Copies 'num' elements of type DType from host memory at 'src' into this
+/// tensor's blob. DType must have the same size as the tensor's data type and
+/// 'num' must not exceed Size(). A null 'src' only logs a warning.
+template<typename DType>
+void Tensor::CopyDataFromHostPtr(const DType* src, int num) {
+  CHECK_EQ(sizeof(DType), SizeOf(data_type_)) << "data_type is "
+                                              << DataType_Name(data_type_)
+                                              << " user given type is of size "
+                                              << sizeof(DType);
+  // The blob was allocated for Size() elements; copying more would overflow.
+  CHECK_GE(num, 0);
+  CHECK_LE(num, Size());
   if (src != nullptr)
-    device_->CopyDataFromHostPtr(blob(), src, size);
+    device_->CopyDataFromHostPtr(blob(), src, sizeof(DType) * num);
   else
     LOG(WARNING) << "Copy data from null host ptr";
 }
+template void Tensor::CopyDataFromHostPtr(const float* src, int num);
 
+/// Copies all Size() elements of 'src' into this tensor. The element counts
+/// must match and this tensor must already own a blob; nothing is copied
+/// when 'src' has no blob.
 void Tensor::CopyData(const Tensor& src) {
   CHECK_EQ(Size(), src.Size());
+  CHECK(blob_ != nullptr);
   // Do copy only if the src's blob is already initialized.
-  if (src.blob_ != nullptr)
-    singa::CopyData(this, src, Size() * SizeOf(data_type_), 0, 0);
+  if (src.blob_ != nullptr) {
+    singa::CopyData(this, src, Size(), 0, 0);
+  }
 }
 
 Tensor Tensor::Clone() {
@@ -112,8 +134,10 @@ Tensor Tensor::Clone() {
 }
 
 Tensor Tensor::T() const {
+  CHECK_EQ(shape_.size(), 2);
   Tensor t(*this);
   t.transpose_ = ~transpose_;
+  std::swap(shape_[0], shape_[1]);
   return t;
 }
 
@@ -132,80 +156,315 @@ void Tensor::operator=(Tensor&& t) {
   if (blob_ != nullptr && blob_->DecRefCount() == 0)
     device_->FreeBlob(blob_);
   transpose_ = t.transpose_;
-  shape_ = t.shape_;
+  shape_ = std::move(t.shape_);
   device_ = t.device_;
   blob_ = t.blob_;
   t.blob_ = nullptr;
 }
 
-void Tensor::operator+=(const Tensor& t) {
-  Add(*this, t, this);
-}
-// ====================Tensor Operations=======================================
+// In-place arithmetic with a Tensor operand: fn(*this, rhs) is written back
+// into *this.
+#define GenUnaryTensorArgMemberFunction(op, fn) \
+  void Tensor::op(const Tensor& rhs) {          \
+    fn(*this, rhs, this);                       \
+  }
+
+GenUnaryTensorArgMemberFunction(operator+=, Add);
+GenUnaryTensorArgMemberFunction(operator-=, Sub);
+GenUnaryTensorArgMemberFunction(operator*=, EltwiseMult);
+GenUnaryTensorArgMemberFunction(operator/=, Div);
+
+// In-place arithmetic with a scalar operand: fn(*this, x) is written back
+// into *this. Each expansion also instantiates the float version.
+#define GenUnaryScalarArgMemberFunction(op, fn) \
+  template <typename DType>                     \
+  void Tensor::op(DType x) {                    \
+    fn(*this, x, this);                         \
+  }                                             \
+  template void Tensor::op<float>(float x)
+
+GenUnaryScalarArgMemberFunction(operator+=, Add);
+GenUnaryScalarArgMemberFunction(operator-=, Sub);
+GenUnaryScalarArgMemberFunction(operator*=, EltwiseMult);
+GenUnaryScalarArgMemberFunction(operator/=, Div);
 
+// ====================Tensor Operations=======================================
+/// Copies 'num' elements from 'src' (starting at element 'src_offset') into
+/// 'dst' (starting at element 'dst_offset'). Counts and offsets are in
+/// elements; both tensors must use element types of the same size, and the
+/// values are converted to bytes for CopyRawData.
 void CopyData(Tensor* dst,
               const Tensor& src,
-              int len,
+              int num,
               int dst_offset,
               int src_offset) {
-  CHECK_GE(src.MemSize(), src_offset + len);
-  CHECK_GE(dst->MemSize(), dst_offset + len);
+  CHECK_GE(src.Size(), src_offset + num);
+  CHECK_GE(dst->Size(), dst_offset + num);
+  int width = SizeOf(src.data_type());
+  CHECK_EQ(width, SizeOf(dst->data_type()));
+  CopyRawData(dst, src, num * width, dst_offset * width, src_offset * width);
+}
+
+/// Copies 'nBytes' raw bytes from 'src' to 'dst'; both offsets are in bytes.
+/// If the two devices use the same lib, the source device copies. If they
+/// differ and one is kCpp, the non-Cpp device performs the copy. Any other
+/// combination is a fatal error.
+void CopyRawData(Tensor* dst,
+              const Tensor& src,
+              int nBytes,
+              int dst_offset,
+              int src_offset) {
+  CHECK_GE(src.MemSize(), src_offset + nBytes);
+  CHECK_GE(dst->MemSize(), dst_offset + nBytes);
   Device* src_dev = src.device(), *dst_dev = dst->device();
   Blob* src_blob = src.blob(), *dst_blob = dst->blob();
   if (dst_dev->device_lib() != src_dev->device_lib()) {
     // let the none cpp device conduct copy op
     if (dst_dev->device_lib() == kCpp) {
-      src_dev->CopyData(dst_blob, *src_blob, len, dst_offset, src_offset);
+      src_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset);
     } else if (src_dev->device_lib() == kCpp) {
-      dst_dev->CopyData(dst_blob, *src_blob, len, dst_offset, src_offset);
+      dst_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset);
     } else {
       LOG(FATAL) << "Not support mem copy betwee Cuda and OpenCL device";
     }
   } else {
-    src_dev->CopyData(dst_blob, *src_blob, len, dst_offset, src_offset);
+    src_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset);
   }
 }
+//============================================================================
+/// Runs the code block __VA_ARGS__ with DType typedef'd to the C++ type that
+/// matches the DataType value 'type': kFloat32 -> float, kInt -> int,
+/// kChar -> char. Any other value is a fatal error.
+#define TYPE_SWITCH(type, DType, ...)                                 \
+  do {                                                                \
+    switch (type) {                                                   \
+      case kFloat32: { typedef float DType; { __VA_ARGS__ } break; }  \
+      case kInt:     { typedef int DType;   { __VA_ARGS__ } break; }  \
+      case kChar:    { typedef char DType;  { __VA_ARGS__ } break; }  \
+      default:                                                        \
+        LOG(FATAL) << "Unknow data type = " << DataType_Name(type);   \
+    }                                                                 \
+  } while (0)
+
+/// Runs the code block __VA_ARGS__ with DType typedef'd to the element type
+/// for the DataType value 'dtype' and Lib typedef'd to the backend tag for the
+/// LibType value 'ltype'. Only kFloat32 combinations are handled; any other
+/// combination is a fatal error.
+/// The pair is hashed as (dtype << 3) + ltype, which assumes every LibType
+/// value is below 8 (TODO: confirm against the LibType enum).
+/// The local names avoid a leading underscore followed by an uppercase letter,
+/// since such identifiers are reserved for the implementation.
+#define TYPE_LIB_SWITCH(dtype, DType, ltype, Lib, ...)                   \
+  do {                                                                   \
+    const int singa_switch_shift = 3;                                    \
+    int singa_switch_hash = ((dtype) << singa_switch_shift) + (ltype);   \
+    switch (singa_switch_hash) {                                         \
+      case ((kFloat32 << singa_switch_shift) + kCuda): {                 \
+        typedef float DType;                                             \
+        typedef lib::Cuda Lib;                                           \
+        { __VA_ARGS__ }                                                  \
+        break;                                                           \
+      }                                                                  \
+      case ((kFloat32 << singa_switch_shift) + kCudnn): {                \
+        typedef float DType;                                             \
+        typedef lib::Cudnn Lib;                                          \
+        { __VA_ARGS__ }                                                  \
+        break;                                                           \
+      }                                                                  \
+      case ((kFloat32 << singa_switch_shift) + kCpp): {                  \
+        typedef float DType;                                             \
+        typedef lib::Cpp Lib;                                            \
+        { __VA_ARGS__ }                                                  \
+        break;                                                           \
+      }                                                                  \
+      case ((kFloat32 << singa_switch_shift) + kOpencl): {               \
+        typedef float DType;                                             \
+        typedef lib::Opencl Lib;                                         \
+        { __VA_ARGS__ }                                                  \
+        break;                                                           \
+      }                                                                  \
+      default:                                                           \
+        LOG(FATAL) << "Unknown combination of data type "                \
+                   << DataType_Name(dtype) << " and library "            \
+                   << LibType_Name(ltype);                               \
+    }                                                                    \
+  } while (0)
+
+
+/// Submits fn<DType, Lib>(t.Size(), t.blob(), ret->blob(), ctx) to ret's
+/// device, declaring t's blob as read and ret's blob as written.
+#define EltwiseUnaryTensorFn(fn, t, ret)                                   \
+  do {                                                                     \
+    TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \
+      auto kernel = [t, ret](Context* ctx) {                               \
+        fn<DType, Lib>(t.Size(), t.blob(), ret->blob(), ctx);              \
+      };                                                                   \
+      ret->device()->Submit(kernel, {t.blob()}, {ret->blob()});            \
+    });                                                                    \
+  } while (0)
+
+/// Defines 'Tensor fn(const Tensor& t)': allocates a tensor with t's shape,
+/// device and data type, and fills it with fn applied element-wise to t.
+#define GenUnaryTensorFunction(fn)                    \
+  Tensor fn(const Tensor& t) {                        \
+    Tensor out(t.shape(), t.device(), t.data_type()); \
+    Tensor* out_ptr = &out;                           \
+    EltwiseUnaryTensorFn(fn, t, out_ptr);             \
+    return out;                                       \
+  }
+
+GenUnaryTensorFunction(Abs);
+GenUnaryTensorFunction(Exp);
+GenUnaryTensorFunction(Log);
+GenUnaryTensorFunction(ReLU);
+GenUnaryTensorFunction(Sigmoid);
+GenUnaryTensorFunction(Sign);
+GenUnaryTensorFunction(Sqrt);
+GenUnaryTensorFunction(Tanh);
 
-Tensor operator+(const Tensor& lhs, const Tensor& rhs) {
-  Tensor ret(lhs.shape(), lhs.device());
-  Add(lhs, rhs, &ret);
+/// Returns softmax(t) in a new tensor with t's shape, device and data type;
+/// 'axis' has the same meaning as in Softmax(const Tensor&, Tensor*, int).
+Tensor Softmax(const Tensor& t, int axis) {
+  Tensor ret(t.shape(), t.device(), t.data_type());
+  Softmax(t, &ret, axis);
   return ret;
 }
 
-void Add(const Tensor& lhs, const Tensor& rhs, Tensor* ret) {
-  TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, {
+/// Writes softmax(t) into ret, which must have as many elements as t.
+/// t is viewed as an nrow x ncol matrix handed to the Softmax<DType, Lib>
+/// kernel. axis == -1 gives nrow = 1; otherwise nrow is the product of
+/// dimensions 0..axis and ncol is the product of the remaining ones.
+void Softmax(const Tensor& t, Tensor* ret, int axis) {
+  // Bind once: if shape() returns by value, iterators taken from two separate
+  // calls would point into different temporaries.
+  const auto& shape = t.shape();
+  int nrow = 1, ncol = t.Size(), size = ncol;
+  CHECK_GE(axis, -1);
+  CHECK_GT(shape.size(), 0u);
+  // A larger axis would advance past shape.end() below.
+  CHECK_LT(axis, static_cast<int>(shape.size()));
+  CHECK_EQ(ret->Size(), t.Size());
+  // An empty tensor could give nrow == 0 and a division by zero.
+  if (axis > -1 && size > 0) {
+    nrow = Product(shape.begin(), shape.begin() + axis + 1);
+    CHECK_EQ(size % nrow, 0) << "Size = " << size << " nrow = " << nrow;
+    ncol = size / nrow;
+  }
+  TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, {
     ret->device()->Submit(
-        [lhs, rhs, ret](Context* ctx) {
-          Add<DType, Lib>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), ctx);
+        [nrow, ncol, t, ret](Context* ctx) {
+          Softmax<DType, Lib>(nrow, ncol, t.blob(), ret->blob(), ctx);
         },
-        {lhs.blob(), rhs.blob()}, {ret->blob()});
-  });
+        {t.blob()}, {ret->blob()});
+    });
 }
-/*
-Tensor operator-(const Tensor& lhs, const Tensor& rhs) {
-  Tensor ret(lhs.shape(), lhs.device());
-  Sub(lhs, rhs, &ret);
+
+/// Submits fn<DType, Lib>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), ctx)
+/// to ret's device, reading lhs and rhs and writing ret. The checks run as
+/// statements before Submit; a statement cannot appear inside Submit's
+/// argument list.
+#define EltwiseBinaryTensorFn(fn, lhs, rhs, ret)                               \
+  do {                                                                         \
+    TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, { \
+      CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type()));                        \
+      /* The kernel reads lhs.Size() elements from rhs as well. */             \
+      CHECK_EQ(lhs.Size(), rhs.Size());                                        \
+      ret->device()->Submit(                                                   \
+          [lhs, rhs, ret](Context* ctx) {                                      \
+            fn<DType, Lib>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(),    \
+                           ctx);                                               \
+          },                                                                   \
+          {lhs.blob(), rhs.blob()}, {ret->blob()});                            \
+    });                                                                        \
+  } while (0)
+
+/// Defines 'Tensor op(lhs, rhs)', which returns the result in a new tensor
+/// shaped like lhs, and 'void fn(lhs, rhs, ret)', which writes it into ret.
+#define GenBinaryTensorFunction(op, fn)                        \
+  Tensor op(const Tensor& lhs, const Tensor& rhs) {            \
+    Tensor ret(lhs.shape(), lhs.device(), lhs.data_type());    \
+    fn(lhs, rhs, &ret);                                        \
+    return ret;                                                \
+  }                                                            \
+  void fn(const Tensor& lhs, const Tensor& rhs, Tensor* ret) { \
+    EltwiseBinaryTensorFn(fn, lhs, rhs, ret);                  \
+  }
+
+GenBinaryTensorFunction(operator+, Add);
+GenBinaryTensorFunction(operator-, Sub);
+GenBinaryTensorFunction(operator*, EltwiseMult);
+GenBinaryTensorFunction(operator/, Div);
+GenBinaryTensorFunction(Pow, Pow);
+
+/// Submits fn<DType, Lib>(t.Size(), t.blob(), x, ret->blob(), ctx) to ret's
+/// device. It must be expanded inside a function template where x's type is
+/// dependent. typeid is not a constant expression, so the scalar/element type
+/// match is checked with std::is_same, before (not inside) the Submit call.
+#define EltwiseTensorScalarFn(fn, t, x, ret)                                   \
+  do {                                                                         \
+    TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, {     \
+      static_assert(                                                           \
+          std::is_same<DType,                                                  \
+                       typename std::decay<decltype(x)>::type>::value,         \
+          "The Scalar type must match the Tensor data type");                  \
+      ret->device()->Submit(                                                   \
+          [t, x, ret](Context* ctx) {                                          \
+            fn<DType, Lib>(t.Size(), t.blob(), x, ret->blob(), ctx);           \
+          },                                                                   \
+          {t.blob()}, {ret->blob()});                                          \
+    });                                                                        \
+  } while (0)
+
+/// Defines 'Tensor op(const Tensor&, SType)', which returns a new tensor, and
+/// 'void fn(const Tensor&, SType, Tensor*)', which writes into ret. Both are
+/// explicitly instantiated for float. The template parameter is named SType
+/// because TYPE_LIB_SWITCH typedefs DType, and a template parameter may not be
+/// redeclared within its own scope.
+#define GenTensorScalarFunction(op, fn)                \
+  template <typename SType>                            \
+  Tensor op(const Tensor& t, SType x) {                \
+    Tensor ret(t.shape(), t.device(), t.data_type());  \
+    fn(t, x, &ret);                                    \
+    return ret;                                        \
+  }                                                    \
+  template <typename SType>                            \
+  void fn(const Tensor& t, SType x, Tensor* ret) {     \
+    EltwiseTensorScalarFn(fn, t, x, ret);              \
+  }                                                    \
+  template Tensor op<float>(const Tensor& t, float x); \
+  template void fn<float>(const Tensor& t, const float x, Tensor* ret)
+
+GenTensorScalarFunction(operator+, Add);
+GenTensorScalarFunction(operator-, Sub);
+GenTensorScalarFunction(operator*, EltwiseMult);
+GenTensorScalarFunction(operator/, Div);
+GenTensorScalarFunction(Pow, Pow);
+
+// ================Blas operations============================================
+/// Shape of the product of the 2-D tensor A and B, using the logical shapes
+/// returned by shape() (already swapped for a tensor produced by T()).
+/// A vector B gives {rows(A)}; a matrix B gives {rows(A), cols(B)}.
+static Shape MultResultShape(const Tensor& A, const Tensor& B) {
+  CHECK_EQ(A.shape().size(), 2u);
+  if (B.shape().size() == 1u) return Shape{A.shape()[0]};
+  CHECK_EQ(B.shape().size(), 2u);
+  return Shape{A.shape()[0], B.shape()[1]};
+}
+
+/// Returns lhs * rhs in a newly allocated tensor of the product's shape.
+template <typename DType>
+Tensor Mult(const Tensor& lhs, const Tensor& rhs) {
+  Tensor ret(MultResultShape(lhs, rhs), lhs.device(), lhs.data_type());
+  Mult<DType>(lhs, rhs, &ret);
+  return ret;
+}
+template Tensor Mult<float>(const Tensor& lhs, const Tensor& rhs);
+
+/// ret = lhs * rhs. beta is 0 so that ret's previous (possibly uninitialized)
+/// contents are overwritten rather than accumulated into.
+template <typename DType>
+void Mult(const Tensor& lhs, const Tensor& rhs, Tensor* ret) {
+  Mult(DType(1), lhs, DType(0), rhs, ret);
+}
+template void Mult<float>(const Tensor& lhs, const Tensor& rhs, Tensor* ret);
+
+/// Returns alpha * A * B + beta * C for a freshly allocated C of the
+/// product's shape.
+/// NOTE(review): C is not initialized here, so any beta other than 0 mixes in
+/// unspecified values; callers should pass beta == 0.
+template <typename DType>
+Tensor Mult(DType alpha, const Tensor& A, DType beta, const Tensor& B) {
+  Tensor ret(MultResultShape(A, B), A.device(), A.data_type());
+  Mult<DType>(alpha, A, beta, B, &ret);
   return ret;
 }
+template Tensor Mult<float>(float alpha, const Tensor& lhs, float beta,
+    const Tensor& rhs);
 
-void Sub(const Tensor& lhs, const Tensor& rhs, Tensor *ret) {
-  TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, {
-      ret->device()->Submit(
-        [lhs, rhs, ret](Context* ctx) {
-          Sub<DType, Lib>(
-            lhs.Size(),
-            lhs.blob(),
-            rhs.blob(),
-            ret->blob(),
-            ctx);}
-        , {lhs.blob(), rhs.blob()}, {ret->blob()});
+/// C = alpha * op(A) * B + beta * C, where op(A) is A or its transpose
+/// according to A.transpose() (likewise for B when it is a matrix).
+/// All dimensions are taken from the logical shapes returned by shape().
+/// A vector B is dispatched to GEMV; a matrix B to GEMM.
+template <typename SType>
+void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, Tensor* C)
+{
+  CHECK_EQ(A.shape().size(), 2u);
+  bool transA = A.transpose();
+  // Rows and columns of op(A).
+  int m = A.shape()[0], k = A.shape()[1];
+  if (B.shape().size() == 1u) {
+    // op(A) (m x k) times a length-k vector gives a length-m vector.
+    CHECK_EQ(B.Size(), k);
+    CHECK_EQ(C->Size(), m);
+    // NOTE(review): GEMV receives op(A)'s shape as (m, n); confirm this
+    // matches the GEMV convention in tensor_math.h.
+    int n = k;
+    TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, {
+      static_assert(std::is_same<SType, DType>::value,
+        "The scalar type must be the same as the tensor data type");
+      C->device()->Submit(
+        [transA, m, n, alpha, A, beta, B, C](Context* ctx) {
+        GEMV<DType, Lib>(transA, m, n, alpha, A.blob(),
+          B.blob(), beta, C->blob(), ctx);
+        },
+        {A.blob(), B.blob()}, {C->blob()});
       });
+  } else {
+    CHECK(!C->transpose());
+    CHECK_EQ(B.shape().size(), 2u);
+    bool transB = B.transpose();
+    // op(B) must be k x n, and C must be m x n.
+    CHECK_EQ(B.shape()[0], k);
+    int n = B.shape()[1];
+    CHECK_EQ(C->shape()[0], m);
+    CHECK_EQ(C->shape()[1], n);
+    TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, {
+        static_assert(std::is_same<SType, DType>::value,
+          "The scalar type must be the same as the tensor data type");
+        C->device()->Submit(
+          [transA, transB, m, n, k, alpha, A, beta, B, C](Context* ctx) {
+          GEMM<DType, Lib>(transA, transB, m, n, k, alpha, A.blob(),
+            B.blob(), beta, C->blob(), ctx);
+          },
+          {A.blob(), B.blob()}, {C->blob()});
+        });
+  }
 }
+template void Mult<float>(float alpha, const Tensor& lhs, float beta,
+    const Tensor& rhs, Tensor* ret);
 
-// ================Blas operations============================================
 
 // ================Neural Net operations======================================
-
+/*
 void Conv(const OpConf* conf, const Tensor& input, const Tensor& W,
           const Tensor& b, Tensor* ret) {
   TYPE_LIB_SWITCH(input.data_type(), DType, input.device()->nn_lib(), Lib, {
@@ -218,5 +477,33 @@ void Conv(const OpConf* conf, const Tensor& input, const Tensor& W,
   });
 }
 */
+/// Submits the Bernoulli<DType, Lib> kernel over all t->Size() elements of t,
+/// forwarding 'threshold' unchanged; Lib is chosen from the device's
+/// nn_lib(). There are no read dependencies; t's blob is written.
+void Bernoulli(float threshold, Tensor* t) {
+  TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, {
+    auto sample = [threshold, t](Context* ctx) {
+      Bernoulli<DType, Lib>(t->Size(), threshold, t->blob(), ctx);
+    };
+    t->device()->Submit(sample, {}, {t->blob()});
+  });
+}
+
+/// Submits the Uniform<DType, Lib> kernel over all t->Size() elements of t,
+/// forwarding 'low' and 'high' (presumably the sampling range); Lib is chosen
+/// from the device's nn_lib(). There are no read dependencies; t's blob is
+/// written.
+void Uniform(float low, float high, Tensor* t) {
+  TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, {
+    auto sample = [low, high, t](Context* ctx) {
+      Uniform<DType, Lib>(t->Size(), low, high, t->blob(), ctx);
+    };
+    t->device()->Submit(sample, {}, {t->blob()});
+  });
+}
 
+/// Submits the Gaussian<DType, Lib> kernel over all t->Size() elements of t,
+/// forwarding 'mean' and 'std'; Lib is chosen from the device's nn_lib().
+/// There are no read dependencies; t's blob is written.
+void Gaussian(float mean, float std, Tensor* t) {
+  TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, {
+    auto sample = [mean, std, t](Context* ctx) {
+      Gaussian<DType, Lib>(t->Size(), mean, std, t->blob(), ctx);
+    };
+    t->device()->Submit(sample, {}, {t->blob()});
+  });
+}
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
new file mode 100644
index 0000000..a4f68e3
--- /dev/null
+++ b/src/core/tensor/tensor_math.h
@@ -0,0 +1,302 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SINGA_CORE_MATH_H_
+#define SINGA_CORE_MATH_H_
+#include <type_traits>
+#include "singa/core/common.h"
+#include "singa/utils/logging.h"
+
+namespace singa {
+
+/// \file tensor_math.h Math functions for linear algebra, neural net and
+/// random operations.
+/// All functions have two template arguments: DType for the data type and Lib
+/// for the backend library, e.g., lib::Cublas, lib::Cudnn, etc.
+
+/// Some operations have many config/hyper-parameters, e.g., Conv, and these
+/// configs vary among implementations, e.g., cuda/cudnn/opencl.
+/// To separate the modules, we pass an OpConf pointer to the Tensor Op
+/// function. The specific fields are implemented by inheriting OpConf, and the
+/// pointer is cast between the base and the sub-class.
+class OpConf {
+ public:
+  /// Downcast to the concrete config type T. That T derives from OpConf is
+  /// checked at compile time; the cast itself is a static_cast, so the caller
+  /// must know the object really is a T.
+  template <typename T>
+  T* CastTo() {
+    static_assert(std::is_base_of<OpConf, T>::value,
+                  "The cast type must be a sub-class of OpConf");
+    return static_cast<T*>(this);
+  }
+
+  /// Const overload, so a `const OpConf*` (the form the op functions take)
+  /// can be downcast to a `const T*`.
+  template <typename T>
+  const T* CastTo() const {
+    static_assert(std::is_base_of<OpConf, T>::value,
+                  "The cast type must be a sub-class of OpConf");
+    return static_cast<const T*>(this);
+  }
+};
+
+// ================Linear algebra functions====================================
+// The generic templates below only abort with "Not Implemented"; working
+// versions are explicit specializations for a concrete <DType, Lib> pair
+// (e.g. in tensor_math_cpp.h / tensor_math_cuda.h). `count` is the number of
+// elements processed.
+
+/// Element-wise absolute value: ret[i] = |input[i]|.
+template <typename DType, typename Lib>
+void Abs(int count, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Sum all elements of input and store the scalar result in *ret.
+template <typename DType, typename Lib>
+void Sum(int count, const Blob* input, DType* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise sign: ret[i] = sign(input[i]).
+template <typename DType, typename Lib>
+void Sign(int count, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise exponential with base e (Euler's number):
+/// ret[i] = exp(input[i]).
+template <typename DType, typename Lib>
+void Exp(int count, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise natural logarithm (base e): ret[i] = log(input[i]).
+template <typename DType, typename Lib>
+void Log(int count, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise square root: ret[i] = sqrt(input[i]).
+template <typename DType, typename Lib>
+void Sqrt(int count, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise hyperbolic tangent: ret[i] = tanh(input[i]).
+template <typename DType, typename Lib>
+void Tanh(int count, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise rectified linear unit: ret[i] = max(0, input[i]).
+template <typename DType, typename Lib>
+void ReLU(int count, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise sigmoid: ret[i] = sigmoid(input[i]).
+template <typename DType, typename Lib>
+void Sigmoid(int count, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise power with a scalar exponent: ret[i] = input[i]^x.
+template <typename DType, typename Lib>
+void Pow(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise power with per-element exponents: ret[i] = lhs[i]^rhs[i].
+template <typename DType, typename Lib>
+void Pow(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Clamp: every element of input is limited to the range [low, high] and the
+/// result is written to ret (values above high become high, values below low
+/// become low).
+template <typename DType, typename Lib>
+void Clamp(int count, DType low, DType high, const Blob* input, Blob* ret,
+           Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Scalar addition: ret[i] = input[i] + x.
+template <typename DType, typename Lib>
+void Add(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Scalar subtraction: ret[i] = input[i] - x, expressed as addition of -x.
+template <typename DType, typename Lib>
+void Sub(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
+  const DType negated = -x;
+  Add<DType, Lib>(count, input, negated, ret, ctx);
+}
+
+/// Scalar multiplication: ret[i] = input[i] * x.
+template <typename DType, typename Lib>
+void EltwiseMult(int count, const Blob* input, DType x, Blob* ret,
+                 Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Scalar division: ret[i] = input[i] / x, expressed as multiplication by the
+/// reciprocal 1/x.
+template <typename DType, typename Lib>
+void Div(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
+  const DType reciprocal = DType(1) / x;
+  EltwiseMult<DType, Lib>(count, input, reciprocal, ret, ctx);
+}
+
+/// Element-wise sum of two blobs: ret[i] = lhs[i] + rhs[i].
+template <typename DType, typename Lib>
+void Add(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise difference of two blobs: ret[i] = lhs[i] - rhs[i].
+template <typename DType, typename Lib>
+void Sub(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise product of two blobs: ret[i] = lhs[i] * rhs[i].
+template <typename DType, typename Lib>
+void EltwiseMult(int count, const Blob* lhs, const Blob* rhs, Blob* ret,
+                 Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Element-wise quotient of two blobs: ret[i] = lhs[i] / rhs[i].
+template <typename DType, typename Lib>
+void Div(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Outer product: lhs and rhs are vectors of length m and n, and ret is the
+/// m x n matrix with ret[i][j] = lhs[i] * rhs[j].
+template <typename DType, typename Lib>
+void Outer(int m, int n, const Blob* lhs, const Blob* rhs, Blob* ret,
+           Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+// TODO(wangwei) unify SumRow and SumCol.
+/// Sum the rows of the nrow x ncol input matrix into a vector.
+/// NOTE(review): presumably ret has ncol elements; confirm with backends.
+template <typename DType, typename Lib>
+void SumRow(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Sum the columns of the nrow x ncol input matrix into a vector.
+/// NOTE(review): presumably ret has nrow elements; confirm with backends.
+template <typename DType, typename Lib>
+void SumCol(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+// TODO(wangwei) unify AddRow and AddCol.
+/// Add the vector v to every row of the nrow x ncol matrix A, storing the
+/// result in ret.
+template <typename DType, typename Lib>
+void AddRow(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret,
+            Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Add the vector v to every column of the nrow x ncol matrix A, storing the
+/// result in ret.
+template <typename DType, typename Lib>
+void AddCol(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret,
+            Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+// ===== BLAS functions, ref to http://docs.nvidia.com/cuda/cublas
+// ===== Level 1
+/// Store in *ret the index of the element with the max value.
+/// NOTE(review): BLAS/cuBLAS i?amax compares absolute values, and cuBLAS
+/// returns a 1-based index; confirm which convention the backends follow.
+template <typename DType, typename Lib>
+void Amax(int count, const Blob* input, int* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Store in *ret the index of the element with the min value.
+/// NOTE(review): as with Amax, cuBLAS i?amin compares absolute values and
+/// returns a 1-based index; confirm the intended convention.
+template <typename DType, typename Lib>
+void Amin(int count, const Blob* input, int* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// *ret = sum of |x| over all x in input.
+template <typename DType, typename Lib>
+void Asum(int count, const Blob* input, DType* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// ret = alpha * input + ret (ret is updated in place).
+template <typename DType, typename Lib>
+void Axpy(int count, DType alpha, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// ret *= x, element-wise and in place.
+template <typename DType, typename Lib>
+void Scale(int count, DType x, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Inner product: *ret = sum over i of lhs[i] * rhs[i].
+template <typename DType, typename Lib>
+void Dot(int count, const Blob* lhs, const Blob* rhs, DType* ret,
+         Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+// ===== Level 2
+/// ret = alpha * op(A) * v + beta * ret.
+/// op(A) = A if trans is false, A^T otherwise; rows(op(A)) = m and
+/// cols(op(A)) = n.
+template <typename DType, typename Lib>
+void GEMV(bool trans, int m, int n, DType alpha, const Blob* A, const Blob* v,
+          DType beta, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+// ===== Level 3
+/// ret = alpha * op(A) * op(B) + beta * ret.
+/// op(A) = A if transA is false, A^T otherwise; op(B) likewise with transB.
+/// rows(ret) = rows(op(A)) = m, cols(ret) = cols(op(B)) = n, and
+/// cols(op(A)) = rows(op(B)) = k.
+template <typename DType, typename Lib>
+void GEMM(bool transA, bool transB, int m, int n, int k, DType alpha,
+          const Blob* A, const Blob* B, DType beta, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+// ================Random functions===========================================
+/// Fill ret with `count` Bernoulli samples: each element is 1 with
+/// probability p and 0 with probability 1-p, where `threshold` plays the role
+/// of p (0 <= p <= 1).
+// Get the random generator from 'ctx'.
+// If DType is not float, then convert the threshold to DType.
+template <typename DType, typename Lib>
+void Bernoulli(int count, float threshold, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Fill ret with `count` samples from a uniform distribution over the range
+/// bounded by `low` and `high`.
+// The random generator should be extracted from ctx.
+// If DType is not float, then convert the low and high to DType.
+template <typename DType, typename Lib>
+void Uniform(int count, float low, float high, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Fill ret with `count` samples from a Gaussian distribution with the given
+/// mean and standard deviation `std`.
+// The random generator should be extracted from ctx.
+// If DType is not float, then convert the mean and std to DType.
+template <typename DType, typename Lib>
+void Gaussian(int count, float mean, float std, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+// ================Neural net functions=======================================
+// Generic versions abort with "Not Implemented"; backends are expected to
+// provide specializations. Purposes below are inferred from the names.
+
+/// Presumably the convolution forward pass: x and w in, y out, with settings
+/// from conf. conf is const, like in the other functions of this section; the
+/// generic version does not modify it.
+template <typename DType, typename Lib>
+void ConvFwd(const ConvConf* conf, const Blob* x, const Blob* w, Blob* y,
+             Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Presumably the bias gradient of convolution: db computed from dy.
+template <typename DType, typename Lib>
+void ConvBwdBias(const ConvConf* conf, const Blob* dy, Blob* db, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Presumably the pooling forward pass: x in, y out, settings from conf.
+template <typename DType, typename Lib>
+void PoolFwd(const PoolConf* conf, const Blob* x, Blob* y, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+/// Presumably the pooling backward pass: dx computed from y, dy and x.
+template <typename DType, typename Lib>
+void PoolBwd(const PoolConf* conf, const Blob* y, const Blob* dy, const Blob* x,
+             Blob* dx, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
+}  // namespace singa
+
+#endif  // SINGA_CORE_MATH_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
new file mode 100644
index 0000000..a953085
--- /dev/null
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_
+#define SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_
+#include "./tensor_math.h"
+#include "singa/core/common.h"
+
+#ifdef USE_CBLAS
+#include <cblas.h>
+#endif
+
+namespace singa {
+/// Host implementation of element-wise ret[i] = lhs[i] + rhs[i] for the first
+/// `count` floats. ret may alias lhs or rhs: each element is read before it
+/// is written.
+/// Declared inline because an explicit specialization is an ordinary function
+/// for ODR purposes, and this header may be included by several translation
+/// units.
+template<>
+inline void Add<float, lib::Cpp>(int count,
+                                 const Blob* lhs,
+                                 const Blob* rhs,
+                                 Blob* ret,
+                                 Context* ctx) {
+  // CHECK_EQ(ctx->stream, nullptr);
+  float *dptr = static_cast<float*>(ret->mutable_data());
+  const float *lptr = static_cast<const float*>(lhs->data());
+  const float *rptr = static_cast<const float*>(rhs->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = lptr[i] + rptr[i];
+  }
+}
+
+#ifdef USE_CBLAS
+/// Inner product of the first `count` floats of lhs and rhs via cblas_sdot;
+/// the scalar result is written to *ret. Declared inline so the header can
+/// be included by several translation units.
+template<>
+inline void Dot<float, lib::Cpp>(int count,
+                                 const Blob* lhs,
+                                 const Blob* rhs,
+                                 float* ret,
+                                 Context* ctx) {
+  const float* lptr = static_cast<const float*>(lhs->data());
+  const float* rptr = static_cast<const float*>(rhs->data());
+  *ret = cblas_sdot(count, lptr, 1, rptr, 1);
+}
+
+#endif
+}
+
+#endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
new file mode 100644
index 0000000..e1c72d8
--- /dev/null
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef  SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_
+#define  SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_
+#include "./tensor_math.h"
+#include "singa/core/common.h"
+
+
+namespace singa {
+
+#ifdef USE_CUDA
+/// cuBLAS implementation of element-wise ret = lhs + rhs for `count` floats,
+/// run on ctx's stream. Declared inline so the header can be included by
+/// several translation units.
+template<>
+inline void Add<float, lib::Cuda>(int count, const Blob* lhs, const Blob* rhs,
+                                  Blob* ret, Context* ctx) {
+  const float* lptr = static_cast<const float*>(lhs->data());
+  const float* rptr = static_cast<const float*>(rhs->data());
+  float* dptr = static_cast<float*>(ret->mutable_data());
+  // cublasSaxpy takes alpha by pointer (a host pointer in the default
+  // CUBLAS_POINTER_MODE_HOST).
+  const float alpha = 1.0f;
+  cublasSetStream(ctx->handle, ctx->stream);
+  if (dptr == rptr) {
+    // ret aliases rhs: copying lhs into ret first would overwrite rhs before
+    // it is read, so accumulate lhs into ret instead.
+    cublasSaxpy(ctx->handle, count, &alpha, lptr, 1, dptr, 1);
+  } else {
+    if (dptr != lptr) {
+      cublasScopy(ctx->handle, count, lptr, 1, dptr, 1);
+    }
+    cublasSaxpy(ctx->handle, count, &alpha, rptr, 1, dptr, 1);
+  }
+}
+
+#ifdef USE_CUDNN
+// TODO: add cuDNN specializations of the neural-net functions declared in
+// tensor_math.h (ConvFwd, ConvBwdBias, PoolFwd, PoolBwd), e.g.
+//   template<>
+//   inline void ConvFwd<float, lib::Cudnn>(...) { ... }
+// A previous `Conv<float, lib::Cudnn>` specialization was removed: there is
+// no primary template named Conv in tensor_math.h, so explicitly specializing
+// it was ill-formed and broke any build with USE_CUDNN defined.
+
+#endif
+#endif
+}
+
+
+#endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor_math_opencl.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_opencl.h b/src/core/tensor/tensor_math_opencl.h
new file mode 100644
index 0000000..c4b1347
--- /dev/null
+++ b/src/core/tensor/tensor_math_opencl.h
@@ -0,0 +1,28 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef  SINGA_CORE_TENSOR_TENSOR_MATH_OPENCL_H_
+#define  SINGA_CORE_TENSOR_TENSOR_MATH_OPENCL_H_
+#include "./tensor_math.h"
+
+namespace singa {
+
+// OpenCL specializations of the functions declared in tensor_math.h belong
+// here; none are provided yet.
+
+}  // namespace singa
+
+
+#endif  // SINGA_CORE_TENSOR_TENSOR_MATH_OPENCL_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/model/layer/layer.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/layer.cc b/src/model/layer/layer.cc
index 1f0e34d..0e83cde 100644
--- a/src/model/layer/layer.cc
+++ b/src/model/layer/layer.cc
@@ -18,5 +18,13 @@
 #include "singa/model/layer.h"
 
 namespace singa {
+// NOTE(review): this is a free function in namespace singa, not a Layer
+// member; confirm against singa/model/layer.h whether it should be
+// Layer::ComputeFeature.
+/// Placeholder that is not implemented yet; returns an empty vector (the
+/// previous body fell off the end of a value-returning function).
+const vector<Tensor> ComputeFeature(int flag, const vector<Tensor>& input) {
+  return vector<Tensor>();
+}
+// A second definition `void ComputeFeature(int, const vector<Tensor>&)` was
+// removed: functions cannot be overloaded on return type alone, so it
+// conflicted with the definition above and the file did not compile.
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/proto/layer.proto
----------------------------------------------------------------------
diff --git a/src/proto/layer.proto b/src/proto/layer.proto
index bb87af9..0fbbb5d 100644
--- a/src/proto/layer.proto
+++ b/src/proto/layer.proto
@@ -97,6 +97,10 @@ message ParamSpec {
 
   // The multiplier on the global weight decay for this parameter.
   optional float decay_mult = 4 [default = 1.0];
+
+  // SINGA field for creating different kinds of Param, e.g., SparseParam or
+  // CompressableParam. Currently there is only a default Param implementation.
+  optional string type = 20 [default = "default"];
 }
 
 // NOTE
@@ -154,27 +158,27 @@ message LayerConf {
   optional ConcatConf concat_conf = 104;
   optional ContrastiveLossConf contrastive_loss_conf = 105;
   optional ConvolutionConf convolution_conf = 106;
-  optional DataConf data_conf = 107;
+  // optional DataConf data_conf = 107;
   optional DropoutConf dropout_conf = 108;
-  optional DummyDataConf dummy_data_conf = 109;
+  // optional DummyDataConf dummy_data_conf = 109;
   optional EltwiseConf eltwise_conf = 110;
   optional EmbedConf embed_conf = 137;
   optional ExpConf exp_conf = 111;
   optional FlattenConf flatten_conf = 135;
-  optional HDF5DataConf hdf5_data_conf = 112;
-  optional HDF5OutputConf hdf5_output_conf = 113;
+  // optional HDF5DataConf hdf5_data_conf = 112;
+  // optional HDF5OutputConf hdf5_output_conf = 113;
   optional HingeLossConf hinge_loss_conf = 114;
-  optional ImageDataConf image_data_conf = 115;
+  // optional ImageDataConf image_data_conf = 115;
   optional InfogainLossConf infogain_loss_conf = 116;
   optional InnerProductConf inner_product_conf = 117;
   optional LogConf log_conf = 134;
   optional LRNConf lrn_conf = 118;
-  optional MemoryDataConf memory_data_conf = 119;
+  // optional MemoryDataConf memory_data_conf = 119;
   optional MVNConf mvn_conf = 120;
   optional PoolingConf pooling_conf = 121;
   optional PowerConf power_conf = 122;
   optional PReLUConf prelu_conf = 131;
-  optional PythonConf python_conf = 130;
+  // optional PythonConf python_conf = 130;
   optional ReductionConf reduction_conf = 136;
   optional ReLUConf relu_conf = 123;
   optional ReshapeConf reshape_conf = 133;
@@ -185,7 +189,7 @@ message LayerConf {
   optional TanHConf tanh_conf = 127;
   optional ThresholdConf threshold_conf = 128;
   optional TileConf tile_conf = 138;
-  optional WindowDataConf window_data_conf = 129;
+  //optional WindowDataConf window_data_conf = 129;
 }
 
 // Message that stores hyper-parameters used to apply transformation
@@ -835,7 +839,7 @@ message PReLUConf {
   // Surpassing Human-Level Performance on ImageNet Classification, 2015.
 
   // Initial value of a_i. Default is a_i=0.25 for all i.
-  optional FillerParameter filler = 1;
+  optional FillerConf filler = 1;
   // Whether or not slope paramters are shared across channels.
   optional bool channel_shared = 2 [default = false];
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/test/singa/test_cpp_math.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cpp_math.cc b/test/singa/test_cpp_math.cc
index 268785d..78c713f 100644
--- a/test/singa/test_cpp_math.cc
+++ b/test/singa/test_cpp_math.cc
@@ -20,8 +20,6 @@
 *************************************************************/
 
 #include "gtest/gtest.h"
-#include "singa/core/math.h"
+#include "../src/core/tensor/tensor_math_cpp.h"
 
-TEST(CppMath, Add) {
 
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/test/singa/test_tensor.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor.cc b/test/singa/test_tensor.cc
index 04068ae..86200a8 100644
--- a/test/singa/test_tensor.cc
+++ b/test/singa/test_tensor.cc
@@ -15,7 +15,7 @@ TEST(TensorTest, TestConstructor) {
 
   EXPECT_NE(float_t.device(), nullptr);
 
-  singa::Tensor float16_t(singa::Shape{2,3}, singa::kFloat16);
+  singa::Tensor float16_t(Shape{2,3}, singa::kFloat16);
   EXPECT_EQ(singa::kFloat16, float16_t.data_type());
   EXPECT_EQ(6, float16_t.Size());
   EXPECT_EQ(12, float16_t.blob()->size());
@@ -68,7 +68,7 @@ TEST(TensorClass, ToDevice) {
 TEST(TensorClass, CopyDataFromHostPtr) {
   float data[] = {1.0f, 2.0f, 3.0f};
   Tensor t(Shape{3});
-  t.CopyDataFromHostPtr(data, sizeof(float) * 3);
+  t.CopyDataFromHostPtr(data, 3);
   const float* dptr = static_cast<const float*>(t.blob()->data());
   EXPECT_FLOAT_EQ(1.0f, dptr[0]);
   EXPECT_FLOAT_EQ(2.0f, dptr[1]);
@@ -78,7 +78,7 @@ TEST(TensorClass, CopyDataFromHostPtr) {
 TEST(TensorClass, CopyData) {
   float data[] = {1.0f, 2.0f, 3.0f};
   Tensor t(Shape{3});
-  t.CopyDataFromHostPtr(data, sizeof(float) * 3);
+  t.CopyDataFromHostPtr(data, 3);
 
   Tensor o(Shape{3});
   o.CopyData(t);
@@ -91,7 +91,7 @@ TEST(TensorClass, CopyData) {
 TEST(TensorClass, Clone) {
   float data[] = {1.0f, 2.0f, 3.0f};
   Tensor t(Shape{3});
-  t.CopyDataFromHostPtr(data, sizeof(float) * 3);
+  t.CopyDataFromHostPtr(data, 3);
 
   Tensor o = t.Clone();
   const float* dptr = static_cast<const float*>(o.blob()->data());
@@ -110,30 +110,5 @@ TEST(TensorClass, T) {
   EXPECT_TRUE((t.shape() ==  o.shape()));
 }
 
-TEST(TensorClass, Add) {
-  const float data[] = {1.0f, 2.0f, 3.0f, 1.1f, 2.1f, 3.1f};
-  Tensor t(Shape{3});
-  t.CopyDataFromHostPtr(data, sizeof(float) * 3);
 
-  Tensor o = t.Clone();
-  o += t;
-  const float* dptr = o.data<float>();
-  EXPECT_FLOAT_EQ(2.0f, dptr[0]);
-  EXPECT_FLOAT_EQ(4.0f, dptr[1]);
-  EXPECT_FLOAT_EQ(6.0f, dptr[2]);
-
-  Tensor p(Shape{3});
-  o += p;
-  const float* dptr1 = o.data<float>();
-  EXPECT_FLOAT_EQ(2.0f, dptr1[0]);
-  EXPECT_FLOAT_EQ(4.0f, dptr1[1]);
-  EXPECT_FLOAT_EQ(6.0f, dptr1[2]);
-
-  Tensor q(Shape{3});
-  q.CopyDataFromHostPtr(data + 3, sizeof(float) * 3);
-  t += q;
-  const float* dptr2 = t.data<float>();
-  EXPECT_FLOAT_EQ(2.1f, dptr2[0]);
-  EXPECT_FLOAT_EQ(4.1f, dptr2[1]);
-  EXPECT_FLOAT_EQ(6.1f, dptr2[2]);
-}
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/test/singa/test_tensor_math.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc
new file mode 100644
index 0000000..51e7cfb
--- /dev/null
+++ b/test/singa/test_tensor_math.cc
@@ -0,0 +1,84 @@
+#include "gtest/gtest.h"
+#include "singa/core/tensor.h"
+using singa::Tensor;
+using singa::Shape;
+using singa::Device;
+
+// Fixture: a and b hold known 6-element float vectors; c and d are only
+// given shapes.
+class TestTensorMath : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    const float kFirst[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+    const float kSecond[] = {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f};
+    a.ReShape(singa::Shape{6});
+    a.CopyDataFromHostPtr<float>(kFirst, 6);
+    b.ReShape(singa::Shape{6});
+    b.CopyDataFromHostPtr<float>(kSecond, 6);
+    c.ReShape(singa::Shape{6, 1});
+    d.ReShape(singa::Shape{3, 2});
+  }
+
+  Tensor a, b, c, d;
+};
+
+// Tensor::operator+= on cloned, freshly allocated and fixture tensors; every
+// element is checked, not just the first few.
+TEST_F(TestTensorMath, MemberAddTensor) {
+  const float kDouble[] = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
+  const float kSum[] = {2.1f, 4.1f, 6.1f, 8.1f, 10.1f, 12.1f};
+
+  Tensor aa = a.Clone();
+  aa += a;
+  const float* dptr = aa.data<float>();
+  for (int i = 0; i < 6; i++) {
+    EXPECT_FLOAT_EQ(kDouble[i], dptr[i]) << "i=" << i;
+  }
+
+  // check p is initialized to 0: 0 + aa must equal aa
+  Tensor p(Shape{6});
+  p += aa;
+  const float* dptr1 = p.data<float>();
+  for (int i = 0; i < 6; i++) {
+    EXPECT_FLOAT_EQ(kDouble[i], dptr1[i]) << "i=" << i;
+  }
+
+  a += b;
+  const float* dptr2 = a.data<float>();
+  for (int i = 0; i < 6; i++) {
+    EXPECT_FLOAT_EQ(kSum[i], dptr2[i]) << "i=" << i;
+  }
+}
+/*
+TEST(TensorClass, SubTensor) {
+  Tensor a(Shape{2,3}), b(Shape{6});
+  float x[]={1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  float y[]={1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f};
+  a.CopyDataFromHostPtr(x, 6);
+  b.CopyDataFromHostPtr(y, 6);
+  b -= a;
+  const float* dptr = b.data<float>();
+  EXPECT_FLOAT_EQ(0.1f, dptr[0]);
+  EXPECT_FLOAT_EQ(0.1f, dptr[1]);
+  EXPECT_FLOAT_EQ(0.1f, dptr[2]);
+  EXPECT_FLOAT_EQ(0.1f, dptr[5]);
+}
+*/
+
+// Free-function Add and operator+ on two tensors, including the in-place case
+// where the output aliases the first input; every element is checked.
+TEST_F(TestTensorMath, AddTensors) {
+  const float kSum[] = {2.1f, 4.1f, 6.1f, 8.1f, 10.1f, 12.1f};
+
+  Tensor ret(a.shape(), a.device(), a.data_type());
+  Add(a, b, &ret);
+  const float* dptr = ret.data<float>();
+  for (int i = 0; i < 6; i++) {
+    EXPECT_FLOAT_EQ(kSum[i], dptr[i]) << "i=" << i;
+  }
+
+  // Named `sum` rather than `d` so it does not shadow the fixture member d.
+  const Tensor sum = a + b;
+  const float* dptr2 = sum.data<float>();
+  for (int i = 0; i < 6; i++) {
+    EXPECT_FLOAT_EQ(kSum[i], dptr2[i]) << "i=" << i;
+  }
+
+  Add(a, b, &a);
+  const float* dptr1 = a.data<float>();
+  for (int i = 0; i < 6; i++) {
+    EXPECT_FLOAT_EQ(kSum[i], dptr1[i]) << "i=" << i;
+  }
+}