You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/11 08:29:35 UTC

[3/4] incubator-singa git commit: SINGA-380) Fix bugs from Reshape

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/io/image_transformer.cc
----------------------------------------------------------------------
diff --git a/src/io/image_transformer.cc b/src/io/image_transformer.cc
index 204ad08..0f49321 100644
--- a/src/io/image_transformer.cc
+++ b/src/io/image_transformer.cc
@@ -26,331 +26,328 @@
 
 namespace singa {
 
-  Tensor ImageTransformer::Apply(int flag, Tensor& input) {
-    CHECK_LE(input.nDim(), 4u);
-    CHECK_GE(input.nDim(), 2u);
-    CHECK_EQ(input.data_type(), kFloat32) << "Data type " << input.data_type()
-      << " is invalid for an raw image";
-    srand((unsigned int)time(NULL));
-    /// TODO
-    /// currently only consider one sample each time
+Tensor ImageTransformer::Apply(int flag, Tensor& input) {
+  CHECK_LE(input.nDim(), 4u);
+  CHECK_GE(input.nDim(), 2u);
+  CHECK_EQ(input.data_type(), kFloat32) << "Data type " << input.data_type()
+                                        << " is invalid for an raw image";
+  srand((unsigned int)time(NULL));
+  /// TODO
+  /// currently only consider one sample each time
 
-    /// resize image using opencv resize
-    Tensor temp1;
+  /// resize image using opencv resize
+  Tensor temp1;
 #ifdef USE_OPENCV
-    temp1 = resize(input, resize_height_, resize_width_, image_dim_order_);
+  temp1 = resize(input, resize_height_, resize_width_, image_dim_order_);
 #else
-    temp1 = input;
+  temp1 = input;
 #endif
 
-    /// crop
-    Tensor temp2;
-    size_t height = 0, width = 0;
-    if (input.nDim() >= 3u) {
-      if (image_dim_order_ == "CHW")
-        height = temp1.shape(input.nDim() - 2), width = temp1.shape(input.nDim() - 1);
-      else if (image_dim_order_ == "HWC")
-        height = temp1.shape(input.nDim() - 3), width = temp1.shape(input.nDim() - 2);
-      else
-        LOG(FATAL) << "Unknow dimension order for images " << image_dim_order_
-               << " Only support 'HWC' and 'CHW'";
-    } else /// input is 2D gray image
-      height = temp1.shape(0), width = temp1.shape(1);
+  /// crop
+  Tensor temp2;
+  size_t height = 0, width = 0;
+  if (input.nDim() >= 3u) {
+    if (image_dim_order_ == "CHW")
+      height = temp1.shape(input.nDim() - 2), width = temp1.shape(input.nDim() - 1);
+    else if (image_dim_order_ == "HWC")
+      height = temp1.shape(input.nDim() - 3), width = temp1.shape(input.nDim() - 2);
+    else
+      LOG(FATAL) << "Unknow dimension order for images " << image_dim_order_
+                 << " Only support 'HWC' and 'CHW'";
+  } else /// input is 2D gray image
+    height = temp1.shape(0), width = temp1.shape(1);
 
-    if (crop_shape_.size() == 2) {
-      if (flag == kTrain) { 
-        /// random crop
-        if (crop_shape_[0] > height || crop_shape_[0] > width)
-          LOG(FATAL) << "Crop size larger than the size of raw image";
-        size_t crop_h_offset = rand() % ((height - crop_shape_[0]) / 2), 
-               crop_w_offset = rand() % ((width - crop_shape_[1]) / 2);
-        temp2 = crop(temp1, crop_shape_[0], crop_shape_[1], 
-                  crop_h_offset, crop_w_offset, image_dim_order_);
-      } else if (flag == kEval) {
-        /// central crop
-        size_t crop_h_offset = (height - crop_shape_[0]) / 2,
-               crop_w_offset = (width - crop_shape_[1]) / 2;
-        temp2 = crop(temp1, crop_shape_[0], crop_shape_[1], 
-                  crop_h_offset, crop_w_offset, image_dim_order_); 
-      }
+  if (crop_shape_.size() == 2) {
+    if (flag == kTrain) {
+      /// random crop
+      if (crop_shape_[0] > height || crop_shape_[0] > width)
+        LOG(FATAL) << "Crop size larger than the size of raw image";
+      size_t crop_h_offset = rand() % ((height - crop_shape_[0]) / 2),
+             crop_w_offset = rand() % ((width - crop_shape_[1]) / 2);
+      temp2 = crop(temp1, crop_shape_[0], crop_shape_[1],
+                   crop_h_offset, crop_w_offset, image_dim_order_);
+    } else if (flag == kEval) {
+      /// central crop
+      size_t crop_h_offset = (height - crop_shape_[0]) / 2,
+             crop_w_offset = (width - crop_shape_[1]) / 2;
+      temp2 = crop(temp1, crop_shape_[0], crop_shape_[1],
+                   crop_h_offset, crop_w_offset, image_dim_order_);
     }
-    else temp2 = temp1;
+  } else temp2 = temp1;
 
-    /// mirror
-    Tensor output;
-    if ((flag == kTrain) && (rand() % 2))
-        output = mirror(temp2, true, false, image_dim_order_);
-    else output = temp2;
-    return output;
-  }
+  /// mirror
+  Tensor output;
+  if ((flag == kTrain) && (rand() % 2))
+    output = mirror(temp2, true, false, image_dim_order_);
+  else output = temp2;
+  return output;
+}
 
 #ifdef USE_OPENCV
-  Tensor resize(Tensor& input, const size_t resize_height, 
-               const size_t resize_width, const string& image_dim_order) {
-    CHECK_LE(input.nDim(), 4u);
-    CHECK_GE(input.nDim(), 2u);
-    if (!resize_height || !resize_width) return input;
-    Tensor output;
-    cv::Mat mat;
-    const auto* in = input.data<float>();
-    if (input.nDim() == 4u) {
-      /// TODO
-      /// batch based resize
-      LOG(FATAL) << "Not implemented";
-    } else if (input.nDim() == 3u) {
-      if (image_dim_order == "CHW") {
-        size_t height = input.shape(1), width = input.shape(2),
-               channel = input.shape(0);
-        if (channel == 3u) {
-          mat = cv::Mat(height, width, CV_32FC3, cv::Scalar(0, 0, 0));
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              for (size_t k = 0; k < channel; k++)
-                mat.at<cv::Vec3f>(i, j)[k] = in[k * height * width + i * width + j];
-        }
-        else if (channel == 1u) {
-          mat = cv::Mat(height, width, CV_32FC1);
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-                mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
-        }
-        else LOG(FATAL) << "Invalid channel size: " << channel;
-      } else if (image_dim_order == "HWC") {
-        size_t height = input.shape(0), width = input.shape(1),
-               channel = input.shape(2);
-        if (channel == 3u) {
-          mat = cv::Mat(height, width, CV_32FC3, cv::Scalar(0, 0, 0));
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              for (size_t k = 0; k < channel; k++)
-                mat.at<cv::Vec3f>(i, j)[k] =
-                  in[i * width * channel + j * channel + k];
-        } else if (channel == 1u) { /// 2D gray image
-          mat = cv::Mat(height, width, CV_32FC1);
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
-        } else LOG(FATAL) << "Invalid channel size: " << channel;
-      } else {
-        LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
-                   << " Only support 'HWC' and 'CHW'";
-      }
-    } else { /// 2D gray image
-      size_t height = input.shape(0), width = input.shape(1);
-      mat = cv::Mat(height, width, CV_32FC1);
-      for (size_t i = 0; i < height; i++)
-        for (size_t j = 0; j < width; j++)
-          mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
+Tensor resize(Tensor& input, const size_t resize_height,
+              const size_t resize_width, const string& image_dim_order) {
+  CHECK_LE(input.nDim(), 4u);
+  CHECK_GE(input.nDim(), 2u);
+  if (!resize_height || !resize_width) return input;
+  Tensor output;
+  cv::Mat mat;
+  const auto* in = input.data<float>();
+  if (input.nDim() == 4u) {
+    /// TODO
+    /// batch based resize
+    LOG(FATAL) << "Not implemented";
+  } else if (input.nDim() == 3u) {
+    if (image_dim_order == "CHW") {
+      size_t height = input.shape(1), width = input.shape(2),
+             channel = input.shape(0);
+      if (channel == 3u) {
+        mat = cv::Mat(height, width, CV_32FC3, cv::Scalar(0, 0, 0));
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            for (size_t k = 0; k < channel; k++)
+              mat.at<cv::Vec3f>(i, j)[k] = in[k * height * width + i * width + j];
+      } else if (channel == 1u) {
+        mat = cv::Mat(height, width, CV_32FC1);
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
+      } else LOG(FATAL) << "Invalid channel size: " << channel;
+    } else if (image_dim_order == "HWC") {
+      size_t height = input.shape(0), width = input.shape(1),
+             channel = input.shape(2);
+      if (channel == 3u) {
+        mat = cv::Mat(height, width, CV_32FC3, cv::Scalar(0, 0, 0));
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            for (size_t k = 0; k < channel; k++)
+              mat.at<cv::Vec3f>(i, j)[k] =
+                in[i * width * channel + j * channel + k];
+      } else if (channel == 1u) { /// 2D gray image
+        mat = cv::Mat(height, width, CV_32FC1);
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
+      } else LOG(FATAL) << "Invalid channel size: " << channel;
+    } else {
+      LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
+                 << " Only support 'HWC' and 'CHW'";
     }
-    cv::Size size(resize_width, resize_height);
-    cv::Mat resized;
-    cv::resize(mat, resized, size);
-    CHECK_EQ(resized.size().height, resize_height);
-    CHECK_EQ(resized.size().width, resize_width);
-    size_t new_size = resize_height * resize_width * resized.channels();
-    float* out = new float[new_size];
-    if (input.nDim() == 4u) {
-      /// TODO
-      /// batch based resize
-      LOG(FATAL) << "Not implemented";
-    } else if (input.nDim() == 3u) {
-      if (image_dim_order == "CHW") {
-        size_t height = resize_height, width = resize_width,
-           channel = input.shape(0);
-        if (channel == 3u) {
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              for (size_t k = 0; k < channel; k++)
-                out[k * height * width + i * width + j] = resized.at<cv::Vec3f>(i, j)[k];
-        } else { /// 2D gray image
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
-        }
-        Tensor temp(Shape{channel, height, width});
-        temp.CopyDataFromHostPtr<float>(out, new_size);
-        output = temp;
-      } else {
-        size_t height = resize_height, width = resize_width,
-           channel = input.shape(2);
-        if (channel == 3u) {
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              for (size_t k = 0; k < channel; k++)
-                out[i * width * channel + j * channel + k] = resized.at<cv::Vec3f>(i, j)[k];
-        } else { /// 1 channel
-          for (size_t i = 0; i < height; i++)
-            for (size_t j = 0; j < width; j++)
-              out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
-        }
-        Tensor temp(Shape{height, width, channel}); 
-        temp.CopyDataFromHostPtr<float>(out, new_size);
-        output = temp;
+  } else { /// 2D gray image
+    size_t height = input.shape(0), width = input.shape(1);
+    mat = cv::Mat(height, width, CV_32FC1);
+    for (size_t i = 0; i < height; i++)
+      for (size_t j = 0; j < width; j++)
+        mat.at<cv::Vec<float, 1>>(i, j)[0] = in[i * width + j];
+  }
+  cv::Size size(resize_width, resize_height);
+  cv::Mat resized;
+  cv::resize(mat, resized, size);
+  CHECK_EQ(resized.size().height, resize_height);
+  CHECK_EQ(resized.size().width, resize_width);
+  size_t new_size = resize_height * resize_width * resized.channels();
+  float* out = new float[new_size];
+  if (input.nDim() == 4u) {
+    /// TODO
+    /// batch based resize
+    LOG(FATAL) << "Not implemented";
+  } else if (input.nDim() == 3u) {
+    if (image_dim_order == "CHW") {
+      size_t height = resize_height, width = resize_width,
+             channel = input.shape(0);
+      if (channel == 3u) {
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            for (size_t k = 0; k < channel; k++)
+              out[k * height * width + i * width + j] = resized.at<cv::Vec3f>(i, j)[k];
+      } else { /// 2D gray image
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
+      }
+      Tensor temp(Shape{channel, height, width});
+      temp.CopyDataFromHostPtr<float>(out, new_size);
+      output = temp;
+    } else {
+      size_t height = resize_height, width = resize_width,
+             channel = input.shape(2);
+      if (channel == 3u) {
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            for (size_t k = 0; k < channel; k++)
+              out[i * width * channel + j * channel + k] = resized.at<cv::Vec3f>(i, j)[k];
+      } else { /// 1 channel
+        for (size_t i = 0; i < height; i++)
+          for (size_t j = 0; j < width; j++)
+            out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
       }
-    } else { /// 2D gray image
-      size_t height = resize_height, width = resize_width;
-      for (size_t i = 0; i < height; i++)
-        for (size_t j = 0; j < width; j++)
-          out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
-      Tensor temp(Shape{height, width});
+      Tensor temp(Shape{height, width, channel});
       temp.CopyDataFromHostPtr<float>(out, new_size);
       output = temp;
     }
-    delete[] out;
-    return output;
+  } else { /// 2D gray image
+    size_t height = resize_height, width = resize_width;
+    for (size_t i = 0; i < height; i++)
+      for (size_t j = 0; j < width; j++)
+        out[i * width + j] = resized.at<cv::Vec<float, 1>>(i, j)[0];
+    Tensor temp(Shape{height, width});
+    temp.CopyDataFromHostPtr<float>(out, new_size);
+    output = temp;
   }
+  delete[] out;
+  return output;
+}
 #endif
 
-  Tensor crop(Tensor& input, const size_t crop_height, const size_t crop_width, 
-             const size_t crop_h_offset, const size_t crop_w_offset, 
-             const string& image_dim_order) {
-    CHECK_LE(input.nDim(), 4u);
-    CHECK_GE(input.nDim(), 2u);
+Tensor crop(Tensor& input, const size_t crop_height, const size_t crop_width,
+            const size_t crop_h_offset, const size_t crop_w_offset,
+            const string& image_dim_order) {
+  CHECK_LE(input.nDim(), 4u);
+  CHECK_GE(input.nDim(), 2u);
 
-    Tensor output;
-    const float* in = input.data<float>();
-    size_t out_idx = 0, in_idx = 0;
-    if (input.nDim() == 4u) {
-      /// TODO
-      LOG(FATAL) << "Not implemented";
-    } else if (input.nDim() == 3u) {
-      if (image_dim_order == "CHW") {
-        size_t height = input.shape(1), width = input.shape(2),
-            channel = input.shape(0); 
-        CHECK_LE(crop_height + crop_h_offset, height);
-        CHECK_LE(crop_width + crop_w_offset, width);
-        float* out = new float[crop_height * crop_width * channel];
-        for (size_t c = 0; c < channel; c++) {
-          for (size_t h = 0; h < crop_height; h++) {
-            for (size_t w = 0; w < crop_width; w++) {
-              in_idx = (c * height + crop_h_offset + h) * width + crop_w_offset + w;
-              out_idx = (c * crop_height + h) * crop_width + w;
-              out[out_idx] = in[in_idx];
-            }
-          }
-        }
-        output = Reshape(output, Shape{channel, crop_height, crop_width});
-        output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
-        delete[] out;
-      } else if (image_dim_order == "HWC") {
-        size_t height = input.shape(0), width = input.shape(1), 
-               channel = input.shape(2); 
-        CHECK_LE(crop_height + crop_h_offset, height);
-        CHECK_LE(crop_width + crop_w_offset, width);
-        float* out = new float[crop_height * crop_width * channel];
-        for (size_t c = 0; c < channel; c++) {
-          for (size_t h = 0; h < crop_height; h++) {
-            for (size_t w = 0; w < crop_width; w++) {
-              in_idx = ((crop_h_offset + h) * width + crop_w_offset + w) * channel + c;
-              out_idx = (h * crop_width + w) * channel + c;
-              out[out_idx] = in[in_idx];
-            }
+  Tensor output;
+  const float* in = input.data<float>();
+  size_t out_idx = 0, in_idx = 0;
+  if (input.nDim() == 4u) {
+    /// TODO
+    LOG(FATAL) << "Not implemented";
+  } else if (input.nDim() == 3u) {
+    if (image_dim_order == "CHW") {
+      size_t height = input.shape(1), width = input.shape(2),
+             channel = input.shape(0);
+      CHECK_LE(crop_height + crop_h_offset, height);
+      CHECK_LE(crop_width + crop_w_offset, width);
+      float* out = new float[crop_height * crop_width * channel];
+      for (size_t c = 0; c < channel; c++) {
+        for (size_t h = 0; h < crop_height; h++) {
+          for (size_t w = 0; w < crop_width; w++) {
+            in_idx = (c * height + crop_h_offset + h) * width + crop_w_offset + w;
+            out_idx = (c * crop_height + h) * crop_width + w;
+            out[out_idx] = in[in_idx];
           }
         }
-        output = Reshape(output, Shape{crop_height, crop_width, channel});
-        output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
-        delete[] out;
-      } else {
-        LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
-                   << " Only support 'HWC' and 'CHW'";
       }
-    } else { /// 2D gray image
-      size_t height = input.shape(0), width = input.shape(1); 
+      output.SetShape(Shape{channel, crop_height, crop_width});
+      output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
+      delete[] out;
+    } else if (image_dim_order == "HWC") {
+      size_t height = input.shape(0), width = input.shape(1),
+             channel = input.shape(2);
       CHECK_LE(crop_height + crop_h_offset, height);
       CHECK_LE(crop_width + crop_w_offset, width);
-      float* out = new float[crop_height * crop_width];
-      for (size_t h = 0; h < crop_height; h++) {
-        for (size_t w = 0; w < crop_width; w++) {
-          in_idx = (crop_h_offset + h) * width + crop_w_offset + w;
-          out_idx = h * crop_width + w;
-          out[out_idx] = in[in_idx];
+      float* out = new float[crop_height * crop_width * channel];
+      for (size_t c = 0; c < channel; c++) {
+        for (size_t h = 0; h < crop_height; h++) {
+          for (size_t w = 0; w < crop_width; w++) {
+            in_idx = ((crop_h_offset + h) * width + crop_w_offset + w) * channel + c;
+            out_idx = (h * crop_width + w) * channel + c;
+            out[out_idx] = in[in_idx];
+          }
         }
       }
-      output = Reshape(output, Shape{crop_height, crop_width});
-      output.CopyDataFromHostPtr<float>(out, crop_height * crop_width);
+      output.SetShape(Shape{crop_height, crop_width, channel});
+      output.CopyDataFromHostPtr<float>(out, crop_height * crop_width * channel);
       delete[] out;
+    } else {
+      LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
+                 << " Only support 'HWC' and 'CHW'";
+    }
+  } else { /// 2D gray image
+    size_t height = input.shape(0), width = input.shape(1);
+    CHECK_LE(crop_height + crop_h_offset, height);
+    CHECK_LE(crop_width + crop_w_offset, width);
+    float* out = new float[crop_height * crop_width];
+    for (size_t h = 0; h < crop_height; h++) {
+      for (size_t w = 0; w < crop_width; w++) {
+        in_idx = (crop_h_offset + h) * width + crop_w_offset + w;
+        out_idx = h * crop_width + w;
+        out[out_idx] = in[in_idx];
+      }
     }
-    return output;
+    output.SetShape(Shape{crop_height, crop_width});
+    output.CopyDataFromHostPtr<float>(out, crop_height * crop_width);
+    delete[] out;
   }
+  return output;
+}
 
-  Tensor mirror(Tensor& input, const bool horizontal_mirror,
-             const bool vertical_mirror, const string& image_dim_order) {
-    CHECK_LE(input.nDim(), 4u);
-    CHECK_GE(input.nDim(), 2u);
-    if (!horizontal_mirror && !vertical_mirror) return input;
+Tensor mirror(Tensor& input, const bool horizontal_mirror,
+              const bool vertical_mirror, const string& image_dim_order) {
+  CHECK_LE(input.nDim(), 4u);
+  CHECK_GE(input.nDim(), 2u);
+  if (!horizontal_mirror && !vertical_mirror) return input;
 
-    Tensor output;
-    const float* in = input.data<float>();
-    size_t out_idx = 0, in_idx = 0;
-    if (input.nDim() == 4u) {
-      /// TODO
-      LOG(FATAL) << "Not implemented";
-    } else if (input.nDim() == 3u) {
-      if (image_dim_order == "CHW") {
-        size_t height = input.shape(1), width = input.shape(2),
-            channel = input.shape(0);
-        float* out = new float[height * width * channel];
-        for (size_t c = 0; c < channel; c++) {
-          for (size_t h = 0; h < height; h++) {
-            for (size_t w = 0; w < width; w++) {
-              in_idx = (c * height + h) * width + w;
-              if (horizontal_mirror && vertical_mirror)
-                out_idx = (c * height + (height - 1 - h)) * width + (width - 1 - w);
-              else if (horizontal_mirror)
-                out_idx = (c * height + h) * width + (width - 1 - w);
-              else /// only do vertical mirror
-                out_idx = (c * height + (height - 1 - h)) * width + w;
-              out[out_idx] = in[in_idx];
-            }
-          }
-        }
-        output = Reshape(output, Shape{channel, height, width});
-        output.CopyDataFromHostPtr<float>(out, height * width * channel);
-        delete[] out;
-      } else if (image_dim_order == "HWC") {
-        size_t height = input.shape(0), width = input.shape(1),
-            channel = input.shape(2);
-        float* out = new float[height * width * channel];
-        for (size_t c = 0; c < channel; c++) {
-          for (size_t h = 0; h < height; h++) {
-            for (size_t w = 0; w < width; w++) {
-              in_idx = (h * width + w) * channel + c;
-              if (horizontal_mirror && vertical_mirror)
-                out_idx = ((height - 1 - h) * width + (width - 1 - w)) * channel + c;
-              else if (horizontal_mirror)
-                out_idx = (h * width + (width - 1 - w)) * channel + c;
-              else /// only do vertical mirror
-                out_idx = ((height - 1 - h) * width + w) * channel + c;
-              out[out_idx] = in[in_idx];
-            }
+  Tensor output;
+  const float* in = input.data<float>();
+  size_t out_idx = 0, in_idx = 0;
+  if (input.nDim() == 4u) {
+    /// TODO
+    LOG(FATAL) << "Not implemented";
+  } else if (input.nDim() == 3u) {
+    if (image_dim_order == "CHW") {
+      size_t height = input.shape(1), width = input.shape(2),
+             channel = input.shape(0);
+      float* out = new float[height * width * channel];
+      for (size_t c = 0; c < channel; c++) {
+        for (size_t h = 0; h < height; h++) {
+          for (size_t w = 0; w < width; w++) {
+            in_idx = (c * height + h) * width + w;
+            if (horizontal_mirror && vertical_mirror)
+              out_idx = (c * height + (height - 1 - h)) * width + (width - 1 - w);
+            else if (horizontal_mirror)
+              out_idx = (c * height + h) * width + (width - 1 - w);
+            else /// only do vertical mirror
+              out_idx = (c * height + (height - 1 - h)) * width + w;
+            out[out_idx] = in[in_idx];
           }
         }
-        output = Reshape(output, Shape{height, width, channel});
-        output.CopyDataFromHostPtr<float>(out, height * width * channel);
-        delete[] out;
-      } else {
-        LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
-                   << " Only support 'HWC' and 'CHW'";
       }
-    } else { /// 2D gray image
-      size_t height = input.shape(0), width = input.shape(1);
-      float* out = new float[height * width];
-      for (size_t h = 0; h < height; h++) {
-        for (size_t w = 0; w < width; w++) {
-          in_idx = h * width + w;
-          if (horizontal_mirror && vertical_mirror)
-            out_idx = (height - 1 - h) * width + (width - 1 - w);
-          else if (horizontal_mirror)
-            out_idx = h * width + (width - 1 - w);
-          else /// only do vertical mirror
-            out_idx = (height - 1 - h) * width + w;
-          out[out_idx] = in[in_idx];
+      output.SetShape(Shape{channel, height, width});
+      output.CopyDataFromHostPtr<float>(out, height * width * channel);
+      delete[] out;
+    } else if (image_dim_order == "HWC") {
+      size_t height = input.shape(0), width = input.shape(1),
+             channel = input.shape(2);
+      float* out = new float[height * width * channel];
+      for (size_t c = 0; c < channel; c++) {
+        for (size_t h = 0; h < height; h++) {
+          for (size_t w = 0; w < width; w++) {
+            in_idx = (h * width + w) * channel + c;
+            if (horizontal_mirror && vertical_mirror)
+              out_idx = ((height - 1 - h) * width + (width - 1 - w)) * channel + c;
+            else if (horizontal_mirror)
+              out_idx = (h * width + (width - 1 - w)) * channel + c;
+            else /// only do vertical mirror
+              out_idx = ((height - 1 - h) * width + w) * channel + c;
+            out[out_idx] = in[in_idx];
+          }
         }
       }
-      output = Reshape(output, Shape{height, width});
-      output.CopyDataFromHostPtr<float>(out, height * width);
+      output.SetShape(Shape{height, width, channel});
+      output.CopyDataFromHostPtr<float>(out, height * width * channel);
       delete[] out;
+    } else {
+      LOG(FATAL) << "Unknow dimension order for images " << image_dim_order
+                 << " Only support 'HWC' and 'CHW'";
+    }
+  } else { /// 2D gray image
+    size_t height = input.shape(0), width = input.shape(1);
+    float* out = new float[height * width];
+    for (size_t h = 0; h < height; h++) {
+      for (size_t w = 0; w < width; w++) {
+        in_idx = h * width + w;
+        if (horizontal_mirror && vertical_mirror)
+          out_idx = (height - 1 - h) * width + (width - 1 - w);
+        else if (horizontal_mirror)
+          out_idx = h * width + (width - 1 - w);
+        else /// only do vertical mirror
+          out_idx = (height - 1 - h) * width + w;
+        out[out_idx] = in[in_idx];
+      }
     }
-    return output;
+    output.SetShape(Shape{height, width});
+    output.CopyDataFromHostPtr<float>(out, height * width);
+    delete[] out;
   }
+  return output;
+}
 } // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
index 4e74a82..d2b0c3e 100644
--- a/src/model/layer/batchnorm.cc
+++ b/src/model/layer/batchnorm.cc
@@ -44,7 +44,7 @@ void BatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
   else
     is_2d_ = false;
 
-  bnScale_.Reshape(Shape{channels_});
+  bnScale_.SetShape(Shape{channels_});
   bnBias_.ResetLike(bnScale_);
   runningMean_.ResetLike(bnScale_);
   runningVariance_.ResetLike(bnScale_);
@@ -68,19 +68,18 @@ void BatchNorm::ToDevice(std::shared_ptr<Device> device) {
 const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
   Tensor x = input.Clone();
   x.Reshape(Shape{input.shape(0), input.Size() / input.shape(0)});
-  Tensor output, mean, var, xnorm;
+  Tensor output;
   output.ResetLike(x);
   // TODO(wangwei) input sample shape check
-
   if ((flag & kTrain) == kTrain) {  // forward for train
     if (is_2d_) {                   // batchnorm_per_activation mode
-      mean = Average(x, 0);
+      auto mean = Average(x, 0);
       runningMean_ *= 1.0f - factor_;
       Axpy(factor_, mean, &runningMean_);
-      xnorm = x.Clone();
+      auto xnorm = x.Clone();
       SubRow(mean, &xnorm);
       xnorm = Square(xnorm);
-      var = Average(xnorm, 0);
+      auto var = Average(xnorm, 0);
       runningVariance_ *= 1.0f - factor_;
       Axpy(factor_, var, &runningVariance_);
       Tensor tmp = var.Clone();
@@ -102,7 +101,7 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
     }
   } else {         // forward for test
     if (is_2d_) {  // batchnorm_per_activation mode
-      xnorm = x.Clone();
+      auto xnorm = x.Clone();
       SubRow(runningMean_, &xnorm);
       Tensor tmp = runningVariance_.Clone();
       tmp = Sqrt(tmp);
@@ -134,7 +133,7 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
       scale.Reshape(Shape{channels_ * height_ * width_});
       bias.Reshape(Shape{channels_ * height_ * width_});
 
-      xnorm = x.Clone();
+      auto xnorm = x.Clone();
       SubRow(mean, &xnorm);
       var = Sqrt(var);
       var += 1e-6f;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc
index cc77433..3718d8d 100755
--- a/src/model/layer/convolution.cc
+++ b/src/model/layer/convolution.cc
@@ -96,9 +96,9 @@ void Convolution::Setup(const Shape &in_sample, const LayerConf &conf) {
   col_width_ = conv_height_ * conv_width_;
 
   // Setup shape of weight_ and bias_
-  weight_.Reshape(Shape{num_filters_, col_height_});
+  weight_.SetShape(Shape{num_filters_, col_height_});
   if (bias_term_)
-    bias_.Reshape(Shape{num_filters_});
+    bias_.SetShape(Shape{num_filters_});
   // Assume the order of param is: weight, bias
   for (const auto &spec : conf.param()) param_specs_.push_back(spec);
 }
@@ -174,8 +174,8 @@ const std::pair<Tensor, vector<Tensor>> Convolution::Backward(
     col_data.CopyDataFromHostPtr(data_col, col_height_ * col_width_);
     Tensor grad_b(Shape{num_filters_, conv_height_ * conv_width_});
     CopyDataToFrom(&grad_b, grad, grad_b.Size(), 0, b * grad_b.Size());
-    dw += Mult(grad_b, col_data.T());
-    Tensor dcol_b = Mult(weight_.T(), grad_b);
+    dw += Mult(grad_b, Transpose(col_data));
+    Tensor dcol_b = Mult(Transpose(weight_), grad_b);
     auto dcol_data = dcol_b.data<float>();
     Col2im(dcol_data, channels_, height_, width_, kernel_h_, kernel_w_, pad_h_,
            pad_w_, stride_h_, stride_w_, dx_b);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
index 5c93a6b..389b41b 100644
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -39,8 +39,8 @@ void CudnnBatchNorm::ToDevice(std::shared_ptr<Device> device) {
 
 void CudnnBatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
   BatchNorm::Setup(in_sample, conf);
-  resultSaveMean_.Reshape(Shape{channels_});
-  resultSaveVariance_.Reshape(Shape{channels_});
+  resultSaveMean_.SetShape(Shape{channels_});
+  resultSaveVariance_.SetShape(Shape{channels_});
 }
 
 void CudnnBatchNorm::InitCudnn(const Shape& shape, DataType dtype) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/dense.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.cc b/src/model/layer/dense.cc
index fac9130..36a7a91 100644
--- a/src/model/layer/dense.cc
+++ b/src/model/layer/dense.cc
@@ -40,11 +40,11 @@ void Dense::Setup(const Shape& in_sample, const LayerConf &conf) {
   transpose_ = dense_conf.transpose();
   bias_term_ = dense_conf.bias_term();
   if (transpose_)  // was {vdim_, hdim} by zhaojing?
-    weight_.Reshape(Shape{hdim_, vdim_});
+    weight_.SetShape(Shape{hdim_, vdim_});
   else
-    weight_.Reshape(Shape{vdim_, hdim_});
+    weight_.SetShape(Shape{vdim_, hdim_});
   if (bias_term_)
-    bias_.Reshape(Shape{hdim_});
+    bias_.SetShape(Shape{hdim_});
   for (auto specs: conf.param())
     param_specs_.push_back(specs);
 }
@@ -55,7 +55,7 @@ const Tensor Dense::Forward(int flag, const Tensor &input) {
   Tensor output;
   CHECK_EQ(input.nDim(), 2u);
   if (transpose_)  // use the transposed version of weight_ for computing
-    output = Mult(input, weight_.T());
+    output = Mult(input, Transpose(weight_));
   else
     output = Mult(input, weight_);
   if (bias_term_)
@@ -81,10 +81,10 @@ const std::pair<Tensor, vector<Tensor>> Dense::Backward(int flag,
   }
   if (transpose_) {
     dx = Mult(grad, weight_);
-    dw = Mult(grad.T(), src_data);
+    dw = Mult(Transpose(grad), src_data);
   } else {
-    dx = Mult(grad, weight_.T());
-    dw = Mult(src_data.T(), grad);
+    dx = Mult(grad, Transpose(weight_));
+    dw = Mult(Transpose(src_data), grad);
   }
   param_grad.push_back(dw);
   if (bias_term_)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/flatten.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.cc b/src/model/layer/flatten.cc
index 561c310..592e892 100644
--- a/src/model/layer/flatten.cc
+++ b/src/model/layer/flatten.cc
@@ -49,8 +49,7 @@ const Tensor Flatten::Forward(int flag, const Tensor &input) {
 const std::pair<Tensor, vector<Tensor> > Flatten::Backward(int flag,
                                                            const Tensor &grad) {
   vector<Tensor> param_grad;
-  Tensor input_grad = grad;
-  input_grad.Reshape(input_shape_);
+  Tensor input_grad = Reshape(grad, input_shape_);
   return std::make_pair(input_grad, param_grad);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/lrn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.cc b/src/model/layer/lrn.cc
index 4fdb5c9..a1776fa 100644
--- a/src/model/layer/lrn.cc
+++ b/src/model/layer/lrn.cc
@@ -52,8 +52,7 @@ const Tensor LRN::Forward(int flag, const Tensor& input) {
                    std::min(input.shape(1), c + local_size_ / 2 + 1));
       window = Square(window);
 
-      Tensor tmp, ch;
-      tmp.Reshape(Shape{input.shape(2) * input.shape(3)});
+      Tensor ch, tmp(Shape{input.shape(2) * input.shape(3)});
       SumRows(window, &tmp);
 
       tmp *= alpha_;
@@ -97,8 +96,7 @@ const std::pair<Tensor, vector<Tensor>> LRN::Backward(int flag,
         Tensor window =
             CopyRows(image, std::max(0, static_cast<int>(c) - local_size_ / 2),
                      std::min(grad.shape(1), c + local_size_ / 2 + 1));
-        Tensor tmp;
-        tmp.Reshape(Shape{grad.shape(2) * grad.shape(3)});
+        Tensor tmp(Shape{grad.shape(2) * grad.shape(3)});
         window = Square(window);
         SumRows(window, &tmp);
         tmp *= alpha_;
@@ -126,8 +124,7 @@ const std::pair<Tensor, vector<Tensor>> LRN::Backward(int flag,
         Tensor window =
             CopyRows(image, std::max(0, static_cast<int>(c) - local_size_ / 2),
                      std::min(grad.shape(1), c + local_size_ / 2 + 1));
-        Tensor tmpr;
-        tmpr.Reshape(Shape{grad.shape(2) * grad.shape(3)});
+        Tensor tmpr(Shape{grad.shape(2) * grad.shape(3)});
         SumRows(window, &tmpr);
         tmpr.Reshape(Shape{grad.shape(2), grad.shape(3)});
         channels.push_back(tmpr);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/opencl_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/opencl_convolution.cc b/src/model/layer/opencl_convolution.cc
index 063c4c3..eb25f5e 100644
--- a/src/model/layer/opencl_convolution.cc
+++ b/src/model/layer/opencl_convolution.cc
@@ -37,9 +37,9 @@ const Tensor OpenclConvolution::Forward(int flag, const Tensor &input) {
   auto data_type = input.data_type();
   auto device = input.device();
 
-   // TODO(wangwei) update the layer config if the input sample shape changes
+  // TODO(wangwei) update the layer config if the input sample shape changes
   CHECK(input.shape(1) == channels_ && input.shape(2) == height_ &&
-      input.shape(3) == width_) << "input sample shape should not change";
+        input.shape(3) == width_) << "input sample shape should not change";
 
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, device, data_type);
@@ -48,16 +48,16 @@ const Tensor OpenclConvolution::Forward(int flag, const Tensor &input) {
   for (size_t b = 0; b < batchsize; b++) {
     int offset = b * imagesize;
 
-    col_data.device()->Exec([input, offset, col_data, this](Context* ctx) mutable {
+    col_data.device()->Exec([input, offset, col_data, this](Context * ctx) mutable {
 
       this->Im2Col(input.block(), offset,
-                   height_, width_,
-                   kernel_h_, kernel_w_,
-                   pad_h_, pad_w_,
-                   stride_h_, stride_w_,
-                   conv_height_, conv_width_,
-                   0, channels_,
-                   col_data.block(), ctx);
+      height_, width_,
+      kernel_h_, kernel_w_,
+      pad_h_, pad_w_,
+      stride_h_, stride_w_,
+      conv_height_, conv_width_,
+      0, channels_,
+      col_data.block(), ctx);
     },
     {input.block()},
     {col_data.block()});
@@ -116,16 +116,17 @@ OpenclConvolution::Backward(int flag, const Tensor &grad) {
     int im_offset = b * imagesize;
     int col_offset = 0; // Always keep this to zero.
 
-    col_data.device()->Exec([src_data, col_data, im_offset, col_offset, this](Context* ctx) mutable {
+    col_data.device()->Exec([src_data, col_data, im_offset, col_offset,
+    this](Context * ctx) mutable {
 
       this->Im2Col(src_data.block(), im_offset,
-                   height_, width_,
-                   kernel_h_, kernel_w_,
-                   pad_h_, pad_w_,
-                   stride_h_, stride_w_,
-                   conv_height_, conv_width_,
-                   col_offset, channels_,
-                   col_data.block(), ctx);
+      height_, width_,
+      kernel_h_, kernel_w_,
+      pad_h_, pad_w_,
+      stride_h_, stride_w_,
+      conv_height_, conv_width_,
+      col_offset, channels_,
+      col_data.block(), ctx);
     },
     {src_data.block()},
     {col_data.block()});
@@ -134,19 +135,20 @@ OpenclConvolution::Backward(int flag, const Tensor &grad) {
                   grad.device(), grad.data_type());
     CopyDataToFrom(&grad_b, grad, grad_b.Size(), 0, b * grad_b.Size());
 
-    dw += Mult(grad_b, col_data.T());
-    Tensor dcol_b = Mult(weight_.T(), grad_b);
+    dw += Mult(grad_b, Transpose(col_data));
+    Tensor dcol_b = Mult(Transpose(weight_), grad_b);
 
-    dx.device()->Exec([dcol_b, dx, im_offset, col_offset, this](Context* ctx) mutable {
+    dx.device()->Exec([dcol_b, dx, im_offset, col_offset,
+    this](Context * ctx) mutable {
 
       this->Col2Im(dcol_b.block(), col_offset,
-                   height_, width_,
-                   kernel_h_, kernel_w_,
-                   pad_h_, pad_w_,
-                   stride_h_, stride_w_,
-                   conv_height_, conv_width_,
-                   im_offset, channels_,
-                   dx.block(), ctx);
+      height_, width_,
+      kernel_h_, kernel_w_,
+      pad_h_, pad_w_,
+      stride_h_, stride_w_,
+      conv_height_, conv_width_,
+      im_offset, channels_,
+      dx.block(), ctx);
     },
     {dcol_b.block()},
     {dx.block()});

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/layer/rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.cc b/src/model/layer/rnn.cc
index b811f9d..e565abc 100644
--- a/src/model/layer/rnn.cc
+++ b/src/model/layer/rnn.cc
@@ -79,7 +79,7 @@ void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
       dim = hidden_size_ * (hidden_size_ +  hidden_size_ + 2);
     weight_size += mult * dim;
   }
-  weight_.Reshape(Shape{weight_size});
+  weight_.SetShape(Shape{weight_size});
 }
 
 const vector<Tensor> RNN::Forward(int flag, const vector<Tensor>& inputs) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/operation/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.cc b/src/model/operation/convolution.cc
index f700203..7c71d7c 100755
--- a/src/model/operation/convolution.cc
+++ b/src/model/operation/convolution.cc
@@ -4,7 +4,8 @@
 
 namespace singa {
 
-ConvHandle::ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+ConvHandle::ConvHandle(const Tensor &input,
+                       const std::vector<size_t>& kernel_size,
                        const std::vector<size_t>& stride, const std::vector<size_t>& padding,
                        const size_t in_channels, const size_t out_channels,
                        const bool bias) {
@@ -23,7 +24,8 @@ ConvHandle::ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_si
   bias_term = bias;
 
   batchsize = input.shape(0);
-  CHECK(input.shape(1) == in_channels) << "the number of input channels mismatched.";
+  CHECK(input.shape(1) == in_channels) <<
+                                       "the number of input channels mismatched.";
   height = input.shape(2);
   width = input.shape(3);
 
@@ -39,14 +41,16 @@ ConvHandle::ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_si
 
 
 
-Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const ConvHandle &ch) {
+Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b,
+                      const ConvHandle &ch) {
   CHECK_EQ(x.device()->lang(), kCpp);
 
   CHECK(x.shape(1) == ch.channels && x.shape(2) == ch.height &&
         x.shape(3) == ch.width) << "input sample shape should not change";
 
   CHECK(W.shape(0) == ch.num_filters && W.shape(1) == ch.channels &&
-        W.shape(2) == ch.kernel_h && W.shape(3) == ch.kernel_w) << "weights shape should not change";
+        W.shape(2) == ch.kernel_h
+        && W.shape(3) == ch.kernel_w) << "weights shape should not change";
 
   Shape w_shape = W.shape();
   Shape b_shape;
@@ -67,8 +71,9 @@ Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const ConvHandle &
   float *data_col = new float[ch.col_height * ch.col_width];
   auto in_data = x.data<float>();
   for (size_t num = 0; num < ch.batchsize; num++) {
-    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width, ch.kernel_h,
-             ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
+    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width,
+           ch.kernel_h,
+           ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
 
     col_data.CopyDataFromHostPtr(data_col, ch.col_height * ch.col_width);
     Tensor each = Mult(W, col_data);
@@ -83,14 +88,16 @@ Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const ConvHandle &
   return output;
 }
 
-Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandle &ch) {
+Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x,
+                        const ConvHandle &ch) {
   CHECK_EQ(dy.device()->lang(), kCpp);
 
   CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
         dy.shape(3) == ch.conv_width) << "input gradients shape should not change";
 
   CHECK(W.shape(0) == ch.num_filters && W.shape(1) == ch.channels &&
-        W.shape(2) == ch.kernel_h && W.shape(3) == ch.kernel_w) << "weights shape should not change";
+        W.shape(2) == ch.kernel_h
+        && W.shape(3) == ch.kernel_w) << "weights shape should not change";
 
   Shape w_shape = W.shape();
   W.Reshape(Shape{ch.num_filters, ch.col_height});
@@ -103,17 +110,19 @@ Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Conv
   for (size_t num = 0; num < ch.batchsize; num++) {
     Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
     CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
-    Tensor dcol_b = Mult(W.T(), grad_b);
+    Tensor dcol_b = Mult(Transpose(W), grad_b);
     auto dcol_data = dcol_b.data<float>();
-    Col2im(dcol_data, ch.channels, ch.height, ch.width, ch.kernel_h, ch.kernel_w, ch.pad_h,
-             ch.pad_w, ch.stride_h, ch.stride_w, dx_b);
+    Col2im(dcol_data, ch.channels, ch.height, ch.width, ch.kernel_h, ch.kernel_w,
+           ch.pad_h,
+           ch.pad_w, ch.stride_h, ch.stride_w, dx_b);
     dx.CopyDataFromHostPtr(dx_b, ch.imagesize, num * ch.imagesize);
   }
   W.Reshape(w_shape);
   return dx;
 }
 
-Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandle &ch) {
+Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W,
+                        const ConvHandle &ch) {
   CHECK_EQ(dy.device()->lang(), kCpp);
 
   CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
@@ -134,18 +143,20 @@ Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
   float *data_col = new float[ch.col_height * ch.col_width];
   auto in_data = dy.data<float>();
   for (size_t num = 0; num < ch.batchsize; num++) {
-    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width, ch.kernel_h,
-             ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
+    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width,
+           ch.kernel_h,
+           ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
     col_data.CopyDataFromHostPtr(data_col, ch.col_height * ch.col_width);
     Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
     CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
-    dW += Mult(grad_b, col_data.T());
+    dW += Mult(grad_b, Transpose(col_data));
   }
   dW.Reshape(w_shape);
   return dW;
 }
 
-Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch) {
+Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b,
+                        const ConvHandle &ch) {
   CHECK_EQ(dy.device()->lang(), kCpp);
 
   CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
@@ -169,11 +180,13 @@ Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch)
 };
 
 #ifdef USE_CUDNN
-CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+CudnnConvHandle::CudnnConvHandle(const Tensor &input,
+                                 const std::vector<size_t>& kernel_size,
                                  const std::vector<size_t>& stride, const std::vector<size_t>& padding,
                                  const size_t in_channels, const size_t out_channels, const bool bias,
                                  const size_t workspace_byte_limit, const std::string& prefer)
-  : ConvHandle(input, kernel_size, stride, padding, in_channels, out_channels, bias) {
+  : ConvHandle(input, kernel_size, stride, padding, in_channels, out_channels,
+               bias) {
 
   DataType dtype = input.data_type();
   auto dev = input.device();
@@ -203,7 +216,7 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>&
 #if CUDNN_MAJOR >= 7
               , GetCudnnDataType(dtype)
 #endif
-              ));
+                                             ));
   CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc, GetCudnnDataType(dtype),
                                          CUDNN_TENSOR_NCHW, num_filters,
                                          channels, kernel_h, kernel_w));
@@ -268,8 +281,8 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>&
                 ctx->cudnn_handle, x_desc, y_desc, conv_desc, filter_desc,
                 bp_filter_alg, &bp_filter_byte));
   workspace_count = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
-                     sizeof(float) +
-                     1;
+                    sizeof(float) +
+                    1;
   if (workspace_count * sizeof(float) > workspace_byte_limit)
     LOG(WARNING) << "The required memory for workspace ("
                  << workspace_count * sizeof(float)
@@ -289,7 +302,8 @@ CudnnConvHandle::~CudnnConvHandle() {
   if (y_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc));
 }
 
-Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandle &cch) {
+Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b,
+                      const CudnnConvHandle &cch) {
   CHECK_EQ(x.device()->lang(), kCuda);
 
   DataType dtype = x.data_type();
@@ -323,7 +337,8 @@ Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const C
   return output;
 }
 
-Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandle &cch) {
+Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x,
+                        const CudnnConvHandle &cch) {
   CHECK_EQ(dy.device()->lang(), kCuda);
 
   Tensor dx;
@@ -344,7 +359,8 @@ Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, cons
   return dx;
 }
 
-Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandle &cch) {
+Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W,
+                        const CudnnConvHandle &cch) {
   CHECK_EQ(dy.device()->lang(), kCuda);
 
   Tensor dW;
@@ -366,7 +382,8 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 }
 
 // input Tensor b for Reset db purpose, can avoid this later.
-Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch) {
+Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b,
+                        const CudnnConvHandle &cch) {
   CHECK_EQ(dy.device()->lang(), kCuda);
 
   Tensor db;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b30d7ea5/src/model/updater/local_updater.cc
----------------------------------------------------------------------
diff --git a/src/model/updater/local_updater.cc b/src/model/updater/local_updater.cc
index c3c6793..04593f4 100644
--- a/src/model/updater/local_updater.cc
+++ b/src/model/updater/local_updater.cc
@@ -43,7 +43,7 @@ void LocalUpdater::Apply(int step, const string& name, Tensor& grad,
   int nth = dev_index_[name]++;
   auto key = std::make_pair(nth, name);
   if (grad_buffer_[key].Size() != grad.Size()) {
-    grad_buffer_[key].Reshape(grad.shape());
+    grad_buffer_[key].SetShape(grad.shape());
     grad_buffer_[key].AsType(grad.data_type());
   }
   grad_buffer_[key].CopyData(grad);
@@ -56,7 +56,7 @@ void LocalUpdater::Apply(int step, const string& name, Tensor& grad,
     }
   } else {
     if (param_buffer_[name].Size() != value.Size()) {
-      param_buffer_[name].Reshape(value.shape());
+      param_buffer_[name].SetShape(value.shape());
       param_buffer_[name].AsType(value.data_type());
       param_buffer_[name].CopyData(value);
       sum_[name].ResetLike(param_buffer_[name]);