You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/01 13:10:34 UTC

[5/7] incubator-singa git commit: SINGA-362 Add functions to support the einsum function 1. Change the API of repeat, reshape, and transpose to be similar to the NumPy API 2. Change the reshape function to 'Tensor Reshape' instead of 'void Reshape'

SINGA-362 Add functions to support the einsum function
1. Change the API of repeat, reshape, and transpose to be similar to the NumPy API.
2. Change the reshape function so that it is 'Tensor Reshape' (returning a new Tensor) instead of 'void Reshape' (this matches Yisen's revision).


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/4940fefb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/4940fefb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/4940fefb

Branch: refs/heads/master
Commit: 4940fefbf65f0da474aff71b23bc60656aa40dc5
Parents: 8d9eb29
Author: sheyujian <sh...@me.com>
Authored: Thu May 31 10:28:12 2018 +0800
Committer: sheyujian <sh...@me.com>
Committed: Sat Jun 2 01:38:15 2018 +0800

----------------------------------------------------------------------
 examples/cifar10/cnn.cc           |   2 +-
 include/singa/core/device.h       |   5 -
 include/singa/core/tensor.h       |  14 ++-
 python/singa/tensor.py            | 194 +++++++++++++++------------------
 src/api/core_tensor.i             |   5 +-
 src/core/device/device.cc         |  23 ----
 src/core/tensor/tensor.cc         | 111 +++++++++++--------
 src/core/tensor/tensor_math.h     |   7 ++
 src/core/tensor/tensor_math_cpp.h |  14 +++
 src/io/image_transformer.cc       |  12 +-
 test/python/test_tensor.py        |   8 +-
 11 files changed, 199 insertions(+), 196 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/examples/cifar10/cnn.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/cnn.cc b/examples/cifar10/cnn.cc
index 61097b6..8af8a2f 100644
--- a/examples/cifar10/cnn.cc
+++ b/examples/cifar10/cnn.cc
@@ -144,7 +144,7 @@ void Train(int num_epoch, string data_dir) {
     auto train = data.ReadTrainData();
     size_t nsamples = train.first.shape(0);
     auto mtrain =
-        Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
+         Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
     const Tensor& mean = Average(mtrain, 0);
     SubRow(mean, &mtrain);
     train_x = Reshape(mtrain, train.first.shape());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/include/singa/core/device.h
----------------------------------------------------------------------
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index d6b8bf3..1a960d8 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -75,11 +75,6 @@ class Device {
   virtual void CopyDataToFrom(Block* dst, Block* src, size_t nBytes,
                       CopyDirection direction, int dst_offset, int src_offset);
 
-  virtual void RepeatDataToFrom(Block* dst, Block* src, size_t nBytes,
-                                CopyDirection direct, bool broadcast_flag, 
-                                int axis_shape, int shape_outer, int chunk, 
-                                vector<size_t> repeats, int dst_offset, int src_offset);
-
   void CopyDataFromHostPtr(Block* dst, const void* src, size_t nBytes,
                            size_t dst_offset = 0);
   /// Submit the operation to the device, which may execute it right now or

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 7947d93..d9bb069 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -133,8 +133,8 @@ class Tensor {
   size_t MemSize() const { return block_->size(); }
 
   /// Reset the tensor shape, it may reallocate block, if MemSize() changes.
-  void Reshape(const Shape &shape);
-  void Reshape(Shape &&shape);
+  // void Reshape(const Shape &shape);
+  // void Reshape(Shape &&shape);
 
   /// Reset the shape, device, and data type as given tensor.
   /// If block size changes, then reallocate a new block.
@@ -191,6 +191,10 @@ class Tensor {
   /// Change the axes
   Tensor Transpose(const vector<size_t> &axes) const;
 
+  Tensor Reshape(const Shape &shape);
+
+  Tensor Reshape(Shape &&shape);
+
   /// Copy the meta info with data block shared.
   Tensor &operator=(const Tensor &in);
 
@@ -269,6 +273,7 @@ inline size_t Product(const Shape &shape, int start = 0, size_t len = 0) {
   return v;
 }
 
+
 inline void CheckDataTypeAndLang(const Tensor &in1, const Tensor &in2) {
   CHECK_EQ(in1.data_type(), in2.data_type());
   CHECK_EQ(in1.device()->lang(), in2.device()->lang());
@@ -292,8 +297,7 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
                     const size_t dst_offset = 0, const size_t src_offset = 0);
 
 void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis, 
-                      Tensor *dst, const Tensor &in, const size_t num, 
-                      const size_t dst_offset = 0, const size_t src_offset = 0);
+                      Tensor *dst, const Tensor &in, const size_t num);
 
 // =============Element-wise operations====================================
 Tensor Abs(const Tensor &in);
@@ -305,6 +309,7 @@ Tensor Sign(const Tensor &in);
 Tensor Sqrt(const Tensor &in);
 Tensor Square(const Tensor &in);
 Tensor Tanh(const Tensor &in);
+Tensor Transform(const Tensor &in);
 
 void Abs(const Tensor &in, Tensor *out);
 void Exp(const Tensor &in, Tensor *out);
@@ -315,6 +320,7 @@ void Sign(const Tensor &in, Tensor *out);
 void Sqrt(const Tensor &in, Tensor *out);
 void Square(const Tensor &in, Tensor *out);
 void Tanh(const Tensor &in, Tensor *out);
+void Transform(const Tensor &in, Tensor *out);
 
 /// Element-wise opeartion, out[i]=in[i]^x
 template <typename SType>

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 21a362a..ba8d02c 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -71,8 +71,6 @@ float32 = core_pb2.kFloat32
 CTensor = singa.Tensor
 
 
-
-
 class Tensor(object):
     '''Python Tensor, which wraps a swig converted Tensor from CPP Tensor.
 
@@ -140,16 +138,18 @@ class Tensor(object):
         '''
         To transpose the tensor
         '''
+        t = Tensor(self.shape, self.device, self.dtype)
         if axes == None:
-            tshape = [self.shape[x] for x in range(len(self.shape))]
-            self.shape = tuple(tshape)
-            self.data = self.data.Transpose()
+            tshape = [self.shape[x] for x in range(len(t.shape))]
+            t.shape = tuple(tshape)
+            t.data = self.data.Transpose()
         else:
             if(len(axes) != len(self.shape)):
                 raise ValueError('dimensions do not match')
             tshape = [self.shape[x] for x in axes]
-            self.shape = tuple(tshape)
-            self.data = self.data.Transpose(list(axes))
+            t.shape = tuple(tshape)
+            t.data = self.data.Transpose(list(axes))
+        return t
 
     def size(self):  # TODO(wangwei) compute size
         '''
@@ -172,10 +172,12 @@ class Tensor(object):
             shape (list<int>): new shape, which should have the same volumn as
                 the original shape.
         '''
+        t = Tensor(self.shape, self.device, self.dtype)
         assert product(self.shape) == product(shape), \
             'product of shape should be equal'
-        self.shape = shape
-        self.data.Reshape(list(shape))
+        t.shape = shape
+        t.data = self.data.Reshape(list(shape))
+        return t
 
     def reset_like(self, t):
         '''Reset the shape, dtype and device as the given tensor.
@@ -283,6 +285,7 @@ class Tensor(object):
             the tensor which has been repeated
         
         '''
+        t = Tensor()
         t_ndim = self.ndim()
         if isinstance(repeats, int) or isinstance(repeats, long):
             if repeats < 0:
@@ -292,15 +295,15 @@ class Tensor(object):
             # broadcast = True
             if axis == None:
                 axis = 9999
-                self.shape = (product(self.shape)*repeats,)
+                t.shape = (product(self.shape)*repeats,)
                 Repeats = [repeats,]
-                self.data = self.data.Repeat(Repeats, axis)
+                t.data = self.data.Repeat(Repeats, axis)
             elif axis >= 0:
                 t_shape = list(self.shape)
                 t_shape[axis] = self.shape[axis]*repeats
-                self.shape = tuple(t_shape)
+                t.shape = tuple(t_shape)
                 Repeats = [repeats,]
-                self.data = self.data.Repeat(Repeats, axis)
+                t.data = self.data.Repeat(Repeats, axis)
 
         elif isinstance(repeats, tuple) or isinstance(repeats, list):
             for rep in repeats:
@@ -315,13 +318,12 @@ class Tensor(object):
             elif axis >= 0:
                 t_shape = list(self.shape)
                 t_shape[axis] = sum(repeats)
-                self.shape = tuple(t_shape)
-                self.data = self.data.Repeat(list(repeats), axis)
+                t.shape = tuple(t_shape)
+                t.data = self.data.Repeat(list(repeats), axis)
         else:
             raise ValueError('repeats should be int or sequence')
-        
 
-        
+        return t     
 
     def T(self):
         ''' shallow copy, negate the transpose field.
@@ -623,8 +625,8 @@ def reshape(t, s):
     return _call_singa_func(singa.Reshape, t.data, s)
 
 def Reshape(t,s):
-    ret = t.deepcopy()
-    ret.reshape(s)
+
+    ret = t.reshape(s)
     return ret
 
 def transpose(t,axes = None):
@@ -632,8 +634,7 @@ def transpose(t,axes = None):
     Returns:
         the transposed tensor 
     '''
-    ret = t.deepcopy()
-    ret.transpose(axes)
+    ret = t.transpose(axes)
     return ret
 
 
@@ -795,24 +796,63 @@ def tanh(t):
     '''
     return _call_singa_func(singa.Tanh, t.data)
 
-
-def sum(t, axis=None):
-    '''Sum elements of the input tensor long the given axis.
+def sum(t, axis=None, out=None):
+    '''Sum of tensor elements over given axis
 
     Args:
-        t (Tensor): input Tensor
-        axis (int, optional): if None, the summation is done over all elements;
-            if axis is provided, then it is calculated along the given axis,
-            e.g. 0 -- sum each column; 1 -- sum each row.
+        t: Singa.tensor
+            The array_like tensor to be sumed
+        axis: None or int or tuple of ints, optional
+            Axis or axes along which a sum is performed.
+            The default, axis=None, will sum all of the elements of the input array.
+            If axis is negative it counts from the last to the first axis.
+            If axis is a tuple of ints, a sum is performed on all of the axes specified
+            in the tuple instead of a single axis or all the axes as before.
+        out:Singa.tensor optional
+            Alternative output array in which to place the result.
+            It must have the same shape as the expected output,
+            but the type of the output values will be cast if necessary.
 
-    Returns:
-        a float value as the sum of all elements, or a new Tensor
+    Return: sum_along_axis: tensor
+        A tensor with the same shape as t, with the specified axis removed.
+        If a is a 0-d array, or if axis is None, a scalar is returned.
+        If an output array is specified, a reference to out is returned
     '''
 
+    t_shape = t.shape
+    t_ndim = t.ndim()
+
     if axis is None:
-        return singa.SumAsFloat(t.data)
+        one = Tensor(t.shape, t.device)
+        one.set_value(1.0)
+        ret = tensordot(t, one, t_ndim)
+
+    if isinstance(axis,int):
+        if axis < 0:
+            axis += t_ndim
+
+        axis_shape = t_shape[axis]
+        axis_shape = int(axis_shape)
+        one = Tensor(shape = (axis_shape, ), device = t.device)
+        one.set_value(1.0)
+        ret = tensordot(t, one, axes=([axis],[0]))
+
+    if isinstance(axis,tuple):
+        l_axis = list(axis)
+        axis_shape = [t_shape[x] for x in axis]
+        axisshape = tuple(axis_shape)
+        one = Tensor(axisshape, t.device)
+        one.set_value(1.0)
+        one_axis = [x for x in range(one.ndim())]
+        ret = tensordot(t, one, (l_axis,one_axis))
+
+    if out is not None:
+        if out.shape != ret.shape:
+            raise ValueError('dimensions do not match')
+        out[:] = ret
+        return out
     else:
-        return _call_singa_func(singa.Sum, t.data, axis)
+        return ret
 
 
 def pow(t, x, out=None):
@@ -1143,10 +1183,10 @@ def einsum(ops, *args):
     if len(broadcast_b) == 0:
         broadcast_b = [1]  
     mult_A = repeat(A, product(broadcast_a))
-    mult_A.reshape(reshape_A)
+    mult_A = mult_A.reshape(reshape_A)
     mult_A = transpose(mult_A,transpose_A)
     mult_B = repeat(B, product(broadcast_b))
-    mult_B.reshape(reshape_B)
+    mult_B = mult_B.reshape(reshape_B)
     mult_B = transpose(mult_B, transpose_B)
 
     if mult_A.shape != mult_B.shape:
@@ -1154,77 +1194,26 @@ def einsum(ops, *args):
     res = eltwise_mult(mult_A, mult_B)
     sum_R = sorted(sums, reverse=True)
     for i in sum_R:
-        res = sum2(res, axis=i)
+        res = sum(res, axis=i)
     transpose_res = [sorted(list(outputops)).index(x) for x in list(outputops)]
     res = transpose(res, transpose_res)
 
     return res
     
 
-
-
-def sum2(t, axis=None, out=None):
-    '''Sum of tensor elements over given axis
-
+def repeat (t, repeats, axis = None):
+    '''Return the repeated tensor
     Args:
-        t: Singa.tensor
-            The array_like tensor to be sumed
-        axis: None or int or tuple of ints, optional
-            Axis or axes along which a sum is performed.
-            The default, axis=None, will sum all of the elements of the input array.
-            If axis is negative it counts from the last to the first axis.
-            If axis is a tuple of ints, a sum is performed on all of the axes specified
-            in the tuple instead of a single axis or all the axes as before.
-        out:Singa.tensor optional
-            Alternative output array in which to place the result.
-            It must have the same shape as the expected output,
-            but the type of the output values will be cast if necessary.
+        t(tensor): the tensor to be repeated
+        repeats(int or a sequence): the number that the tensor need to repeat for
+        axis (int):the axis to do repeat
+                    If it is None, then the repeated tensor will be flattened.If it isn't None,
+                    the repeats could be sequence, but it's size should match the axis's shape
 
-    Return: sum_along_axis: tensor
-        A tensor with the same shape as t, with the specified axis removed.
-        If a is a 0-d array, or if axis is None, a scalar is returned.
-        If an output array is specified, a reference to out is returned
+    Return:
+        the tensor which has been repeated
     '''
-
-    t_shape = t.shape
-    t_ndim = t.ndim()
-
-    if axis is None:
-        one = Tensor(t.shape, t.device)
-        one.set_value(1.0)
-        ret = tensordot(t, one, t_ndim)
-
-    if isinstance(axis,int):
-        if axis < 0:
-            axis += t_ndim
-
-        axis_shape = t_shape[axis]
-        axis_shape = int(axis_shape)
-        one = Tensor(shape = (axis_shape, ), device = t.device)
-        one.set_value(1.0)
-        ret = tensordot(t, one, axes=([axis],[0]))
-
-    if isinstance(axis,tuple):
-        l_axis = list(axis)
-        axis_shape = [t_shape[x] for x in axis]
-        axisshape = tuple(axis_shape)
-        one = Tensor(axisshape, t.device)
-        one.set_value(1.0)
-        one_axis = [x for x in range(one.ndim())]
-        ret = tensordot(t, one, (l_axis,one_axis))
-
-    if out is not None:
-        if out.shape != ret.shape:
-            raise ValueError('dimensions do not match')
-        out[:] = ret
-        return out
-    else:
-        return ret
-
-def repeat (t, repeats, axis = None):
-
-    ret = t.deepcopy()
-    ret.repeat(repeats,axis)
+    ret = t.repeat(repeats,axis)
     return ret
 
         
@@ -1325,18 +1314,9 @@ def tensordot (A,B,axes=2):
         N1 *= b_shape[bx]
     newshape_b = (N2, N1)
     oldb = [b_shape[axis] for axis in notin]
-    # do transpose and reshape to get the 2D matrix to do multiplication
-    # A_ = to_numpy(A)
-    # B_ = to_numpy(B)
-    # at_ = np.transpose(A_,newaxes_a).reshape(newshape_a)
-    # bt_ = np.transpose(B_,newaxes_b).reshape(newshape_b)
-    # at = from_numpy(at_)
-    # bt = from_numpy(bt_)
 
     A = transpose(A, newaxes_a)
     B = transpose(B, newaxes_b)
-    A = add(A, 0)
-    B = add(B, 0)
     at = Reshape(A, newshape_a)
     bt = Reshape(B, newshape_b)
 
@@ -1344,9 +1324,9 @@ def tensordot (A,B,axes=2):
     if len(olda + oldb) == 0:
         olda = [1]
         oldb = [1]
-        res.reshape(tuple(olda + oldb))
+        res = res.reshape(tuple(olda + oldb))
     else:
-        res.reshape(tuple(olda + oldb))
+        res = res.reshape(tuple(olda + oldb))
 
     return res
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/src/api/core_tensor.i
----------------------------------------------------------------------
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index 587dddd..d94e506 100644
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -106,7 +106,7 @@ namespace singa{
     Tensor Transpose(const std::vector<size_t> &axes) const;
     size_t Size() const;
     size_t MemSize() const;
-    void Reshape(const std::vector<size_t> &shape);
+    Tensor Reshape(const std::vector<size_t> &shape);
     void ResetLike(const Tensor &t);
     void AsType(DataType type);
     void ToDevice(std::shared_ptr<singa::Device> dev);
@@ -163,8 +163,7 @@ namespace singa{
                       size_t src_offset = 0, size_t dst_offset = 0);
 
   void RepeatDataToFrom(bool broadcast_flag, std::vector<size_t> repeats, int axis, 
-                        Tensor *dst, const Tensor &src, const size_t num, 
-                        const size_t dst_offset, const size_t src_offset);
+                        Tensor *dst, const Tensor &src, const size_t num);
 
   Tensor Reshape(const Tensor &in, const std::vector<size_t> &s);
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/src/core/device/device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/device.cc b/src/core/device/device.cc
index 0c9c6a2..cda1b9f 100644
--- a/src/core/device/device.cc
+++ b/src/core/device/device.cc
@@ -64,29 +64,6 @@ void Device::CopyDataToFrom(Block* dst, Block* src, size_t nBytes,
       {src}, {dst});
 }
 
-void Device::RepeatDataToFrom(Block* dst, Block* src, size_t nBytes,
-                              CopyDirection direct, bool broadcast_flag, 
-                              int axis_shape, int shape_outer, int chunk, 
-                              vector<size_t> repeats, int dst_offset, int src_offset) {
-  const char *src_data = reinterpret_cast<const char*>(src->data()) + src_offset;
-  char *dst_data = reinterpret_cast<char*>(dst->mutable_data()) + dst_offset;
-
-  for (int i = 0; i < shape_outer; i++) {
-    for (int j = 0; j < axis_shape; j++) {
-      int temp = broadcast_flag ? repeats[0] : repeats[j];
-      for (int k = 0; k < temp; k++) {
-        this->Exec(
-            [this, dst_data, src_data, direct, chunk, repeats](Context* ctx) {
-              this->CopyToFrom(dst_data, src_data, chunk, direct, ctx);
-            },
-            {src}, {dst});
-        dst_data += chunk;
-      }
-      src_data += chunk;
-    }
-  }
-}
-
 void Device::CopyDataFromHostPtr(Block* dst, const void* src, size_t nBytes,
                                  size_t dst_offset) {
   auto direct = lang_ == kCpp ? kHostToHost : kHostToDevice;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index de2ea8a..3bf0a77 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -124,12 +124,7 @@ void Tensor::ResetLike(const Tensor &in) {
   strides_ = in.strides_;
 }
 
-// if tensor is not transposed yet i.e strides == 1,
-// then we simply change the shape and generate new default strides
-// if tensor is already transposed i.e strides != 1,
-// it should be copied to a new tensor with newly generated default strides
-// TODO(wangwei) raise error if the shape not match
-void Tensor::Reshape(const Shape &shape) {
+Tensor Tensor::Reshape(const Shape &shape) {
   if (strides_.size() == 0)
     strides_.push_back(1);
 
@@ -137,14 +132,27 @@ void Tensor::Reshape(const Shape &shape) {
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
     block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+    shape_ = shape;
+    generate_strides();
+    return *this;
+
   } else if (transpose()) {
-    LOG(FATAL) << "Reshape Error: Reshape called on tranposed tensor. Not implemented yet." ;
-  }
+    Tensor t(shape_, device_, data_type_);
+    t.block_ = t.device()->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+    singa::Transform(*this, &t);
+    t.shape_ = shape;
+    return t;
+ }
+
   shape_ = shape;
   generate_strides();
+  Tensor t(shape, device_, data_type_);
+  t.block_ = block_;
+  t.block_->IncRefCount();
+  return t;
 }
 
-void Tensor::Reshape(Shape &&shape) {
+Tensor Tensor::Reshape(Shape &&shape) {
   if (strides_.size() == 0)
     strides_.push_back(1);
 
@@ -152,11 +160,24 @@ void Tensor::Reshape(Shape &&shape) {
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
     block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+    shape_ = std::move(shape);
+    generate_strides();
+    return *this;
+
   } else if (transpose()) {
-    LOG(FATAL) << "Reshape Error: Reshape called on tranposed tensor. Not implemented yet." ;
-  }
-  shape_ = std::move(shape);
+    Tensor t(shape_, device_, data_type_);
+    t.block_ = t.device()->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+    singa::Transform(*this, &t);
+    t.shape_ = shape;
+    return t;
+ }
+
+  shape_ = shape;
   generate_strides();
+  Tensor t(shape, device_, data_type_);
+  t.block_ = block_;
+  t.block_->IncRefCount();
+  return t;
 }
 
 void Tensor::AsType(const DataType type) {
@@ -226,7 +247,7 @@ void Tensor::RepeatData(vector<size_t> repeats, int axis, int total_repeats, con
   CHECK(block_ != nullptr);
   // Do repeat only if the src's block is already initialized.
   if (src.block_ != nullptr) {
-    singa::RepeatDataToFrom(false, repeats, axis, this, src, Size(), 0, 0);
+    singa::RepeatDataToFrom(false, repeats, axis, this, src, Size());
   }
 }
 
@@ -234,10 +255,9 @@ void Tensor::FromProto(const singa::TensorProto &proto) {
   if (block_ != nullptr && block_->DecRefCount() == 0)
     device_->FreeBlock(block_);
   block_ = nullptr;
-  Shape shape;
-  for (uint32_t s : proto.shape()) shape.push_back(s);
+  for (uint32_t s : proto.shape()) shape_.push_back(s);
   data_type_ = proto.data_type();
-  Reshape(shape);
+  block_ = device_->NewBlock((int)(Product(shape()) * SizeOf(data_type_)));
   //transpose_ = proto.transpose();
   strides_.clear();
   for (int32_t s : proto.strides()) strides_.push_back(s);
@@ -477,13 +497,13 @@ Tensor &Tensor::operator=(Tensor &&in) {
 //yisen todo
 Tensor Reshape(const Tensor &in, const Shape &s) {
   Tensor out(in);
-  out.Reshape(s);
+  out = out.Reshape(s);
   return out;
 }
 
 Tensor Reshape(const Tensor &in, Shape &&s) {
   Tensor out(in);
-  out.Reshape(std::move(s));
+  out = out.Reshape(std::move(s));
   return out;
 }
 
@@ -542,12 +562,10 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
 }
 
 void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis, 
-                      Tensor *dst, const Tensor &src, const size_t num, 
-                      const size_t dst_offset, const size_t src_offset) {
+                      Tensor *dst, const Tensor &src, const size_t num) {
   if (repeats.size() == 1) {
     broadcast_flag = true;
-  }
-  if (repeats.size() > 1) {
+  } else if (repeats.size() > 1) {
     if (axis == Noaxis) {
       LOG(FATAL) << "When repeats parameter is sequence, axis cannot be None";
     }
@@ -557,9 +575,7 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
   }
   auto width = SizeOf(src.data_type());
   CHECK_EQ(width, SizeOf(dst->data_type()));
-  size_t nBytes = num * width;
-  auto d_offset = dst_offset * width;
-  auto s_offset = src_offset * width;
+  // size_t nBytes = num * width;
   int chunk = width;
   int axis_shape = 1;
   int shape_outer = 1;
@@ -575,26 +591,34 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
       chunk *= src.shape()[i];
     }
   }
-  
+  int dst_offset = 0;
+  int src_offset = 0;
   std::shared_ptr<Device> src_dev = src.device(), dst_dev = dst->device();
   Block *from = src.block(), *to = dst->block();
-  if (dst_dev->lang() != src_dev->lang()) {
-    // let the none cpp device conduct copy op
-    if (dst_dev->lang() == kCpp) {
-      src_dev->RepeatDataToFrom(to, from, nBytes, kDeviceToHost, broadcast_flag, axis_shape, 
-                                shape_outer, chunk, repeats, (int)d_offset, (int)s_offset);
-    } else if (src_dev->lang() == kCpp) {
-      dst_dev->RepeatDataToFrom(to, from, nBytes, kHostToDevice, broadcast_flag, axis_shape, 
-                                shape_outer, chunk, repeats, (int)d_offset, (int)s_offset);
-    } else {
-      LOG(FATAL) << "Not support mem repeat copy betwee Cuda and OpenCL device";
+  for (int i = 0; i < shape_outer; i++) {
+    for (int j = 0; j < axis_shape; j++) {
+      int temp = broadcast_flag ? repeats[0] : repeats[j];
+      for (int k = 0; k < temp; k++) {
+        if (dst_dev->lang() != src_dev->lang()) {
+          // let the none cpp device conduct copy op
+          if (dst_dev->lang() == kCpp) {
+            src_dev->CopyDataToFrom(to, from, chunk, kDeviceToHost, dst_offset, src_offset);
+          } else if (src_dev->lang() == kCpp) {
+            dst_dev->CopyDataToFrom(to, from, chunk, kHostToDevice, dst_offset, src_offset);
+          } else {
+            LOG(FATAL) << "Not support mem repeat copy betwee Cuda and OpenCL device";
+          }
+        } else {
+          auto direct = src_dev->lang() == kCpp ? kHostToHost : kDeviceToDevice;
+          src_dev->CopyDataToFrom(to, from, chunk, direct, dst_offset, src_offset);
+        }
+        dst_offset += chunk;
+      }
+      src_offset += chunk;
     }
-  } else {
-    auto direct = src_dev->lang() == kCpp ? kHostToHost : kDeviceToDevice;
-    src_dev->RepeatDataToFrom(to, from, nBytes, direct, broadcast_flag, axis_shape, 
-                              shape_outer, chunk, repeats, (int)d_offset, (int)s_offset);
   }
 }
+
 //============================================================================
 /// typedef DType accroding to type value.
 /// DType would be used in the code block __VA_ARGS__.
@@ -729,6 +753,7 @@ GenUnaryTensorFn(Sign);
 GenUnaryTensorFn(Sqrt);
 GenUnaryTensorFn(Square);
 GenUnaryTensorFn(Tanh);
+GenUnaryTensorFn(Transform);
 
 #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret)                            \
   do {                                                                      \
@@ -977,7 +1002,7 @@ Tensor ConcatOn(const vector<Tensor> &in, int axis) {
       tmp.push_back(Reshape(t, {t.shape(0), t.Size() / t.shape(0)}));
     }
     auto ret = ConcatenateRows(tmp);
-    ret.Reshape(out_shape);
+    ret = ret.Reshape(out_shape);
     return ret;
   } else {
     for (const auto& t : in) {
@@ -987,7 +1012,7 @@ Tensor ConcatOn(const vector<Tensor> &in, int axis) {
       tmp.push_back(Reshape(t, {nrow, t.Size() / nrow}));
     }
     auto ret = ConcatenateColumns(tmp);
-    ret.Reshape(out_shape);
+    ret = ret.Reshape(out_shape);
     return ret;
   }
 }
@@ -1071,7 +1096,7 @@ Tensor SliceOn(const Tensor&in, const size_t start, const size_t end, int axis)
     auto suffix = in.Size() / nrow / in.shape(axis);
     auto ret = SliceColumns(Reshape(in, {nrow, in.Size() / nrow}),
                             start * suffix, end * suffix);
-    ret.Reshape(out_shape);
+    ret = ret.Reshape(out_shape);
     return ret;
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index c7fdfe5..388c010 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -251,6 +251,13 @@ void Tanh(const Tensor &in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Tanh Not Implemented";
 }
 
+/// similar to cudnnTransformTensor
+/// copies the data from one tensor to another tensor with a different layout
+/// the tensors must have the same dimensions but not necessarily the same strides 
+template <typename DType, typename Lang>
+void Transform(const Tensor &in, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "Transform Not Implemented";
+}
 // **************************************
 // Random functions
 // **************************************

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index bfdd026..e302b04 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -427,6 +427,20 @@ void Tanh<float, lang::Cpp>(const Tensor& in, Tensor* out,
 }
 
 template <>
+void Transform<float, lang::Cpp>(const Tensor& in, Tensor* out,
+                            Context *ctx) {
+  float *outPtr = static_cast<float *>(out->block()->mutable_data());
+  const float *inPtr = static_cast<const float *>(in.block()->data());
+  vector<int> traversal_info = generate_traversal_info(in);
+  vector<int> shape_multipliers = generate_shape_multipliers(in);
+
+  for (size_t i = 0; i < in.Size(); i++) {
+    outPtr[i] = inPtr[traversal_info[in.shape().size()]];
+    traverse_next(in, shape_multipliers, traversal_info, i + 1);
+  }
+}
+
+template <>
 void Bernoulli<float, lang::Cpp>(const float p, Tensor* out,
                                  Context *ctx) {
   std::bernoulli_distribution distribution(p);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/src/io/image_transformer.cc
----------------------------------------------------------------------
diff --git a/src/io/image_transformer.cc b/src/io/image_transformer.cc
index 6e5567d..204ad08 100644
--- a/src/io/image_transformer.cc
+++ b/src/io/image_transformer.cc
@@ -229,7 +229,7 @@ namespace singa {
             }
           }
         }
-        output.Reshape(Shape{channel, crop_height, crop_width});
+        output = Reshape(output, Shape{channel, crop_height, crop_width});
         output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
         delete[] out;
       } else if (image_dim_order == "HWC") {
@@ -247,7 +247,7 @@ namespace singa {
             }
           }
         }
-        output.Reshape(Shape{crop_height, crop_width, channel});
+        output = Reshape(output, Shape{crop_height, crop_width, channel});
         output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
         delete[] out;
       } else {
@@ -266,7 +266,7 @@ namespace singa {
           out[out_idx] = in[in_idx];
         }
       }
