Posted to commits@singa.apache.org by wa...@apache.org on 2018/05/13 15:26:29 UTC

[02/10] incubator-singa git commit: Singa-341 Added stride functionality to tensors for CPP

Singa-341 Added stride functionality to tensors for CPP


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/a88efa00
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/a88efa00
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/a88efa00

Branch: refs/heads/master
Commit: a88efa00c425f610c54a359e597ecaa82d41ff25
Parents: 060e7df
Author: Vaan Ng <cm...@gmail.com>
Authored: Tue Apr 17 20:09:19 2018 +0800
Committer: Vaan Ng <cm...@gmail.com>
Committed: Tue Apr 17 20:09:19 2018 +0800

----------------------------------------------------------------------
 include/singa/core/tensor.h       |  118 +++-
 src/core/tensor/tensor.cc         |  199 ++++--
 src/core/tensor/tensor_math.h     |  173 +++--
 src/core/tensor/tensor_math_cpp.h | 1199 ++++++++++++++++++++++++--------
 src/proto/core.proto              |   21 +-
 5 files changed, 1275 insertions(+), 435 deletions(-)
----------------------------------------------------------------------
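
Before reading the diff, a minimal usage sketch of the strided API this commit adds. It is written against the declarations in include/singa/core/tensor.h as changed below; the demo function name and the concrete shape are ours, and device setup and error handling are omitted.

#include "singa/core/tensor.h"

// Create a tensor, take a transposed view, and inspect its strides.
void strided_demo() {
  singa::Tensor a(singa::Shape{3, 4});  // default strides, innermost-first: {1, 4}
  a.SetValue(1.0f);

  singa::Tensor at = a.Transpose();     // shares a's block; shape {4, 3}, strides {4, 1}
  bool transposed = at.transpose();     // true, because at.strides()[0] != 1
  (void)transposed;
}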


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a88efa00/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 6621fa0..6eafbdf 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -22,6 +22,7 @@
 #include <vector>
 #include <tuple>
 #include <memory>
+#include <algorithm>
 
 #include "singa/core/common.h"
 #include "singa/core/device.h"
@@ -30,6 +31,7 @@
 
 using std::vector;
 using std::tuple;
+using std::reverse;
 namespace singa {
 
 typedef vector<size_t> Shape;
@@ -58,12 +60,14 @@ class Tensor {
   Tensor();
   explicit Tensor(Shape &&shape, DataType dtype = kFloat32);
   explicit Tensor(const Shape &shape, DataType dtype = kFloat32);
+
   Tensor(Shape &&shape, std::shared_ptr<Device> dev, DataType dtype = kFloat32);
-  Tensor(const Shape &shape, std::shared_ptr<Device> dev,
-         DataType dtype = kFloat32);
+  Tensor(const Shape &shape, std::shared_ptr<Device> dev, DataType dtype = kFloat32);
 
   /// Copy Tensor to share the internal data.  No deep copy.
   Tensor(const Tensor &from);
+  /// Copy Tensor to share the internal data.  No deep copy. For two tensors sharing the same block but with different shapes and strides.
+  Tensor(const Tensor &from, Shape &new_shape, vector<int> &new_strides);
   /// Copy Tensor to share the internal data.  No deep copy.
   Tensor(Tensor &&from);
 
@@ -104,7 +108,12 @@ class Tensor {
 
   bool empty() const { return nDim() == 0; }
 
-  bool transpose() const { return transpose_; }
+  //bool transpose() const { return transpose_; }
+  bool transpose() const { return (strides_[0] != 1); }
+
+  const vector<int>& strides() const { return strides_; }
+
+  const vector<int>& shape_multipliers() const { return shape_multipliers_; }
 
   /// return true if the content of the tensor is initialized
   bool initailized() const {
@@ -171,6 +180,10 @@ class Tensor {
 /// No data copy, just set the transpose_ field of the returned tensor.
   Tensor T() const;
 
+  Tensor Transpose() const;
+
+  Tensor Transpose(Shape axes) const;
+
   /// Copy the meta info with data block shared.
   Tensor &operator=(const Tensor &in);
 
@@ -209,15 +222,106 @@ class Tensor {
   /// Return average L2 norm
   float L2() const;
 
+  //generate strides automatically if the strides field is not passed
+void Generate_Strides() {
+    if (shape_.size() == 0) {
+      strides_ = {1};
+      return;
+    }
+    strides_.clear();
+    size_t dim = Size();
+    int cumulative_product = 1;
+    for (size_t n = 0; n < shape_.size(); ++n) {
+        cumulative_product = cumulative_product * shape_[n];
+        strides_.push_back(dim / cumulative_product);
+    }
+    reverse(strides_.begin(), strides_.end());
+};
+
+//generate shape multipliers
+//e.g. a tensor of shape (3,3) with strides (1,3) has shape multipliers (3,1),
+//and a tensor of shape (3,3) with strides (3,1) also has shape multipliers (3,1):
+//the multipliers depend only on the shape and mark where row boundaries fall during traversal,
+//so we use the inner stride when stepping within a row and the outer stride when stepping to the next row
+vector<int> Generate_Shape_Multipliers(Shape y_shape) const {
+    if(y_shape.size()==0){
+      return {1};
+    }
+    reverse(y_shape.begin(), y_shape.end());
+    vector<int> shape_multipliers = {};
+    int cumulative_product = 1;
+
+    shape_multipliers.push_back(1);
+    for (size_t n=0; n<(y_shape.size()-1); ++n) {
+        cumulative_product = cumulative_product*y_shape[n];
+        shape_multipliers.push_back(cumulative_product);
+    }
+    reverse(shape_multipliers.begin(), shape_multipliers.end());
+    return shape_multipliers;
+};
+
+// ******************************************************************************************
+// Some traversal operations (works on const declarations without modifying tensor variables)
+// ******************************************************************************************
+
+//generate a traversal_info vector based on the tensor's shape for the traverse_next function to work
+vector<int> generate_traversal_info() const {
+    vector<int> traversal_info = {};
+    for(size_t n=0; n<(shape_.size()+2); ++n) {
+      traversal_info.push_back(0);
+    }
+    return traversal_info;
+};
+
+//this function checks whether the next index falls on a special multiplier of the outer shape
+//so the algorithm knows when to jump over/back to a starting element of the outer shape
+//e.g. in [[1,4,7], [2,5,8], [3,6,9]], elements 1, 2 and 3 are the starting elements of their respective rows
+//for a 2d matrix this check is a single loop,
+//but for higher dimensional tensors the traversal may degrade to O(nlog(n))
+int determine_order(int counter) const {
+    for (size_t n=0; n<(shape_multipliers_.size()-1); ++n) {
+        if((counter%shape_multipliers_[n])==0){
+            return ((shape_multipliers_.size()) - 1 - n);
+        }
+    }
+    return 0;
+};
+
+//this function updates the base indexes with the current index after every traversal step; it generalizes beyond 2d cases
+void update_base_index(std::vector<int>& traversal_info) const {
+    for (int n=0; n<(traversal_info[shape_.size()+1]+1); ++n) {
+        traversal_info[n] = traversal_info[shape_.size()];
+    }
+};
+
+//function to traverse a const strided tensor object
+//it requires an additional vector, traversal_info {0,0,0,0 ...}, comprising (shape_.size()+2) zeros
+//for e.g. a 2d matrix:
+//indexes 0 and 1 store the base row and column index respectively
+//index 2 stores the current index of the traversal
+//index 3 stores the order of the traversal: if the order is 0, the next element is reached using the innermost stride
+void traverse_next(std::vector<int>& traversal_info, int counter) const {
+    update_base_index(traversal_info);
+    traversal_info[shape_.size()+1] = determine_order(counter);
+    traversal_info[shape_.size()] = traversal_info[traversal_info[shape_.size()+1]]+strides_[traversal_info[shape_.size()+1]];
+};
+
+// ******************************************************************************************
+// traversal operations end
+// ******************************************************************************************
+
  protected:
-  bool transpose_ = false;
+  //bool transpose_ = false;
   DataType data_type_ = kFloat32;
   std::shared_ptr<Device> device_ = nullptr;
   /// Note: block_ is allocated in lazy manner to avoid frequent malloc/free.
   /// If you want to get an allocated Block, use block() instead of block_.
   Block *block_ = nullptr;
   Shape shape_ = {};
-};
+  vector<int> strides_ = {};
+  vector<int> shape_multipliers_ = {};
+
+}; //end of tensor class
 
 typedef Shape::iterator ShapeIter;
 inline size_t Product(const Shape &shape, int start = 0, size_t len = 0) {
@@ -452,12 +556,16 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
 /// each instance, t[i] could be 2 or [0, 0, 1]. If one instance could have
 /// multiple labels, then t[i] could be [1, 0, 1].
 /// The loss is computed into p.
+
 void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss);
+
 /// Compute the dx, given prediction probability 'p' (p=softmax(x)) and
 /// the target (ground truth) labels 't'. 'p' and 't' are either 1-d vector
 /// or 2-d matrix. 'grad' has the same shape as 'p'. dx is computed into p.
+
 void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p);
 
+
 /// Return a tensor consisting of rows ([start, end)) from 'in'. It copies the
 /// values from 'in'. 'in' is a 2D Tensor.
 Tensor CopyRows(const Tensor &in, const size_t start, const size_t end);
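
To make the stride bookkeeping above concrete, here is a standalone mirror of Generate_Strides and Generate_Shape_Multipliers with a worked example. It is not part of the commit; it just reproduces the two member functions outside the class to show the asymmetric storage orders.

#include <algorithm>
#include <cstdio>
#include <vector>

// Mirror of Tensor::Generate_Strides: strides come out innermost-first,
// so strides[0] == 1 for a non-transposed tensor -- exactly what the new
// Tensor::transpose() tests.
std::vector<int> strides_for(const std::vector<size_t>& shape) {
  if (shape.empty()) return {1};
  size_t dim = 1;
  for (size_t d : shape) dim *= d;  // total number of elements
  std::vector<int> strides;
  int cum = 1;
  for (size_t n = 0; n < shape.size(); ++n) {
    cum *= (int)shape[n];
    strides.push_back((int)(dim / cum));
  }
  std::reverse(strides.begin(), strides.end());
  return strides;
}

// Mirror of Generate_Shape_Multipliers: multipliers stay outermost-first
// and depend only on the shape, not on the strides.
std::vector<int> multipliers_for(std::vector<size_t> shape) {
  if (shape.empty()) return {1};
  std::reverse(shape.begin(), shape.end());
  std::vector<int> m = {1};
  int cum = 1;
  for (size_t n = 0; n + 1 < shape.size(); ++n) {
    cum *= (int)shape[n];
    m.push_back(cum);
  }
  std::reverse(m.begin(), m.end());
  return m;
}

int main() {
  for (int s : strides_for({2, 3, 4})) std::printf("%d ", s);      // 1 4 12
  std::printf("| ");
  for (int m : multipliers_for({2, 3, 4})) std::printf("%d ", m);  // 12 4 1
  std::printf("\n");
  return 0;
}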

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a88efa00/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index ed4da96..48751ef 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -21,6 +21,7 @@
 #include "./tensor_math_cuda.h"
 #include "./tensor_math_opencl.h"
 #include <utility>
+#include <iostream>
 
 namespace singa {
 
@@ -30,52 +31,87 @@ Tensor::~Tensor() {
   block_ = nullptr;
 }
 
-Tensor::Tensor() { device_ = defaultDevice; }
+Tensor::Tensor() { 
+  device_ = defaultDevice;
+  strides_ = {1};
+  shape_multipliers_ = {1};
+}
 
+//non-strided constructors 
 Tensor::Tensor(const Shape &shape, DataType dtype)
     : data_type_(dtype), device_(defaultDevice), shape_(shape) {
   size_t size = Product(shape_) * SizeOf(data_type_);
   if (size)
     block_ = device_->NewBlock((int)size);
+  Generate_Strides();
+  shape_multipliers_ = Generate_Shape_Multipliers(shape_);
 }
 Tensor::Tensor(Shape &&shape, DataType dtype)
     : data_type_(dtype), device_(defaultDevice), shape_(shape) {
   size_t size = Product(shape_) * SizeOf(data_type_);
   if (size)
     block_ = device_->NewBlock((int)size);
+  Generate_Strides();
+  shape_multipliers_ = Generate_Shape_Multipliers(shape_);
 }
+
+//non-strided constructors with device
 Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device,
                DataType dtype)
     : data_type_(dtype), device_(device), shape_(shape) {
   size_t size = Product(shape_) * SizeOf(data_type_);
   if (size)
     block_ = device_->NewBlock((int)size);
+  Generate_Strides();
+  shape_multipliers_ = Generate_Shape_Multipliers(shape_);
 }
 Tensor::Tensor(Shape &&shape, std::shared_ptr<Device> device, DataType dtype)
     : data_type_(dtype), device_(device), shape_(shape) {
   size_t size = Product(shape_) * SizeOf(data_type_);
   if (size)
     block_ = device_->NewBlock((int)size);
+  Generate_Strides();
+  shape_multipliers_ = Generate_Shape_Multipliers(shape_);
 }
+
+
 Tensor::Tensor(const Tensor &in)
-    : transpose_(in.transpose_),
+    : //transpose_(in.transpose_),
+      data_type_(in.data_type_),
+      device_(in.device_),
+      block_(in.block()),
+      shape_(in.shape_),
+      strides_(in.strides_),
+      shape_multipliers_(in.shape_multipliers_) {
+  if (block_ != nullptr)
+    block_->IncRefCount();
+}
+
+//strided constructor taking in a tensor, shape and strides
+Tensor::Tensor(const Tensor &in, Shape &new_shape, vector<int> &new_strides)
+    : //transpose_(in.transpose_),
       data_type_(in.data_type_),
       device_(in.device_),
       block_(in.block()),
-      shape_(in.shape_) {
+      shape_(new_shape),
+      strides_(new_strides) {
+  shape_multipliers_ = Generate_Shape_Multipliers(shape_);
   if (block_ != nullptr)
     block_->IncRefCount();
 }
 
 Tensor::Tensor(Tensor &&in)
-    : transpose_(in.transpose_),
+    : //transpose_(in.transpose_),
       data_type_(in.data_type_),
       device_(in.device_),
-      shape_(std::move(in.shape_)) {
+      shape_(std::move(in.shape_)),
+      strides_(in.strides_),
+      shape_multipliers_(in.shape_multipliers_) {
   block_ = in.block_;
   in.block_ = nullptr;
 }
 
+
 void Tensor::SetBlock(Block *block) {
   LOG(WARNING) << "Please avoid using this function; it may have side effects.";
   if (block_ != nullptr)
@@ -92,24 +128,46 @@ void Tensor::ResetLike(const Tensor &in) {
     block_ = device_->NewBlock((int)in.MemSize());
   }
   shape_ = in.shape_;
+  strides_ = in.strides_;
+  shape_multipliers_ = in.shape_multipliers_;
 }
 
+//yisen todo
+//if the tensor is not transposed (i.e. strides_[0] == 1), simply change the shape and generate new default strides
+//if the tensor is already transposed (i.e. strides_[0] != 1), it should be copied to a new tensor with newly generated default strides
+
 void Tensor::Reshape(const Shape &shape) {
+  if(strides_.size()==0)
+    strides_.push_back(1);
+
   if (Product(shape_) != Product(shape)) {
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
     block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+  } else if (strides_[0] != 1) {
+    std::cout << "Reshape Error: Transposed tensor must return a new tensor. Not implemented yet." << std::endl;
+    return;
   }
   shape_ = shape;
+  Generate_Strides();
+  shape_multipliers_ = Generate_Shape_Multipliers(shape_);
 }
 
 void Tensor::Reshape(Shape &&shape) {
+  if(strides_.size()==0)
+    strides_.push_back(1);
+
   if (Product(shape_) != Product(shape)) {
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
     block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+  } else if (strides_[0] != 1) {
+    std::cout << "Reshape Error: Transposed tensor must return a new tensor. Not implemented yet." << std::endl;
+    return;
   }
   shape_ = std::move(shape);
+  Generate_Strides();
+  shape_multipliers_ = Generate_Shape_Multipliers(shape_);
 }
 
 void Tensor::AsType(const DataType type) {
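
A short sketch of the Reshape behaviour described in the comments above; the demo function is ours.

#include "singa/core/tensor.h"

// Reshape on a plain tensor regenerates default strides in place; Reshape on
// a transposed view hits the strides_[0] != 1 branch above, prints the
// "Reshape Error" message and leaves the tensor unchanged.
void reshape_demo() {
  singa::Tensor m(singa::Shape{2, 3});
  m.Reshape(singa::Shape{3, 2});  // same element count: shape and strides regenerated

  singa::Tensor mt = m.T();       // transposed view sharing m's block
  mt.Reshape(singa::Shape{6});    // not implemented yet: prints the error and returns
}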
@@ -177,7 +235,9 @@ void Tensor::FromProto(const singa::TensorProto &proto) {
   for (uint32_t s : proto.shape()) shape.push_back(s);
   data_type_ = proto.data_type();
   Reshape(shape);
-  transpose_ = proto.transpose();
+  //transpose_ = proto.transpose();
+  strides_.clear();
+  for (int32_t s : proto.strides()) strides_.push_back(s);
   switch (data_type_) {
     case kFloat32: {
       std::unique_ptr<float[]> data_ptr(new float[Product(shape_)]);
@@ -226,7 +286,11 @@ void Tensor::ToProto(singa::TensorProto *proto) const {
     proto->add_shape(s);
   }
   proto->set_data_type(data_type_);
-  proto->set_transpose(transpose_);
+  //proto->set_transpose(transpose_);
+  proto->clear_strides();
+  for (auto s : strides_) {
+    proto->add_strides(s);
+  }
   switch (data_type_) {
     case kFloat32: {
       proto->clear_float_data();
@@ -272,19 +336,67 @@ void Tensor::ToProto(singa::TensorProto *proto) const {
 Tensor Tensor::Clone(std::shared_ptr<Device> device) const {
   if (device == nullptr) device = device_;
   Tensor t(shape_, device_, data_type_);
-  t.transpose_ = transpose_;
+  //t.transpose_ = transpose_;
+  t.strides_ = strides_;
   t.CopyData(*this);
   return t;
 }
 
+//yisen todo
 Tensor Tensor::T() const {
+  // this function only works for 2d tensors
   CHECK_EQ(shape_.size(), 2u);
   Tensor t;
   t.device_ = device_;
   t.data_type_ = data_type_;
-  t.transpose_ = !transpose_;
   t.shape_.push_back(shape_[1]);
   t.shape_.push_back(shape_[0]);
+  t.strides_.clear();
+  t.strides_.push_back(strides_[1]);
+  t.strides_.push_back(strides_[0]);
+  t.shape_multipliers_ = Generate_Shape_Multipliers(t.shape_);
+  t.block_ = block_;
+  block_->IncRefCount();
+  return t;
+}
+
+//normal transpose without axes
+Tensor Tensor::Transpose() const {
+  // if(shape_.size() != strides_.size())
+  //   Generate_Strides();
+
+  Tensor t;
+  t.device_ = device_;
+  t.data_type_ = data_type_;
+  t.strides_.clear();
+  for(size_t n=0; n<shape_.size(); ++n){
+    t.shape_.push_back(shape_[shape_.size()-n-1]);
+    t.strides_.push_back(strides_[shape_.size()-n-1]);
+  }
+  t.shape_multipliers_ = Generate_Shape_Multipliers(t.shape_);
+  t.block_ = block_;
+  block_->IncRefCount();
+  return t;
+}
+
+//transpose with axes
+Tensor Tensor::Transpose(Shape axes) const {
+  // if(axes.size() != shape_.size()){
+  //   std::cout << "Warning: Size of input axes doesn't match size of shape" << std::endl;
+  //   return void();
+  // }
+  // if(shape_.size() != strides_.size())
+  //   Generate_Strides();
+
+  Tensor t;
+  t.device_ = device_;
+  t.data_type_ = data_type_;
+  t.strides_.clear();
+  for(size_t n=0; n<axes.size(); ++n){
+    t.shape_.push_back(shape_[axes[n]]);
+    t.strides_.push_back(strides_[axes[n]]);
+  }
+  t.shape_multipliers_ = Generate_Shape_Multipliers(t.shape_);
   t.block_ = block_;
   block_->IncRefCount();
   return t;
@@ -294,7 +406,8 @@ Tensor &Tensor::operator=(const Tensor &in) {
   // LOG(ERROR) << "= const &";
   if (block_ != nullptr && block_->DecRefCount() == 0)
     device_->FreeBlock(block_);
-  transpose_ = in.transpose_;
+  //transpose_ = in.transpose_;
+  strides_ = in.strides_;
   data_type_ = in.data_type_;
   shape_ = in.shape_;
   device_ = in.device_;
@@ -308,7 +421,8 @@ Tensor &Tensor::operator=(Tensor &&in) {
   // LOG(ERROR) << "= &&";
   if (block_ != nullptr && block_->DecRefCount() == 0)
     device_->FreeBlock(block_);
-  transpose_ = in.transpose_;
+  //transpose_ = in.transpose_;
+  strides_ = in.strides_;
   data_type_ = in.data_type_;
   shape_ = std::move(in.shape_);
   device_ = in.device_;
@@ -317,6 +431,7 @@ Tensor &Tensor::operator=(Tensor &&in) {
   return *this;
 }
 
+//yisen todo
 Tensor Reshape(const Tensor &in, const Shape &s) {
   Tensor out(in);
   out.Reshape(s);
@@ -373,7 +488,7 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
                               (int)s_offset);
     } else if (src_dev->lang() == kCpp) {
       dst_dev->CopyDataToFrom(to, from, nBytes, kHostToDevice, (int)d_offset,
-							  (int)s_offset);
+                (int)s_offset);
     } else {
      LOG(FATAL) << "Memory copy between Cuda and OpenCL devices is not supported";
     }
@@ -453,7 +568,7 @@ float Tensor::L1() const {
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     device_->Exec([&nrm, this](Context *ctx) {
       DType ret = DType(0);
-      Asum<DType, Lang>(this->Size(), this->block(), &ret, ctx);
+      Asum<DType, Lang>(this, &ret, ctx);
       nrm = TypeCast<DType, float>(ret);
     }, {this->block()}, {});
   });
@@ -466,7 +581,7 @@ float Tensor::L2() const {
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     device_->Exec([&nrm, this](Context *ctx) {
       DType ret = DType(0);
-      Nrm2<DType, Lang>(this->Size(), this->block(), &ret, ctx);
+      Nrm2<DType, Lang>(this, &ret, ctx);
       nrm = TypeCast<DType, float>(ret);
     }, {this->block()}, {});
   });
@@ -476,12 +591,12 @@ float Tensor::L2() const {
 template <typename SType>
 void Tensor::SetValue(const SType x) {
   CHECK_EQ(sizeof(SType), SizeOf(data_type_));
-  auto size = Size();
+  //auto size = Size();
   auto ptr = block_;
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     // TODO(wangwei) cast x to DType
-    device_->Exec([size, x, ptr](Context *ctx) {
-      Set<DType, Lang>(size, x, ptr, ctx);
+    device_->Exec([this, x, ptr](Context *ctx) {
+      Set<DType, Lang>(x, this, ctx);
     }, {}, {ptr});
   });
 }
@@ -492,7 +607,7 @@ template void Tensor::SetValue<int>(const int x);
   do {                                                                 \
     TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \
       ret->device()->Exec([t, ret](Context * ctx) {                    \
-        fn<DType, Lang>(t.Size(), t.block(), ret->block(), ctx);       \
+        fn<DType, Lang>(&t, ret, ctx);       \
       }, {t.block()}, {ret->block()});                                 \
     });                                                                \
   } while (0)
@@ -521,7 +636,7 @@ GenUnaryTensorFn(Tanh);
     TYPE_LANG_SWITCH(lhs.data_type(), DType, lhs.device()->lang(), Lang, {  \
       CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type()));                     \
       ret->device()->Exec([lhs, rhs, ret](Context * ctx) {                  \
-        fn<DType, Lang>(lhs.Size(), lhs.block(), rhs.block(), ret->block(), \
+        fn<DType, Lang>(&lhs, &rhs, ret, \
                         ctx);                                               \
       }, {lhs.block(), rhs.block()}, {ret->block()});                       \
     });                                                                     \
@@ -552,7 +667,7 @@ GenBinaryTensorFn(operator>=, GE);
       static_assert(std::is_same<SType, DType>::value,                  \
                     "The Scalar type must match the Tensor data type"); \
       ret->device()->Exec([t, x, ret](Context * ctx) {                  \
-        fn<DType, Lang>(t.Size(), t.block(), x, ret->block(), ctx);     \
+        fn<DType, Lang>(&t, x, ret, ctx);     \
       }, {t.block()}, {ret->block()});                                  \
     });                                                                 \
   } while (0)
@@ -595,7 +710,7 @@ void Div(const SType alpha, const Tensor &in, Tensor *out) {
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
     // TODO(wangwei) type cast SType to DType;
     in.device()->Exec([alpha, in, out](Context *ctx) {
-      Div<DType, Lang>(in.Size(), alpha, in.block(), out->block(), ctx);
+      Div<DType, Lang>(alpha, &in, out, ctx);
     }, {in.block()}, {out->block()});
   });
 }
@@ -632,7 +747,7 @@ float Sum<float>(const Tensor &in) {
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
     one.device()->Exec([in, one, &s](Context *ctx) {
       DType ret = DType(0);
-      Dot<DType, Lang>(in.Size(), in.block(), one.block(), &ret, ctx);
+      Dot<DType, Lang>(&in, &one, &ret, ctx);
       s = ret;
     }, {in.block(), one.block()}, {});
   });
@@ -661,11 +776,11 @@ Tensor SoftMax(const Tensor &in) {
 Tensor RowMax(const Tensor &in) {
   Tensor ret({in.shape(0)}, in.device(), in.data_type());
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
-    in.device()->Exec([in, ret](Context *ctx) {
-      size_t nrow = 1;
-      if (in.nDim() > 1) nrow = in.shape(0);
-      size_t ncol = in.Size() / nrow;
-      RowMax<DType, Lang>(nrow, ncol, in.block(), ret.block(), ctx);
+    in.device()->Exec([&in, &ret](Context *ctx) {
+      //size_t nrow = 1;
+      //if (in.nDim() > 1) nrow = in.shape(0);
+      //size_t ncol = in.Size() / nrow;
+      RowMax<DType, Lang>(&in, &ret, ctx);
     }, {in.block()}, {ret.block()});
   });
   return ret;
@@ -708,13 +823,13 @@ void AddColumn(const SType alpha, const SType beta, const Tensor &v,
     Tensor vmat = Reshape(v, Shape{nb_row, 1});
     Mult(alpha, vmat, one, beta, M);
   }
-}
+} 
 template
 void AddColumn(const float alpha, const float beta, const Tensor &v, Tensor *M);
 
 void AddRow(const Tensor &v, Tensor *M) { AddRow(1, 1, v, M); }
 
-/// Sub column 'v' by each column of matrix M; write results into 'out'
+/// Add row 'v' to each row of matrix M; write results into 'out'
 template <typename SType>
 void AddRow(const SType alpha, const SType beta, const Tensor &v, Tensor *M) {
   if (M->transpose()) {
@@ -894,30 +1009,30 @@ void DivRow(const Tensor &v, Tensor *M) {
 
 /// Multiply column 'v' and each column of matrix M; write results into 'out'
 void MultColumn(const Tensor &v, Tensor *M) {
-  CHECK(!M->transpose()) << "Not supported yet";
+  //CHECK(!M->transpose()) << "Not supported yet";
   CHECK_EQ(M->nDim(), 2u);
   // CHECK_EQ(v.nDim(), 1u); (chonho) shape of v is 2-element tuple
   CHECK_EQ(v.Size(), M->shape(0));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
     v.device()->Exec([M, v](Context *ctx) {
-      DGMM<DType, Lang>(false, M->shape(0), M->shape(1), M->block(), v.block(),
-                        M->block(), ctx);
+      DGMM<DType, Lang>(false, M, &v,
+                        M, ctx);
     }, {M->block(), v.block()}, {M->block()});
   });
 }
 
 /// Multiply row 'v' with each row of matrix M; write results into 'out'
 void MultRow(const Tensor &v, Tensor *M) {
-  CHECK(!M->transpose()) << "Not supported yet";
+  //CHECK(!M->transpose()) << "Not supported yet";
   CHECK_EQ(M->nDim(), 2u);
   // CHECK_EQ(v.nDim(), 1u); (chonho) shape of v is 2-element tuple
   CHECK_EQ(v.Size(), M->shape(1));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
     v.device()->Exec([M, v](Context *ctx) {
-      DGMM<DType, Lang>(true, M->shape(0), M->shape(1), M->block(), v.block(),
-                        M->block(), ctx);
+      DGMM<DType, Lang>(true, M, &v,
+                        M, ctx);
     }, {M->block(), v.block()}, {M->block()});
   });
 }
@@ -963,7 +1078,7 @@ void Bernoulli(const SType p, Tensor *out) {
   TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
     auto prob = TypeCast<SType, DType>(p);
     out->device()->Exec([prob, out](Context *ctx) {
-      Bernoulli<DType, Lang>(out->Size(), prob, out->block(), ctx);
+      Bernoulli<DType, Lang>(prob, out, ctx);
     }, {}, {out->block()}, true);
   });
 }
@@ -976,7 +1091,7 @@ void Uniform(const SType low, const SType high, Tensor *out) {
     auto l = TypeCast<SType, DType>(low);
     auto h = TypeCast<SType, DType>(high);
     out->device()->Exec([l, h, out](Context *ctx) {
-      Uniform<DType, Lang>(out->Size(), l, h, out->block(), ctx);
+      Uniform<DType, Lang>(l, h, out, ctx);
     }, {}, {out->block()}, true);
   });
 }
@@ -989,7 +1104,7 @@ void Gaussian(const SType mean, const SType std, Tensor *out) {
     auto m = TypeCast<SType, DType>(mean);
     auto s = TypeCast<SType, DType>(std);
     out->device()->Exec([m, s, out](Context *ctx) {
-      Gaussian<DType, Lang>(out->Size(), m, s, out->block(), ctx);
+      Gaussian<DType, Lang>(m, s, out, ctx);
     }, {}, {out->block()}, true);
   });
 }
@@ -1002,7 +1117,7 @@ void Axpy(const SType alpha, const Tensor &in, Tensor *out) {
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
     auto a = TypeCast<SType, DType>(alpha);
     out->device()->Exec([a, in, out](Context *ctx) {
-      Axpy<DType, Lang>(in.Size(), a, in.block(), out->block(), ctx);
+      Axpy<DType, Lang>(a, &in, out, ctx);
     }, {in.block(), out->block()}, {out->block()});
   });
 }
@@ -1032,8 +1147,7 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
       auto a = TypeCast<SType, DType>(alpha);
       auto b = TypeCast<SType, DType>(beta);
       C->device()->Exec([a, A, b, B, C](Context *ctx) {
-        GEMV<DType, Lang>(A.transpose(), A.shape(0), A.shape(1), a, A.block(),
-                          B.block(), b, C->block(), ctx);
+        GEMV<DType, Lang>(a, &A, &B, b, C, ctx);
       }, {A.block(), B.block()}, {C->block()});
     });
   } else {
@@ -1042,8 +1156,7 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
       auto a = TypeCast<SType, DType>(alpha);
       auto b = TypeCast<SType, DType>(beta);
       C->device()->Exec([a, A, b, B, C](Context *ctx) {
-        GEMM<DType, Lang>(A.transpose(), B.transpose(), A.shape(0), B.shape(1),
-                          A.shape(1), a, A.block(), B.block(), b, C->block(),
+        GEMM<DType, Lang>(a, &A, &B, b, C,
                           ctx);
       }, {A.block(), B.block()}, {C->block()});
     });
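
One subtlety when reading Transpose(Shape axes) above: Generate_Strides stores strides_ innermost-first while shape_ is outermost-first, so indexing both with axes[n] lines up for the pure reversals performed by T() and Transpose(), but does not obviously extend to an arbitrary permutation. Below is a standalone sketch of a permutation that follows the innermost-first convention under our reading of the code; it is an illustration, not part of the commit.

#include <algorithm>
#include <cstddef>
#include <vector>

// Permute an outermost-first 'shape' by 'axes' and produce the matching
// innermost-first strides. For axes = {n-1, ..., 1, 0} this reproduces what
// Tensor::Transpose() computes.
void permuted_view(const std::vector<size_t>& shape,
                   const std::vector<int>& strides,  // innermost-first
                   const std::vector<size_t>& axes,
                   std::vector<size_t>* new_shape,
                   std::vector<int>* new_strides) {
  const size_t ndim = shape.size();
  new_shape->clear();
  new_strides->clear();
  for (size_t n = 0; n < ndim; ++n) {
    new_shape->push_back(shape[axes[n]]);
    // the stride of old dimension axes[n] sits at the mirrored index
    new_strides->push_back(strides[ndim - 1 - axes[n]]);
  }
  // store the result innermost-first again
  std::reverse(new_strides->begin(), new_strides->end());
}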

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a88efa00/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 6d42211..c403f30 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -19,7 +19,9 @@
 #define SINGA_CORE_MATH_H_
 #include <type_traits>
 #include "singa/core/common.h"
+#include "singa/core/tensor.h"
 #include "singa/utils/logging.h"
+#include <vector>
 
 namespace singa {
 
@@ -33,20 +35,52 @@ namespace singa {
 /// first
 ///    letter.
 /// 2. Order functions based on function name in alphabetical order.
-/// 3. Function arguments order is [const basic type] [const Block] [mutable
-/// Block].
+/// 3. Function arguments order is [const basic type] [const Tensor] [mutable
+/// Tensor].
 /// 4. Function argument names, use 'num' for total number of elements in
-///    elementwise operations; use 'in1' 'in2' for in blocks; use 'out' for
-///    output block or value. With exceptions for some functions, e.g.,
-///      Scale(const float alpha, const Block* in, Block* out);
+///    elementwise operations; use 'in1' 'in2' for input Tensors; use 'out' for
+///    output Tensor or value. With exceptions for some functions, e.g.,
+///      Scale(const float alpha, const Tensor* in, Tensor* out);
 ///    For such cases, use x, v, alpha, etc for scalar types.
 ///    For blas functions, follow the blas style for argument names.
 ///    Use 'M' and 'v' for matrix and vector tensors in functions involving both
 ///    matrix and vectors.
-/// 5. For Block argument xxx, name its raw pointer as xxxPtr.
+/// 5. For Tensor argument xxx, name its raw pointer as xxxPtr.
 /// 6. Pass the 'cudaStream_t s' to every function in math_kernel.h
 /// 7. Use size_t for the number of elements, rows or columns.
-/// 8. Use the same name for the Tensor and Block level math functions.
+/// 8. Use the same name for the high-level Tensor functions and these low-level math functions.
+
+// template <typename DType>
+// void TraverseUnary(const Tensor* in, Tensor* out, std::function<DType(DType)> func){}
+
+// template <typename DType>
+// void TraverseBinary(const Tensor* in1, const Tensor* in2, Tensor* out, std::function<DType(DType, DType)> func){}
+
+template <typename DType>
+void TraverseUnary(const Tensor* in, Tensor* out, std::function<DType(DType)> func){
+  DType *outPtr = static_cast<DType *>(out->block()->mutable_data());
+  const DType *inPtr = static_cast<const DType *>(in->block()->data());
+  vector<int> traversal_info = in->generate_traversal_info();
+  for (size_t i = 0; i < in->Size(); i++) { 
+    outPtr[i] = func(inPtr[traversal_info[in->shape().size()]]);
+    in->traverse_next(traversal_info, i+1);
+  }
+}
+
+template <typename DType>
+void TraverseBinary(const Tensor* in1, const Tensor* in2, Tensor* out, std::function<DType(DType, DType)> func){
+  DType *outPtr = static_cast<DType *>(out->block()->mutable_data());
+  const DType *in1Ptr = static_cast<const DType *>(in1->block()->data());
+  const DType *in2Ptr = static_cast<const DType *>(in2->block()->data());
+  vector<int> traversal_info_in1 = in1->generate_traversal_info();
+  vector<int> traversal_info_in2 = in2->generate_traversal_info();
+  for (size_t i = 0; i < in1->Size(); i++) {
+    outPtr[i] = func(in1Ptr[traversal_info_in1[in1->shape().size()]], in2Ptr[traversal_info_in2[in2->shape().size()]]);
+    in1->traverse_next(traversal_info_in1, i+1);
+    in2->traverse_next(traversal_info_in2, i+1);
+  }
+}
+
 
 // **************************************
 // Element-wise functions
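
TraverseUnary and TraverseBinary above read their inputs through traverse_next, so a strided (e.g. transposed) input is linearized on the fly while the output is written contiguously. Below is a standalone mirror of that traversal loop for the transposed (3,3) example from tensor.h, runnable outside SINGA.

#include <cstdio>
#include <vector>

// Mirror of Tensor::determine_order for shape multipliers 'mult'.
static int determine_order(const std::vector<int>& mult, int counter) {
  for (size_t n = 0; n + 1 < mult.size(); ++n)
    if (counter % mult[n] == 0) return (int)(mult.size() - 1 - n);
  return 0;
}

int main() {
  const size_t ndim = 2;
  std::vector<int> strides = {3, 1};   // innermost-first, i.e. a transposed (3,3) view
  std::vector<int> mult = {3, 1};      // shape multipliers, outermost-first
  std::vector<int> info(ndim + 2, 0);  // {base row, base col, current, order}

  for (int i = 0; i < 9; ++i) {
    std::printf("%d ", info[ndim]);    // buffer offset of the i-th visited element
    // update_base_index, using the order from the previous step
    for (int n = 0; n <= info[ndim + 1]; ++n) info[n] = info[ndim];
    // traverse_next: pick the stride order for this step, then advance
    info[ndim + 1] = determine_order(mult, i + 1);
    info[ndim] = info[info[ndim + 1]] + strides[info[ndim + 1]];
  }
  std::printf("\n");                   // prints: 0 3 6 1 4 7 2 5 8
  return 0;
}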
@@ -54,197 +88,197 @@ namespace singa {
 
 /// out[i] = |in[i]|
 template <typename DType, typename Lang>
-void Abs(const size_t num, const Block *in, Block *out, Context *ctx) {
+void Abs(const Tensor *in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Abs Not Implemented";
 }
 
 /// out[i] = in[i] + x
 template <typename DType, typename Lang>
-void Add(const size_t num, const Block *in, const DType x, Block *out,
+void Add(const Tensor *in, const DType x, Tensor *out,
          Context *ctx) {
   LOG(FATAL) << "Add Not Implemented";
 }
 
 /// out[i] = in1[i] + in2[i]
 template <typename DType, typename Lang>
-void Add(const size_t num, const Block *in1, const Block *in2, Block *out,
+void Add(const Tensor *in1, const Tensor *in2, Tensor *out,
          Context *ctx) {
   LOG(FATAL) << "Add-Pair Not Implemented";
 }
 /// Clamp every element into [low, high]
 /// if in[i]>high, then out[i]=high; if in[i]<low, then out[i]=low.
 template <typename DType, typename Lang>
-void Clamp(const size_t num, const DType low, const DType high, const Block *in,
-           Block *out, Context *ctx) {
+void Clamp(const DType low, const DType high, const Tensor *in,
+           Tensor *out, Context *ctx) {
   LOG(FATAL) << "Clamp Not Implemented";
 }
 
 /// out[i] = x / in[i]
 template <typename DType, typename Lang>
-void Div(const size_t num, const DType x, const Block *in, Block *out,
+void Div(const DType x, const Tensor *in, Tensor *out,
          Context *ctx) {
   LOG(FATAL) << "Div Not Implemented";
 }
 
 /// out[i] = in[i] / x
 template <typename DType, typename Lang>
-void Div(const size_t num, const Block *in, const DType x, Block *out,
+void Div(const Tensor *in, const DType x, Tensor *out,
          Context *ctx) {
   CHECK_NE(x, 0.f);
-  EltwiseMult<DType, Lang>(num, in, DType(1) / x, out, ctx);
+  EltwiseMult<DType, Lang>(in, DType(1) / x, out, ctx);
 }
 
 /// out[i] = in1[i] / in2[i]
 template <typename DType, typename Lang>
-void Div(const size_t num, const Block *in1, const Block *in2, Block *out,
+void Div(const Tensor *in1, const Tensor *in2, Tensor *out,
          Context *ctx) {
   LOG(FATAL) << "Div-Pair Not Implemented";
 }
 
 /// out[i] = in[i] * x
 template <typename DType, typename Lang>
-void EltwiseMult(const size_t num, const Block *in, const DType x, Block *out,
+void EltwiseMult(const Tensor *in, const DType x, Tensor *out,
                  Context *ctx) {
   LOG(FATAL) << "EltwiseMult Not Implemented";
 }
 
 /// out[i] = in1[i] * in2[i]
 template <typename DType, typename Lang>
-void EltwiseMult(const size_t num, const Block *in1, const Block *in2, Block *out,
+void EltwiseMult(const Tensor *in1, const Tensor *in2, Tensor *out,
                  Context *ctx) {
   LOG(FATAL) << "EltwiseMult-Pair Not Implemented";
 }
 
 /// Base is e (Napier's constant): out[i]=exp(in[i])
 template <typename DType, typename Lang>
-void Exp(const size_t num, const Block *in, Block *out, Context *ctx) {
+void Exp(const Tensor *in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Exp Not Implemented";
 }
 
 /// out[i]=(in[i]<=x)?1.f:0.f
 template <typename DType, typename Lang>
-void LE(const size_t num, const Block *in, const DType x, Block *out,
+void LE(const Tensor *in, const DType x, Tensor *out,
         Context *ctx) {
   LOG(FATAL) << "LE Not Implemented";
 }
 /// out[i]=(in1[i]<=in2[i])?1.f:0.f
 template <typename DType, typename Lang>
-void LE(const size_t num, const Block *in1, const Block *in2, Block *out,
+void LE(const Tensor *in1, const Tensor *in2, Tensor *out,
         Context *ctx) {
   LOG(FATAL) << "Tensor-Tensor LE Not Implemented";
 }
 /// Natural logarithm, base e (Napier's constant): out[i]=log(in[i]).
 template <typename DType, typename Lang>
-void Log(const size_t num, const Block *in, Block *out, Context *ctx) {
+void Log(const Tensor *in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Log Not Implemented";
 }
 /// out[i]=(in[i]<x)?1.f:0.f
 template <typename DType, typename Lang>
-void LT(const size_t num, const Block *in, const DType x, Block *out,
+void LT(const Tensor *in, const DType x, Tensor *out,
         Context *ctx) {
   LOG(FATAL) << "LT Not Implemented";
 }
 /// out[i]=(in1[i]<in2[i])?1.f:0.f
 template <typename DType, typename Lang>
-void LT(const size_t num, const Block *in1, const Block *in2, Block *out,
+void LT(const Tensor *in1, const Tensor *in2, Tensor *out,
         Context *ctx) {
   LOG(FATAL) << "Tensor-Tensor LT Not Implemented";
 }
 /// out[i]=(in[i]>=x)?1.f:0.f
 template <typename DType, typename Lang>
-void GE(const size_t num, const Block *in, const DType x, Block *out,
+void GE(const Tensor *in, const DType x, Tensor *out,
         Context *ctx) {
   LOG(FATAL) << "GE Not Implemented";
 }
 /// out[i]=(in1[i]>=in2[i])?1.f:0.f
 template <typename DType, typename Lang>
-void GE(const size_t num, const Block *in1, const Block *in2, Block *out,
+void GE(const Tensor *in1, const Tensor *in2, Tensor *out,
         Context *ctx) {
   LOG(FATAL) << "Tensor-Tensor GE Not Implemented";
 }
 /// out[i]=(in[i]>x)?1.f:0.f
 template <typename DType, typename Lang>
-void GT(const size_t num, const Block *in, const DType x, Block *out,
+void GT(const Tensor *in, const DType x, Tensor *out,
         Context *ctx) {
   LOG(FATAL) << "GT Not Implemented";
 }
 /// out[i]=(in[i]>in2[i])?1.f:0.f
 template <typename DType, typename Lang>
-void GT(const size_t num, const Block *in, const Block *in2, Block *out,
+void GT(const Tensor *in, const Tensor *in2, Tensor *out,
         Context *ctx) {
   LOG(FATAL) << "Tensor-Tensor GT Not Implemented";
 }
 /// out[i] = pow(in[i], x)
 template <typename DType, typename Lang>
-void Pow(const size_t num, const Block *in, const DType x, Block *out,
+void Pow(const Tensor *in, const DType x, Tensor *out,
          Context *ctx) {
   LOG(FATAL) << "Pow Not Implemented";
 }
 
 /// out[i]=pow(in1[i], in2[i])
 template <typename DType, typename Lang>
-void Pow(const size_t num, const Block *in1, const Block *in2, Block *out,
+void Pow(const Tensor *in1, const Tensor *in2, Tensor *out,
          Context *ctx) {
   LOG(FATAL) << "Pow-Pair Not Implemented";
 }
 
 /// out[i]=max(0, in[i])
 template <typename DType, typename Lang>
-void ReLU(const size_t num, const Block *in, Block *out, Context *ctx) {
+void ReLU(const Tensor *in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "ReLU Not Implemented";
 }
 
 /// out[i] = x
 template <typename DType, typename Lang>
-void Set(const size_t num, const DType x, Block *out, Context *ctx) {
+void Set(const DType x, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Set Not Implemented";
 }
 /// out[i]=sigmoid(in[i])
 template <typename DType, typename Lang>
-void Sigmoid(const size_t num, const Block *in, Block *out, Context *ctx) {
+void Sigmoid(const Tensor *in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Sigmoid Not Implemented";
 }
 
 /// out[i] = sign(in[i])
 template <typename DType, typename Lang>
-void Sign(const size_t num, const Block *in, Block *out, Context *ctx) {
+void Sign(const Tensor *in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Sign Not Implemented";
 }
 /// out[i]=sqrt(in[i])
 template <typename DType, typename Lang>
-void Sqrt(const size_t num, const Block *in, Block *out, Context *ctx) {
+void Sqrt(const Tensor *in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Sqrt Not Implemented";
 }
 
 /// out[i]=square(in[i])
 template <typename DType, typename Lang>
-void Square(const size_t num, const Block *in, Block *out, Context *ctx) {
-  EltwiseMult<DType, Lang>(num, in, in, out, ctx);
+void Square(const Tensor *in, Tensor *out, Context *ctx) {
+  EltwiseMult<DType, Lang>(in, in, out, ctx);
 }
 
 /// out[i] =  in[i] - x
 template <typename DType, typename Lang>
-void Sub(const size_t num, const Block *in, const DType x, Block *out,
+void Sub(const Tensor *in, const DType x, Tensor *out,
          Context *ctx) {
-  Add<DType, Lang>(num, in, -x, out, ctx);
+  Add<DType, Lang>(in, -x, out, ctx);
 }
 
 /// out[i] = in1[i] - in2[i]
 template <typename DType, typename Lang>
-void Sub(const size_t num, const Block *in1, const Block *in2, Block *out,
+void Sub(const Tensor *in1, const Tensor *in2, Tensor *out,
          Context *ctx) {
   LOG(FATAL) << "Sub-Pair Not Implemented";
 }
 
 /// sum all elements of in into out
 template <typename DType, typename Lang>
-void Sum(const size_t num, const Block *in, DType *out, Context *ctx) {
+void Sum(const Tensor *in, DType *out, Context *ctx) {
   LOG(FATAL) << "Sum Not Implemented";
 }
 
 /// out[i]=tanh(in[i])
 template <typename DType, typename Lang>
-void Tanh(const size_t num, const Block *in, Block *out, Context *ctx) {
+void Tanh(const Tensor *in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Tanh Not Implemented";
 }
 
@@ -255,20 +289,20 @@ void Tanh(const size_t num, const Block *in, Block *out, Context *ctx) {
 // Get the random generator from 'ctx'
 // If DType is not float, then convert the threshold to DType
 template <typename DType, typename Lang>
-void Bernoulli(const size_t num, const float p, Block *out, Context *ctx) {
+void Bernoulli(const float p, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Bernoulli Not Implemented";
 }
 // The random generator should be extracted from ctx.
 // If DType is not float, then convert the mean and std to DType
 template <typename DType, typename Lang>
-void Gaussian(const size_t num, const float mean, const float std, Block *out,
+void Gaussian(const float mean, const float std, Tensor *out,
               Context *ctx) {
   LOG(FATAL) << "Gaussian Not Implemented";
 }
 // The random generator should be extracted from ctx.
 // If DType is not float, then convert the low and high to DType
 template <typename DType, typename Lang>
-void Uniform(const size_t num, const float low, const float high, Block *out,
+void Uniform(const float low, const float high, Tensor *out,
              Context *ctx) {
   LOG(FATAL) << "Uniform Not Implemented";
 }
@@ -279,43 +313,43 @@ void Uniform(const size_t num, const float low, const float high, Block *out,
 
 /// Return the index of the element with the max value.
 template <typename DType, typename Lang>
-void Amax(const size_t num, const Block *in, size_t *out, Context *ctx) {
+void Amax(const Tensor *in, size_t *out, Context *ctx) {
   LOG(FATAL) << "Amax Not Implemented";
 }
 
 /// Return the index of the element with the min value.
 template <typename DType, typename Lang>
-void Amin(const size_t num, const Block *in, size_t *out, Context *ctx) {
+void Amin(const Tensor *in, size_t *out, Context *ctx) {
   LOG(FATAL) << "Amin Not Implemented";
 }
 /// out = sum |x| for all x in in
 template <typename DType, typename Lang>
-void Asum(const size_t num, const Block *in, DType *out, Context *ctx) {
+void Asum(const Tensor *in, DType *out, Context *ctx) {
   LOG(FATAL) << "Asum Not Implemented";
 }
 
 /// out = alpha * in + out
 template <typename DType, typename Lang>
-void Axpy(const size_t num, const DType alpha, const Block *in, Block *out,
+void Axpy(const DType alpha, const Tensor *in, Tensor *out,
           Context *ctx) {
   LOG(FATAL) << "Axpy Not Implemented";
 }
 
 /// out = ||in||_2^2, i.e, L2 norm.
 template <typename DType, typename Lang>
-void Nrm2(const size_t num, const Block *in, float *out, Context *ctx) {
+void Nrm2(const Tensor *in, float *out, Context *ctx) {
   LOG(FATAL) << "Nrm2 Not Implemented";
 }
 
 /// out *= x
 template <typename DType, typename Lang>
-void Scale(const size_t num, const DType x, Block *out, Context *ctx) {
+void Scale(const DType x, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Scale Not Implemented";
 }
 
 /// inner product of array in1 and in2
 template <typename DType, typename Lang>
-void Dot(const size_t num, const Block *in1, const Block *in2, DType *out,
+void Dot(const Tensor *in1, const Tensor *in2, DType *out,
          Context *ctx) {
   LOG(FATAL) << "Dot Not Implemented";
 }
@@ -323,8 +357,8 @@ void Dot(const size_t num, const Block *in1, const Block *in2, DType *out,
 /// out = alpha * A * v + beta * out.
 /// transA indicates whether the internal data layout of A is transposed
 template <typename DType, typename Lang>
-void GEMV(bool trans, const size_t m, const size_t n, const DType alpha,
-          const Block *A, const Block *v, const DType beta, Block *out,
+void GEMV(const DType alpha,
+          const Tensor *A, const Tensor *v, const DType beta, Tensor *out,
           Context *ctx) {
   LOG(FATAL) << "GEMV Not Implemented";
 }
@@ -332,21 +366,21 @@ void GEMV(bool trans, const size_t m, const size_t n, const DType alpha,
 /// multiply a matrix with a diagonal matrix constructed using values from 'v'.
 /// if matrix_left_side is true, do M*v; else do v*M
 template <typename DType, typename Lang>
-void DGMM(const bool side_right, const size_t nrow, const size_t ncol,
-          const Block *M, const Block *v, Block *out, Context *ctx) {
+void DGMM(const bool side_right,
+  const Tensor *M, const Tensor *v, Tensor *out, Context *ctx) {
   LOG(FATAL) << "DGMM Not Implemented";
 }
 
 /// C = alpha * A * B + beta * C.
 /// transA indicates whether the internal data layout of A is transposed
 template <typename DType, typename Lang>
-void GEMM(const bool transA, const bool transB, const size_t nrowA,
-          const size_t ncolB, const size_t ncolA, const DType alpha,
-          const Block *A, const Block *B, const DType beta, Block *C,
+void GEMM(const DType alpha,
+          const Tensor *A, const Tensor *B, const DType beta, Tensor *C,
           Context *ctx) {
   LOG(FATAL) << "GEMM Not Implemented";
 }
 
+//yisen todo
 template <typename DType, typename Lang>
 void ComputeCrossEntropy(bool int_target, const size_t batchsize,
                          const size_t dim, const Block *p, const Block *t,
@@ -362,8 +396,7 @@ void SoftmaxCrossEntropyBwd(bool int_target, const size_t batchsize,
 }
 
 template <typename DType, typename Lang>
-void RowMax(const size_t nrow, const size_t ncol, const Block *in,
-    Block *ret, Context* ctx) {
+void RowMax(const Tensor *in, Tensor *out, Context* ctx) {
   LOG(FATAL) << "Not Implemented";
 }
 // **************************************
@@ -372,40 +405,40 @@ void RowMax(const size_t nrow, const size_t ncol, const Block *in,
 /*
 /// Add the vector v to every column of A as the column of out
 template <typename DType, typename Lang>
-void AddCol(const size_t nrow, const size_t ncol, const Block *A, const Block *v,
-            Block *out, Context *ctx) {
+void AddCol(const size_t nrow, const size_t ncol, const Tensor *A, const Tensor *v,
+            Tensor *out, Context *ctx) {
   LOG(FATAL) << "AddCol Not Implemented";
 }
 // TODO(wangwei) unify AddRow and AddCol.
 /// Add the vector v to every row of A as the row of out
 template <typename DType, typename Lang>
-void AddRow(const size_t nrow, const size_t ncol, const Block *A, const Block *v,
-            Block *out, Context *ctx) {
+void AddRow(const size_t nrow, const size_t ncol, const Tensor *A, const Tensor *v,
+            Tensor *out, Context *ctx) {
   LOG(FATAL) << "AddRow Not Implemented";
 }
 /// outer-product.
 /// in1 and in2 are vectors of len m and n. out is matrix of shape m * n
 template <typename DType, typename Lang>
-void Outer(const size_t m, const size_t n, const Block *in1, const Block *in2,
-           Block *out, Context *ctx) {
+void Outer(const size_t m, const size_t n, const Tensor *in1, const Tensor *in2,
+           Tensor *out, Context *ctx) {
   LOG(FATAL) << "Outer Not Implemented";
 }
 
 /// Sum the columns of the in matrix into a vector
 template <typename DType, typename Lang>
-void SumColumns(const size_t nrow, const size_t ncol, const Block *in, Block *out,
+void SumColumns(const size_t nrow, const size_t ncol, const Tensor *in, Tensor *out,
                 Context *ctx) {
   LOG(FATAL) << "SumColumns Not Implemented";
 }
 template <typename DType, typename Lang>
-void Set(const size_t num, const DType x, Block *out, Context *ctx) {
+void Set(const DType x, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Not Implemented";
 }
 
 // TODO(wangwei) unify SumRow and SumCol.
 /// Sum the rows of the in matrix into a vector
 template <typename DType, typename Lang>
-void SumRows(const size_t nrow, const size_t ncol, const Block *in, Block *out,
+void SumRows(const size_t nrow, const size_t ncol, const Tensor *in, Tensor *out,
              Context *ctx) {
   LOG(FATAL) << "SumRows Not Implemented";
 }
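
The kCpp specializations of these stubs live in src/core/tensor/tensor_math_cpp.h, which this commit rewrites (shown later in this series). For orientation, a sketch of what one such specialization looks like under the new Tensor-based signatures, reusing the TraverseUnary helper above; we assume lang::Cpp is the CPU language tag and <cmath> is available.

#include <cmath>

// Sketch only -- the committed implementation is in tensor_math_cpp.h.
// TraverseUnary walks 'in' according to its strides and writes the results
// contiguously into 'out', so transposed inputs are handled transparently.
template <>
void Abs<float, lang::Cpp>(const Tensor *in, Tensor *out, Context *ctx) {
  TraverseUnary<float>(in, out, [](float x) { return std::fabs(x); });
}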