You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/11 08:29:33 UTC

[1/4] incubator-singa git commit: SINGA-381 - Update the autograd API to yield the gradients

Repository: incubator-singa
Updated Branches:
  refs/heads/master e16cea129 -> b30d7ea55


SINGA-381 - Update the autograd API to yield the gradients

Yield gradients from backward() in autograd.py; this saves memory by releasing gradients early.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/81908a82
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/81908a82
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/81908a82

Branch: refs/heads/master
Commit: 81908a82f4c9ea01b1359ed3d8fb4118a5bfd147
Parents: e16cea1
Author: Wang Wei <wa...@gmail.com>
Authored: Thu Jul 5 22:09:27 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Wed Jul 11 15:19:27 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mlp.py       |  8 +++-----
 examples/autograd/mnist_cnn.py |  6 ++----
 python/singa/autograd.py       | 20 +++++++++++++-------
 3 files changed, 18 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81908a82/examples/autograd/mlp.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mlp.py b/examples/autograd/mlp.py
index 0447927..e90ff1d 100755
--- a/examples/autograd/mlp.py
+++ b/examples/autograd/mlp.py
@@ -62,7 +62,7 @@ if __name__ == '__main__':
     label = to_categorical(label, 2).astype(np.float32)
     print('train_data_shape:', data.shape)
     print('train_label_shape:', label.shape)
-    
+
     inputs = Tensor(data=data)
     target = Tensor(data=label)
 
@@ -86,10 +86,8 @@ if __name__ == '__main__':
         x = autograd.add_bias(x, b1)
         x = autograd.soft_max(x)
         loss = autograd.cross_entropy(x, target)
-        in_grads = autograd.backward(loss)
-
-        for param in in_grads:
-            sgd.apply(0, in_grads[param], param, '')
+        for p, gp in autograd.backward(loss):
+            sgd.apply(0, gp, p, '')
 
         if (i % 100 == 0):
             print('training loss = ', tensor.to_numpy(loss)[0])

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81908a82/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index 5b4e608..db21485 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -135,7 +135,5 @@ if __name__ == '__main__':
                 print('accuracy is:', accuracy_rate, 'loss is:',
                       tensor.to_numpy(loss)[0])
 
-            in_grads = autograd.backward(loss)
-
-            for param in in_grads:
-                sgd.apply(0, in_grads[param], param, '')
+            for p, gp in autograd.backward(loss):
+                sgd.apply(0, gp, p, '')

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81908a82/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 9fd8b4d..2ba3098 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -538,6 +538,13 @@ def infer_dependency(op):
     return dependency_count
 
 
+def gradients(y, dy=None):  # collect backward(y, dy) pairs into a dict
+    grads = {}  # mapping: x->dx if x.stores_grad
+    for p, dp in backward(y, dy):
+        grads[p] = dp  # store in the local dict, not the function
+    return grads
+
+
 def backward(y, dy=None):
     '''
     Run the backward propagation starting at y.
@@ -566,7 +573,7 @@ def backward(y, dy=None):
     # ready is a queue of (operation, dy list)
     ready = deque([(y.creator, (dy,))])
     not_ready = {}  # mapping: op->[dy]
-    gradients = {}  # mapping: x->dx if x.stores_grad
+
     if y.stores_grad:
         gradients[y] = dy
 
@@ -608,7 +615,8 @@ def backward(y, dy=None):
             if y_stores_grad:
                 # store the gradient for final return, e.g. if x is parameter
                 g = not_ready[src_op][y_idx]
-                gradients[y] = Tensor(device=g.device(), data=g)
+                tg = Tensor(device=g.device(), data=g)
+                yield (y, tg)
             dependency[src_op] -= 1
             if src_op.requires_grad is True:
                 if dependency[src_op] == 0:
@@ -616,10 +624,8 @@ def backward(y, dy=None):
                         ready.append((src_op, not_ready[src_op]))
                     del not_ready[src_op]
 
-    return gradients
-
 
-class NewLayer(object):
+class Layer(object):
 
     def __init__(self):
         pass
@@ -631,7 +637,7 @@ class NewLayer(object):
                 var.to_device(x_device)
 
 
-class Linear(NewLayer):
+class Linear(Layer):
 
     def __init__(self, in_features, out_features, bias=True):
         #self.in_features = in_features
@@ -661,7 +667,7 @@ class Linear(NewLayer):
         return y
 
 
-class Conv2D(NewLayer):
+class Conv2D(Layer):
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True, **kwargs):


[4/4] incubator-singa git commit: SINGA-380 - Fix bugs in Reshape

Posted by wa...@apache.org.
SINGA-380 - Fix bugs in Reshape

Update reshape API in C++ and Python.
The C++ Tensor::Reshape method changes the original tensor in place;
all other reshape functions return a new tensor (which shares memory with the original tensor if possible).

APIs for transpose are updated in the same way.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b30d7ea5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b30d7ea5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b30d7ea5

Branch: refs/heads/master
Commit: b30d7ea55cd58bb0858aa354833c1ba9a3242470
Parents: 58e6640
Author: Wang Wei <wa...@gmail.com>
Authored: Mon Jul 9 23:52:10 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Wed Jul 11 15:24:27 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py        |  17 +-
 examples/cifar10/cnn-parallel.cc      |   8 +-
 examples/cifar10/vgg-parallel.cc      |   8 +-
 examples/imagenet/alexnet/alexnet.cc  |   2 +-
 examples/imagenet/alexnet/ilsvrc12.h  |  16 +-
 include/singa/core/tensor.h           | 162 ++++----
 python/singa/autograd.py              | 273 +++++++-------
 python/singa/tensor.py                | 109 +++---
 src/api/core_tensor.i                 |  19 +-
 src/core/tensor/tensor.cc             | 297 ++++-----------
 src/core/tensor/tensor_math.h         |   2 +-
 src/core/tensor/tensor_math_cuda.h    | 323 ++++------------
 src/io/image_transformer.cc           | 573 ++++++++++++++---------------
 src/model/layer/batchnorm.cc          |  15 +-
 src/model/layer/convolution.cc        |   8 +-
 src/model/layer/cudnn_batchnorm.cc    |   4 +-
 src/model/layer/dense.cc              |  14 +-
 src/model/layer/flatten.cc            |   3 +-
 src/model/layer/lrn.cc                |   9 +-
 src/model/layer/opencl_convolution.cc |  58 +--
 src/model/layer/rnn.cc                |   2 +-
 src/model/operation/convolution.cc    |  67 ++--
 src/model/updater/local_updater.cc    |   4 +-
 23 files changed, 849 insertions(+), 1144 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index 43a22ba..f78ccc8 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -84,7 +84,7 @@ if __name__ == '__main__':
         dev = device.get_default_device()
     else:
         print('Using GPU')
-        dev = device.create_cuda_gpu()
+        dev = device.create_cuda_gpu_on(1)
 
     train, test = load_data(args.file_path)
 
@@ -92,7 +92,7 @@ if __name__ == '__main__':
     num_classes = 10
     epochs = 1
 
-    sgd = optimizer.SGD(0.001)
+    sgd = optimizer.SGD(0.01)
 
     x_train = preprocess(train[0])
     y_train = to_categorical(train[1], num_classes)
@@ -111,7 +111,6 @@ if __name__ == '__main__':
 
 
     def forward(x, t):
-        
         y = conv1(x)
         y = autograd.relu(y)
         y = autograd.max_pool_2d(y)
@@ -124,11 +123,11 @@ if __name__ == '__main__':
         return loss, y
 
     autograd.training = True
-    for epoch in range(50):
+    for epoch in range(epochs):
         for i in range(batch_number):
             inputs = tensor.Tensor(device=dev, data=x_train[ i * 100:(1 + i) * 100], stores_grad=False)
             targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100], requires_grad=False, stores_grad=False)
-            
+
             loss, y = forward(inputs, targets)
 
             accuracy_rate = accuracy(tensor.to_numpy(y),
@@ -136,12 +135,6 @@ if __name__ == '__main__':
             if (i % 5 == 0):
                 print('accuracy is:', accuracy_rate, 'loss is:',
                       tensor.to_numpy(loss)[0])
-            
+
             for p, gp in autograd.backward(loss):
                 sgd.apply(epoch, gp, p, '')
-            
-            
-
-            
-            
-

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/examples/cifar10/cnn-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/cnn-parallel.cc b/examples/cifar10/cnn-parallel.cc
index 8cc3352..4bee575 100644
--- a/examples/cifar10/cnn-parallel.cc
+++ b/examples/cifar10/cnn-parallel.cc
@@ -154,20 +154,20 @@ void Train(float lr, int num_epoch, string data_dir) {
     train_y = train.second;
 
     LOG(INFO) << "Slicing training data...";
-    train_x_1.Reshape(Shape{nsamples / 2, train.first.shape(1),
+    train_x_1 = Tensor(Shape{nsamples / 2, train.first.shape(1),
         train.first.shape(2), train.first.shape(3)});
     LOG(INFO) << "Copying first data slice...";
     CopyDataToFrom(&train_x_1, train_x, train_x.Size() / 2);
-    train_x_2.Reshape(Shape{nsamples / 2, train.first.shape(1),
+    train_x_2 = Tensor(Shape{nsamples / 2, train.first.shape(1),
         train.first.shape(2), train.first.shape(3)});
     LOG(INFO) << "Copying second data slice...";
     CopyDataToFrom(&train_x_2, train_x, train_x.Size() / 2, 0,
                    train_x.Size() / 2);
-    train_y_1.Reshape(Shape{nsamples / 2});
+    train_y_1 = Tensor(Shape{nsamples / 2});
     train_y_1.AsType(kInt);
     LOG(INFO) << "Copying first label slice...";
     CopyDataToFrom(&train_y_1, train_y, train_y.Size() / 2);
-    train_y_2.Reshape(Shape{nsamples / 2});
+    train_y_2 = Tensor(Shape{nsamples / 2});
     train_y_2.AsType(kInt);
     LOG(INFO) << "Copying second label slice...";
     CopyDataToFrom(&train_y_2, train_y, train_y.Size() / 2, 0,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/examples/cifar10/vgg-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/vgg-parallel.cc b/examples/cifar10/vgg-parallel.cc
index 90e9fce..33c533b 100644
--- a/examples/cifar10/vgg-parallel.cc
+++ b/examples/cifar10/vgg-parallel.cc
@@ -223,20 +223,20 @@ void Train(float lr, int num_epoch, string data_dir) {
     train_y = train.second;
 
     LOG(INFO) << "Slicing training data...";
-    train_x_1.Reshape(Shape{nsamples / 2, train.first.shape(1),
+    train_x_1 = Tensor(Shape{nsamples / 2, train.first.shape(1),
         train.first.shape(2), train.first.shape(3)});
     LOG(INFO) << "Copying first data slice...";
     CopyDataToFrom(&train_x_1, train_x, train_x.Size() / 2);
-    train_x_2.Reshape(Shape{nsamples / 2, train.first.shape(1),
+    train_x_2 = Tensor(Shape{nsamples / 2, train.first.shape(1),
         train.first.shape(2), train.first.shape(3)});
     LOG(INFO) << "Copying second data slice...";
     CopyDataToFrom(&train_x_2, train_x, train_x.Size() / 2, 0,
                    train_x.Size() / 2);
-    train_y_1.Reshape(Shape{nsamples / 2});
+    train_y_1 = Tensor(Shape{nsamples / 2});
     train_y_1.AsType(kInt);
     LOG(INFO) << "Copying first label slice...";
     CopyDataToFrom(&train_y_1, train_y, train_y.Size() / 2);
-    train_y_2.Reshape(Shape{nsamples / 2});
+    train_y_2 = Tensor(Shape{nsamples / 2});
     train_y_2.AsType(kInt);
     LOG(INFO) << "Copying second label slice...";
     CopyDataToFrom(&train_y_2, train_y, train_y.Size() / 2, 0,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/examples/imagenet/alexnet/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/imagenet/alexnet/alexnet.cc b/examples/imagenet/alexnet/alexnet.cc
index 4ac1130..2d8db2d 100644
--- a/examples/imagenet/alexnet/alexnet.cc
+++ b/examples/imagenet/alexnet/alexnet.cc
@@ -174,7 +174,7 @@ void TrainOneEpoch(FeedForwardNet &net, ILSVRC &data,
   size_t b = 0;
   size_t n_read;
   Timer timer, ttr;
-  Tensor prefetch_x, prefetch_y;
+  Tensor prefetch_x(Shape{batchsize, 3, kCropSize, kCropSize}), prefetch_y(Shape{batchsize}, kInt);
   string binfile = bin_folder + "/train1.bin";
   timer.Tick();
   data.LoadData(kTrain, binfile, batchsize, &prefetch_x, &prefetch_y, &n_read,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/examples/imagenet/alexnet/ilsvrc12.h
----------------------------------------------------------------------
diff --git a/examples/imagenet/alexnet/ilsvrc12.h b/examples/imagenet/alexnet/ilsvrc12.h
index 74fffbb..05b3451 100644
--- a/examples/imagenet/alexnet/ilsvrc12.h
+++ b/examples/imagenet/alexnet/ilsvrc12.h
@@ -43,6 +43,12 @@
 using std::string;
 using namespace singa::io;
 namespace singa {
+
+ /// size for resizing
+const size_t kImageSize = 256;
+const size_t kImageNBytes = 3 * kImageSize * kImageSize;
+/// size for cropping
+const size_t kCropSize = 227;
 /// For reading ILSVRC2012 image data as tensors.
 class ILSVRC {
  public:
@@ -105,11 +111,7 @@ class ILSVRC {
   void WriteMean(Tensor &mean, string path);
 
  private:
-  /// size for resizing
-  const size_t kImageSize = 256;
-  const size_t kImageNBytes = 3 * kImageSize * kImageSize;
-  /// size for cropping
-  const size_t kCropSize = 227;
+ 
   Tensor mean;
   string last_read_file = "";
 
@@ -299,9 +301,7 @@ std::thread ILSVRC::AsyncLoadData(int flag, string file, size_t read_size,
 
 size_t ILSVRC::LoadData(int flag, string file, size_t read_size, Tensor *x,
                         Tensor *y, size_t *n_read, int nthreads) {
-  x->Reshape(Shape{read_size, 3, kCropSize, kCropSize});
-  y->AsType(kInt);
-  y->Reshape(Shape{read_size});
+  
   if (file != last_read_file) {
     if (reader != nullptr) {
       reader->Close();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 5921762..a73821c 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -57,47 +57,38 @@ class Tensor {
  public:
   ~Tensor();
   Tensor();
-  explicit Tensor(Shape &&shape, DataType dtype = kFloat32);
+
+  /// Constructor using default device.
   explicit Tensor(const Shape &shape, DataType dtype = kFloat32);
 
-  Tensor(Shape &&shape,
-         std::shared_ptr<Device> dev,
-         DataType dtype = kFloat32);
+  /// Constructor with shape, device and data type
   Tensor(const Shape &shape,
          std::shared_ptr<Device> dev,
          DataType dtype = kFloat32);
 
-  /// Copy Tensor to share the internal data.  No deep copy.
+  /// Copy constructor.  No deep copy.
   Tensor(const Tensor &from);
-  /// Copy Tensor to share the internal data.  No deep copy.
-  /// For 2 tensors sharing same block but different strides.
-  Tensor(const Tensor &from, Shape &new_shape, vector<int> &new_strides);
-  /// Copy Tensor to share the internal data.  No deep copy.
+
+  /// Move constructor.  No deep copy.
   Tensor(Tensor &&from);
 
+  // --------------------------------------------------------------------------
+  // ---Following methods return info of the class without making any changes--
+  // --------------------------------------------------------------------------
+
   /// For functions in xx_math.cc to access the block.
   /// Users should not operate against Block directly.
   /// block_ is allocated in constructors.
   Block *block() const { return block_; }
-  void SetBlock(Block *block);
 
   std::shared_ptr<Device> device() const { return device_; }
 
-  /// return immutable Tensor values with given type.
+  /// Return immutable Tensor values with given type.
   template <typename SType>
   const SType *data() const {
     return static_cast<const SType *>(block()->data());
   }
 
-  /// used for swig code to convert Tensor into numpy array.
-  /// It gets data into 'value'
-  template <typename SType>
-  void GetValue(SType *value, const size_t num) {
-    CHECK(device_ == defaultDevice);
-    const SType* ptr = data<SType>();
-    for (size_t i = 0; i < num; i++) value[i] = ptr[i];
-  }
-
   /// data type, including kFloat16, kFloat32, kInt
   const DataType data_type() const { return data_type_; }
 
@@ -113,28 +104,55 @@ class Tensor {
   bool empty() const { return nDim() == 0; }
 
   /// Check if the tensor's last stride==1
-  bool transpose() const { return (strides_.back() != 1); }
+  bool transpose() const {
+    if (!strides_.empty()) {
+      auto last = strides_.front();
+      for (auto s : strides_) {
+        if (s > last)
+          return true;
+        last = s;
+      }
+    }
+    return false;
+  }
 
   const vector<int>& strides() const { return strides_; }
 
-  /// return true if the content of the tensor is initialized
+  /// Return true if the content of the tensor is initialized
   bool initailized() const {
     return block_ != nullptr && block_->initialized();
   }
 
-  /// return number of total elements
+  /// Return number of total elements
   size_t Size() const {
     if (block_ == nullptr) return 0u;
     CHECK_EQ(block_->size() % SizeOf(data_type_), 0u);
     return block_->size() / SizeOf(data_type_);
   }
 
-  /// return memory size (i.e., Bytes)
+  /// Return memory size (i.e., Bytes)
   size_t MemSize() const { return block_->size(); }
 
-  /// Reset the tensor shape, it may reallocate block, if MemSize() changes.
-  Tensor Reshape(const Shape &shape);
-  Tensor Reshape(Shape &&shape);
+  /// used for swig code to convert Tensor into numpy array.
+  /// It gets data into 'value'
+  template <typename SType>
+  void GetValue(SType *value, const size_t num) {
+    CHECK(device_ == defaultDevice);
+    const SType* ptr = data<SType>();
+    for (size_t i = 0; i < num; i++) value[i] = ptr[i];
+  }
+
+  /// Serialize data, shape and transpose to protobuf object.
+  void ToProto(singa::TensorProto *proto) const;
+
+  /// Return average L1 norm
+  float L1() const;
+
+  /// Return average L2 norm
+  float L2() const;
+  // --------------------------------------------------------------------------
+  // ---Following methods changes the internal members
+  // --------------------------------------------------------------------------
 
   /// Reset the shape, device, and data type as given tensor.
   /// If block size changes, then reallocate a new block.
@@ -155,6 +173,8 @@ class Tensor {
   template <typename SType>
   void SetValue(const SType x);
 
+  void SetShape(const Shape& shape);
+
   /// For init the tensor values, copy 'num' elements from 'src' to the internal
   /// memory with 'offset' (elements).
   template <typename SType>
@@ -165,46 +185,41 @@ class Tensor {
   /// Meta data would not be copied!
   void CopyData(const Tensor &other);
 
-  void RepeatData(vector<size_t> repeats, int axis, int total_repeats, const Tensor &other);
-
   /// Deserialize data, shape and transpose from protobuf object.
   void FromProto(const singa::TensorProto &proto);
 
-  /// Serialize data, shape and transpose to protobuf object.
-  void ToProto(singa::TensorProto *proto) const;
 
-  /// return an exactly the same Tensor with data been deep copied to the given
-  /// device. If 'device' is nullptr, then clone it one the current device.
-  Tensor Clone(std::shared_ptr<Device> device = nullptr) const;
+  /// TODO(wangwei) merge RepeatData into  Repeat?
+  void RepeatData(const vector<size_t>& repeats, int axis, int total_repeats,
+                  const Tensor &other);
 
-  Tensor Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device> device = nullptr) ;
+  // --------------------------------------------------------------------------
+  // ---Following methods returns a new Tensor without change original tensor
+  // --------------------------------------------------------------------------
 
-  // Tensor operations
-
-  /// Matrix transpose.  Valid only if shape.size() == 2.
-  /// No data copy, just set the transpose_ filed of the returned tensor.
-  Tensor T() const;
-
-  /// Reverse the shape vector
-  Tensor Transpose() const;
+  Tensor Repeat(const vector<size_t>& repeats, int axis,
+                std::shared_ptr<Device> device = nullptr);
 
-  /// Change the axes
-  Tensor Transpose(const vector<size_t> &axes) const;
+  /// return an exactly the same Tensor with data been deep copied to the given
+  /// device. If 'device' is nullptr, then clone it one the current device.
+  Tensor Clone(std::shared_ptr<Device> device = nullptr) const;
 
-  /// Copy the meta info with data block shared.
+  // --------------------------------------------------------------------------
+  // ---Following methods change the tensor and return itself
+  // --------------------------------------------------------------------------
+  /// Copy assignment
   Tensor &operator=(const Tensor &in);
 
-  /// Copy the meta info with data block shared.
+  /// Move assignment
   Tensor &operator=(Tensor &&in);
 
   Tensor &operator+=(const Tensor &in);
-  // void operator+=(Tensor&& in);
+
   Tensor &operator-=(const Tensor &in);
-  // void operator-=(Tensor&& in);
+
   Tensor &operator*=(const Tensor &in);
-  // void operator*=(Tensor&& in);
+
   Tensor &operator/=(const Tensor &in);
-  // void operator/=(Tensor&& in);
 
   // Scalar operations.
 
@@ -224,10 +239,19 @@ class Tensor {
   template <typename SType>
   Tensor &operator/=(const SType x);
 
-  /// Return average L1 norm
-  float L1() const;
-  /// Return average L2 norm
-  float L2() const;
+  /// change the shape (and stride); the block may be reallocated.
+  Tensor &Reshape(const Shape &shape);
+
+  /// Matrix transpose.  Valid only if shape.size() == 2.
+  Tensor& T();
+
+  /// Reverse the shape vector
+  Tensor& Transpose();
+
+  /// Change the axes
+  Tensor& Transpose(const vector<size_t> &axes);
+
+ protected:
 
   //generate strides automatically if stride field is not passed
   void generate_strides() {
@@ -259,10 +283,10 @@ class Tensor {
   vector<int> strides_ = {};
 }; //end of tensor class
 
+
 inline size_t Product(const Shape &shape, int start = 0, size_t len = 0) {
   if (len == 0) len = shape.size();
-  if (len == 0)
-    return 0;
+  if (len == 0) return 0;
   CHECK_LE(len, shape.size());
   size_t v = 1;
   for (unsigned int i = start; i < len; i++) v *= shape[i];
@@ -275,24 +299,31 @@ inline void CheckDataTypeAndLang(const Tensor &in1, const Tensor &in2) {
   CHECK_EQ(in1.device()->lang(), in2.device()->lang());
 }
 
+
 template <typename FromType, typename ToType>
 ToType TypeCast(const FromType &x) {
   // TODO(wangwei) cast fp16; prevent some casts, e.g., float to char
   return static_cast<ToType>(x);
 }
 
+
+/// Reshape the given tensor and generate a new tensor,
+/// which shares the memory with in if possible
 Tensor Reshape(const Tensor &in, const Shape &s);
-Tensor Reshape(const Tensor &in, Shape &&s);
 
-// For tensors with sparse content, e.g., missing columns or rows.
-// class SparseTensor : public Tensor {};
+/// Reverse the shape vector
+Tensor Transpose(const Tensor& in);
+
+/// Change the axes
+Tensor Transpose(const Tensor& in, const vector<size_t> &axes);
 
 /// Copy 'num' elements of src to dst.
 /// The first 'src_offset' ('dst_offset') elements will be skipped.
 void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
                     const size_t dst_offset = 0, const size_t src_offset = 0);
 
-void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
+
+void RepeatDataToFrom(bool broadcast_flag, const vector<size_t>& repeats, int axis,
                       Tensor *dst, const Tensor &in, const size_t num);
 
 // =============Element-wise operations====================================
@@ -411,6 +442,8 @@ void Div(const SType x, const Tensor &in, Tensor *out);
 
 template <typename SType = float>
 SType Sum(const Tensor &in);
+
+
 // ============Matrix (row/column) operations==================================
 /// Average elements in the Tensor, currently only support vector and matrix.
 /// if 'axis' is 0, average all rows into a single row
@@ -510,8 +543,8 @@ void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p);
 
 /// To be called by pysinga autograd operations;
 /// swig ignores the const qualifier http://www.swig.org/Doc3.0/SWIGPlus.html#SWIGPlus_const
-const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
-const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);
+Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
+Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);
 
 /// Return a tensor consisting of rows ([start, end)) from 'in'. It copies the
 /// values from 'in'. 'in' ia a 2D Tensor.
@@ -519,7 +552,8 @@ Tensor CopyRows(const Tensor &in, const size_t start, const size_t end);
 /// Alias of CopyRows
 Tensor SliceRows(const Tensor &in, const size_t start, const size_t end);
 /// Slice the input tensor along the give axis to generate a new tensor
-Tensor SliceOn(const Tensor &in, const size_t start, const size_t end, int axis);
+Tensor SliceOn(const Tensor &in, const size_t start, const size_t end,
+               int axis);
 /// Return a tensor consisting of columns ([start, end)) from 'in'. It copies
 /// the values from 'in'. 'in' is a  2D Tensor.
 Tensor CopyColumns(const Tensor &in, const size_t start, const size_t end);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 63698c2..aa6b37a 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -33,6 +33,126 @@ CTensor = singa.Tensor
 training = False
 
 
+
+def infer_dependency(op):
+    '''
+    Infer the dependency of all operations with the
+    given op as the last operation.
+
+    Operation A depends on B if A uses the output(s) of B.
+
+    Args:
+        op: an Operation instance, e.g. the loss operation.
+
+    Return:
+        a Counter instance with the operation as the key,
+        and the number of (op, input) edges that depend on it as the value
+    '''
+    # Dummy src ops get a count but are not queued for further traversal
+    dependency_count = Counter()
+    queue = deque([op])
+    while len(queue) > 0:
+        cur_op = queue.pop()
+        for src_op, _, _, _ in cur_op.src:
+            if src_op not in dependency_count and \
+                    (not isinstance(src_op, Dummy)):
+                # dependency[src_op] = [Counter() for _ in src_op.y_id2idx]
+                dependency_count[src_op] = 0
+                queue.append(src_op)
+            # y_idx = src_op.y_id2idx[x_id]
+            # dependency[src_op][y_idx][cur_op] += 1
+            dependency_count[src_op] += 1
+    return dependency_count
+
+
+def gradients(y, dy=None):  # collect backward(y, dy) pairs into a dict
+    grads = {}  # mapping: x->dx if x.stores_grad
+    for p, dp in backward(y, dy):
+        grads[p] = dp  # store in the local dict, not the function
+    return grads
+
+
+def backward(y, dy=None):
+    '''
+    Run the backward propagation starting at y.
+
+    Args:
+        y: a Tensor instance, usually the loss
+        dy: a number or a Tensor instance, for the gradient of the
+            objective/loss w.r.t y, usually 1.0
+
+    Yields:
+        (tensor, gradient) pairs, one per yield, for every tensor
+        whose stores_grad is true (e.g. parameter tensors)
+    '''
+    dependency = infer_dependency(y.creator)
+    assert y.size() == 1, 'y must be a Tensor with a single value;'\
+        'size of y is % d' % y.size()
+
+    # by default the dy is a tensor with 1.0 for each sample;
+    if dy is None:
+        dy = float(1.0)
+    elif isinstance(dy, Tensor):
+        dy = dy.data
+    else:
+        dy = float(dy)
+
+    # ready is a queue of (operation, dy list)
+    ready = deque([(y.creator, (dy,))])
+    not_ready = {}  # mapping: op->[dy]
+
+    if y.stores_grad:
+        yield (y, dy)  # `gradients` is now a function; yield, don't store
+
+    while len(ready) > 0:
+        op, dys = ready.pop()
+        if not op.requires_grad or isinstance(op, Dummy):
+            continue
+        # if not isinstance(op, tensor.Dummy):
+        dxs = op._do_backward(*dys)
+        # TODO src and dx must match
+        assert len(op.src) == len(dxs), \
+            'the number of src ops (=%d) and dx (=%d) not match' \
+            % (len(op.src), len(dxs))
+        for (src_op, x_id, y, y_stores_grad), dx in zip(op.src, dxs):
+            # prefix x is w.r.t op; prefix y is w.r.t src_op.
+            # x_id is the python id of one input arg of src_op, denoted as x.
+            # y_idx (below) is the index of x among the outputs of src_op.
+            # not_ready[src_op][y_idx] records the intermediate gradient
+            # of the y_idx'th output of src_op. 'intermediate gradient'
+            # indicates that if this output is used in multiple children
+            # operations, then we have to add the graident (dx) from all these
+            # children operations. When src_op is ready, it means that
+            # the gradient of all its outputs are available, i.e. all children
+            # operations have been backwarded.
+            # y is None if y.stores_grad is false; otherwise it is a Tensor
+            y_idx = src_op.y_id2idx[x_id]
+            if src_op not in not_ready:
+                # src_op may have mulitple outputs
+                not_ready[src_op] = [None for _ in src_op.y_id2idx]
+                not_ready[src_op][y_idx] = dx
+            else:
+                dxs = not_ready[src_op]
+                if dxs[y_idx] is None:
+                    dxs[y_idx] = dx
+                else:
+                    # add the gradient from another children operation that
+                    # uses y_idx'th output of src_op as input arg
+                    dxs[y_idx] += dx
+            if y_stores_grad:
+                # NOTE(review): re-yielded after every child's dx (partial sums)
+                g = not_ready[src_op][y_idx]
+                tg = Tensor(device=g.device(), data=g)
+                yield (y, tg)
+            dependency[src_op] -= 1
+            if src_op.requires_grad is True:
+                if dependency[src_op] == 0:
+                    if not isinstance(src_op, Dummy):
+                        ready.append((src_op, not_ready[src_op]))
+                    del not_ready[src_op]
+        del op  # drop this loop's reference to op; its tensors free once unreferenced
+
+
 class Operation(object):
     '''
     An operation includes the forward and backward function of
