You are viewing a plain-text version of this content. The link to its canonical version ("here" on the original page) was not preserved in this rendering.
Posted to commits@singa.apache.org by wa...@apache.org on 2019/06/22 13:27:05 UTC
[incubator-singa] branch master updated: transpose reshape fix
SINGA-462
This is an automated email from the ASF dual-hosted git repository.
wangwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-singa.git
The following commit(s) were added to refs/heads/master by this push:
new e679ff5 transpose reshape fix SINGA-462
new 6e90679 Merge pull request #469 from dcslin/SINGA-462
e679ff5 is described below
commit e679ff560a977844d14d9afc071223d60d6042ef
Author: slin004 <13...@users.noreply.github.com>
AuthorDate: Fri Jun 21 19:06:09 2019 +0800
transpose reshape fix SINGA-462
---
include/singa/core/tensor.h | 6 +-----
src/core/tensor/tensor.cc | 15 ++++++++++++++-
test/singa/test_tensor_math.cc | 37 +++++++++++++++++++++++++++++++++++++
3 files changed, 52 insertions(+), 6 deletions(-)
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index ed37cdb..1eeab5e 100755
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -137,11 +137,7 @@ class Tensor {
/// used for swig code to convert Tensor into numpy array.
/// It gets data into 'value'
template <typename SType>
- void GetValue(SType *value, const size_t num) {
- CHECK(device_ == defaultDevice);
- const SType* ptr = data<SType>();
- for (size_t i = 0; i < num; i++) value[i] = ptr[i];
- }
+ void GetValue(SType *value, const size_t num);
/// Serialize data, shape and transpose to protobuf object.
void ToProto(singa::TensorProto *proto) const;
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 8c50437..a58bff5 100755
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -654,6 +654,18 @@ void Tensor::SetValue(const SType x) {
template void Tensor::SetValue<float>(const float x);
template void Tensor::SetValue<int>(const int x);
+template <typename SType>
+void Tensor::GetValue(SType *value, const size_t num) {  // copy the first `num` elements, in stride-aware (logical) order, into `value`
+ CHECK(device_ == defaultDevice);
+ CHECK_LE(num, Size());  // never read past the end of the tensor's data
+ Tensor t(shape_, device_, data_type_);
+ singa::Transform(*this, &t);  // transform arranges data in memory considering stride (e.g. after Transpose)
+ auto ptr = static_cast<const SType*>(t.block()->data());  // read as SType, not float, so GetValue<int> is not bit-reinterpreted; NOTE(review): assumes SType matches data_type_
+ for (size_t i = 0; i < num; i++) value[i] = ptr[i];
+}
+template void Tensor::GetValue<float>(float *value, const size_t num);
+template void Tensor::GetValue<int>(int *value, const size_t num);
+
#define EltwiseUnaryTensorFn(fn, t, ret) \
do { \
TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \
@@ -1290,8 +1302,9 @@ Tensor& Tensor::Reshape(const Shape &shape) {
CHECK_EQ(Product(shape), Size());
if (transpose()) {
Tensor t(shape, device_, data_type_);
- singa::Transform(*this, &t);
shape_ = shape;
+ // `Transform` after assigning new shape, to keep this and t consistent
+ singa::Transform(*this, &t);
std::swap(t.block_, block_);
} else {
shape_ = shape;
diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc
index 6228607..e874c15 100644
--- a/test/singa/test_tensor_math.cc
+++ b/test/singa/test_tensor_math.cc
@@ -368,6 +368,43 @@ TEST_F(TensorMath, ReshapeCpp) {
EXPECT_EQ(p.shape(0), 4u);
EXPECT_EQ(p.shape(1), 1u);
for (int i = 0; i < 4; i++) EXPECT_FLOAT_EQ(ptr[i], 0.3f);
+
+ // test transpose then reshape
+ // {2,3,2} => {2,2,3} => {2,6}
+ Tensor t2(Shape{2,3,2});
+ t2.SetValue(0.2f);
+ t2.Transpose({2,0,1});
+ EXPECT_EQ(t2.shape(0), 2u);
+ EXPECT_EQ(t2.shape(1), 2u);
+ EXPECT_EQ(t2.shape(2), 3u);
+
+ t2.Reshape(Shape{2,6});
+ EXPECT_EQ(t2.shape(0), 2u);
+ EXPECT_EQ(t2.shape(1), 6u);
+}
+
+
+TEST_F(TensorMath, TransposeCpp) {
+ // {2,3,2} tensor holding 1..12, permuted by axes {2,0,1} => shape {2,2,3}
+ Tensor t(Shape{2,3,2});
+ const float dat1[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+                         7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+ t.CopyDataFromHostPtr(dat1, 12);
+
+ t.Transpose({2,0,1});
+ EXPECT_EQ(t.shape(0), 2u);
+ EXPECT_EQ(t.shape(1), 2u);
+ EXPECT_EQ(t.shape(2), 3u);
+
+ // GetValue must return elements in the transposed (logical) order,
+ // i.e. out[i][j][k] == in[j][k][i], not in the original memory order.
+ float dptr[12];
+ t.GetValue(dptr, 12);
+ const float expected[12] = {1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f,
+                             2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
+ for (int i = 0; i < 12; i++) {
+   EXPECT_FLOAT_EQ(expected[i], dptr[i]) << "mismatch at index " << i;
+ }
}