Posted to commits@singa.apache.org by wa...@apache.org on 2018/06/30 13:30:41 UTC

[1/2] incubator-singa git commit: SINGA-370 Improvement to tensor reshape and various misc. changes related to SINGA-341 and 351

Repository: incubator-singa
Updated Branches:
  refs/heads/master 9b06c2971 -> c8df172c7


SINGA-370 Improvement to tensor reshape and various misc. changes related to SINGA-341 and 351

Converted some upper-case helper function names to lower-case (e.g. TraverseUnary -> traverse_unary, generate_tensorND_desc -> generate_tensor_nd_desc)

Added missing return statements to some non-void CUDA helper functions

Added a check_cudnn macro that verifies CUDNN_STATUS_SUCCESS for every cuDNN call (a minimal sketch of the pattern appears after this list)

Removed set_strides() calls from the CUDA backend and replaced them with additional cudnnTransformTensor-based transforms

Updated all unary and binary CUDA functions to use the transform path when layouts differ

Updated the Reshape function to return a Tensor

Fixed the FromProto function by removing the original Reshape call

Changed Reshape to support in-place operation as well as returning a new tensor

Added Transform functions via GenUnaryTensorFn, modelled on cuDNN's transform function

Updated tensor_math.h with the Transform declaration

Added Transform implementations for the CPP and CUDA backends

Used Transform (instead of Add) inside Reshape to copy transposed data
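
For reference, the error-checking pattern introduced here: every cuDNN call in tensor_math_cuda.h is wrapped so that a non-success status aborts with the failing line number and cuDNN's error string. A minimal sketch (the macro body mirrors the one added in the diff below; the descriptor example is illustrative only):

  // Abort with the cuDNN error string if a call does not return CUDNN_STATUS_SUCCESS.
  #define check_cudnn(expression)                              \
    {                                                          \
      cudnnStatus_t status = (expression);                     \
      if (status != CUDNN_STATUS_SUCCESS) {                    \
        LOG(FATAL) << "Error on line " << __LINE__ << ": "     \
                   << cudnnGetErrorString(status) << " ";      \
      }                                                        \
    }

  // Typical use, as in generate_tensor_nd_desc():
  cudnnTensorDescriptor_t desc;
  check_cudnn(cudnnCreateTensorDescriptor(&desc));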


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f9e7caaf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f9e7caaf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f9e7caaf

Branch: refs/heads/master
Commit: f9e7caaf5bf91bd3236b8d1532b192d4416999e6
Parents: c343ff9
Author: Vaan Ng <cm...@gmail.com>
Authored: Wed May 30 13:27:44 2018 +0800
Committer: Vaan Ng <cm...@gmail.com>
Committed: Fri Jun 1 19:02:25 2018 +0800

----------------------------------------------------------------------
 include/singa/core/tensor.h        |   6 +-
 src/core/tensor/tensor.cc          | 151 +++++++---
 src/core/tensor/tensor_math.h      |   8 +
 src/core/tensor/tensor_math_cpp.h  |  62 ++--
 src/core/tensor/tensor_math_cuda.h | 502 ++++++++++++++++++++++++--------
 5 files changed, 543 insertions(+), 186 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f9e7caaf/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 3cc28ff..47c8874 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -133,8 +133,8 @@ class Tensor {
   size_t MemSize() const { return block_->size(); }
 
   /// Reset the tensor shape, it may reallocate block, if MemSize() changes.
-  void Reshape(const Shape &shape);
-  void Reshape(Shape &&shape);
+  Tensor Reshape(const Shape &shape);
+  Tensor Reshape(Shape &&shape);
 
   /// Reset the shape, device, and data type as given tensor.
   /// If block size changes, then reallocate a new block.
@@ -297,6 +297,7 @@ Tensor Sign(const Tensor &in);
 Tensor Sqrt(const Tensor &in);
 Tensor Square(const Tensor &in);
 Tensor Tanh(const Tensor &in);
+Tensor Transform(const Tensor &in);
 
 void Abs(const Tensor &in, Tensor *out);
 void Exp(const Tensor &in, Tensor *out);
@@ -307,6 +308,7 @@ void Sign(const Tensor &in, Tensor *out);
 void Sqrt(const Tensor &in, Tensor *out);
 void Square(const Tensor &in, Tensor *out);
 void Tanh(const Tensor &in, Tensor *out);
+void Transform(const Tensor &in, Tensor *out);
 
 /// Element-wise opeartion, out[i]=in[i]^x
 template <typename SType>

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f9e7caaf/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index de0d7d2..182a9eb 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -127,35 +127,36 @@ void Tensor::ResetLike(const Tensor &in) {
 // if tensor is already transposed i.e strides != 1,
 // it should be copied to a new tensor with newly generated default strides
 // TODO(wangwei) raise error if the shape not match
-void Tensor::Reshape(const Shape &shape) {
-  if (strides_.size() == 0)
-    strides_.push_back(1);
-
-  if (Product(shape_) != Product(shape)) {
-    if (block_ != nullptr && block_->DecRefCount() == 0)
-      device_->FreeBlock(block_);
-    block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
-  } else if (transpose()) {
-    LOG(FATAL) << "Reshape Error: Reshape called on tranposed tensor. Not implemented yet." ;
-  }
-  shape_ = shape;
-  generate_strides();
-}
 
-void Tensor::Reshape(Shape &&shape) {
-  if (strides_.size() == 0)
-    strides_.push_back(1);
-
-  if (Product(shape_) != Product(shape)) {
-    if (block_ != nullptr && block_->DecRefCount() == 0)
-      device_->FreeBlock(block_);
-    block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
-  } else if (transpose()) {
-    LOG(FATAL) << "Reshape Error: Reshape called on tranposed tensor. Not implemented yet." ;
-  }
-  shape_ = std::move(shape);
-  generate_strides();
-}
+// void Tensor::Reshape(const Shape &shape) {
+//   if (strides_.size() == 0)
+//     strides_.push_back(1);
+
+//   if (Product(shape_) != Product(shape)) {
+//     if (block_ != nullptr && block_->DecRefCount() == 0)
+//       device_->FreeBlock(block_);
+//     block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+//   } else if (transpose()) {
+//     LOG(FATAL) << "Reshape Error: Reshape called on tranposed tensor. Not implemented yet." ;
+//   }
+//   shape_ = shape;
+//   generate_strides();
+// }
+
+// void Tensor::Reshape(Shape &&shape) {
+//   if (strides_.size() == 0)
+//     strides_.push_back(1);
+
+//   if (Product(shape_) != Product(shape)) {
+//     if (block_ != nullptr && block_->DecRefCount() == 0)
+//       device_->FreeBlock(block_);
+//     block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+//   } else if (transpose()) {
+//     LOG(FATAL) << "Reshape Error: Reshape called on tranposed tensor. Not implemented yet." ;
+//   }
+//   shape_ = std::move(shape);
+//   generate_strides();
+// }
 
 void Tensor::AsType(const DataType type) {
   if (data_type_ != type) {
@@ -218,10 +219,9 @@ void Tensor::FromProto(const singa::TensorProto &proto) {
   if (block_ != nullptr && block_->DecRefCount() == 0)
     device_->FreeBlock(block_);
   block_ = nullptr;
-  Shape shape;
-  for (uint32_t s : proto.shape()) shape.push_back(s);
+  for (uint32_t s : proto.shape()) shape_.push_back(s);
   data_type_ = proto.data_type();
-  Reshape(shape);
+  block_ = device_->NewBlock((int)(Product(shape()) * SizeOf(data_type_)));
   //transpose_ = proto.transpose();
   strides_.clear();
   for (int32_t s : proto.strides()) strides_.push_back(s);
@@ -329,7 +329,6 @@ Tensor Tensor::Clone(std::shared_ptr<Device> device) const {
   return t;
 }
 
-//yisen todo
 Tensor Tensor::T() const {
   // this function only works for 2d tensors
   CHECK_EQ(shape_.size(), 2u);
@@ -416,18 +415,17 @@ Tensor &Tensor::operator=(Tensor &&in) {
   return *this;
 }
 
-//yisen todo
-Tensor Reshape(const Tensor &in, const Shape &s) {
-  Tensor out(in);
-  out.Reshape(s);
-  return out;
-}
+// Tensor Reshape(const Tensor &in, const Shape &s) {
+//   // Tensor out(in);
+//   // out.Reshape(s);
+//   return out;
+// }
 
-Tensor Reshape(const Tensor &in, Shape &&s) {
-  Tensor out(in);
-  out.Reshape(std::move(s));
-  return out;
-}
+// Tensor Reshape(const Tensor &in, Shape &&s) {
+//   // Tensor out(in);
+//   // out.Reshape(std::move(s));
+//   return out;
+// }
 
 #define GenUnaryTensorArgMemberFn(op, fn) \
   Tensor &Tensor::op(const Tensor &in) {  \
@@ -615,6 +613,7 @@ GenUnaryTensorFn(Sign);
 GenUnaryTensorFn(Sqrt);
 GenUnaryTensorFn(Square);
 GenUnaryTensorFn(Tanh);
+GenUnaryTensorFn(Transform);
 
 #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret)                            \
   do {                                                                      \
@@ -1181,4 +1180,70 @@ void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p) {
   });
 }
 
+Tensor Tensor::Reshape(const Shape &shape) {
+  if (strides_.size() == 0)
+    strides_.push_back(1);
+
+  if (Product(shape_) != Product(shape)) {
+    if (block_ != nullptr && block_->DecRefCount() == 0)
+      device_->FreeBlock(block_);
+    block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+    shape_ = shape;
+    generate_strides();
+    return *this;
+
+  } else if (transpose()) {
+    Tensor t(shape_, device_, data_type_);
+    t.block_ = t.device()->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+    singa::Transform(*this, &t);
+    t.shape_ = shape;
+    return t;
+ }
+
+  shape_ = shape;
+  generate_strides();
+  Tensor t(shape, device_, data_type_);
+  t.block_ = block_;
+  t.block_->IncRefCount();
+  return t;
+}
+
+Tensor Tensor::Reshape(Shape &&shape) {
+  if (strides_.size() == 0)
+    strides_.push_back(1);
+
+  if (Product(shape_) != Product(shape)) {
+    if (block_ != nullptr && block_->DecRefCount() == 0)
+      device_->FreeBlock(block_);
+    block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+    shape_ = std::move(shape);
+    generate_strides();
+    return *this;
+
+  } else if (transpose()) {
+    Tensor t(shape_, device_, data_type_);
+    t.block_ = t.device()->NewBlock((int)(Product(shape) * SizeOf(data_type_)));
+    singa::Transform(*this, &t);
+    t.shape_ = shape;
+    return t;
+ }
+
+  shape_ = shape;
+  generate_strides();
+  Tensor t(shape, device_, data_type_);
+  t.block_ = block_;
+  t.block_->IncRefCount();
+  return t;
+}
+
+Tensor Reshape(const Tensor &in, const Shape &s) {
+  Tensor out(in);
+  return out.Reshape(s);
+}
+
+Tensor Reshape(const Tensor &in, Shape &&s) {
+  Tensor out(in);
+  return out.Reshape(std::move(s));
+}
+
 }  // namespace singa
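
A hedged usage sketch of the new Reshape semantics implemented above (tensor names and shapes are illustrative; the behaviour follows the three branches of Tensor::Reshape, and the free Reshape(in, s) first makes a shallow copy of in):

  Tensor a(Shape{2, 3});               // contiguous tensor with 6 elements
  Tensor b = Reshape(a, Shape{3, 2});  // same element count, not transposed:
                                       // b shares a's block, strides regenerated
  Tensor c = a.T();                    // 3 x 2 view with swapped strides
  Tensor d = Reshape(c, Shape{6});     // transposed input: data is copied into a
                                       // new contiguous block via singa::Transform
  a.Reshape(Shape{2, 3, 4});           // element count changes: a new block is
                                       // allocated and *this is reshaped in place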

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f9e7caaf/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index c7fdfe5..f438fc6 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -251,6 +251,14 @@ void Tanh(const Tensor &in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Tanh Not Implemented";
 }
 
+/// similar to cudnnTransformTensor
+/// copies the data from one tensor to another tensor with a different layout
+/// the tensors must have the same dimensions but not necessarily the same strides 
+template <typename DType, typename Lang>
+void Transform(const Tensor &in, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "Transform Not Implemented";
+}
+
 // **************************************
 // Random functions
 // **************************************
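
Transform<DType, Lang> is the backend hook behind the new public Transform() functions declared in tensor.h. A small hedged sketch of the intended use, compacting a transposed tensor into the default contiguous layout (names are illustrative):

  Tensor a(Shape{2, 3});
  // ... fill a with data ...
  Tensor at = a.T();               // non-default strides, same data block
  Tensor dense = Transform(at);    // new tensor holding at's values contiguously
  Transform(at, &dense);           // equivalent two-argument form writing into
                                   // an existing output tensor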

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f9e7caaf/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index bfdd026..5ae3d41 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -107,7 +107,7 @@ void traverse_next(const Tensor& x,
 };
 
 template <typename DType>
-void TraverseUnary(const Tensor &in, Tensor* out, std::function<DType(DType)> func) {
+void traverse_unary(const Tensor &in, Tensor* out, std::function<DType(DType)> func) {
   DType *outPtr = static_cast<DType *>(out->block()->mutable_data());
   const DType *inPtr = static_cast<const DType *>(in.block()->data());
   vector<int> traversal_info = generate_traversal_info(in);
@@ -120,7 +120,7 @@ void TraverseUnary(const Tensor &in, Tensor* out, std::function<DType(DType)> fu
 }
 
 template <typename DType>
-void TraverseBinary(const Tensor &in1, const Tensor &in2, Tensor* out,
+void traverse_binary(const Tensor &in1, const Tensor &in2, Tensor* out,
                     std::function<DType(DType, DType)> func) {
   DType *outPtr = static_cast<DType *>(out->block()->mutable_data());
   const DType *in1Ptr = static_cast<const DType *>(in1.block()->data());
@@ -146,7 +146,7 @@ void TraverseBinary(const Tensor &in1, const Tensor &in2, Tensor* out,
 
 template <>
 void Abs<float, lang::Cpp>(const Tensor& in, Tensor* out, Context *ctx) {
-  TraverseUnary<float>(in, out, [](float x) {return fabs(x);});
+  traverse_unary<float>(in, out, [](float x) {return fabs(x);});
 }
 
 template <>
@@ -154,7 +154,7 @@ void Add<float, lang::Cpp>(const Tensor& in, const float x, Tensor* out, Context
   auto add_lambda = [&x](float a) {
     return (a + x);
   };
-  TraverseUnary<float>(in, out, add_lambda);
+  traverse_unary<float>(in, out, add_lambda);
 }
 
 template <>
@@ -163,7 +163,7 @@ void Add<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out, Co
   auto add_lambda_binary = [](float a, float b) {
     return (a + b);
   };
-  TraverseBinary<float>(in1, in2, out, add_lambda_binary);
+  traverse_binary<float>(in1, in2, out, add_lambda_binary);
 
 }
 
@@ -176,7 +176,7 @@ void Clamp<float, lang::Cpp>(const float low, const float high,
     else if (a > high) {return high;}
     else {return a;}
   };
-  TraverseUnary<float>(in, out, clamp_lambda);
+  traverse_unary<float>(in, out, clamp_lambda);
 }
 
 template <>
@@ -219,7 +219,7 @@ void EltwiseMult<float, lang::Cpp>(const Tensor& in, const float x, Tensor* out,
   auto eltwisemult_lambda = [&x](float a) {
     return (a * x);
   };
-  TraverseUnary<float>(in, out, eltwisemult_lambda);
+  traverse_unary<float>(in, out, eltwisemult_lambda);
 }
 
 template <>
@@ -228,12 +228,12 @@ void EltwiseMult<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor*
   auto eltwisemult_lambda_binary = [](float a, float b) {
     return (a * b);
   };
-  TraverseBinary<float>(in1, in2, out, eltwisemult_lambda_binary);
+  traverse_binary<float>(in1, in2, out, eltwisemult_lambda_binary);
 }
 
 template <>
 void Exp<float, lang::Cpp>(const Tensor& in, Tensor *out, Context *ctx) {
-  TraverseUnary<float>(in, out, [](float x) {return exp(x);});
+  traverse_unary<float>(in, out, [](float x) {return exp(x);});
 }
 
 template <>
@@ -242,7 +242,7 @@ void GE<float, lang::Cpp>(const Tensor& in, const float x, Tensor* out,
   auto ge_lambda = [&x](float a) {
     return (a >= x) ? 1.f : 0.f;
   };
-  TraverseUnary<float>(in, out, ge_lambda);
+  traverse_unary<float>(in, out, ge_lambda);
 }
 
 template <>
@@ -251,7 +251,7 @@ void GE<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out,
   auto ge_lambda_binary = [](float a, float b) {
     return (a >= b) ? 1.f : 0.f;
   };
-  TraverseBinary<float>(in1, in2, out, ge_lambda_binary);
+  traverse_binary<float>(in1, in2, out, ge_lambda_binary);
 }
 
 template <>
@@ -260,7 +260,7 @@ void GT<float, lang::Cpp>(const Tensor& in, const float x, Tensor* out,
   auto gt_lambda = [&x](float a) {
     return (a > x) ? 1.f : 0.f;
   };
-  TraverseUnary<float>(in, out, gt_lambda);
+  traverse_unary<float>(in, out, gt_lambda);
 }
 
 template <>
@@ -269,7 +269,7 @@ void GT<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out,
   auto gt_lambda_binary = [](float a, float b) {
     return (a > b) ? 1.f : 0.f;
   };
-  TraverseBinary<float>(in1, in2, out, gt_lambda_binary);
+  traverse_binary<float>(in1, in2, out, gt_lambda_binary);
 }
 
 template <>
@@ -278,7 +278,7 @@ void LE<float, lang::Cpp>(const Tensor& in, const float x, Tensor* out,
   auto le_lambda = [&x](float a) {
     return (a <= x) ? 1.f : 0.f;
   };
-  TraverseUnary<float>(in, out, le_lambda);
+  traverse_unary<float>(in, out, le_lambda);
 }
 
 template <>
@@ -287,7 +287,7 @@ void LE<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out,
   auto le_lambda_binary = [](float a, float b) {
     return (a <= b) ? 1.f : 0.f;
   };
-  TraverseBinary<float>(in1, in2, out, le_lambda_binary);
+  traverse_binary<float>(in1, in2, out, le_lambda_binary);
 }
 
 template <>
@@ -311,7 +311,7 @@ void LT<float, lang::Cpp>(const Tensor& in, const float x, Tensor* out,
   auto lt_lambda = [&x](float a) {
     return (a < x) ? 1.f : 0.f;
   };
-  TraverseUnary<float>(in, out, lt_lambda);
+  traverse_unary<float>(in, out, lt_lambda);
 }
 
 
@@ -321,12 +321,12 @@ void LT<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out,
   auto lt_lambda_binary = [](float a, float b) {
     return (a < b) ? 1.f : 0.f;
   };
-  TraverseBinary<float>(in1, in2, out, lt_lambda_binary);
+  traverse_binary<float>(in1, in2, out, lt_lambda_binary);
 }
 
 template <>
 void Pow<float, lang::Cpp>(const Tensor& in, const float x, Tensor *out, Context *ctx) {
-  TraverseUnary<float>(in, out, [x](float y) {return pow(y, x);});
+  traverse_unary<float>(in, out, [x](float y) {return pow(y, x);});
 }
 
 template <>
@@ -335,7 +335,7 @@ void Pow<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out,
   auto pow_lambda_binary = [](float a, float b) {
     return pow(a, b);
   };
-  TraverseBinary<float>(in1, in2, out, pow_lambda_binary);
+  traverse_binary<float>(in1, in2, out, pow_lambda_binary);
 }
 
 template <>
@@ -344,7 +344,7 @@ void ReLU<float, lang::Cpp>(const Tensor& in, Tensor* out,
   auto relu_lambda = [](float a) {
     return (a >= 0.f) ? a : 0.f;
   };
-  TraverseUnary<float>(in, out, relu_lambda);
+  traverse_unary<float>(in, out, relu_lambda);
 }
 
 template <>
@@ -367,7 +367,7 @@ void Sigmoid<float, lang::Cpp>(const Tensor& in, Tensor* out,
   auto sigmoid_lambda = [](float a) {
     return 1.f / (1.f + exp(-a));
   };
-  TraverseUnary<float>(in, out, sigmoid_lambda);
+  traverse_unary<float>(in, out, sigmoid_lambda);
 }
 
 template <>
@@ -376,7 +376,7 @@ void Sign<float, lang::Cpp>(const Tensor& in, Tensor* out,
   auto sign_lambda = [](float a) {
     return (a > 0) - (a < 0);
   };
-  TraverseUnary<float>(in, out, sign_lambda);
+  traverse_unary<float>(in, out, sign_lambda);
 }
 
 template <>
@@ -401,7 +401,7 @@ void Sub<float, lang::Cpp>(const Tensor& in1, const Tensor& in2,
   auto sub_lambda_binary = [](float a, float b) {
     return (a - b);
   };
-  TraverseBinary<float>(in1, in2, out, sub_lambda_binary);
+  traverse_binary<float>(in1, in2, out, sub_lambda_binary);
 }
 
 // sum all elements of input into out
@@ -423,7 +423,21 @@ void Tanh<float, lang::Cpp>(const Tensor& in, Tensor* out,
   auto tanh_lambda = [](float a) {
     return tanh(a);
   };
-  TraverseUnary<float>(in, out, tanh_lambda);
+  traverse_unary<float>(in, out, tanh_lambda);
+}
+
+template <>
+void Transform<float, lang::Cpp>(const Tensor& in, Tensor* out,
+                            Context *ctx) {
+  float *outPtr = static_cast<float *>(out->block()->mutable_data());
+  const float *inPtr = static_cast<const float *>(in.block()->data());
+  vector<int> traversal_info = generate_traversal_info(in);
+  vector<int> shape_multipliers = generate_shape_multipliers(in);
+
+  for (size_t i = 0; i < in.Size(); i++) {
+    outPtr[i] = inPtr[traversal_info[in.shape().size()]];
+    traverse_next(in, shape_multipliers, traversal_info, i + 1);
+  }
 }
 
 template <>

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f9e7caaf/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 55d6a1b..a1b9381 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -30,6 +30,16 @@
 #include "singa/utils/cuda_utils.h"
 #include <cudnn.h>
 
+#define check_cudnn(expression)                              \
+  {                                                          \
+    cudnnStatus_t status = (expression);                     \
+    if (status != CUDNN_STATUS_SUCCESS) {                    \
+      LOG(FATAL) << "Error on line " << __LINE__ << ": "     \
+                 << cudnnGetErrorString(status) << " ";      \
+    }                                                        \
+  }
+
+
 namespace singa {
 
 // ===================== Helper Functions =============================
@@ -62,6 +72,7 @@ vector<int> generate_shape_cuda(const Tensor& x) {
   } else {
     LOG(FATAL) << "Dimensions (shape) beyond 5 are currently not supported" ;
   }
+  return shape_arr;
 }
 
 int generate_dim_cuda(const Tensor& x) {
@@ -70,6 +81,7 @@ int generate_dim_cuda(const Tensor& x) {
   else {
     LOG(FATAL) << "Dimensions (shape) beyond 5 are currently not supported" ;
   }
+  return 0;
 }
 
 /*
@@ -105,27 +117,28 @@ vector<int> generate_strides_cuda(const Tensor& x) {
   } else {
     LOG(FATAL) << "Dimensions (strides) beyond 5 are currently not supported" ;
   }
+  return strides_arr;
 }
 
-cudnnTensorDescriptor_t generate_tensorND_desc(const Tensor& x) {
+cudnnTensorDescriptor_t generate_tensor_nd_desc(const Tensor& x) {
   cudnnTensorDescriptor_t x_desc;
-  cudnnCreateTensorDescriptor(&x_desc);
-  cudnnSetTensorNdDescriptor(x_desc, CUDNN_DATA_FLOAT,
+  check_cudnn(cudnnCreateTensorDescriptor(&x_desc));
+  check_cudnn(cudnnSetTensorNdDescriptor(x_desc, CUDNN_DATA_FLOAT,
                              generate_dim_cuda(x),
                              generate_shape_cuda(x).data(),
                              generate_strides_cuda(x).data()
-                            );
+                            ));
 
   return x_desc;
 }
 
-cudnnOpTensorDescriptor_t generate_Op_desc(cudnnOpTensorOp_t op) {
+cudnnOpTensorDescriptor_t generate_op_desc(cudnnOpTensorOp_t op) {
   cudnnOpTensorDescriptor_t op_desc;
-  cudnnCreateOpTensorDescriptor(&op_desc);
-  cudnnSetOpTensorDescriptor(op_desc, op,
+  check_cudnn(cudnnCreateOpTensorDescriptor(&op_desc));
+  check_cudnn(cudnnSetOpTensorDescriptor(op_desc, op,
                              CUDNN_DATA_FLOAT,
                              CUDNN_PROPAGATE_NAN
-                            );
+                            ));
 
   return op_desc;
 }
@@ -142,12 +155,12 @@ void Abs<float, lang::Cuda>(const Tensor& in, Tensor* out,
   float alpha1 = 1.0;
   float alpha2 = -1.0;
   float beta = 0.0;
-  cudnnTensorDescriptor_t in_desc = generate_tensorND_desc(in);
-  cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_MAX),
+  cudnnTensorDescriptor_t in_desc = generate_tensor_nd_desc(in);
+  check_cudnn(cudnnOpTensor(ctx->cudnn_handle, generate_op_desc(CUDNN_OP_TENSOR_MAX),
                 (void*)(&alpha1), in_desc, inPtr,
                 (void*)(&alpha2), in_desc, inPtr,
-                (void*)(&beta), generate_tensorND_desc(*out), outPtr
-               );
+                (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+               ));
   cudnnDestroyTensorDescriptor(in_desc);
 }
 
@@ -156,8 +169,8 @@ void Set<float, lang::Cuda>(const float x, Tensor* out,
                             Context* ctx) {
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
 
-  cudnnSetTensor(ctx->cudnn_handle, generate_tensorND_desc(*out),
-                 outPtr, (void*)(&x));
+  check_cudnn(cudnnSetTensor(ctx->cudnn_handle, generate_tensor_nd_desc(*out),
+                 outPtr, (void*)(&x)));
 }
 
 template <>
@@ -168,10 +181,10 @@ void Add<float, lang::Cuda>(const Tensor& in, const float x,
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
 
   float alpha = 1.0, beta = 1.0;
-  cudnnAddTensor(ctx->cudnn_handle,
-                 (void*)(&alpha), generate_tensorND_desc(in), inPtr,
-                 (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                );
+  check_cudnn(cudnnAddTensor(ctx->cudnn_handle,
+                 (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                 (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                ));
 }
 
 /// out = in1 + in2
@@ -187,17 +200,17 @@ void Add<float, lang::Cuda>(const Tensor& in1,
   float beta = 0.0;
 
   if ((in1.nDim() == in2.nDim()) || (in2.nDim() == 1)) {
-    cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD),
-                  (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1,
-                  (void*)(&alpha2), generate_tensorND_desc(in2), inPtr2,
-                  (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                 );
+    check_cudnn(cudnnOpTensor(ctx->cudnn_handle, generate_op_desc(CUDNN_OP_TENSOR_ADD),
+                  (void*)(&alpha1), generate_tensor_nd_desc(in1), inPtr1,
+                  (void*)(&alpha2), generate_tensor_nd_desc(in2), inPtr2,
+                  (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                 ));
   } else {
-    cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD),
-                  (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1,
-                  (void*)(&alpha2), generate_tensorND_desc(in1), inPtr2,
-                  (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                 );
+    check_cudnn(cudnnOpTensor(ctx->cudnn_handle, generate_op_desc(CUDNN_OP_TENSOR_ADD),
+                  (void*)(&alpha1), generate_tensor_nd_desc(in1), inPtr1,
+                  (void*)(&alpha2), generate_tensor_nd_desc(in1), inPtr2,
+                  (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                 ));
   }
 }
 
@@ -214,17 +227,17 @@ void Sub<float, lang::Cuda>(const Tensor& in1,
   float beta = 0.0;
 
   if ((in1.nDim() == in2.nDim()) || (in2.nDim() == 1)) {
-    cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD),
-                  (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1,
-                  (void*)(&alpha2), generate_tensorND_desc(in2), inPtr2,
-                  (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                 );
+    check_cudnn(cudnnOpTensor(ctx->cudnn_handle, generate_op_desc(CUDNN_OP_TENSOR_ADD),
+                  (void*)(&alpha1), generate_tensor_nd_desc(in1), inPtr1,
+                  (void*)(&alpha2), generate_tensor_nd_desc(in2), inPtr2,
+                  (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                 ));
   } else {
-    cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD),
-                  (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1,
-                  (void*)(&alpha2), generate_tensorND_desc(in1), inPtr2,
-                  (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                 );
+    check_cudnn(cudnnOpTensor(ctx->cudnn_handle, generate_op_desc(CUDNN_OP_TENSOR_ADD),
+                  (void*)(&alpha1), generate_tensor_nd_desc(in1), inPtr1,
+                  (void*)(&alpha2), generate_tensor_nd_desc(in1), inPtr2,
+                  (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                 ));
   }
 }
 
@@ -237,9 +250,22 @@ void Clamp<float, lang::Cuda>(const float low,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::clamp(num, low, high, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+  //if both in and out strides are the same, we proceed to normal cuda::clamp
+  if (in.strides() == out->strides()) {
+    cuda::clamp(num, low, high, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::clamp(num, low, high, outPtr, outPtr, ctx->stream);
+  }
 }
+
 /// out = in1 / in2
 template <>
 void Div<float, lang::Cuda>(const Tensor& in1,
@@ -249,21 +275,43 @@ void Div<float, lang::Cuda>(const Tensor& in1,
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in1.Size();
 
-  //if both in1 and in2 strides are the same, we proceed to normal cuda::div
-  if (in1.strides() == in2.strides()) {
+  //if both in1 and in2 are not transposed, and have the same strides,
+  //we proceed to normal cuda::div
+  if (!in1.transpose() && !in2.transpose() && (in1.strides() == in2.strides())) {
     cuda::div(num, inPtr1, inPtr2, outPtr, ctx->stream);
-    out->set_strides(in1.strides());
-  } else { //else we transform in1 to out to store first
+  } else { //else we check whether in1 or in2 or both are transposed
     float alpha = 1.0;
     float beta = 0.0;
 
-    out->set_strides(in2.strides());
-    cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensorND_desc(in1), inPtr1,
-                         (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                        );
-
-    cuda::div(num, outPtr, inPtr2, outPtr, ctx->stream);
+    if (in1.transpose() && in2.transpose()) {
+      Tensor t(in1.shape(), in1.device(), in1.data_type());
+      float* tPtr = static_cast<float*>(t.block()->mutable_data());
+
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
+                           (void*)(&beta), generate_tensor_nd_desc(t), tPtr
+                          ));
+
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::div(num, tPtr, outPtr, outPtr, ctx->stream);
+
+    } else if (in1.transpose()) {
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::div(num, outPtr, inPtr2, outPtr, ctx->stream);
+
+    } else if (in2.transpose()) {
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::div(num, inPtr1, outPtr, outPtr, ctx->stream);
+    }
   }
 }
 
@@ -273,8 +321,20 @@ void Div<float, lang::Cuda>(const float x, const Tensor& in,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::div(num, x, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::div(num, x, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::div(num, x, outPtr, outPtr, ctx->stream);
+  }
 }
 
 /// out = in * x
@@ -285,10 +345,10 @@ void EltwiseMult<float, lang::Cuda>(const Tensor& in,
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
 
   float alpha = x, beta = 0.0;
-  cudnnAddTensor(ctx->cudnn_handle,
-                 (void*)(&alpha), generate_tensorND_desc(in), inPtr,
-                 (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                );
+  check_cudnn(cudnnAddTensor(ctx->cudnn_handle,
+                 (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                 (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                ));
 }
 
 /// out = in1 * in2
@@ -301,21 +361,43 @@ void EltwiseMult<float, lang::Cuda>(const Tensor& in1,
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in1.Size();
 
-  //if both in1 and in2 strides are the same, we proceed to normal cuda::mult
-  if (in1.strides() == in2.strides()) {
+  //if both in1 and in2 are not transposed, and have the same strides,
+  //we proceed to normal cuda::mult
+  if (!in1.transpose() && !in2.transpose() && (in1.strides() == in2.strides())) {
     cuda::mult(num, inPtr1, inPtr2, outPtr, ctx->stream);
-    out->set_strides(in1.strides());
-  } else { //else we transform in1 to out to store first
+  } else { //else we check whether in1 or in2 or both are transposed
     float alpha = 1.0;
     float beta = 0.0;
 
-    out->set_strides(in2.strides());
-    cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensorND_desc(in1), inPtr1,
-                         (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                        );
-
-    cuda::mult(num, outPtr, inPtr2, outPtr, ctx->stream);
+    if (in1.transpose() && in2.transpose()) {
+      Tensor t(in1.shape(), in1.device(), in1.data_type());
+      float* tPtr = static_cast<float*>(t.block()->mutable_data());
+
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
+                           (void*)(&beta), generate_tensor_nd_desc(t), tPtr
+                          ));
+
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::mult(num, tPtr, outPtr, outPtr, ctx->stream);
+
+    } else if (in1.transpose()) {
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::mult(num, outPtr, inPtr2, outPtr, ctx->stream);
+
+    } else if (in2.transpose()) {
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::mult(num, inPtr1, outPtr, outPtr, ctx->stream);
+    }
   }
 }
 
@@ -327,8 +409,20 @@ void Exp<float, lang::Cuda>(const Tensor& in, Tensor* out,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::exp(num, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::exp(num, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::exp(num, outPtr, outPtr, ctx->stream);
+  }
 }
 
 template <>
@@ -337,8 +431,20 @@ void GE<float, lang::Cuda>(const Tensor& in, const float x,
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const float* inPtr = static_cast<const float*>(in.block()->data());
   const size_t num = in.Size();
-  cuda::ge(num, inPtr, x, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::ge(num, inPtr, x, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::ge(num, outPtr, x, outPtr, ctx->stream);
+  }
 }
 template <>
 void GE<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
@@ -359,8 +465,20 @@ void GT<float, lang::Cuda>(const Tensor& in, const float x,
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const float* inPtr = static_cast<const float*>(in.block()->data());
   const size_t num = in.Size();
-  cuda::gt(num, inPtr, x, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::gt(num, inPtr, x, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::gt(num, outPtr, x, outPtr, ctx->stream);
+  }
 }
 template <>
 void GT<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
@@ -373,14 +491,27 @@ void GT<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
   //cuda::gt(num, inPtr1, inPtr2, outPtr, ctx->stream);
   cuda::gt(num, outPtr, 0.0, outPtr, ctx->stream);
 }
+
 template <>
 void LE<float, lang::Cuda>(const Tensor& in, const float x,
                            Tensor* out, Context* ctx) {
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const float* inPtr = static_cast<const float*>(in.block()->data());
   const size_t num = in.Size();
-  cuda::le(num, inPtr, x, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::le(num, inPtr, x, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::le(num, outPtr, x, outPtr, ctx->stream);
+  }
 }
 template <>
 void LE<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
@@ -401,17 +532,42 @@ void Log<float, lang::Cuda>(const Tensor& in, Tensor* out,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::log(num, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::log(num, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::log(num, outPtr, outPtr, ctx->stream);
+  }
 }
+
 template <>
 void LT<float, lang::Cuda>(const Tensor& in, const float x,
                            Tensor* out, Context* ctx) {
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const float* inPtr = static_cast<const float*>(in.block()->data());
   const size_t num = in.Size();
-  cuda::lt(num, inPtr, x, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::lt(num, inPtr, x, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::lt(num, outPtr, x, outPtr, ctx->stream);
+  }
 }
 template <>
 void LT<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
@@ -424,6 +580,7 @@ void LT<float, lang::Cuda>(const Tensor& in1, const Tensor& in2,
   //cuda::lt(num, inPtr1, inPtr2, outPtr, ctx->stream);
   cuda::lt(num, outPtr, 0.0, outPtr, ctx->stream);
 }
+
 /// Element-wise operation, out[i] = in[i]^x
 template <>
 void Pow<float, lang::Cuda>(const Tensor& in, const float x,
@@ -431,8 +588,20 @@ void Pow<float, lang::Cuda>(const Tensor& in, const float x,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::pow(num, inPtr, x, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::pow(num, inPtr, x, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::pow(num, outPtr, x, outPtr, ctx->stream);
+  }
 }
 /// Element-wise operation, out[i] = in1[i]^in2[i]
 template <>
@@ -443,20 +612,43 @@ void Pow<float, lang::Cuda>(const Tensor& in1,
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in1.Size();
 
-  if (in1.strides() == in2.strides()) {
+  //if both in1 and in2 are not transposed, and have the same strides,
+  //we proceed to normal cuda::pow
+  if (!in1.transpose() && !in2.transpose() && (in1.strides() == in2.strides())) {
     cuda::pow(num, inPtr1, inPtr2, outPtr, ctx->stream);
-    out->set_strides(in1.strides());
-  } else { //else we transform in1 to out to store first
+  } else { //else we check whether in1 or in2 or both are transposed
     float alpha = 1.0;
     float beta = 0.0;
 
-    out->set_strides(in2.strides());
-    cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensorND_desc(in1), inPtr1,
-                         (void*)(&beta), generate_tensorND_desc(*out), outPtr
-                        );
-
-    cuda::pow(num, outPtr, inPtr2, outPtr, ctx->stream);
+    if (in1.transpose() && in2.transpose()) {
+      Tensor t(in1.shape(), in1.device(), in1.data_type());
+      float* tPtr = static_cast<float*>(t.block()->mutable_data());
+
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
+                           (void*)(&beta), generate_tensor_nd_desc(t), tPtr
+                          ));
+
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::pow(num, tPtr, outPtr, outPtr, ctx->stream);
+
+    } else if (in1.transpose()) {
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in1), inPtr1,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::pow(num, outPtr, inPtr2, outPtr, ctx->stream);
+
+    } else if (in2.transpose()) {
+      check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                           (void*)(&alpha), generate_tensor_nd_desc(in2), inPtr2,
+                           (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                          ));
+      cuda::pow(num, inPtr1, outPtr, outPtr, ctx->stream);
+    }
   }
 }
 
@@ -498,8 +690,20 @@ void ReLU<float, lang::Cuda>(const Tensor& in, Tensor* out,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::relu(num, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::relu(num, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::relu(num, outPtr, outPtr, ctx->stream);
+  }
 }
 
 // /// Element-wise operation, out[i]=sigmoid([in[i])
@@ -541,8 +745,20 @@ void Sigmoid<float, lang::Cuda>(const Tensor& in, Tensor* out,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::sigmoid(num, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::sigmoid(num, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::sigmoid(num, outPtr, outPtr, ctx->stream);
+  }
 }
 
 // out[i] = sign(in[i])
@@ -552,8 +768,20 @@ void Sign<float, lang::Cuda>(const Tensor& in, Tensor* out,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::sign(num, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::sign(num, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::sign(num, outPtr, outPtr, ctx->stream);
+  }
 }
 
 // Element-wise operation, out[i]=sqrt([in[i])
@@ -566,12 +794,12 @@ void Sqrt<float, lang::Cuda>(const Tensor& in, Tensor* out,
   float alpha1 = 1.0;
   float alpha2 = 0.0;
   float beta = 0.0;
-  cudnnTensorDescriptor_t in_desc = generate_tensorND_desc(in);
-  cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_SQRT),
+  cudnnTensorDescriptor_t in_desc = generate_tensor_nd_desc(in);
+  check_cudnn(cudnnOpTensor(ctx->cudnn_handle, generate_op_desc(CUDNN_OP_TENSOR_SQRT),
                 (void*)(&alpha1), in_desc, inPtr,
                 (void*)(&alpha2), in_desc, inPtr,
-                (void*)(&beta), generate_tensorND_desc(*out), outPtr
-               );
+                (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+               ));
 }
 
 /// Element-wise operation, out[i]=in[i]^2
@@ -581,8 +809,20 @@ void Square<float, lang::Cuda>(const Tensor& in, Tensor* out,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::square(num, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::square(num, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::square(num, outPtr, outPtr, ctx->stream);
+  }
 }
 
 // template <>
@@ -614,9 +854,9 @@ void Sum<float, lang::Cuda>(const Tensor& in, float* out,
   cudnnNanPropagation_t cudnn_propagation = CUDNN_PROPAGATE_NAN;
   cudnnReduceTensorIndices_t cudnn_indices = CUDNN_REDUCE_TENSOR_NO_INDICES;
   cudnnIndicesType_t cudnn_indices_type = CUDNN_32BIT_INDICES;
-  cudnnCreateReduceTensorDescriptor(&reduce_desc);
-  cudnnSetReduceTensorDescriptor(reduce_desc, reduce_op, cudnn_dtype,
-                                 cudnn_propagation, cudnn_indices, cudnn_indices_type);
+  check_cudnn(cudnnCreateReduceTensorDescriptor(&reduce_desc));
+  check_cudnn(cudnnSetReduceTensorDescriptor(reduce_desc, reduce_op, cudnn_dtype,
+                                 cudnn_propagation, cudnn_indices, cudnn_indices_type));
 
   //instantiate 2 new tensors to use new blocks as memory instead of cudaMalloc
   size_t reduction_size_int = Product(in.shape());
@@ -632,11 +872,11 @@ void Sum<float, lang::Cuda>(const Tensor& in, float* out,
 
   float alpha = 1.0;
   float beta = 0.0;
-  cudnnReduceTensor(ctx->cudnn_handle, reduce_desc,
+  check_cudnn(cudnnReduceTensor(ctx->cudnn_handle, reduce_desc,
                     indicesPtr, indices_bytes, workspacePtr, workspace_bytes,
-                    (void*)(&alpha), generate_tensorND_desc(in), inPtr,
-                    (void*)(&beta), generate_tensorND_desc(t), tPtr
-                   );
+                    (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                    (void*)(&beta), generate_tensor_nd_desc(t), tPtr
+                   ));
 
   *out = tPtr[0];
 }
@@ -680,8 +920,36 @@ void Tanh<float, lang::Cuda>(const Tensor& in, Tensor* out,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = in.Size();
-  cuda::tanh(num, inPtr, outPtr, ctx->stream);
-  out->set_strides(in.strides());
+
+  if (in.strides() == out->strides()) {
+    cuda::tanh(num, inPtr, outPtr, ctx->stream);
+  } else { //else we transform in to out to store first
+    float alpha = 1.0;
+    float beta = 0.0;
+
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+
+    cuda::tanh(num, outPtr, outPtr, ctx->stream);
+  }
+}
+
+template <>
+void Transform<float, lang::Cuda>(const Tensor& in, Tensor* out,
+                             Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+
+  float alpha = 1.0;
+  float beta = 0.0;
+
+  check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
+                        ));
+  
 }
 
 // ================Random functions===========================================
@@ -953,10 +1221,10 @@ void RowMax<float, lang::Cuda>(const Tensor& in, Tensor* out,
     float alpha = 1.0;
     float beta = 0.0;
 
-    cudnnTransformTensor(ctx->cudnn_handle,
-                         (void*)(&alpha), generate_tensorND_desc(in), inPtr,
-                         (void*)(&beta), generate_tensorND_desc(t), tPtr
-                        );
+    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
+                         (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                         (void*)(&beta), generate_tensor_nd_desc(t), tPtr
+                        ));
 
     const float* tPtr_const = static_cast<const float*>(t.block()->data());
     cuda::RowMax(nrow, ncol, tPtr_const, outPtr, ctx->stream);
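
The CUDA changes above all follow one pattern: when the input and output layouts already agree, the raw cuda:: kernel is launched directly; otherwise the input is first repacked into the output's layout with cudnnTransformTensor and the kernel is then applied to the output in place. A condensed, hedged restatement of that pattern, using ReLU as the example (in, out, inPtr, outPtr and ctx are the usual backend-op arguments, and the helpers are those defined earlier in the same file):

  const size_t num = in.Size();
  if (in.strides() == out->strides()) {
    cuda::relu(num, inPtr, outPtr, ctx->stream);   // layouts match: run kernel directly
  } else {
    float alpha = 1.0f, beta = 0.0f;
    check_cudnn(cudnnTransformTensor(ctx->cudnn_handle,
                (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
                (void*)(&beta), generate_tensor_nd_desc(*out), outPtr));
    cuda::relu(num, outPtr, outPtr, ctx->stream);  // then apply the op in place on out
  }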


[2/2] incubator-singa git commit: Merge branch 'pr388'

Posted by wa...@apache.org.
Merge branch 'pr388'


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/c8df172c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/c8df172c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/c8df172c

Branch: refs/heads/master
Commit: c8df172c71941913d0aaea0d70782a4d649ae0be
Parents: 9b06c29 f9e7caa
Author: Wang Wei <wa...@gmail.com>
Authored: Sat Jun 30 21:30:20 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Sat Jun 30 21:30:20 2018 +0800

----------------------------------------------------------------------
 include/singa/core/tensor.h        |   6 +-
 src/core/tensor/tensor.cc          | 151 +++++++---
 src/core/tensor/tensor_math.h      |   8 +
 src/core/tensor/tensor_math_cpp.h  |  62 ++--
 src/core/tensor/tensor_math_cuda.h | 502 ++++++++++++++++++++++++--------
 5 files changed, 543 insertions(+), 186 deletions(-)
----------------------------------------------------------------------