@@ -194,8 +314,8 @@ class Matmul(Operation):
         Returns:
             a tuple for (dx, dw)
         '''
-        return singa.Mult(dy, self.input[1].T()), \
-            singa.Mult(self.input[0].T(), dy)
+        return singa.Mult(dy, singa.DefaultTranspose(self.input[1])), \
+            singa.Mult(singa.DefaultTranspose(self.input[0]), dy)
 
 
 def matmul(x, w):
@@ -268,12 +388,12 @@ class SoftMax(Operation):
             the result Tensor
         '''
         if self.axis == 1:
-            x = x.T()
+            x = singa.DefaultTranspose(x)
         self.output = singa.SoftMax(x)
         if self.axis == 0:
             return self.output
         elif self.axis == 1:
-            return self.output.T()
+            return singa.DefaultTranspose(self.output)
 
     def backward(self, dy):
         '''
@@ -286,7 +406,7 @@ class SoftMax(Operation):
         '''
         # calculations are made on numpy array
         if self.axis == 1:
-            dy = dy.T()
+            dy = singa.DefaultTranspose(dy)
         grad = ctensor2numpy(dy)
         output = ctensor2numpy(self.output)
         out_1 = np.einsum('ki,ki->ki', grad, output)
@@ -298,14 +418,14 @@ class SoftMax(Operation):
         if self.axis == 0:
             return dx
         elif self.axis == 1:
-            return dx.T()
+            return singa.DefaultTranspose(dx)
 
 
 def soft_max(x, axis=0):
     return SoftMax(axis)(x)[0]
 
 
-class NLL(Operation):
+class CrossEntropy(Operation):
     '''
     Calculte negative log likelihood loss for a batch of training data.
 
@@ -350,12 +470,11 @@ class NLL(Operation):
             pass  # TODO, broadcast elementwise multiply seems not support
 
 
-def nll(y, t):
-    return NLL()(y, t)[0]
+def cross_entropy(y, t):
+    return CrossEntropy()(y, t)[0]
 
 
 class SoftMaxCrossEntropy(Operation):
-
     def forward(self, x, t):
         self.p = singa.SoftMax(x)
         self.t = t
@@ -365,7 +484,8 @@ class SoftMaxCrossEntropy(Operation):
         return loss
 
     def backward(self, dy=1.0):
-        return singa.SoftmaxCrossEntropyBwd(self.p, self.t), None
+        dx = singa.SoftmaxCrossEntropyBwd(self.p, self.t)
+        return singa.DivFloat(dx, float(self.p.shape()[0])), None
 
 
 def softmax_cross_entropy(x, t):
@@ -448,11 +568,11 @@ class Flatten(Operation):
     def forward(self, x):
         # TODO Do flatten start from axis != 1
         self.shape = list(x.shape())
-        y = x.Reshape((x.shape()[0], x.Size() // x.shape()[0]))
+        y = singa.Reshape(x, (x.shape()[0], x.Size() // x.shape()[0]))
         return y
 
     def backward(self, dy):
-        dx = dy.Reshape(self.shape)
+        dx = singa.Reshape(dy, self.shape)
         return dx
 
 
@@ -466,11 +586,7 @@ class _Conv2D(Operation):
         self.handle = handle
 
     def forward(self, x, W, b):
-        #assert x.nDim() == 4, 'The dimensions of input should be 4D.'
-        #assert x.shape()[1] == self.in_channels, 'in_channels dismatched.'
-        #assert (xs[0].shape()[2]+2*self.padding[0]-self.kernel_size[0])%self.stride[0] == 0, 'invalid padding.'
-        #assert (xs[0].shape()[3]+2*self.padding[1]-self.kernel_size[1])%self.stride[1] == 0, 'invalid padding'
-        #assert 0 == 0, 'invalid padding'
+        assert x.nDim() == 4, 'The dimensions of input should be 4D.'
 
         if training:
             if self.handle.bias_term:
@@ -517,125 +633,6 @@ def conv2d(x, W, b, handle):
     return _Conv2D(handle)(x, W, b)[0]
 
 
-def infer_dependency(op):
-    '''
-    Infer the dependency of all operations with the
-    given op as the last operation.
-
-    Operation A is depending on B is A uses the output(s) of B.
-
-    Args:
-        op: an Operation instance, e.g. the loss operation.
-
-    Return:
-        a Counter instance with the operation as the key,
-        and the number of operations that are depending on it as the value
-    '''
-    # dependency = {}
-    dependency_count = Counter()
-    queue = deque([op])
-    while len(queue) > 0:
-        cur_op = queue.pop()
-        for src_op, _, _, _ in cur_op.src:
-            if src_op not in dependency_count and \
-                    (not isinstance(src_op, Dummy)):
-                # dependency[src_op] = [Counter() for _ in src_op.y_id2idx]
-                dependency_count[src_op] = 0
-                queue.append(src_op)
-            # y_idx = src_op.y_id2idx[x_id]
-            # dependency[src_op][y_idx][cur_op] += 1
-            dependency_count[src_op] += 1
-    return dependency_count
-
-
-def gradients(y, dy=None):
-    grads = {}  # mapping: x->dx if x.stores_grad
-    for p, dp in backward(y, dy):
-        gradients[p] = dp
-    return grads
-
-
-def backward(y, dy=None):
-    '''
-    Run the backward propagation starting at y.
-
-    Args:
-        y: a Tensor instance, usually the loss
-        dy: a number or a Tensor instance, for the gradient of the
-            objective/loss w.r.t y, usually 1.0
-
-    Return:
-        a dictionary storing the gradient tensors of all tensors
-        whose stores_grad is true (e.g. parameter tensors)
-    '''
-    dependency = infer_dependency(y.creator)
-    assert y.size() == 1, 'y must be a Tensor with a single value;'\
-        'size of y is % d' % y.size()
-
-    # by default the dy is a tensor with 1.0 for each sample;
-    if dy is None:
-        dy = float(1.0)
-    elif isinstance(dy, Tensor):
-        dy = dy.data
-    else:
-        dy = float(dy)
-
-    # ready is a queue of (operation, dy list)
-    ready = deque([(y.creator, (dy,))])
-    not_ready = {}  # mapping: op->[dy]
-
-    if y.stores_grad:
-        gradients[y] = dy
-
-    while len(ready) > 0:
-        op, dys = ready.pop()
-        if not op.requires_grad or isinstance(op, Dummy):
-            continue
-        # if not isinstance(op, tensor.Dummy):
-        dxs = op._do_backward(*dys)
-        # TODO src and dx must match
-        assert len(op.src) == len(dxs), \
-            'the number of src ops (=%d) and dx (=%d) not match' \
-            % (len(op.src), len(dxs))
-        for (src_op, x_id, y, y_stores_grad), dx in zip(op.src, dxs):
-            # prefix x is w.r.t op; prefix y is w.r.t src_op.
-            # x_id is the python id of one input arg of src_op, denoted as x.
-            # y_idx (below) is the index of x among the outputs of src_op.
-            # not_ready[src_op][y_idx] records the intermediate gradient
-            # of the y_idx'th output of src_op. 'intermediate gradient'
-            # indicates that if this output is used in multiple children
-            # operations, then we have to add the graident (dx) from all these
-            # children operations. When src_op is ready, it means that
-            # the gradient of all its outputs are available, i.e. all children
-            # operations have been backwarded.
-            # y is None if y.stores_grad is false; otherwise it is a Tensor
-            y_idx = src_op.y_id2idx[x_id]
-            if src_op not in not_ready:
-                # src_op may have mulitple outputs
-                not_ready[src_op] = [None for _ in src_op.y_id2idx]
-                not_ready[src_op][y_idx] = dx
-            else:
-                dxs = not_ready[src_op]
-                if dxs[y_idx] is None:
-                    dxs[y_idx] = dx
-                else:
-                    # add the gradient from another children operation that
-                    # uses y_idx'th output of src_op as input arg
-                    dxs[y_idx] += dx
-            if y_stores_grad:
-                # store the gradient for final return, e.g. if x is parameter
-                g = not_ready[src_op][y_idx]
-                tg = Tensor(device=g.device(), data=g)
-                yield (y, tg)
-            dependency[src_op] -= 1
-            if src_op.requires_grad is True:
-                if dependency[src_op] == 0:
-                    if not isinstance(src_op, Dummy):
-                        ready.append((src_op, not_ready[src_op]))
-                    del not_ready[src_op]
-        del op  # delete the operation to free all tensors from this op
-
-
 class Layer(object):
 
     def __init__(self):
@@ -651,8 +648,6 @@ class Layer(object):
 class Linear(Layer):
 
     def __init__(self, in_features, out_features, bias=True):
-        #self.in_features = in_features
-        #self.out_features = out_features
         w_shape = (in_features, out_features)
         b_shape = (1, out_features)
         self.bias = bias

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 0860d9d..46a47b7 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -134,7 +134,7 @@ class Tensor(object):
         '''
         return self.data.transpose()
 