-      output.Reshape(Shape{crop_height, crop_width});
+      output = Reshape(output, Shape{crop_height, crop_width});
       output.CopyDataFromHostPtr<float>(out, crop_height * crop_width);
       delete[] out;
     }
@@ -304,7 +304,7 @@ namespace singa {
             }
           }
         }
-        output.Reshape(Shape{channel, height, width});
+        output = Reshape(output, Shape{channel, height, width});
         output.CopyDataFromHostPtr<float>(out, height * width * channel);
         delete[] out;
       } else if (image_dim_order == "HWC") {
@@ -325,7 +325,7 @@ namespace singa {
             }
           }
         }
-        output.Reshape(Shape{height, width, channel});
+        output = Reshape(output, Shape{height, width, channel});
         output.CopyDataFromHostPtr<float>(out, height * width * channel);
         delete[] out;
       } else {
@@ -347,7 +347,7 @@ namespace singa {
           out[out_idx] = in[in_idx];
         }
       }
-      output.Reshape(Shape{height, width});
+      output = Reshape(output, Shape{height, width});
       output.CopyDataFromHostPtr<float>(out, height * width);
       delete[] out;
     }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4940fefb/test/python/test_tensor.py
----------------------------------------------------------------------
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index 098994b..080dd1f 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -212,19 +212,19 @@ class TestTensorMethods(unittest.TestCase):
         self.assertAlmostEqual(np.sum(Ta_repeat1 - a_repeat1), 0., places=3)
         self.assertAlmostEqual(np.sum(Ta_repeat2 - a_repeat2), 0., places=3)
 
-    def test_sum2(self):
+    def test_sum(self):
         a = np.array([1.1,1.1,1.1,1.1,1.4,1.3,1.1,1.6,1.1,1.1,1.1,1.2])
         a = np.reshape(a,(2,3,2))
         ta = tensor.from_numpy(a)
 
         a_sum0 = np.sum(a)
-        ta_sum0 = tensor.sum2(ta)
+        ta_sum0 = tensor.sum(ta)
         Ta_sum0 = tensor.to_numpy(ta_sum0)
         a_sum1 = np.sum(a, axis = 1)
-        ta_sum1 = tensor.sum2(ta, axis = 1)
+        ta_sum1 = tensor.sum(ta, axis = 1)
         Ta_sum1 = tensor.to_numpy(ta_sum1)
         a_sum2 = np.sum(a, axis = 2)
-        ta_sum2 = tensor.sum2(ta, axis = 2)
+        ta_sum2 = tensor.sum(ta, axis = 2)
         Ta_sum2 = tensor.to_numpy(ta_sum2)
 
         self.assertAlmostEqual(np.sum(a_sum0 - Ta_sum0), 0., places=3)