-    def transpose(self,axes = None):
+    def transpose(self, axes=None):
         '''
         To transpose the tensor
         '''
@@ -142,13 +142,13 @@ class Tensor(object):
         if axes == None:
             tshape = [self.shape[x] for x in range(len(t.shape))]
             t.shape = tuple(tshape)
-            t.data = self.data.Transpose()
+            t.data = singa.DefaultTranspose(self.data)
         else:
             if(len(axes) != len(self.shape)):
                 raise ValueError('dimensions do not match')
             tshape = [self.shape[x] for x in axes]
             t.shape = tuple(tshape)
-            t.data = self.data.Transpose(list(axes))
+            t.data = singa.Transpose(self.data, list(axes))
         return t
 
     def size(self):  # TODO(wangwei) compute size
@@ -166,17 +166,18 @@ class Tensor(object):
         return self.data.MemSize()
 
     def reshape(self, shape):
-        '''Change the tensor shape.
+        '''Return a new tensor with the given shape; the original
+        tensor is not changed.
 
         Args:
-            shape (list<int>): new shape, which should have the same volumn as
-                the original shape.
+            shape (list<int>): new shape, which should have the same
+                volume (number of elements) as the original shape.
         '''
         t = Tensor(self.shape, self.device, self.dtype)
         assert product(self.shape) == product(shape), \
             'product of shape should be equal'
-        t.shape = shape
-        t.data = self.data.Reshape(list(shape))
+        t.shape = tuple(shape)
+        t.data = singa.Reshape(self.data, shape)
         return t
 
     def reset_like(self, t):
@@ -283,38 +284,41 @@ class Tensor(object):
 
         Return:
             the tensor which has been repeated
-        
+
         '''
         t = Tensor()
         t_ndim = self.ndim()
         if isinstance(repeats, int) or isinstance(repeats, long):
             if repeats < 0:
-                raise ValueError("'repeats' should not be negative: {}".format(repeats))
+                raise ValueError(
+                    "'repeats' should not be negative: {}".format(repeats))
             if axis != None and axis < 0:
                 axis += t_ndim
             # broadcast = True
             if axis == None:
                 axis = 9999
-                t.shape = (product(self.shape)*repeats,)
-                Repeats = [repeats,]
+                t.shape = (product(self.shape) * repeats,)
+                Repeats = [repeats, ]
                 t.data = self.data.Repeat(Repeats, axis)
             elif axis >= 0:
                 t_shape = list(self.shape)
-                t_shape[axis] = self.shape[axis]*repeats
+                t_shape[axis] = self.shape[axis] * repeats
                 t.shape = tuple(t_shape)
-                Repeats = [repeats,]
+                Repeats = [repeats, ]
                 t.data = self.data.Repeat(Repeats, axis)
 
         elif isinstance(repeats, tuple) or isinstance(repeats, list):
             for rep in repeats:
                 if rep < 0:
-                    raise ValueError("'repeats' should be int or sequence: {}".format(repeats))
+                    raise ValueError(
+                        "'repeats' should be int or sequence: {}".format(repeats))
 
             if axis != None and axis < 0:
                 axis += t_ndim
             if axis == None:
                 axis = 9999
-                raise ValueError("when axis us None, 'repeats' should be int: {}".format(repeats))
+                raise ValueError(
+                    "when axis us None, 'repeats' should be int: {}".format(repeats))
             elif axis >= 0:
                 t_shape = list(self.shape)
                 t_shape[axis] = sum(repeats)
@@ -323,16 +327,15 @@ class Tensor(object):
         else:
             raise ValueError('repeats should be int or sequence')
 
-        return t     
+        return t
 
     def T(self):
-        ''' shallow copy, negate the transpose field.
+        ''' shallow copy.
 
         Returns:
-            a new Tensor which shares the underlying data memory (shallow copy)
-            but is marked as a transposed version of this tensor.
+            a new Tensor which shares the underlying data memory (shallow copy).
         '''
-        return _call_singa_func(self.data.T)
+        return _call_singa_func(singa.DefaultTranspose, self.data)
 
     def copy(self):
         '''shallow copy calls copy constructor of singa::Tensor
@@ -611,8 +614,9 @@ def sizeof(dtype):
     return singa.SizeOf(dtype)
 
 
-def reshape(t, s):
-    '''Reshape the input tensor with the given shape.
+def reshape(tensor, shape):
+    '''Return a reshaped version of the input tensor; the original
+    tensor is not changed.
 
     Args:
-        t (Tensor): the tensor to be changed
+        tensor (Tensor): the tensor to reshape (it is not modified)
@@ -624,12 +628,8 @@ def reshape(t, s):
     '''
-    return _call_singa_func(singa.Reshape, t.data, s)
+    return _call_singa_func(singa.Reshape, tensor.data, shape)
 
-def Reshape(t,s):
-
-    ret = t.reshape(s)
-    return ret
 
-def transpose(t,axes = None):
+def transpose(t, axes=None):
     '''
     Returns:
         the transposed tensor 
@@ -796,6 +796,7 @@ def tanh(t):
     '''
     return _call_singa_func(singa.Tanh, t.data)
 
+
 def sum(t, axis=None, out=None):
     '''Sum of tensor elements over given axis
 
@@ -827,24 +828,24 @@ def sum(t, axis=None, out=None):
         one.set_value(1.0)
         ret = tensordot(t, one, t_ndim)
 
-    if isinstance(axis,int):
+    if isinstance(axis, int):
         if axis < 0:
             axis += t_ndim
 
         axis_shape = t_shape[axis]
         axis_shape = int(axis_shape)
-        one = Tensor(shape = (axis_shape, ), device = t.device)
+        one = Tensor(shape=(axis_shape, ), device=t.device)
         one.set_value(1.0)
-        ret = tensordot(t, one, axes=([axis],[0]))
+        ret = tensordot(t, one, axes=([axis], [0]))
 
-    if isinstance(axis,tuple):
+    if isinstance(axis, tuple):
         l_axis = list(axis)
         axis_shape = [t_shape[x] for x in axis]
         axisshape = tuple(axis_shape)
         one = Tensor(axisshape, t.device)
         one.set_value(1.0)
         one_axis = [x for x in range(one.ndim())]
-        ret = tensordot(t, one, (l_axis,one_axis))
+        ret = tensordot(t, one, (l_axis, one_axis))
 
     if out is not None:
         if out.shape != ret.shape:
@@ -1181,10 +1182,10 @@ def einsum(ops, *args):
     if len(broadcast_a) == 0:
         broadcast_a = [1]
     if len(broadcast_b) == 0:
-        broadcast_b = [1]  
+        broadcast_b = [1]
     mult_A = repeat(A, product(broadcast_a))
     mult_A = mult_A.reshape(reshape_A)
-    mult_A = transpose(mult_A,transpose_A)
+    mult_A = transpose(mult_A, transpose_A)
     mult_B = repeat(B, product(broadcast_b))
     mult_B = mult_B.reshape(reshape_B)
     mult_B = transpose(mult_B, transpose_B)
@@ -1199,9 +1200,9 @@ def einsum(ops, *args):
     res = transpose(res, transpose_res)
 
     return res
-    
 
-def repeat (t, repeats, axis = None):
+
+def repeat(t, repeats, axis=None):
     '''Return the repeated tensor
     Args:
         t(tensor): the tensor to be repeated
@@ -1213,12 +1214,11 @@ def repeat (t, repeats, axis = None):
     Return:
         the tensor which has been repeated
     '''
-    ret = t.repeat(repeats,axis)
+    ret = t.repeat(repeats, axis)
     return ret
 
-        
-def tensordot (A,B,axes=2):
 
+def tensordot(A, B, axes=2):
     """Returns the tensor multiplication of two tensors along specified axes.
 
     This is equivalent to compute dot product along the specified axes which
@@ -1244,30 +1244,33 @@ def tensordot (A,B,axes=2):
     # when axes is an integer, axes_A and axes_B represent axes at the last of ''A'' and
     # the first of ''B''. For example, when axes is 1, we do the normal multiplication :
     # if A is in shape(3,2,4), B is in shape(4,2,5), it will return a matrix in shape(3,2,2,5)
-    #when axes is 2 and A,B are shape (3,2,4) and (2,4,5), it will return a matrix in shape(3,5)
+    # when axes is 2 and A,B are shape (3,2,4) and (2,4,5), it will return a
+    # matrix in shape(3,5)
 
     if type(axes) == int or type(axes) == long:
         axes_A = list(range(-axes, 0))
         axes_B = list(range(0, axes))
         axes_B = axes_B
     else:
-        axes_A,axes_B =axes
+        axes_A, axes_B = axes
     # when axes is a pair of sequences of integers.For example, A is in shape(3,2,4),
-    #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5)
-    if isinstance(axes_A,list):
+    # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a
+    # matrix in shape(3,5)
+    if isinstance(axes_A, list):
         na = len(axes_A)
         axes_A = list(axes_A)
     else:
         axes_A = [axes_A]
         na = 1
-    if isinstance(axes_B,list):
+    if isinstance(axes_B, list):
         nb = len(axes_B)
         axes_B = list(axes_B)
     else:
         axes_B = [axes_B]
         nb = 1
 
-    # a_shape and b_shape are the shape of tensor A and B, while nda and ndb are the dim of A and B
+    # a_shape and b_shape are the shape of tensor A and B, while nda and ndb
+    # are the dim of A and B
     a_shape = A.shape
     nda = A.ndim()
     b_shape = B.shape
@@ -1277,7 +1280,7 @@ def tensordot (A,B,axes=2):
     if na != nb:
         equal = False
     else:
-    # to make the shape match
+        # to make the shape match
         for k in range(na):
             if a_shape[axes_A[k]] != b_shape[axes_B[k]]:
                 equal = False
@@ -1291,18 +1294,19 @@ def tensordot (A,B,axes=2):
     '''start to do the calculation according to the axes'''
 
     notin = [k for k in range(nda) if k not in axes_A]
-    # nda is the dim of A, and axes_a is the axis for A, notin is the axis which is not in axes_A
+    # nda is the dim of A, and axes_a is the axis for A, notin is the axis
+    # which is not in axes_A
     newaxes_a = notin + axes_A
     N2 = 1
     for axis in axes_A:
         N2 *= a_shape[axis]
     N1 = 1
     for ax in notin:
-        N1 *=a_shape[ax]
+        N1 *= a_shape[ax]
     # newshape_a is the shape to do multiplication.For example, A is in shape(3,2,4),
-    #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,5)
-    #olda is the shape that will be shown in the result.
-    newshape_a = (N1,N2)
+    # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,5)
+    # olda is the shape that will be shown in the result.
+    newshape_a = (N1, N2)
     olda = [a_shape[axis] for axis in notin]
     notin = [k for k in range(ndb) if k not in axes_B]
     newaxes_b = axes_B + notin
@@ -1320,7 +1324,7 @@ def tensordot (A,B,axes=2):
     at = Reshape(A, newshape_a)
     bt = Reshape(B, newshape_b)
 
-    res = mult(at,bt)
+    res = mult(at, bt)
     if len(olda + oldb) == 0:
         olda = [1]
         oldb = [1]
@@ -1330,6 +1334,7 @@ def tensordot (A,B,axes=2):
 
     return res
 
+
 def div(lhs, rhs, ret=None):
     '''Elementi-wise division.
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/api/core_tensor.i
----------------------------------------------------------------------
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index cc72d21..9427b11 100644
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -101,12 +101,11 @@ namespace singa{
     const std::vector<size_t> &shape() const;
     const size_t shape(size_t idx) const;
     bool transpose() const;
-    size_t nDim() const;
-    Tensor Transpose() const;
-    Tensor Transpose(const std::vector<size_t> &axes) const;
+    size_t nDim() const;    
+
     size_t Size() const;
     size_t MemSize() const;
-    Tensor Reshape(const std::vector<size_t> &shape);
+    
     void ResetLike(const Tensor &t);
     void AsType(DataType type);
     void ToDevice(std::shared_ptr<singa::Device> dev);
@@ -122,10 +121,10 @@ namespace singa{
 
     void CopyData(const Tensor &other);
     void RepeatData(std::vector<size_t> repeats, int axis, int total_repeats, const Tensor &src);
+    
     Tensor Clone() const;
     Tensor Repeat(std::vector<size_t> repeats, int axis);
-    Tensor T() const;
-
+    
 
 #if USE_JAVA
     %rename(iAdd) operator+=(const Tensor &t);
@@ -166,6 +165,10 @@ namespace singa{
                         Tensor *dst, const Tensor &src, const size_t num);
 
   Tensor Reshape(const Tensor &in, const std::vector<size_t> &s);
+  Tensor Transpose(const Tensor &in, const std::vector<size_t> &axes);
+
+  %rename(DefaultTranspose) Transpose(const Tensor &in);
+  Tensor Transpose(const Tensor &in);
 
   Tensor Abs(const Tensor &t);
   Tensor Exp(const Tensor &t);
@@ -326,6 +329,6 @@ namespace singa{
   Tensor SoftMax(const Tensor &in);
   void SoftMax(const Tensor &in, Tensor *out);
 
-  const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
-  const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);
+  Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
+  Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index e5e8017..1ac1b42 100755
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -21,6 +21,7 @@
 #include "./tensor_math_cuda.h"
 #include "./tensor_math_opencl.h"
 #include <utility>
+#include <algorithm>
 
 #define Noaxis 9999
 
@@ -45,13 +46,7 @@ Tensor::Tensor(const Shape &shape, DataType dtype)
     block_ = device_->NewBlock((int)size);
   generate_strides();
 }
-Tensor::Tensor(Shape &&shape, DataType dtype)
-  : data_type_(dtype), device_(defaultDevice), shape_(shape) {
-  size_t size = Product(shape_) * SizeOf(data_type_);
-  if (size)
-    block_ = device_->NewBlock((int)size);
-  generate_strides();
-}
+
 
 //non-strided constructors with device
 Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device,
@@ -62,56 +57,24 @@ Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device,
     block_ = device_->NewBlock((int)size);
   generate_strides();
 }
-Tensor::Tensor(Shape &&shape, std::shared_ptr<Device> device, DataType dtype)
-  : data_type_(dtype), device_(device), shape_(shape) {
-  size_t size = Product(shape_) * SizeOf(data_type_);
-  if (size)
-    block_ = device_->NewBlock((int)size);
-  generate_strides();
-}
 
 
-Tensor::Tensor(const Tensor &in)
-  : //transpose_(in.transpose_),
-    data_type_(in.data_type_),
-    device_(in.device_),
-    block_(in.block()),
-    shape_(in.shape_),
-    strides_(in.strides_) {
+Tensor::Tensor(const Tensor &in) : data_type_(in.data_type_),
+  device_(in.device_),  block_(in.block()),  shape_(in.shape_),
+  strides_(in.strides_) {
   if (block_ != nullptr)
     block_->IncRefCount();
 }
 
-//strided constructor taking in a tensor, shape and strides
-Tensor::Tensor(const Tensor &in, Shape &new_shape, vector<int> &new_strides)
-  : //transpose_(in.transpose_),
-    data_type_(in.data_type_),
-    device_(in.device_),
-    block_(in.block()),
-    shape_(new_shape),
-    strides_(new_strides) {
-  if (block_ != nullptr)
-    block_->IncRefCount();
-}
 
-Tensor::Tensor(Tensor &&in)
-  : //transpose_(in.transpose_),
-    data_type_(in.data_type_),
-    device_(in.device_),
-    shape_(std::move(in.shape_)),
-    strides_(in.strides_) {
+Tensor::Tensor(Tensor &&in) : data_type_(in.data_type_),
+  device_(in.device_), shape_(std::move(in.shape_)),
+  strides_(std::move(in.strides_)) {
   block_ = in.block_;
   in.block_ = nullptr;
 }
 
 
-void Tensor::SetBlock(Block *block) {
-  LOG(WARNING) << "Pls avoid using this function, which may have side-effect.";
-  if (block_ != nullptr)
-    if (block_->DecRefCount()) device_->FreeBlock(block_);
-  block_ = block;
-}
-
 void Tensor::ResetLike(const Tensor &in) {
   if (block_ == nullptr || device_ != in.device_ || MemSize() != in.MemSize()) {
     if (block_ != nullptr && block_->DecRefCount() == 0)
@@ -124,41 +87,16 @@ void Tensor::ResetLike(const Tensor &in) {
   strides_ = in.strides_;
 }
 
-// if tensor is not transposed yet i.e strides == 1,
-// then we simply change the shape and generate new default strides
-// if tensor is already transposed i.e strides != 1,
-// it should be copied to a new tensor with newly generated default strides
-// TODO(wangwei) raise error if the shape not match
-
-// void Tensor::Reshape(const Shape &shape) {
-//   if (strides_.size() == 0)
-//     strides_.push_back(1);
-
-//   if (Product(shape_) != Product(shape)) {
-//     if (block_ != nullptr && block_->DecRefCount() == 0)
-//       device_->FreeBlock(block_);
-//     block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
-//   } else if (transpose()) {
-//     LOG(FATAL) << "Reshape Error: Reshape called on tranposed tensor. Not implemented yet." ;
-//   }
-//   shape_ = shape;
-//   generate_strides();
-// }
-
-// void Tensor::Reshape(Shape &&shape) {
-//   if (strides_.size() == 0)
-//     strides_.push_back(1);
-
-//   if (Product(shape_) != Product(shape)) {
-//     if (block_ != nullptr && block_->DecRefCount() == 0)
-//       device_->FreeBlock(block_);
-//     block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
-//   } else if (transpose()) {
-//     LOG(FATAL) << "Reshape Error: Reshape called on tranposed tensor. Not implemented yet." ;
-//   }
-//   shape_ = std::move(shape);
-//   generate_strides();
-// }
+// Set the shape in place; data is kept (no copy) if the element count is unchanged.
+void Tensor::SetShape(const Shape& shape) {
+  if (block_ == nullptr || Product(shape_) != Product(shape)) {
+    if (block_ != nullptr && block_->DecRefCount() == 0) device_->FreeBlock(block_);
+    const size_t size = Product(shape) * SizeOf(data_type_);
+    block_ = size ? device_->NewBlock((int)size) : nullptr;  // like the ctors
+  }
+  shape_ = shape;
+  generate_strides();
+}
 
 void Tensor::AsType(const DataType type) {
   if (data_type_ != type) {
@@ -217,7 +155,8 @@ void Tensor::CopyData(const Tensor &src) {
   }
 }
 
-void Tensor::RepeatData(vector<size_t> repeats, int axis, int total_repeats, const Tensor &src) {
+void Tensor::RepeatData(const vector<size_t>& repeats, int axis, int total_repeats,
+                        const Tensor &src) {
   if (repeats.size() == 1) {
     CHECK_EQ(Size(), src.Size()*total_repeats);
   } else {
@@ -336,7 +275,8 @@ void Tensor::ToProto(singa::TensorProto *proto) const {
   }
 }
 
-Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device> device) {
+Tensor Tensor::Repeat(const vector<size_t>& repeats, int axis,
+                      std::shared_ptr<Device> device) {
   if (device == nullptr) device = device_;
   vector<size_t> tshape;
   int total_repeats = 0;
@@ -346,7 +286,7 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
   } else {
     if (repeats.size() == 1) {
       total_repeats = repeats[0];
-      for (size_t i = 0; i < shape_.size(); i++) {
+      for (int i = 0; i < static_cast<int>(shape_.size()); i++) {
         if (i == axis) {
           tshape.push_back(shape_[i] * total_repeats);
         } else {
@@ -363,7 +303,7 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
         }
         total_repeats += repeats[i];
       }
-      for (size_t i = 0; i < shape_.size(); i++) {
+      for (int i = 0; i < static_cast<int>(shape_.size()); i++) {
         if (i == axis) {
           tshape.push_back(total_repeats);
         } else {
@@ -387,68 +327,53 @@ Tensor Tensor::Clone(std::shared_ptr<Device> device) const {
   return t;
 }
 
-Tensor Tensor::T() const {
+Tensor& Tensor::T() {
   // this function only works for 2d tensors
   CHECK_EQ(shape_.size(), 2u);
-  Tensor t;
-  t.device_ = device_;
-  t.data_type_ = data_type_;
-  t.shape_.push_back(shape_[1]);
-  t.shape_.push_back(shape_[0]);
-  t.strides_.clear();
-  t.strides_.push_back(strides_[1]);
-  t.strides_.push_back(strides_[0]);
-  t.block_ = block_;
-  block_->IncRefCount();
-  return t;
+  Transpose();
+  return *this;
 }
 
 //normal transpose without axes
-Tensor Tensor::Transpose() const {
-  // if(shape_.size() != strides_.size())
-  //   generate_strides();
-
-  Tensor t;
-  t.device_ = device_;
-  t.data_type_ = data_type_;
-  t.strides_.clear();
-  for (size_t n = 0; n < shape_.size(); ++n) {
-    t.shape_.push_back(shape_[shape_.size() - n - 1]);
-    t.strides_.push_back(strides_[shape_.size() - n - 1]);
-  }
-  t.block_ = block_;
-  block_->IncRefCount();
-  return t;
+Tensor& Tensor::Transpose() {
+  std::reverse(shape_.begin(), shape_.end());
+  std::reverse(strides_.begin(), strides_.end());
+  return *this;
 }
 
 //transpose with axes
-// TODO(wangwei) the shape and axes should match
-Tensor Tensor::Transpose(const vector<size_t> &axes) const {
-  // if(axes.size() != shape_.size()){
-  //   std::cout << "Warning: Size of input axes doesn't match size of shape" << std::endl;
-  //   return void();
-  // }
-  // if(shape_.size() != strides_.size())
-  //   generate_strides();
+Tensor& Tensor::Transpose(const vector<size_t> &axes) {
+  CHECK_EQ(axes.size(), shape_.size()) <<
+                                       "Tranpose axes's length should be equal to shape";
 
-  Tensor t;
-  t.device_ = device_;
-  t.data_type_ = data_type_;
-  t.strides_.clear();
+  auto shape = shape_;
+  auto strides = strides_;
+  shape_.clear();
+  strides_.clear();
   for (size_t n = 0; n < axes.size(); ++n) {
-    t.shape_.push_back(shape_[axes[n]]);
-    t.strides_.push_back(strides_[axes[n]]);
+    shape_.push_back(shape[axes[n]]);
+    strides_.push_back(strides[axes[n]]);
   }
-  t.block_ = block_;
-  block_->IncRefCount();
-  return t;
+  return *this;
+}
+
+//normal transpose without axes
+Tensor Transpose(const Tensor& in) {
+  Tensor out(in);
+  out.Transpose();
+  return out;
+}
+
+//transpose with axes
+Tensor Transpose(const Tensor& in, const vector<size_t> &axes) {
+  Tensor out(in);
+  out.Transpose(axes);
+  return out;
 }
 
 Tensor &Tensor::operator=(const Tensor &in) {
-  // LOG(ERROR) << "= const &";
   if (block_ != nullptr && block_->DecRefCount() == 0)
     device_->FreeBlock(block_);
-  //transpose_ = in.transpose_;
   strides_ = in.strides_;
   data_type_ = in.data_type_;
   shape_ = in.shape_;
@@ -460,11 +385,9 @@ Tensor &Tensor::operator=(const Tensor &in) {
 }
 
 Tensor &Tensor::operator=(Tensor &&in) {
-  // LOG(ERROR) << "= &&";
   if (block_ != nullptr && block_->DecRefCount() == 0)
     device_->FreeBlock(block_);
-  //transpose_ = in.transpose_;
-  strides_ = std::move(in.strides_);
+    strides_ = std::move(in.strides_);
   data_type_ = in.data_type_;
   shape_ = std::move(in.shape_);
   device_ = in.device_;
@@ -473,17 +396,6 @@ Tensor &Tensor::operator=(Tensor &&in) {
   return *this;
 }
 
-// Tensor Reshape(const Tensor &in, const Shape &s) {
-//   // Tensor out(in);
-//   // out.Reshape(s);
-//   return out;
-// }
-
-// Tensor Reshape(const Tensor &in, Shape &&s) {
-//   // Tensor out(in);
-//   // out.Reshape(std::move(s));
-//   return out;
-// }
 
 #define GenUnaryTensorArgMemberFn(op, fn) \
   Tensor &Tensor::op(const Tensor &in) {  \
@@ -539,7 +451,7 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
   }
 }
 
-void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
+void RepeatDataToFrom(bool broadcast_flag, const vector<size_t>& repeats, int axis,
                       Tensor *dst, const Tensor &src, const size_t num) {
   if (repeats.size() == 1) {
     broadcast_flag = true;
@@ -561,11 +473,11 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
     axis_shape = 1;
     shape_outer = Product(src.shape());
   } else {
-    for (size_t i = 0; i < axis; i++) {
+    for (int i = 0; i < axis; i++) {
       shape_outer *= src.shape()[i];
     }
     axis_shape = src.shape()[axis];
-    for (size_t i = axis + 1; i < src.nDim(); i++) {
+    for (int i = axis + 1; i < static_cast<int>(src.nDim()); i++) {
       chunk *= src.shape()[i];
     }
   }
@@ -912,7 +824,7 @@ template <typename SType>
 void AddColumn(const SType alpha, const SType beta, const Tensor &v,
                Tensor *M) {
   if (M->transpose()) {
-    Tensor X = M->T();
+    Tensor X = Transpose(*M);
     AddRow(v, &X);
   } else {
     CHECK_EQ(M->nDim(), 2u);
@@ -935,7 +847,7 @@ void AddRow(const Tensor &v, Tensor *M) { AddRow(1, 1, v, M); }
 template <typename SType>
 void AddRow(const SType alpha, const SType beta, const Tensor &v, Tensor *M) {
   if (M->transpose()) {
-    Tensor X = M->T();
+    Tensor X = Transpose(*M);
     AddColumn(v, &X);
   } else {
     CHECK_EQ(M->nDim(), 2u);
@@ -980,7 +892,7 @@ Tensor ConcatOn(const vector<Tensor> &in, int axis) {
       tmp.push_back(Reshape(t, {t.shape(0), t.Size() / t.shape(0)}));
     }
     auto ret = ConcatenateRows(tmp);
-    ret = ret.Reshape(out_shape);
+    ret.Reshape(out_shape);
     return ret;
   } else {
     for (const auto& t : in) {
@@ -990,7 +902,7 @@ Tensor ConcatOn(const vector<Tensor> &in, int axis) {
       tmp.push_back(Reshape(t, {nrow, t.Size() / nrow}));
     }
     auto ret = ConcatenateColumns(tmp);
-    ret = ret.Reshape(out_shape);
+    ret.Reshape(out_shape);
     return ret;
   }
 }
@@ -1059,7 +971,8 @@ Tensor CopyRows(const Tensor &in, const size_t start, const size_t end) {
 }
 
 
-Tensor SliceOn(const Tensor&in, const size_t start, const size_t end, int axis) {
+Tensor SliceOn(const Tensor&in, const size_t start, const size_t end,
+               int axis) {
   Shape out_shape = in.shape();
   out_shape[axis] = end - start;
   if (axis == 0) {
@@ -1074,7 +987,7 @@ Tensor SliceOn(const Tensor&in, const size_t start, const size_t end, int axis)
     auto suffix = in.Size() / nrow / in.shape(axis);
     auto ret = SliceColumns(Reshape(in, {nrow, in.Size() / nrow}),
                             start * suffix, end * suffix);
-    ret = ret.Reshape(out_shape);
+    ret.Reshape(out_shape);
     return ret;
   }
 }
@@ -1145,7 +1058,7 @@ void SubRow(const Tensor &v, Tensor *M) { AddRow(-1, 1, v, M); }
 
 void SumColumns(const Tensor &M, Tensor *v) {
   if (M.transpose()) {
-    Tensor X = M.T();
+    Tensor X = Transpose(M);
     SumRows(X, v);
   } else {
     CHECK_EQ(M.nDim(), 2u);
@@ -1160,7 +1073,7 @@ void SumColumns(const Tensor &M, Tensor *v) {
 }
 void SumRows(const Tensor &M, Tensor *v) {
   if (M.transpose()) {
-    Tensor X = M.T();
+    Tensor X = Transpose(M);
     SumColumns(X, v);
   } else {
     CHECK_EQ(M.nDim(), 2u);
@@ -1170,7 +1083,7 @@ void SumRows(const Tensor &M, Tensor *v) {
 
     Tensor one(Shape{nb_row}, M.device(), M.data_type());
     one.SetValue(1.0f);  // TODO(wangwei) cast type
-    Tensor X = M.T();
+    Tensor X = Transpose(M);
     Mult(X, one, v);
   }
 }
@@ -1268,13 +1181,13 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
 // ************************
 // Misc.
 // ************************
-const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t) {
+Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t) {
   Tensor loss({p.shape(0)}, p.device(), p.data_type());
   ComputeCrossEntropy(p, t, &loss);
   return loss;
 }
 
-const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t) {
+Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t) {
   auto g = p.Clone();
   SoftmaxCrossEntropyBwd(t, &g);
   return g;
@@ -1310,65 +1223,20 @@ void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p) {
   });
 }
 
-Tensor Tensor::Reshape(const Shape &shape) {
-  if (strides_.size() == 0)
-    strides_.push_back(1);
 
-  // TODO(wangwei) remove this condition and report error if size changes.
-  if (Product(shape_) != Product(shape)) {
-    if (block_ != nullptr && block_->DecRefCount() == 0)
-      device_->FreeBlock(block_);
-    block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
-    shape_ = shape;
-    generate_strides();
-    return *this;
-
-  } else if (transpose()) {
-    Tensor t(shape_, device_, data_type_);
-    t.block_ = t.device()->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+// if tensor is not transposed yet, we change the shape and generate new strides
+// if tensor is already transposed, we reallocate the memory and generate strides
+Tensor& Tensor::Reshape(const Shape &shape) {
+  if (transpose()) {
+    Tensor t(shape, device_, data_type_);
     singa::Transform(*this, &t);
-    t.shape_ = shape;
-    return t;
+    shape_ = shape;
+    std::swap(t.block_, block_);
   } else {
-    Tensor t;
-    t.shape_ = shape;
-    t.device_ = device_;
-    t.data_type_ = data_type_;
-    t.block_ = block_;  // be careful about the block inference (mem leaking)
-    t.block_->IncRefCount();
-    t.generate_strides();
-    return t;
-  }
-}
-
-Tensor Tensor::Reshape(Shape &&shape) {
-  if (strides_.size() == 0)
-    strides_.push_back(1);
-
-  if (Product(shape_) != Product(shape)) {
-    if (block_ != nullptr && block_->DecRefCount() == 0)
-      device_->FreeBlock(block_);
-    block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
-    shape_ = std::move(shape);
+    shape_ = shape;
     generate_strides();
-    return *this;
-
-  } else if (transpose()) {
-    Tensor t(shape_, device_, data_type_);
-    t.block_ = t.device()->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
-    singa::Transform(*this, &t);
-    t.shape_ = shape;
-    return t;
-  } else {
-    Tensor t;
-    t.shape_ = shape;
-    t.device_ = device_;
-    t.data_type_ = data_type_;
-    t.block_ = block_;  // be careful about the block inference (mem leaking)
-    t.block_->IncRefCount();
-    t.generate_strides();
-    return t;
   }
+  return *this;
 }
 
 Tensor Reshape(const Tensor &in, const Shape &s) {
@@ -1376,9 +1244,4 @@ Tensor Reshape(const Tensor &in, const Shape &s) {
   return out.Reshape(s);
 }
 
-Tensor Reshape(const Tensor &in, Shape &&s) {
-  Tensor out(in);
-  return out.Reshape(std::move(s));
-}
-
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index f438fc6..f5fbc84 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -253,7 +253,7 @@ void Tanh(const Tensor &in, Tensor *out, Context *ctx) {
 
 /// similar to cudnnTransformTensor
 /// copies the data from one tensor to another tensor with a different layout
-/// the tensors must have the same dimensions but not necessarily the same strides 
+/// the tensors must have the same dimensions but not necessarily the same strides
 template <typename DType, typename Lang>
 void Transform(const Tensor &in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Transform Not Implemented";

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 2a43468..dfe5724 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -54,34 +54,23 @@ cudnn requires tensor dimensions to fulfill 1 requirement:
            Tensor B has shape (2,3,4), cudnn requires shape of {1,2,3,4} to be the input
 */
 vector<int> generate_shape_cuda(const Tensor& x) {
-  Shape shape_ = x.shape();
+  Shape shape = x.shape();
+  CHECK_LE(shape.size(), 5) << "Dimensions (shape) beyond 5 are currently not supported" ;
   vector<int> shape_arr;
-  if (shape_.size() <= 4) {
-    for (size_t n = 0; n < 4 - shape_.size(); ++n) {
+  if (shape.size() <= 4) {
+    for (int n = 0; n < 4 - shape.size(); ++n) {
       shape_arr.push_back(1);
     }
-    for (size_t n = 0; n < shape_.size(); ++n) {
-      shape_arr.push_back(shape_.at(n));
-    }
-    return shape_arr;
-  } else if (shape_.size() == 5) {
-    for (size_t n = 0; n < shape_.size(); ++n) {
-      shape_arr.push_back(shape_.at(n));
-    }
-    return shape_arr;
-  } else {
-    LOG(FATAL) << "Dimensions (shape) beyond 5 are currently not supported" ;
   }
+  for(auto x: shape)
+    shape_arr.push_back(static_cast<int>(x));
   return shape_arr;
 }
 
 int generate_dim_cuda(const Tensor& x) {
+  CHECK_LE(x.nDim(), 5) << "Dimensions (shape) beyond 5 are currently not supported" ;
   if (x.shape().size() <= 4) {return 4;}
-  else if (x.shape().size() == 5) {return 5;}
-  else {
-    LOG(FATAL) << "Dimensions (shape) beyond 5 are currently not supported" ;
-  }
-  return 0;
+  else {return 5;}
 }
 
 /*
@@ -94,29 +83,17 @@ int generate_dim_cuda(const Tensor& x) {
     and stride {9, 9, 3, 1} or {9, 9, 1, 3} to be the inputs
   */
 vector<int> generate_strides_cuda(const Tensor& x) {
-  Shape shape_ = x.shape();
-  vector<int> strides_ = x.strides();
+  Shape shape = x.shape();
+  auto& strides = x.strides();
   vector<int> strides_arr;
-  int product = 1;
-  for (size_t n = 0; n < (shape_.size()); ++n) {
-    product *= shape_[n];
-  }
-  if (shape_.size() <= 4) {
-    for (size_t n = 0; n < 4 - shape_.size(); ++n) {
+  int product = Product(shape);
+  if (shape.size() <= 4) {
+    for (int n = 0; n < 4 - shape.size(); ++n) {
       strides_arr.push_back(product);
     }
-    for (size_t n = 0; n < strides_.size(); ++n) {
-      strides_arr.push_back(strides_[n]);
-    }
-    return strides_arr;
-  } else if (shape_.size() == 5) {
-    for (size_t n = 0; n < strides_.size(); ++n) {
-      strides_arr.push_back(strides_[n]);
-    }
-    return strides_arr;
-  } else {
-    LOG(FATAL) << "Dimensions (strides) beyond 5 are currently not supported" ;
   }
+  for(auto x : strides)
+    strides_arr.push_back(static_cast<int>(x));
   return strides_arr;
 }
 
@@ -241,6 +218,22 @@ void Sub<float, lang::Cuda>(const Tensor& in1,
   }
 }
 
+template <>
+void Transform<float, lang::Cuda>(const Tensor& in, Tensor* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+
+  float alpha = 1.0;
+  float beta = 0.0;
+
+  check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+}
+
 /// Element-wise operation, clamp every element into [low, high]
 /// if x>high, then x=high; if x<low, then x=low.
 template <>
@@ -254,14 +247,7 @@ void Clamp<float, lang::Cuda>(const float low,
   if (in.strides() == out->strides()) {
     cuda::clamp(num, low, high, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::clamp(num, low, high, outPtr, outPtr, ctx->stream);
   }
 }
@@ -280,36 +266,18 @@ void Div<float, lang::Cuda>(const Tensor& in1,
   if (!in1.transpose() && !in2.transpose() && (in1.strides() == in2.strides())) {
     cuda::div(num, inPtr1, inPtr2, outPtr, ctx->stream);
   } else { //else we check whether in1 or in2 or both are transposed
-    float alpha = 1.0;
-    float beta = 0.0;
-
     if (in1.transpose() && in2.transpose()) {
       Tensor t(in1.shape(), in1.device(), in1.data_type());
-      float* tPtr = static_cast<float*>(t.block()->mutable_data());
-
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
-                           (void*)(&beta), generate_tensor_nd_desc(t), tPtr
-                          ));
+      Transform<float, lang::Cuda>(in1, &t, ctx);
+      Transform<float, lang::Cuda>(in2, out, ctx);
 
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
+      float* tPtr = static_cast<float*>(t.block()->mutable_data());
       cuda::div(num, tPtr, outPtr, outPtr, ctx->stream);
-
     } else if (in1.transpose()) {
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
+      Transform<float, lang::Cuda>(in1, out, ctx);
       cuda::div(num, outPtr, inPtr2, outPtr, ctx->stream);
-
     } else if (in2.transpose()) {
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
+      Transform<float, lang::Cuda>(in2, out, ctx);
       cuda::div(num, inPtr1, outPtr, outPtr, ctx->stream);
     }
   }
@@ -325,14 +293,7 @@ void Div<float, lang::Cuda>(const float x, const Tensor& in,
   if (in.strides() == out->strides()) {
     cuda::div(num, x, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::div(num, x, outPtr, outPtr, ctx->stream);
   }
 }
@@ -366,36 +327,17 @@ void EltwiseMult<float, lang::Cuda>(const Tensor& in1,
   if (!in1.transpose() && !in2.transpose() && (in1.strides() == in2.strides())) {
     cuda::mult(num, inPtr1, inPtr2, outPtr, ctx->stream);
   } else { //else we check whether in1 or in2 or both are transposed
-    float alpha = 1.0;
-    float beta = 0.0;
-
     if (in1.transpose() && in2.transpose()) {
       Tensor t(in1.shape(), in1.device(), in1.data_type());
+      Transform<float, lang::Cuda>(in1, &t, ctx);
+      Transform<float, lang::Cuda>(in2, out, ctx);
       float* tPtr = static_cast<float*>(t.block()->mutable_data());
-
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
-                           (void*)(&beta), generate_tensor_nd_desc(t), tPtr
-                          ));
-
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
       cuda::mult(num, tPtr, outPtr, outPtr, ctx->stream);
-
     } else if (in1.transpose()) {
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
+      Transform<float, lang::Cuda>(in1, out, ctx);
       cuda::mult(num, outPtr, inPtr2, outPtr, ctx->stream);
-
     } else if (in2.transpose()) {
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
+      Transform<float, lang::Cuda>(in2, out, ctx);
       cuda::mult(num, inPtr1, outPtr, outPtr, ctx->stream);
     }
   }
@@ -413,14 +355,7 @@ void Exp<float, lang::Cuda>(const Tensor& in, Tensor* out,
   if (in.strides() == out->strides()) {
     cuda::exp(num, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::exp(num, outPtr, outPtr, ctx->stream);
   }
 }
@@ -435,14 +370,7 @@ void GE<float, lang::Cuda>(const Tensor& in, const float x,
   if (in.strides() == out->strides()) {
     cuda::ge(num, inPtr, x, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::ge(num, outPtr, x, outPtr, ctx->stream);
   }
 }
@@ -451,10 +379,7 @@ void GE<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
                            Tensor* out, Context* ctx) {
   Sub<float, lang::Cuda>(in1, in2, out, ctx);
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
-  // const float* inPtr1 = static_cast<const float*>(in1.block()->data());
-  // const float* inPtr2 = static_cast<const float*>(in2.block()->data());
   const size_t num = in1.Size();
-  //cuda::ge(num, inPtr1, inPtr2, outPtr, ctx->stream);
   cuda::ge(num, outPtr, 0.0, outPtr, ctx->stream);
 }
 
@@ -469,14 +394,7 @@ void GT<float, lang::Cuda>(const Tensor& in, const float x,
   if (in.strides() == out->strides()) {
     cuda::gt(num, inPtr, x, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::gt(num, outPtr, x, outPtr, ctx->stream);
   }
 }
@@ -485,10 +403,7 @@ void GT<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
                            Tensor* out, Context* ctx) {
   Sub<float, lang::Cuda>(in1, in2, out, ctx);
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
-  // const float* inPtr1 = static_cast<const float*>(in1.block()->data());
-  // const float* inPtr2 = static_cast<const float*>(in2.block()->data());
   const size_t num = in1.Size();
-  //cuda::gt(num, inPtr1, inPtr2, outPtr, ctx->stream);
   cuda::gt(num, outPtr, 0.0, outPtr, ctx->stream);
 }
 
@@ -502,14 +417,7 @@ void LE<float, lang::Cuda>(const Tensor& in, const float x,
   if (in.strides() == out->strides()) {
     cuda::le(num, inPtr, x, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::le(num, outPtr, x, outPtr, ctx->stream);
   }
 }
@@ -518,10 +426,7 @@ void LE<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
                            Tensor* out, Context* ctx) {
   Sub<float, lang::Cuda>(in1, in2, out, ctx);
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
-  // const float* inPtr1 = static_cast<const float*>(in1.block()->data());
-  // const float* inPtr2 = static_cast<const float*>(in2.block()->data());
   const size_t num = in1.Size();
-  //cuda::le(num, inPtr1, inPtr2, outPtr, ctx->stream);
   cuda::le(num, outPtr, 0.0, outPtr, ctx->stream);
 }
 
@@ -536,14 +441,7 @@ void Log<float, lang::Cuda>(const Tensor& in, Tensor* out,
   if (in.strides() == out->strides()) {
     cuda::log(num, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::log(num, outPtr, outPtr, ctx->stream);
   }
 }
@@ -558,14 +456,7 @@ void LT<float, lang::Cuda>(const Tensor& in, const float x,
   if (in.strides() == out->strides()) {
     cuda::lt(num, inPtr, x, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::lt(num, outPtr, x, outPtr, ctx->stream);
   }
 }
@@ -574,10 +465,7 @@ void LT<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
                            Tensor* out, Context* ctx) {
   Sub<float, lang::Cuda>(in1, in2, out, ctx);
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
-  // const float* inPtr1 = static_cast<const float*>(in1.block()->data());
-  // const float* inPtr2 = static_cast<const float*>(in2.block()->data());
   const size_t num = in1.Size();
-  //cuda::lt(num, inPtr1, inPtr2, outPtr, ctx->stream);
   cuda::lt(num, outPtr, 0.0, outPtr, ctx->stream);
 }
 
@@ -592,14 +480,7 @@ void Pow<float, lang::Cuda>(const Tensor& in, const float x,
   if (in.strides() == out->strides()) {
     cuda::pow(num, inPtr, x, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::pow(num, outPtr, x, outPtr, ctx->stream);
   }
 }
@@ -617,36 +498,17 @@ void Pow<float, lang::Cuda>(const Tensor& in1,
   if (!in1.transpose() && !in2.transpose() && (in1.strides() == in2.strides())) {
     cuda::pow(num, inPtr1, inPtr2, outPtr, ctx->stream);
   } else { //else we check whether in1 or in2 or both are transposed
-    float alpha = 1.0;
-    float beta = 0.0;
-
     if (in1.transpose() && in2.transpose()) {
       Tensor t(in1.shape(), in1.device(), in1.data_type());
       float* tPtr = static_cast<float*>(t.block()->mutable_data());
-
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
-                           (void*)(&beta), generate_tensor_nd_desc(t), tPtr
-                          ));
-
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
+      Transform<float, lang::Cuda>(in1, &t, ctx);
+      Transform<float, lang::Cuda>(in2, out, ctx);
       cuda::pow(num, tPtr, outPtr, outPtr, ctx->stream);
-
     } else if (in1.transpose()) {
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
+      Transform<float, lang::Cuda>(in1, out, ctx);
       cuda::pow(num, outPtr, inPtr2, outPtr, ctx->stream);
-
     } else if (in2.transpose()) {
-      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
-                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                          ));
+      Transform<float, lang::Cuda>(in2, out, ctx);
       cuda::pow(num, inPtr1, outPtr, outPtr, ctx->stream);
     }
   }
@@ -694,14 +556,7 @@ void ReLU<float, lang::Cuda>(const Tensor& in, Tensor* out,
   if (in.strides() == out->strides()) {
     cuda::relu(num, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::relu(num, outPtr, outPtr, ctx->stream);
   }
 }
@@ -749,14 +604,7 @@ void Sigmoid<float, lang::Cuda>(const Tensor& in, Tensor* out,
   if (in.strides() == out->strides()) {
     cuda::sigmoid(num, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::sigmoid(num, outPtr, outPtr, ctx->stream);
   }
 }
@@ -772,14 +620,7 @@ void Sign<float, lang::Cuda>(const Tensor& in, Tensor* out,
   if (in.strides() == out->strides()) {
     cuda::sign(num, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::sign(num, outPtr, outPtr, ctx->stream);
   }
 }
@@ -788,15 +629,14 @@ void Sign<float, lang::Cuda>(const Tensor& in, Tensor* out,
 template <>
 void Sqrt<float, lang::Cuda>(const Tensor& in, Tensor* out,
                              Context* ctx) {
-  const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
 
 #if CUDNN_MAJOR < 7
+  Transform<float, lang::Cuda>(in, out, ctx);
   size_t num = in.Size();
-  cuda::sqrt(num, inPtr, outPtr, ctx->stream);
-
+  cuda::sqrt(num, outPtr, outPtr, ctx->stream);
 #else
-
+  const float* inPtr = static_cast<const float*>(in.block()->data());
   float alpha1 = 1.0;
   float alpha2 = 0.0;
   float beta = 0.0;
@@ -820,14 +660,7 @@ void Square<float, lang::Cuda>(const Tensor& in, Tensor* out,
   if (in.strides() == out->strides()) {
     cuda::square(num, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::square(num, outPtr, outPtr, ctx->stream);
   }
 }
@@ -883,34 +716,11 @@ void Tanh<float, lang::Cuda>(const Tensor& in, Tensor* out,
   if (in.strides() == out->strides()) {
     cuda::tanh(num, inPtr, outPtr, ctx->stream);
   } else { //else we transform in to out to store first
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, out, ctx);
     cuda::tanh(num, outPtr, outPtr, ctx->stream);
   }
 }
 
-template <>
-void Transform<float, lang::Cuda>(const Tensor& in, Tensor* out,
-                             Context* ctx) {
-  const float* inPtr = static_cast<const float*>(in.block()->data());
-  float* outPtr = static_cast<float*>(out->block()->mutable_data());
-
-  float alpha = 1.0;
-  float beta = 0.0;
-
-  check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
-                        ));
-
-}
-
 // ================Random functions===========================================
 /// Each element of out would be 1 with prob p and 0 with 1-p. 0<= p <= 1
 // Get the random generator from 'ctx'
@@ -1175,16 +985,7 @@ void RowMax<float, lang::Cuda>(const Tensor& in, Tensor* out,
 
   if (in.transpose()) {
     Tensor t(in.shape(), in.device(), in.data_type());
-    float* tPtr = static_cast<float*>(t.block()->mutable_data());
-
-    float alpha = 1.0;
-    float beta = 0.0;
-
-    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                         (void*)(&beta), generate_tensor_nd_desc(t), tPtr
-                        ));
-
+    Transform<float, lang::Cuda>(in, &t, ctx);
     const float* tPtr_const = static_cast<const float*>(t.block()->data());
     cuda::RowMax(nrow, ncol, tPtr_const, outPtr, ctx->stream);
   } else {



[3/4] incubator-singa git commit: SINGA-380 - Fix bugs from Reshape

Posted by wa...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/io/image_transformer.cc
----------------------------------------------------------------------
diff --git a/src/io/image_transformer.cc b/src/io/image_transformer.cc
index 204ad08..0f49321 100644
--- a/src/io/image_transformer.cc
+++ b/src/io/image_transformer.cc
@@ -26,331 +26,328 @@
 
 namespace singa {
 
-  Tensor ImageTransformer::Apply(int flag, Tensor& input) {
-    CHECK_LE(input.nDim(), 4u);
-    CHECK_GE(input.nDim(), 2u);
-    CHECK_EQ(input.data_type(), kFloat32) << "Data type " << input.data_type()
-      << " is invalid for an raw image";
-    srand((unsigned int)time(NULL));
-    /// TODO
-    /// currently only consider one sample each time
+Tensor ImageTransformer::Apply(int flag, Tensor& input) {
+  CHECK_LE(input.nDim(), 4u);
+  CHECK_GE(input.nDim(), 2u);
+  CHECK_EQ(input.data_type(), kFloat32) << "Data type " << input.data_type()
+                                        << " is invalid for an raw image";
+  srand((unsigned int)time(NULL));
+  /// TODO
+  /// currently only consider one sample each time
 
-    /// resize image using opencv resize
-    Tensor temp1;
+  /// resize image using opencv resize
+  Tensor temp1;
 #ifdef USE_OPENCV
-    temp1 = resize(input, resize_height_, resize_width_, image_dim_order_);
+  temp1 = resize(input, resize_height_, resize_width_, image_dim_order_);
 #else
-    temp1 = input;
+  temp1 = input;
 #endif
 
-    /// crop
-    Tensor temp2;
-    size_t height = 0, width = 0;
-    if (input.nDim() >= 3u) {
-      if (image_dim_order_ == "CHW")
-        height = temp1.shape(input.nDim() - 2), width = temp1.shape(input.nDim() - 1);
-      else if (image_dim_order_ == "HWC")
-        height = temp1.shape(input.nDim() - 3), width = temp1.shape(input.nDim() - 2);
-      else
-        LOG(FATAL) << "Unknow dimension order for images " << image_dim_order_
-               << " Only support 'HWC' and 'CHW'";
-    } else /// input is 2D gray image
-      height = temp1.shape(0), width = temp1.shape(1);
+  /// crop
+  Tensor temp2;
+  size_t height = 0, width = 0;
+  if (input.nDim() >= 3u) {
+    if (image_dim_order_ == "CHW")
+      height = temp1.shape(input.nDim() - 2), width = temp1.shape(input.nDim() - 1);
+    else if (image_dim_order_ == "HWC")
+      height = temp1.shape(input.nDim() - 3), width = temp1.shape(input.nDim() - 2);
+    else
+      LOG(FATAL) << "Unknow dimension order for images " << image_dim_order_
+                 << " Only support 'HWC' and 'CHW'";
+  } else /// input is 2D gray image
+    height = temp1.shape(0), width = temp1.shape(1);
 
-    if (crop_shape_.size() == 2) {
-      if (flag == kTrain) { 
-        /// random crop
-        if (crop_shape_[0] > height || crop_shape_[0] > width)
-          LOG(FATAL) << "Crop size larger than the size of raw image";
-        size_t crop_h_offset = rand() % ((height - crop_shape_[0]) / 2), 
-               crop_w_offset = rand() % ((width - crop_shape_[1]) / 2);
-        temp2 = crop(temp1, crop_shape_[0], crop_shape_[1], 
-                  crop_h_offset, crop_w_offset, image_dim_order_);
-      } else if (flag == kEval) {
-        /// central crop
-        size_t crop_h_offset = (height - crop_shape_[0]) / 2,
-               crop_w_offset = (width - crop_shape_[1]) / 2;
-        temp2 = crop(temp1, crop_shape_[0], crop_shape_[1], 
-                  crop_h_offset, crop_w_offset, image_dim_order_); 
-      }
+  if (crop_shape_.size() == 2) {
+    if (flag == kTrain) {
+      /// random crop
+      if (crop_shape_[0] > height || crop_shape_[0] > width)
+        LOG(FATAL) << "Crop size larger than the size of raw image";
+      size_t crop_h_offset = rand() % ((height - crop_shape_[0]) / 2),
+             crop_w_offset = rand() % ((width - crop_shape_[1]) / 2);
+      temp2 = crop(temp1, crop_shape_[0], crop_shape_[1],
+                   crop_h_offset, crop_w_offset, image_dim_order_);
+    } else if (flag == kEval) {
+      /// central crop
+      size_t crop_h_offset = (height - crop_shape_[0]) / 2,
+             crop_w_offset = (width - crop_shape_[1]) / 2;
+      temp2 = crop(temp1, crop_shape_[0], crop_shape_[1],
+                   crop_h_offset, crop_w_offset, image_dim_order_);
     }
-    else temp2 = temp1;
+  } else temp2 = temp1;
 
-    /// mirror
-    Tensor output;
-    if ((flag == kTrain) && (rand() % 2))
-        output = mirror(temp2, true, false, image_dim_order_);
-    else output = temp2;
-    return output;
-  }
+  /// mirror
+  Tensor output;
+  if ((flag == kTrain) && (rand() % 2))
+    output = mirror(temp2, true, false, image_dim_order_);
+  else output = temp2;
+  return output;
+}
 
 #ifdef USE_OPENCV
-  Tensor resize(Tensor& input, const size_t resize_height, 
-               const size_t resize_width, const string& image_dim_order) {
-    CHECK_LE(input.nDim(), 4u);
-    CHECK_GE(input.nDim(), 2u);
-    if (!resize_height || !resize_width) return input;
-    Tensor output;
-    cv::Mat mat;
-    const auto* in = input.data<float>();
-    if (input.nDim() == 4u) {
-      /// TODO
-      /// batch based resize
-      LOG(FATAL) << "Not implemented";
-    } else if (input.nDim() == 3u) {
-      if (image_dim_order == "CHW") {
-        size_t height = input.shape(1), width = input.shape(2),
-               channel = input.shape(0);
-        if (channel == 3u) {
-          mat = cv::Mat(height, width, CV_32FC3, cv::Scalar(0, 0, 0));
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              for (size_t k = 0; k < channel; k++)
-                mat.at<cv::Vec3f>(i, j)[k] = in[k * height * width + i * width + j];
-        }
-        else if (channel == 1u) {
-          mat = cv::Mat(height, width, CV_32FC1);
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-                mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
-        }
-        else LOG(FATAL) << "Invalid channel size: " << channel;
-      } else if (image_dim_order == "HWC") {
-        size_t height = input.shape(0), width = input.shape(1),
-               channel = input.shape(2);
-        if (channel == 3u) {
-          mat = cv::Mat(height, width, CV_32FC3, cv::Scalar(0, 0, 0));
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              for (size_t k = 0; k < channel; k++)
-                mat.at<cv::Vec3f>(i, j)[k] =
-                  in[i * width * channel + j * channel + k];
-        } else if (channel == 1u) { /// 2D gray image
-          mat = cv::Mat(height, width, CV_32FC1);
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
-        } else LOG(FATAL) << "Invalid channel size: " << channel;
-      } else {
-        LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
-                   << " Only support 'HWC' and 'CHW'";
-      }
-    } else { /// 2D gray image
-      size_t height = input.shape(0), width = input.shape(1);
-      mat = cv::Mat(height, width, CV_32FC1);
-      for (size_t i = 0; i < height; i++)
-        for (size_t j = 0; j < width; j++)
-          mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
+Tensor resize(Tensor& input, const size_t resize_height,
+              const size_t resize_width, const string& image_dim_order) {
+  CHECK_LE(input.nDim(), 4u);
+  CHECK_GE(input.nDim(), 2u);
+  if (!resize_height || !resize_width) return input;
+  Tensor output;
+  cv::Mat mat;
+  const auto* in = input.data<float>();
+  if (input.nDim() == 4u) {
+    /// TODO
+    /// batch based resize
+    LOG(FATAL) << "Not implemented";
+  } else if (input.nDim() == 3u) {
+    if (image_dim_order == "CHW") {
+      size_t height = input.shape(1), width = input.shape(2),
+             channel = input.shape(0);
+      if (channel == 3u) {
+        mat = cv::Mat(height, width, CV_32FC3, cv::Scalar(0, 0, 0));
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            for (size_t k = 0; k < channel; k++)
+              mat.at<cv::Vec3f>(i, j)[k] = in[k * height * width + i * width + j];
+      } else if (channel == 1u) {
+        mat = cv::Mat(height, width, CV_32FC1);
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
+      } else LOG(FATAL) << "Invalid channel size: " << channel;
+    } else if (image_dim_order == "HWC") {
+      size_t height = input.shape(0), width = input.shape(1),
+             channel = input.shape(2);
+      if (channel == 3u) {
+        mat = cv::Mat(height, width, CV_32FC3, cv::Scalar(0, 0, 0));
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            for (size_t k = 0; k < channel; k++)
+              mat.at<cv::Vec3f>(i, j)[k] =
+                in[i * width * channel + j * channel + k];
+      } else if (channel == 1u) { /// 2D gray image
+        mat = cv::Mat(height, width, CV_32FC1);
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
+      } else LOG(FATAL) << "Invalid channel size: " << channel;
+    } else {
+      LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
+                 << " Only support 'HWC' and 'CHW'";
     }
-    cv::Size size(resize_width, resize_height);
-    cv::Mat resized;
-    cv::resize(mat, resized, size);
-    CHECK_EQ(resized.size().height, resize_height);
-    CHECK_EQ(resized.size().width, resize_width);
-    size_t new_size = resize_height * resize_width * resized.channels();
-    float* out = new float[new_size];
-    if (input.nDim() == 4u) {
-      /// TODO
-      /// batch based resize
-      LOG(FATAL) << "Not implemented";
-    } else if (input.nDim() == 3u) {
-      if (image_dim_order == "CHW") {
-        size_t height = resize_height, width = resize_width,
-           channel = input.shape(0);
-        if (channel == 3u) {
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              for (size_t k = 0; k < channel; k++)
-                out[k * height * width + i * width + j] = resized.at<cv::Vec3f>(i, j)[k];
-        } else { /// 2D gray image
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
-        }
-        Tensor temp(Shape{channel, height, width});
-        temp.CopyDataFromHostPtr<float>(out, new_size);
-        output = temp;
-      } else {
-        size_t height = resize_height, width = resize_width,
-           channel = input.shape(2);
-        if (channel == 3u) {
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              for (size_t k = 0; k < channel; k++)
-                out[i * width * channel + j * channel + k] = resized.at<cv::Vec3f>(i, j)[k];
-        } else { /// 1 channel
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
-        }
-        Tensor temp(Shape{height, width, channel}); 
-        temp.CopyDataFromHostPtr<float>(out, new_size);
-        output = temp;
+  } else { /// 2D gray image
+    size_t height = input.shape(0), width = input.shape(1);
+    mat = cv::Mat(height, width, CV_32FC1);
+    for (size_t i = 0; i < height; i++)
+      for (size_t j = 0; j < width; j++)
+        mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
+  }
+  cv::Size size(resize_width, resize_height);
+  cv::Mat resized;
+  cv::resize(mat, resized, size);
+  CHECK_EQ(resized.size().height, resize_height);
+  CHECK_EQ(resized.size().width, resize_width);
+  size_t new_size = resize_height * resize_width * resized.channels();
+  float* out = new float[new_size];
+  if (input.nDim() == 4u) {
+    /// TODO
+    /// batch based resize
+    LOG(FATAL) << "Not implemented";
+  } else if (input.nDim() == 3u) {
+    if (image_dim_order == "CHW") {
+      size_t height = resize_height, width = resize_width,
+             channel = input.shape(0);
+      if (channel == 3u) {
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            for (size_t k = 0; k < channel; k++)
+              out[k * height * width + i * width + j] = resized.at<cv::Vec3f>(i, j)[k];
+      } else { /// 2D gray image
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
+      }
+      Tensor temp(Shape{channel, height, width});
+      temp.CopyDataFromHostPtr<float>(out, new_size);
+      output = temp;
+    } else {
+      size_t height = resize_height, width = resize_width,
+             channel = input.shape(2);
+      if (channel == 3u) {
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            for (size_t k = 0; k < channel; k++)
+              out[i * width * channel + j * channel + k] = resized.at<cv::Vec3f>(i, j)[k];
+      } else { /// 1 channel
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
       }
-    } else { /// 2D gray image
-      size_t height = resize_height, width = resize_width;
-      for (size_t i = 0; i < height; i++)
-        for (size_t j = 0; j < width; j++)
-          out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
-      Tensor temp(Shape{height, width});
+      Tensor temp(Shape{height, width, channel});
       temp.CopyDataFromHostPtr<float>(out, new_size);
       output = temp;
     }
-    delete[] out;
-    return output;
+  } else { /// 2D gray image
+    size_t height = resize_height, width = resize_width;
+    for (size_t i = 0; i < height; i++)
+      for (size_t j = 0; j < width; j++)
+        out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
+    Tensor temp(Shape{height, width});
+    temp.CopyDataFromHostPtr<float>(out, new_size);
+    output = temp;
   }
+  delete[] out;
+  return output;
+}
 #endif
 
-  Tensor crop(Tensor& input, const size_t crop_height, const size_t crop_width, 
-             const size_t crop_h_offset, const size_t crop_w_offset, 
-             const string& image_dim_order) {
-    CHECK_LE(input.nDim(), 4u);
-    CHECK_GE(input.nDim(), 2u);
+Tensor crop(Tensor& input, const size_t crop_height, const size_t crop_width,
+            const size_t crop_h_offset, const size_t crop_w_offset,
+            const string& image_dim_order) {
+  CHECK_LE(input.nDim(), 4u);
+  CHECK_GE(input.nDim(), 2u);
 
-    Tensor output;
-    const float* in = input.data<float>();
-    size_t out_idx = 0, in_idx = 0;
-    if (input.nDim() == 4u) {
-      /// TODO
-      LOG(FATAL) << "Not implemented";
-    } else if (input.nDim() == 3u) {
-      if (image_dim_order == "CHW") {
-        size_t height = input.shape(1), width = input.shape(2),
-            channel = input.shape(0); 
-        CHECK_LE(crop_height + crop_h_offset, height);
-        CHECK_LE(crop_width + crop_w_offset, width);
-        float* out = new float[crop_height * crop_width * channel];
-        for (size_t c = 0; c < channel; c++) {
-          for (size_t h = 0; h < crop_height; h++) {
-            for (size_t w = 0; w < crop_width; w++) {
-              in_idx = (c * height + crop_h_offset + h) * width + crop_w_offset + w;
-              out_idx = (c * crop_height + h) * crop_width + w;
-              out[out_idx] = in[in_idx];
-            }
-          }
-        }
-        output = Reshape(output, Shape{channel, crop_height, crop_width});
-        output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
-        delete[] out;
-      } else if (image_dim_order == "HWC") {
-        size_t height = input.shape(0), width = input.shape(1), 
-               channel = input.shape(2); 
-        CHECK_LE(crop_height + crop_h_offset, height);
-        CHECK_LE(crop_width + crop_w_offset, width);
-        float* out = new float[crop_height * crop_width * channel];
-        for (size_t c = 0; c < channel; c++) {
-          for (size_t h = 0; h < crop_height; h++) {
-            for (size_t w = 0; w < crop_width; w++) {
-              in_idx = ((crop_h_offset + h) * width + crop_w_offset + w) * channel + c;
-              out_idx = (h * crop_width + w) * channel + c;
-              out[out_idx] = in[in_idx];
-            }
+  Tensor output;
+  const float* in = input.data<float>();
+  size_t out_idx = 0, in_idx = 0;
+  if (input.nDim() == 4u) {
+    /// TODO
+    LOG(FATAL) << "Not implemented";
+  } else if (input.nDim() == 3u) {
+    if (image_dim_order == "CHW") {
+      size_t height = input.shape(1), width = input.shape(2),
+             channel = input.shape(0);
+      CHECK_LE(crop_height + crop_h_offset, height);
+      CHECK_LE(crop_width + crop_w_offset, width);
+      float* out = new float[crop_height * crop_width * channel];
+      for (size_t c = 0; c < channel; c++) {
+        for (size_t h = 0; h < crop_height; h++) {
+          for (size_t w = 0; w < crop_width; w++) {
+            in_idx = (c * height + crop_h_offset + h) * width + crop_w_offset + w;
+            out_idx = (c * crop_height + h) * crop_width + w;
+            out[out_idx] = in[in_idx];
           }
         }
-        output = Reshape(output, Shape{crop_height, crop_width, channel});
-        output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
-        delete[] out;
-      } else {
-        LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
-                   << " Only support 'HWC' and 'CHW'";
       }
-    } else { /// 2D gray image
-      size_t height = input.shape(0), width = input.shape(1); 
+      output.SetShape(Shape{channel, crop_height, crop_width});
+      output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
+      delete[] out;
+    } else if (image_dim_order == "HWC") {
+      size_t height = input.shape(0), width = input.shape(1),
+             channel = input.shape(2);
       CHECK_LE(crop_height + crop_h_offset, height);
       CHECK_LE(crop_width + crop_w_offset, width);
-      float* out = new float[crop_height * crop_width];
-      for (size_t h = 0; h < crop_height; h++) {
-        for (size_t w = 0; w < crop_width; w++) {
-          in_idx = (crop_h_offset + h) * width + crop_w_offset + w;
-          out_idx = h * crop_width + w;
-          out[out_idx] = in[in_idx];
+      float* out = new float[crop_height * crop_width * channel];
+      for (size_t c = 0; c < channel; c++) {
+        for (size_t h = 0; h < crop_height; h++) {
+          for (size_t w = 0; w < crop_width; w++) {
+            in_idx = ((crop_h_offset + h) * width + crop_w_offset + w) * channel + c;
+            out_idx = (h * crop_width + w) * channel + c;
+            out[out_idx] = in[in_idx];
+          }
         }
       }
-      output = Reshape(output, Shape{crop_height, crop_width});
-      output.CopyDataFromHostPtr<float>(out, crop_height * crop_width);
+      output.SetShape(Shape{crop_height, crop_width, channel});
+      output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
       delete[] out;
+    } else {
+      LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
+                 << " Only support 'HWC' and 'CHW'";
+    }
+  } else { /// 2D gray image
+    size_t height = input.shape(0), width = input.shape(1);
+    CHECK_LE(crop_height + crop_h_offset, height);
+    CHECK_LE(crop_width + crop_w_offset, width);
+    float* out = new float[crop_height * crop_width];
+    for (size_t h = 0; h < crop_height; h++) {
+      for (size_t w = 0; w < crop_width; w++) {
+        in_idx = (crop_h_offset + h) * width + crop_w_offset + w;
+        out_idx = h * crop_width + w;
+        out[out_idx] = in[in_idx];
+      }
     }
-    return output;
+    output.SetShape(Shape{crop_height, crop_width});
+    output.CopyDataFromHostPtr<float>(out, crop_height * crop_width);
+    delete[] out;
   }
+  return output;
+}
 
-  Tensor mirror(Tensor& input, const bool horizontal_mirror,
-             const bool vertical_mirror, const string& image_dim_order) {
-    CHECK_LE(input.nDim(), 4u);
-    CHECK_GE(input.nDim(), 2u);
-    if (!horizontal_mirror && !vertical_mirror) return input;
+Tensor mirror(Tensor& input, const bool horizontal_mirror,
+              const bool vertical_mirror, const string& image_dim_order) {
+  CHECK_LE(input.nDim(), 4u);
+  CHECK_GE(input.nDim(), 2u);
+  if (!horizontal_mirror && !vertical_mirror) return input;
 
-    Tensor output;
-    const float* in = input.data<float>();
-    size_t out_idx = 0, in_idx = 0;
-    if (input.nDim() == 4u) {
-      /// TODO
-      LOG(FATAL) << "Not implemented";
-    } else if (input.nDim() == 3u) {
-      if (image_dim_order == "CHW") {
-        size_t height = input.shape(1), width = input.shape(2),
-            channel = input.shape(0);
-        float* out = new float[height * width * channel];
-        for (size_t c = 0; c < channel; c++) {
-          for (size_t h = 0; h < height; h++) {
-            for (size_t w = 0; w < width; w++) {
-              in_idx = (c * height + h) * width + w;
-              if (horizontal_mirror && vertical_mirror)
-                out_idx = (c * height + (height - 1 - h)) * width + (width - 1 - w);
-              else if (horizontal_mirror)
-                out_idx = (c * height + h) * width + (width - 1 - w);
-              else /// only do vertical mirror
-                out_idx = (c * height + (height - 1 - h)) * width + w;
-              out[out_idx] = in[in_idx];
-            }
-          }
-        }
-        output = Reshape(output, Shape{channel, height, width});
-        output.CopyDataFromHostPtr<float>(out, height * width * channel);
-        delete[] out;
-      } else if (image_dim_order == "HWC") {
-        size_t height = input.shape(0), width = input.shape(1),
-            channel = input.shape(2);
-        float* out = new float[height * width * channel];
-        for (size_t c = 0; c < channel; c++) {
-          for (size_t h = 0; h < height; h++) {
-            for (size_t w = 0; w < width; w++) {
-              in_idx = (h * width + w) * channel + c;
-              if (horizontal_mirror && vertical_mirror)
-                out_idx = ((height - 1 - h) * width + (width - 1 - w)) * channel + c;
-              else if (horizontal_mirror)
-                out_idx = (h * width + (width - 1 - w)) * channel + c;
-              else /// only do vertical mirror
-                out_idx = ((height - 1 - h) * width + w) * channel + c;
-              out[out_idx] = in[in_idx];
-            }
+  Tensor output;
+  const float* in = input.data<float>();
+  size_t out_idx = 0, in_idx = 0;
+  if (input.nDim() == 4u) {
+    /// TODO
+    LOG(FATAL) << "Not implemented";
+  } else if (input.nDim() == 3u) {
+    if (image_dim_order == "CHW") {
+      size_t height = input.shape(1), width = input.shape(2),
+             channel = input.shape(0);
+      float* out = new float[height * width * channel];
+      for (size_t c = 0; c < channel; c++) {
+        for (size_t h = 0; h < height; h++) {
+          for (size_t w = 0; w < width; w++) {
+            in_idx = (c * height + h) * width + w;
+            if (horizontal_mirror && vertical_mirror)
+              out_idx = (c * height + (height - 1 - h)) * width + (width - 1 - w);
+            else if (horizontal_mirror)
+              out_idx = (c * height + h) * width + (width - 1 - w);
+            else /// only do vertical mirror
+              out_idx = (c * height + (height - 1 - h)) * width + w;
+            out[out_idx] = in[in_idx];
           }
         }
-        output = Reshape(output, Shape{height, width, channel});
-        output.CopyDataFromHostPtr<float>(out, height * width * channel);
-        delete[] out;
-      } else {
-        LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
-                   << " Only support 'HWC' and 'CHW'";
       }
-    } else { /// 2D gray image
-      size_t height = input.shape(0), width = input.shape(1);
-      float* out = new float[height * width];
-      for (size_t h = 0; h < height; h++) {
-        for (size_t w = 0; w < width; w++) {
-          in_idx = h * width + w;
-          if (horizontal_mirror && vertical_mirror)
-            out_idx = (height - 1 - h) * width + (width - 1 - w);
-          else if (horizontal_mirror)
-            out_idx = h * width + (width - 1 - w);
-          else /// only do vertical mirror
-            out_idx = (height - 1 - h) * width + w;
-          out[out_idx] = in[in_idx];
+      output.SetShape(Shape{channel, height, width});
+      output.CopyDataFromHostPtr<float>(out, height * width * channel);
+      delete[] out;
+    } else if (image_dim_order == "HWC") {
+      size_t height = input.shape(0), width = input.shape(1),
+             channel = input.shape(2);
+      float* out = new float[height * width * channel];
+      for (size_t c = 0; c < channel; c++) {
+        for (size_t h = 0; h < height; h++) {
+          for (size_t w = 0; w < width; w++) {
+            in_idx = (h * width + w) * channel + c;
+            if (horizontal_mirror && vertical_mirror)
+              out_idx = ((height - 1 - h) * width + (width - 1 - w)) * channel + c;
+            else if (horizontal_mirror)
+              out_idx = (h * width + (width - 1 - w)) * channel + c;
+            else /// only do vertical mirror
+              out_idx = ((height - 1 - h) * width + w) * channel + c;
+            out[out_idx] = in[in_idx];
+          }
         }
       }
-      output = Reshape(output, Shape{height, width});
-      output.CopyDataFromHostPtr<float>(out, height * width);
+      output.SetShape(Shape{height, width, channel});
+      output.CopyDataFromHostPtr<float>(out, height * width * channel);
       delete[] out;
+    } else {
+      LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
+                 << " Only support 'HWC' and 'CHW'";
+    }
+  } else { /// 2D gray image
+    size_t height = input.shape(0), width = input.shape(1);
+    float* out = new float[height * width];
+    for (size_t h = 0; h < height; h++) {
+      for (size_t w = 0; w < width; w++) {
+        in_idx = h * width + w;
+        if (horizontal_mirror && vertical_mirror)
+          out_idx = (height - 1 - h) * width + (width - 1 - w);
+        else if (horizontal_mirror)
+          out_idx = h * width + (width - 1 - w);
+        else /// only do vertical mirror
+          out_idx = (height - 1 - h) * width + w;
+        out[out_idx] = in[in_idx];
+      }
     }
-    return output;
+    output.SetShape(Shape{height, width});
+    output.CopyDataFromHostPtr<float>(out, height * width);
+    delete[] out;
   }
+  return output;
+}
 } // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
index 4e74a82..d2b0c3e 100644
--- a/src/model/layer/batchnorm.cc
+++ b/src/model/layer/batchnorm.cc
@@ -44,7 +44,7 @@ void BatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
   else
     is_2d_ = false;
 
-  bnScale_.Reshape(Shape{channels_});
+  bnScale_.SetShape(Shape{channels_});
   bnBias_.ResetLike(bnScale_);
   runningMean_.ResetLike(bnScale_);
   runningVariance_.ResetLike(bnScale_);
@@ -68,19 +68,18 @@ void BatchNorm::ToDevice(std::shared_ptr<Device> device) {
 const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
   Tensor x = input.Clone();
   x.Reshape(Shape{input.shape(0), input.Size() / input.shape(0)});
-  Tensor output, mean, var, xnorm;
+  Tensor output;
   output.ResetLike(x);
   // TODO(wangwei) input sample shape check
-
   if ((flag & kTrain) == kTrain) {  // forward for train
     if (is_2d_) {                   // batchnorm_per_activation mode
-      mean = Average(x, 0);
+      auto mean = Average(x, 0);
       runningMean_ *= 1.0f - factor_;
       Axpy(factor_, mean, &runningMean_);
-      xnorm = x.Clone();
+      auto xnorm = x.Clone();
       SubRow(mean, &xnorm);
       xnorm = Square(xnorm);
-      var = Average(xnorm, 0);
+      auto var = Average(xnorm, 0);
       runningVariance_ *= 1.0f - factor_;
       Axpy(factor_, var, &runningVariance_);
       Tensor tmp = var.Clone();
@@ -102,7 +101,7 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
     }
   } else {         // forward for test
     if (is_2d_) {  // batchnorm_per_activation mode
-      xnorm = x.Clone();
+      auto xnorm = x.Clone();
       SubRow(runningMean_, &xnorm);
       Tensor tmp = runningVariance_.Clone();
       tmp = Sqrt(tmp);
@@ -134,7 +133,7 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
       scale.Reshape(Shape{channels_ * height_ * width_});
       bias.Reshape(Shape{channels_ * height_ * width_});
 
-      xnorm = x.Clone();
+      auto xnorm = x.Clone();
       SubRow(mean, &xnorm);
       var = Sqrt(var);
       var += 1e-6f;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc
index cc77433..3718d8d 100755
--- a/src/model/layer/convolution.cc
+++ b/src/model/layer/convolution.cc
@@ -96,9 +96,9 @@ void Convolution::Setup(const Shape &in_sample, const LayerConf &conf) {
   col_width_ = conv_height_ * conv_width_;
 
   // Setup shape of weight_ and bias_
-  weight_.Reshape(Shape{num_filters_, col_height_});
+  weight_.SetShape(Shape{num_filters_, col_height_});
   if (bias_term_)
-    bias_.Reshape(Shape{num_filters_});
+    bias_.SetShape(Shape{num_filters_});
   // Assume the order of param is: weight, bias
   for (const auto &spec : conf.param()) param_specs_.push_back(spec);
 }
@@ -174,8 +174,8 @@ const std::pair<Tensor, vector<Tensor>> Convolution::Backward(
     col_data.CopyDataFromHostPtr(data_col, col_height_ * col_width_);
     Tensor grad_b(Shape{num_filters_, conv_height_ * conv_width_});
     CopyDataToFrom(&grad_b, grad, grad_b.Size(), 0, b * grad_b.Size());
-    dw += Mult(grad_b, col_data.T());
-    Tensor dcol_b = Mult(weight_.T(), grad_b);
+    dw += Mult(grad_b, Transpose(col_data));
+    Tensor dcol_b = Mult(Transpose(weight_), grad_b);
     auto dcol_data = dcol_b.data<float>();
     Col2im(dcol_data, channels_, height_, width_, kernel_h_, kernel_w_, pad_h_,
            pad_w_, stride_h_, stride_w_, dx_b);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
index 5c93a6b..389b41b 100644
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -39,8 +39,8 @@ void CudnnBatchNorm::ToDevice(std::shared_ptr<Device> device) {
 
 void CudnnBatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
   BatchNorm::Setup(in_sample, conf);
-  resultSaveMean_.Reshape(Shape{channels_});
-  resultSaveVariance_.Reshape(Shape{channels_});
+  resultSaveMean_.SetShape(Shape{channels_});
+  resultSaveVariance_.SetShape(Shape{channels_});
 }
 
 void CudnnBatchNorm::InitCudnn(const Shape& shape, DataType dtype) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/dense.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.cc b/src/model/layer/dense.cc
index fac9130..36a7a91 100644
--- a/src/model/layer/dense.cc
+++ b/src/model/layer/dense.cc
@@ -40,11 +40,11 @@ void Dense::Setup(const Shape& in_sample, const LayerConf &conf) {
   transpose_ = dense_conf.transpose();
   bias_term_ = dense_conf.bias_term();
   if (transpose_)  // was {vdim_, hdim} by zhaojing?
-    weight_.Reshape(Shape{hdim_, vdim_});
+    weight_.SetShape(Shape{hdim_, vdim_});
   else
-    weight_.Reshape(Shape{vdim_, hdim_});
+    weight_.SetShape(Shape{vdim_, hdim_});
   if (bias_term_)
-    bias_.Reshape(Shape{hdim_});
+    bias_.SetShape(Shape{hdim_});
   for (auto specs: conf.param())
     param_specs_.push_back(specs);
 }
@@ -55,7 +55,7 @@ const Tensor Dense::Forward(int flag, const Tensor &input) {
   Tensor output;
   CHECK_EQ(input.nDim(), 2u);
   if (transpose_)  // use the transposed version of weight_ for computing
-    output = Mult(input, weight_.T());
+    output = Mult(input, Transpose(weight_));
   else
     output = Mult(input, weight_);
   if (bias_term_)
@@ -81,10 +81,10 @@ const std::pair<Tensor, vector<Tensor>> Dense::Backward(int flag,
   }
   if (transpose_) {
     dx = Mult(grad, weight_);
-    dw = Mult(grad.T(), src_data);
+    dw = Mult(Transpose(grad), src_data);
   } else {
-    dx = Mult(grad, weight_.T());
-    dw = Mult(src_data.T(), grad);
+    dx = Mult(grad, Transpose(weight_));
+    dw = Mult(Transpose(src_data), grad);
   }
   param_grad.push_back(dw);
   if (bias_term_)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/flatten.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.cc b/src/model/layer/flatten.cc
index 561c310..592e892 100644
--- a/src/model/layer/flatten.cc
+++ b/src/model/layer/flatten.cc
@@ -49,8 +49,7 @@ const Tensor Flatten::Forward(int flag, const Tensor &input) {
 const std::pair<Tensor, vector<Tensor> > Flatten::Backward(int flag,
                                                            const Tensor &grad) {
   vector<Tensor> param_grad;
-  Tensor input_grad = grad;
-  input_grad.Reshape(input_shape_);
+  Tensor input_grad = Reshape(grad, input_shape_);
   return std::make_pair(input_grad, param_grad);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/lrn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.cc b/src/model/layer/lrn.cc
index 4fdb5c9..a1776fa 100644
--- a/src/model/layer/lrn.cc
+++ b/src/model/layer/lrn.cc
@@ -52,8 +52,7 @@ const Tensor LRN::Forward(int flag, const Tensor& input) {
                    std::min(input.shape(1), c + local_size_ / 2 + 1));
       window = Square(window);
 
-      Tensor tmp, ch;
-      tmp.Reshape(Shape{input.shape(2) * input.shape(3)});
+      Tensor ch, tmp(Shape{input.shape(2) * input.shape(3)});
       SumRows(window, &tmp);
 
       tmp *= alpha_;
@@ -97,8 +96,7 @@ const std::pair<Tensor, vector<Tensor>> LRN::Backward(int flag,
         Tensor window =
             CopyRows(image, std::max(0, static_cast<int>(c) - local_size_ / 2),
                      std::min(grad.shape(1), c + local_size_ / 2 + 1));
-        Tensor tmp;
-        tmp.Reshape(Shape{grad.shape(2) * grad.shape(3)});
+        Tensor tmp(Shape{grad.shape(2) * grad.shape(3)});
         window = Square(window);
         SumRows(window, &tmp);
         tmp *= alpha_;
@@ -126,8 +124,7 @@ const std::pair<Tensor, vector<Tensor>> LRN::Backward(int flag,
         Tensor window =
             CopyRows(image, std::max(0, static_cast<int>(c) - local_size_ / 2),
                      std::min(grad.shape(1), c + local_size_ / 2 + 1));
-        Tensor tmpr;
-        tmpr.Reshape(Shape{grad.shape(2) * grad.shape(3)});
+        Tensor tmpr(Shape{grad.shape(2) * grad.shape(3)});
         SumRows(window, &tmpr);
         tmpr.Reshape(Shape{grad.shape(2), grad.shape(3)});
         channels.push_back(tmpr);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/opencl_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/opencl_convolution.cc b/src/model/layer/opencl_convolution.cc
index 063c4c3..eb25f5e 100644
--- a/src/model/layer/opencl_convolution.cc
+++ b/src/model/layer/opencl_convolution.cc
@@ -37,9 +37,9 @@ const Tensor OpenclConvolution::Forward(int flag, const Tensor &input) {
   auto data_type = input.data_type();
   auto device = input.device();
 
-   // TODO(wangwei) update the layer config if the input sample shape changes
+  // TODO(wangwei) update the layer config if the input sample shape changes
   CHECK(input.shape(1) == channels_ && input.shape(2) == height_ &&
-      input.shape(3) == width_) << "input sample shape should not change";
+        input.shape(3) == width_) << "input sample shape should not change";
 
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, device, data_type);
@@ -48,16 +48,16 @@ const Tensor OpenclConvolution::Forward(int flag, const Tensor &input) {
   for (size_t b = 0; b < batchsize; b++) {
     int offset = b * imagesize;
 
-    col_data.device()->Exec([input, offset, col_data, this](Context* ctx) mutable {
+    col_data.device()->Exec([input, offset, col_data, this](Context * ctx) mutable {
 
       this->Im2Col(input.block(), offset,
-                   height_, width_,
-                   kernel_h_, kernel_w_,
-                   pad_h_, pad_w_,
-                   stride_h_, stride_w_,
-                   conv_height_, conv_width_,
-                   0, channels_,
-                   col_data.block(), ctx);
+      height_, width_,
+      kernel_h_, kernel_w_,
+      pad_h_, pad_w_,
+      stride_h_, stride_w_,
+      conv_height_, conv_width_,
+      0, channels_,
+      col_data.block(), ctx);
     },
     {input.block()},
     {col_data.block()});
@@ -116,16 +116,17 @@ OpenclConvolution::Backward(int flag, const Tensor &grad) {
     int im_offset = b * imagesize;
     int col_offset = 0; // Always keep this to zero.
 
-    col_data.device()->Exec([src_data, col_data, im_offset, col_offset, this](Context* ctx) mutable {
+    col_data.device()->Exec([src_data, col_data, im_offset, col_offset,
+    this](Context * ctx) mutable {
 
       this->Im2Col(src_data.block(), im_offset,
-                   height_, width_,
-                   kernel_h_, kernel_w_,
-                   pad_h_, pad_w_,
-                   stride_h_, stride_w_,
-                   conv_height_, conv_width_,
-                   col_offset, channels_,
-                   col_data.block(), ctx);
+      height_, width_,
+      kernel_h_, kernel_w_,
+      pad_h_, pad_w_,
+      stride_h_, stride_w_,
+      conv_height_, conv_width_,
+      col_offset, channels_,
+      col_data.block(), ctx);
     },
     {src_data.block()},
     {col_data.block()});
@@ -134,19 +135,20 @@ OpenclConvolution::Backward(int flag, const Tensor &grad) {
                   grad.device(), grad.data_type());
     CopyDataToFrom(&grad_b, grad, grad_b.Size(), 0, b * grad_b.Size());
 
-    dw += Mult(grad_b, col_data.T());
-    Tensor dcol_b = Mult(weight_.T(), grad_b);
+    dw += Mult(grad_b, Transpose(col_data));
+    Tensor dcol_b = Mult(Transpose(weight_), grad_b);
 
-    dx.device()->Exec([dcol_b, dx, im_offset, col_offset, this](Context* ctx) mutable {
+    dx.device()->Exec([dcol_b, dx, im_offset, col_offset,
+    this](Context * ctx) mutable {
 
       this->Col2Im(dcol_b.block(), col_offset,
-                   height_, width_,
-                   kernel_h_, kernel_w_,
-                   pad_h_, pad_w_,
-                   stride_h_, stride_w_,
-                   conv_height_, conv_width_,
-                   im_offset, channels_,
-                   dx.block(), ctx);
+      height_, width_,
+      kernel_h_, kernel_w_,
+      pad_h_, pad_w_,
+      stride_h_, stride_w_,
+      conv_height_, conv_width_,
+      im_offset, channels_,
+      dx.block(), ctx);
     },
     {dcol_b.block()},
     {dx.block()});

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.cc b/src/model/layer/rnn.cc
index b811f9d..e565abc 100644
--- a/src/model/layer/rnn.cc
+++ b/src/model/layer/rnn.cc
@@ -79,7 +79,7 @@ void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
       dim = hidden_size_ * (hidden_size_ +  hidden_size_ + 2);
     weight_size += mult * dim;
   }
-  weight_.Reshape(Shape{weight_size});
+  weight_.SetShape(Shape{weight_size});
 }
 
 const vector<Tensor> RNN::Forward(int flag, const vector<Tensor>& inputs) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/operation/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.cc b/src/model/operation/convolution.cc
index f700203..7c71d7c 100755
--- a/src/model/operation/convolution.cc
+++ b/src/model/operation/convolution.cc
@@ -4,7 +4,8 @@
 
 namespace singa {
 
-ConvHandle::ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+ConvHandle::ConvHandle(const Tensor &input,
+                       const std::vector<size_t>& kernel_size,
                        const std::vector<size_t>& stride, const std::vector<size_t>& padding,
                        const size_t in_channels, const size_t out_channels,
                        const bool bias) {
@@ -23,7 +24,8 @@ ConvHandle::ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_si
   bias_term = bias;
 
   batchsize = input.shape(0);
-  CHECK(input.shape(1) == in_channels) << "the number of input channels mismatched.";
+  CHECK(input.shape(1) == in_channels) <<
+                                       "the number of input channels mismatched.";
   height = input.shape(2);
   width = input.shape(3);
 
@@ -39,14 +41,16 @@ ConvHandle::ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_si
 
 
 
-Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const ConvHandle &ch) {
+Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b,
+                      const ConvHandle &ch) {
   CHECK_EQ(x.device()->lang(), kCpp);
 
   CHECK(x.shape(1) == ch.channels && x.shape(2) == ch.height &&
         x.shape(3) == ch.width) << "input sample shape should not change";
 
   CHECK(W.shape(0) == ch.num_filters && W.shape(1) == ch.channels &&
-        W.shape(2) == ch.kernel_h && W.shape(3) == ch.kernel_w) << "weights shape should not change";
+        W.shape(2) == ch.kernel_h
+        && W.shape(3) == ch.kernel_w) << "weights shape should not change";
 
   Shape w_shape = W.shape();
   Shape b_shape;
@@ -67,8 +71,9 @@ Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const ConvHandle &
   float *data_col = new float[ch.col_height * ch.col_width];
   auto in_data = x.data<float>();
   for (size_t num = 0; num < ch.batchsize; num++) {
-    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width, ch.kernel_h,
-             ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
+    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width,
+           ch.kernel_h,
+           ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
 
     col_data.CopyDataFromHostPtr(data_col, ch.col_height * ch.col_width);
     Tensor each = Mult(W, col_data);
@@ -83,14 +88,16 @@ Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const ConvHandle &
   return output;
 }
 
-Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandle &ch) {
+Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x,
+                        const ConvHandle &ch) {
   CHECK_EQ(dy.device()->lang(), kCpp);
 
   CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
         dy.shape(3) == ch.conv_width) << "input gradients shape should not change";
 
   CHECK(W.shape(0) == ch.num_filters && W.shape(1) == ch.channels &&
-        W.shape(2) == ch.kernel_h && W.shape(3) == ch.kernel_w) << "weights shape should not change";
+        W.shape(2) == ch.kernel_h
+        && W.shape(3) == ch.kernel_w) << "weights shape should not change";
 
   Shape w_shape = W.shape();
   W.Reshape(Shape{ch.num_filters, ch.col_height});
@@ -103,17 +110,19 @@ Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Conv
   for (size_t num = 0; num < ch.batchsize; num++) {
     Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
     CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
-    Tensor dcol_b = Mult(W.T(), grad_b);
+    Tensor dcol_b = Mult(Transpose(W), grad_b);
     auto dcol_data = dcol_b.data<float>();
-    Col2im(dcol_data, ch.channels, ch.height, ch.width, ch.kernel_h, ch.kernel_w, ch.pad_h,
-             ch.pad_w, ch.stride_h, ch.stride_w, dx_b);
+    Col2im(dcol_data, ch.channels, ch.height, ch.width, ch.kernel_h, ch.kernel_w,
+           ch.pad_h,
+           ch.pad_w, ch.stride_h, ch.stride_w, dx_b);
     dx.CopyDataFromHostPtr(dx_b, ch.imagesize, num * ch.imagesize);
   }
   W.Reshape(w_shape);
   return dx;
 }
 
-Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandle &ch) {
+Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W,
+                        const ConvHandle &ch) {
   CHECK_EQ(dy.device()->lang(), kCpp);
 
   CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
@@ -134,18 +143,20 @@ Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
   float *data_col = new float[ch.col_height * ch.col_width];
   auto in_data = dy.data<float>();
   for (size_t num = 0; num < ch.batchsize; num++) {
-    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width, ch.kernel_h,
-             ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
+    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width,
+           ch.kernel_h,
+           ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
     col_data.CopyDataFromHostPtr(data_col, ch.col_height * ch.col_width);
     Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
     CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
-    dW += Mult(grad_b, col_data.T());
+    dW += Mult(grad_b, Transpose(col_data));
   }
   dW.Reshape(w_shape);
   return dW;
 }
 
-Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch) {
+Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b,
+                        const ConvHandle &ch) {
   CHECK_EQ(dy.device()->lang(), kCpp);
 
   CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
@@ -169,11 +180,13 @@ Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch)
 };
 
 #ifdef USE_CUDNN
-CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+CudnnConvHandle::CudnnConvHandle(const Tensor &input,
+                                 const std::vector<size_t>& kernel_size,
                                  const std::vector<size_t>& stride, const std::vector<size_t>& padding,
                                  const size_t in_channels, const size_t out_channels, const bool bias,
                                  const size_t workspace_byte_limit, const std::string& prefer)
-  : ConvHandle(input, kernel_size, stride, padding, in_channels, out_channels, bias) {
+  : ConvHandle(input, kernel_size, stride, padding, in_channels, out_channels,
+               bias) {
 
   DataType dtype = input.data_type();
   auto dev = input.device();
@@ -203,7 +216,7 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>&
 #if CUDNN_MAJOR >= 7
               , GetCudnnDataType(dtype)
 #endif
-              ));
+                                             ));
   CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc, GetCudnnDataType(dtype),
                                          CUDNN_TENSOR_NCHW, num_filters,
                                          channels, kernel_h, kernel_w));
@@ -268,8 +281,8 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>&
                 ctx->cudnn_handle, x_desc, y_desc, conv_desc, filter_desc,
                 bp_filter_alg, &bp_filter_byte));
   workspace_count = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
-                     sizeof(float) +
-                     1;
+                    sizeof(float) +
+                    1;
   if (workspace_count * sizeof(float) > workspace_byte_limit)
     LOG(WARNING) << "The required memory for workspace ("
                  << workspace_count * sizeof(float)
@@ -289,7 +302,8 @@ CudnnConvHandle::~CudnnConvHandle() {
   if (y_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc));
 }
 
-Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandle &cch) {
+Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b,
+                      const CudnnConvHandle &cch) {
   CHECK_EQ(x.device()->lang(), kCuda);
 
   DataType dtype = x.data_type();
@@ -323,7 +337,8 @@ Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const C
   return output;
 }
 
-Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandle &cch) {
+Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x,
+                        const CudnnConvHandle &cch) {
   CHECK_EQ(dy.device()->lang(), kCuda);
 
   Tensor dx;
@@ -344,7 +359,8 @@ Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, cons
   return dx;
 }
 
-Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandle &cch) {
+Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W,
+                        const CudnnConvHandle &cch) {
   CHECK_EQ(dy.device()->lang(), kCuda);
 
   Tensor dW;
@@ -366,7 +382,8 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 }
 
 // input Tensor b for Reset db purpose, can avoid this later.
-Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch) {
+Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b,
+                        const CudnnConvHandle &cch) {
   CHECK_EQ(dy.device()->lang(), kCuda);
 
   Tensor db;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/updater/local_updater.cc
----------------------------------------------------------------------
diff --git a/src/model/updater/local_updater.cc b/src/model/updater/local_updater.cc
index c3c6793..04593f4 100644
--- a/src/model/updater/local_updater.cc
+++ b/src/model/updater/local_updater.cc
@@ -43,7 +43,7 @@ void LocalUpdater::Apply(int step, const string& name, Tensor& grad,
   int nth = dev_index_[name]++;
   auto key = std::make_pair(nth, name);
   if (grad_buffer_[key].Size() != grad.Size()) {
-    grad_buffer_[key].Reshape(grad.shape());
+    grad_buffer_[key].SetShape(grad.shape());
     grad_buffer_[key].AsType(grad.data_type());
   }
   grad_buffer_[key].CopyData(grad);
@@ -56,7 +56,7 @@ void LocalUpdater::Apply(int step, const string& name, Tensor& grad,
     }
   } else {
     if (param_buffer_[name].Size() != value.Size()) {
-      param_buffer_[name].Reshape(value.shape());
+      param_buffer_[name].SetShape(value.shape());
       param_buffer_[name].AsType(value.data_type());
       param_buffer_[name].CopyData(value);
       sum_[name].ResetLike(param_buffer_[name]);


[2/4] incubator-singa git commit: SINGA-380 Fix bugs from Reshape

Posted by wa...@apache.org.
SINGA-380 Fix bugs from Reshape

Add a SoftmaxCrossEntropy operation, which accepts logits as input (softmax is applied inside the operation).
There is another operation for cross-entropy, which accepts probabilities as input.

Fix the memory leak caused by Reshape in C++.

fix mem leak bug from reshape


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/58e6640d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/58e6640d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/58e6640d

Branch: refs/heads/master
Commit: 58e6640df8d8f3572e15e818e59eb552ea314eaa
Parents: 81908a8
Author: Wang Wei <wa...@gmail.com>
Authored: Sun Jul 8 21:42:26 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Wed Jul 11 15:20:43 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py | 30 ++++++++++------
 include/singa/core/tensor.h    |  6 +++-
 python/singa/autograd.py       | 61 ++++++++++++++++++-------------
 src/api/core_tensor.i          |  3 ++
 src/core/tensor/tensor.cc      | 71 +++++++++++++++++++++++--------------
 src/model/layer/cudnn_rnn.cc   |  2 +-
 6 files changed, 108 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index db21485..43a22ba 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -77,7 +77,7 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
     assert os.path.exists(args.file_path), \
-        'Pls download the MNIST dataset from '
+        'Pls download the MNIST dataset from https://s3.amazonaws.com/img-datasets/mnist.npz'
 
     if args.use_cpu:
         print('Using CPU')
@@ -92,7 +92,7 @@ if __name__ == '__main__':
     num_classes = 10
     epochs = 1
 
-    sgd = optimizer.SGD(0.05)
+    sgd = optimizer.SGD(0.001)
 
     x_train = preprocess(train[0])
     y_train = to_categorical(train[1], num_classes)
@@ -109,24 +109,26 @@ if __name__ == '__main__':
     conv2 = autograd.Conv2D(32, 32, 3, padding=1)
     linear = autograd.Linear(32 * 28 * 28, 10)
 
+
     def forward(x, t):
+        
         y = conv1(x)
         y = autograd.relu(y)
+        y = autograd.max_pool_2d(y)
         y = conv2(y)
         y = autograd.relu(y)
         y = autograd.max_pool_2d(y)
-        y = autograd.flatten(y)
+        y=autograd.flatten(y)
         y = linear(y)
-        y = autograd.soft_max(y)
-        loss = autograd.cross_entropy(y, t)
+        loss = autograd.softmax_cross_entropy(y, t)
         return loss, y
 
     autograd.training = True
-    for epoch in range(epochs):
+    for epoch in range(50):
         for i in range(batch_number):
-            inputs = tensor.Tensor(device=dev, data=x_train[i * 100:(1 + i) * 100])
-            targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100])
-
+            inputs = tensor.Tensor(device=dev, data=x_train[ i * 100:(1 + i) * 100], stores_grad=False)
+            targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100], requires_grad=False, stores_grad=False)
+            
             loss, y = forward(inputs, targets)
 
             accuracy_rate = accuracy(tensor.to_numpy(y),
@@ -134,6 +136,12 @@ if __name__ == '__main__':
             if (i % 5 == 0):
                 print('accuracy is:', accuracy_rate, 'loss is:',
                       tensor.to_numpy(loss)[0])
-
+            
             for p, gp in autograd.backward(loss):
-                sgd.apply(0, gp, p, '')
+                sgd.apply(epoch, gp, p, '')
+            
+            
+
+            
+            
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index a1badd4..5921762 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -292,7 +292,7 @@ Tensor Reshape(const Tensor &in, Shape &&s);
 void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
                     const size_t dst_offset = 0, const size_t src_offset = 0);
 
-void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis, 
+void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
                       Tensor *dst, const Tensor &in, const size_t num);
 
 // =============Element-wise operations====================================
@@ -508,6 +508,10 @@ void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss);
 
 void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p);
 
+/// To be called by pysinga autograd operations;
+/// swig ignores the const qualifier http://www.swig.org/Doc3.0/SWIGPlus.html#SWIGPlus_const
+const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
+const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);
 
 /// Return a tensor consisting of rows ([start, end)) from 'in'. It copies the
 /// values from 'in'. 'in' ia a 2D Tensor.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 2ba3098..63698c2 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -40,14 +40,8 @@ class Operation(object):
 
     Steps to add a specific operation Xxxx:
     1. create a subclass of Operation, name it as Xxxx
-    2. if Xxxx is implemented using other Operations, then override
-       _do_forward() function;
-       if Xxxx is implemented using CTensor operations,
-       then override the forward() and backward(); The arguments of forward()
+    2. override the forward() and backward(); The arguments of forward()
        and backward() should only include CTensor;
-       if Xxxx is implemented by calling functions in layer.py, then override
-       __call__(), forward() and backward(). TODO(wangwei) avoid this complex
-       case.
     '''
 
     def __call__(self, *xs):
@@ -311,9 +305,9 @@ def soft_max(x, axis=0):
     return SoftMax(axis)(x)[0]
 
 
-class CrossEntropy(Operation):
+class NLL(Operation):
     '''
-    Calculte CrossEntropy loss for a batch of training data.
+    Calculte negative log likelihood loss for a batch of training data.
 
     '''
 
@@ -356,8 +350,27 @@ class CrossEntropy(Operation):
             pass  # TODO, broadcast elementwise multiply seems not support
 
 
-def cross_entropy(y, t):
-    return CrossEntropy()(y, t)[0]
+def nll(y, t):
+    return NLL()(y, t)[0]
+
+
+class SoftMaxCrossEntropy(Operation):
+
+    def forward(self, x, t):
+        self.p = singa.SoftMax(x)
+        self.t = t
+        loss = CTensor((1,), self.p.device())
+        ret = singa.CrossEntropyFwd(self.p, t)
+        loss.SetFloatValue(singa.SumAsFloat(ret) / x.shape()[0])
+        return loss
+
+    def backward(self, dy=1.0):
+        return singa.SoftmaxCrossEntropyBwd(self.p, self.t), None
+
+
+def softmax_cross_entropy(x, t):
+    # x is the logits and t is the ground truth; both are 2D.
+    return SoftMaxCrossEntropy()(x, t)[0]
 
 
 def ctensor2numpy(x):
@@ -427,23 +440,20 @@ def max_pool_2d(x, kernel_size=3, stride=1, padding=0, dilation=1,
 
 class Flatten(Operation):
 
-    def __init__(self):
-        self.PyLayer = layer.Flatten('flatten', 1)
+    def __init(self, start_axis=1):
+        # flatten all axis after (inclusive) start_axis
+        self.start_axis = start_axis
+        assert start_axis == 1, 'must flatten into 2d array not'
 
-    def __call__(self, x):
-        if training:
-            self.flag = model_pb2.kTrain
-        else:
-            self.flag = model_pb2.kEval
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
-        return self._do_forward(x)
-
-    def forward(self, *xs):
-        return self.PyLayer.layer.Forward(self.flag, xs[0])
+    def forward(self, x):
+        # TODO Do flatten start from axis != 1
+        self.shape = list(x.shape())
+        y = x.Reshape((x.shape()[0], x.Size() // x.shape()[0]))
+        return y
 
     def backward(self, dy):
-        return self.PyLayer.layer.Backward(0, dy)[0]
+        dx = dy.Reshape(self.shape)
+        return dx
 
 
 def flatten(x):
@@ -623,6 +633,7 @@ def backward(y, dy=None):
                     if not isinstance(src_op, Dummy):
                         ready.append((src_op, not_ready[src_op]))
                     del not_ready[src_op]
+        del op  # delete the operation to free all tensors from this op
 
 
 class Layer(object):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/src/api/core_tensor.i
----------------------------------------------------------------------
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index d94e506..cc72d21 100644
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -325,4 +325,7 @@ namespace singa{
 
   Tensor SoftMax(const Tensor &in);
   void SoftMax(const Tensor &in, Tensor *out);
+
+  const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
+  const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
old mode 100644
new mode 100755
index 05db7cf..e5e8017
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -218,10 +218,10 @@ void Tensor::CopyData(const Tensor &src) {
 }
 
 void Tensor::RepeatData(vector<size_t> repeats, int axis, int total_repeats, const Tensor &src) {
-  if(repeats.size() == 1) {
+  if (repeats.size() == 1) {
     CHECK_EQ(Size(), src.Size()*total_repeats);
   } else {
-    CHECK_EQ(Size(), src.Size()*total_repeats/src.shape()[axis]);
+    CHECK_EQ(Size(), src.Size()*total_repeats / src.shape()[axis]);
   }
 
   CHECK(block_ != nullptr);
@@ -344,7 +344,7 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
     total_repeats = repeats[0];
     tshape.push_back(Product(shape_)*total_repeats);
   } else {
-    if (repeats.size() == 1){
+    if (repeats.size() == 1) {
       total_repeats = repeats[0];
       for (size_t i = 0; i < shape_.size(); i++) {
         if (i == axis) {
@@ -358,15 +358,15 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
         LOG(FATAL) << "the repeats number doesn't match the axis";
       }
       for (size_t i = 0; i < shape_[axis]; i++) {
-        if(repeats[i] < 0) {
+        if (repeats[i] < 0) {
           LOG(FATAL) << "the repeats number is less than zero";
         }
         total_repeats += repeats[i];
       }
-      for (size_t i = 0; i < shape_.size(); i++){
+      for (size_t i = 0; i < shape_.size(); i++) {
         if (i == axis) {
           tshape.push_back(total_repeats);
-        } else{
+        } else {
           tshape.push_back(shape_[i]);
         }
       }
@@ -539,7 +539,7 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
   }
 }
 
-void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis, 
+void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
                       Tensor *dst, const Tensor &src, const size_t num) {
   if (repeats.size() == 1) {
     broadcast_flag = true;
@@ -548,7 +548,7 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
       LOG(FATAL) << "When repeats parameter is sequence, axis cannot be None";
     }
   }
-  for (size_t i = 0; i < repeats.size(); i++){
+  for (size_t i = 0; i < repeats.size(); i++) {
     CHECK_GE(repeats[i], 0);
   }
   auto width = SizeOf(src.data_type());
@@ -557,7 +557,7 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
   int chunk = width;
   int axis_shape = 1;
   int shape_outer = 1;
-  if (axis == Noaxis){
+  if (axis == Noaxis) {
     axis_shape = 1;
     shape_outer = Product(src.shape());
   } else {
@@ -565,7 +565,7 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
       shape_outer *= src.shape()[i];
     }
     axis_shape = src.shape()[axis];
-    for(size_t i = axis + 1; i < src.nDim(); i++) {
+    for (size_t i = axis + 1; i < src.nDim(); i++) {
       chunk *= src.shape()[i];
     }
   }
@@ -693,7 +693,7 @@ void Tensor::SetValue(const SType x) {
   CHECK_EQ(sizeof(SType), SizeOf(data_type_));
   //auto size = Size();
   auto ptr = block_;
-  
+
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     // TODO(wangwei) cast x to DType
     device_->Exec([this, x, ptr](Context * ctx) {
@@ -1268,6 +1268,18 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
 // ************************
 // Misc.
 // ************************
+const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t) {  // CE loss
+  Tensor loss({p.shape(0)}, p.device(), p.data_type());  // 1 per row of p
+  ComputeCrossEntropy(p, t, &loss);  // fills the out-param `loss`
+  return loss;  // TODO(review): 1-D p yields p.shape(0) entries; confirm
+}
+
+const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t) {  // grad
+  auto g = p.Clone();  // work on a copy; the caller's p is left untouched
+  SoftmaxCrossEntropyBwd(t, &g);  // (t, Tensor*) overload; presumably in-place
+  return g;  // new tensor holding the result of the in-place overload
+}
+
 void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss) {
   CHECK_LE(p.nDim(), 2u);
   CHECK_LE(t.nDim(), 2u);
@@ -1302,6 +1314,7 @@ Tensor Tensor::Reshape(const Shape &shape) {
   if (strides_.size() == 0)
     strides_.push_back(1);
 
+  // TODO(wangwei) remove this condition and report error if size changes.
   if (Product(shape_) != Product(shape)) {
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
@@ -1316,14 +1329,16 @@ Tensor Tensor::Reshape(const Shape &shape) {
     singa::Transform(*this, &t);
     t.shape_ = shape;
     return t;
- }
-
-  shape_ = shape;
-  generate_strides();
-  Tensor t(shape, device_, data_type_);
-  t.block_ = block_;
-  t.block_->IncRefCount();
-  return t;
+  } else {
+    Tensor t;
+    t.shape_ = shape;
+    t.device_ = device_;
+    t.data_type_ = data_type_;
+    t.block_ = block_;  // be careful about the block inference (mem leaking)
+    t.block_->IncRefCount();
+    t.generate_strides();
+    return t;
+  }
 }
 
 Tensor Tensor::Reshape(Shape &&shape) {
@@ -1344,14 +1359,16 @@ Tensor Tensor::Reshape(Shape &&shape) {
     singa::Transform(*this, &t);
     t.shape_ = shape;
     return t;
- }
-
-  shape_ = shape;
-  generate_strides();
-  Tensor t(shape, device_, data_type_);
-  t.block_ = block_;
-  t.block_->IncRefCount();
-  return t;
+  } else {
+    Tensor t;
+    t.shape_ = shape;
+    t.device_ = device_;
+    t.data_type_ = data_type_;
+    t.block_ = block_;  // be careful about the block inference (mem leaking)
+    t.block_->IncRefCount();
+    t.generate_strides();
+    return t;
+  }
 }
 
 Tensor Reshape(const Tensor &in, const Shape &s) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
old mode 100644
new mode 100755
index 28a52c5..eb2bfd3
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -144,7 +144,7 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
     rnn_mode = CUDNN_RNN_TANH;
   else if (rnn_mode_ == "gru")
     rnn_mode = CUDNN_GRU;
-#ifdef CUDNN_MAJOR == 5
+#if CUDNN_MAJOR <= 5
   CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_,
                                     dropout_desc_, input_mode, direction,
                                     rnn_mode, dtype_));