You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/29 07:20:28 UTC
incubator-singa git commit: SINGA-210 Enable checkpoint and resume
for v1.0
Repository: incubator-singa
Updated Branches:
refs/heads/dev d3c1bae61 -> 62c6603ff
SINGA-210 Enable checkpoint and resume for v1.0
This ticket adds code for dumping the model parameters as checkpoint
files, which can be used for fine-tuning and deployment.
Each Tensor is serialized into a TensorProto and saved in a BinFile stored
as <prefix>.model, and a description of the parameters is written to
<prefix>.desc.
Unit test cases pass for the kFloat32, kInt and kDouble data types.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/62c6603f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/62c6603f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/62c6603f
Branch: refs/heads/dev
Commit: 62c6603ff7a3fe9f9749021e84ad9ec35f3fef7d
Parents: d3c1bae
Author: WANG Ji <ij...@gmail.com>
Authored: Tue Jun 28 23:30:36 2016 +0800
Committer: WANG Ji <ij...@gmail.com>
Committed: Wed Jun 29 13:52:30 2016 +0800
----------------------------------------------------------------------
include/singa/core/tensor.h | 38 +++++++------
include/singa/io/snapshot.h | 79 ++++++++++++++++++++++++++
src/core/tensor/tensor.cc | 106 ++++++++++++++++++++++++++++++++++-
src/io/snapshot.cc | 104 +++++++++++++++++++++++++++++++++++
src/proto/core.proto | 11 ++++
test/singa/test_snapshot.cc | 116 +++++++++++++++++++++++++++++++++++++++
6 files changed, 437 insertions(+), 17 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 15e7b7f..4ef3286 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -34,8 +34,9 @@ namespace singa {
typedef vector<size_t> Shape;
/// hardcode the width of types defined in DataType
-const size_t kDataWidth[] = {sizeof(float), sizeof(float) / 2, sizeof(int),
- sizeof(char), sizeof(double), sizeof(unsigned char)};
+const size_t kDataWidth[] = {sizeof(float), sizeof(float) / 2,
+ sizeof(int), sizeof(char),
+ sizeof(double), sizeof(unsigned char)};
inline size_t SizeOf(DataType t) {
static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(size_t),
"Num of data types not match num of data width");
@@ -70,14 +71,14 @@ class Tensor {
/// Users should not operate against Block directly.
/// block_ is allocated in constructors.
Block *block() const { return block_; }
- void SetBlock(Block* block);
+ void SetBlock(Block *block);
std::shared_ptr<Device> device() const { return device_; }
/// return immutable Tensor values with given type.
template <typename SType>
- const SType* data() const {
- return static_cast<const SType*>(block()->data());
+ const SType *data() const {
+ return static_cast<const SType *>(block()->data());
}
/// data type, including kFloat16, kFloat32, kInt
@@ -96,8 +97,7 @@ class Tensor {
/// return number of total elements
size_t Size() const {
- if (block_ == nullptr)
- return 0u;
+ if (block_ == nullptr) return 0u;
CHECK_EQ(block_->size() % SizeOf(data_type_), 0u);
return block_->size() / SizeOf(data_type_);
}
@@ -110,7 +110,8 @@ class Tensor {
void Reshape(Shape &&shape);
/// Reset the shape, device, and data type as given tensor.
- /// If block size changes, then reallocate a new block. The previous block would
+ /// If the block size changes, a new block is allocated and the previous
+ /// block is deleted.
void ResetLike(const Tensor &t);
@@ -138,6 +139,12 @@ class Tensor {
/// Meta data would not be copied!
void CopyData(const Tensor &other);
+ /// Deserialize data, shape and transpose from protobuf object.
+ void FromProto(const singa::TensorProto &proto);
+
+ /// Serialize data, shape and transpose to protobuf object.
+ void ToProto(singa::TensorProto *proto) const;
+
/// return an exactly the same Tensor with data been deep copied to the given
/// device. If 'device' is nullptr, then clone it one the current device.
Tensor Clone(std::shared_ptr<Device> device = nullptr) const;
@@ -248,7 +255,6 @@ void Sqrt(const Tensor &in, Tensor *out);
void Square(const Tensor &in, Tensor *out);
void Tanh(const Tensor &in, Tensor *out);
-
/// Element-wise opeartion, out[i]=in[i]^x
template <typename SType>
Tensor Pow(const Tensor &in, const SType x);
@@ -404,27 +410,27 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
/// Compute the cross entropy loss given the prediction probability 'p' and
/// the target (ground truth) labels 't'. 'p' and 't' are either 1-d vector
/// or 2-d matrix. 'loss' is 1-d vector. The loss is computed into p.
-void ComputeCrossEntropy(const Tensor& p, const Tensor& t, Tensor* loss);
+void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss);
/// Compute the dx, given prediction probability 'p' (p=softmax(x)) and
/// the target (ground truth) labels 't'. 'p' and 't' are either 1-d vector
/// or 2-d matrix. 'grad' has the same shape as 'p'. dx is computed into p.
-void SoftmaxCrossEntropyBwd(const Tensor& t, Tensor* p);
+void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p);
/// Return a tensor consisting of rows ([start, end)) from 'in'. It shares the
/// memory with 'in'. 'in' is a 1D or 2D Tensor.
-Tensor SliceRows(const Tensor& in, const size_t start, const size_t end);
+Tensor SliceRows(const Tensor &in, const size_t start, const size_t end);
/// Return a tensor consisting of rows ([start, end)) from 'in'. It copies the
/// values from 'in'. 'in' ia a 2D Tensor.
-Tensor CopyRows(const Tensor& in, const size_t start, const size_t end);
+Tensor CopyRows(const Tensor &in, const size_t start, const size_t end);
/// Return a tensor consisting of columns ([start, end)) from 'in'. It copies
/// the values from 'in'. 'in' is a 2D Tensor.
-Tensor CopyColumns(const Tensor& in, const size_t start, const size_t end);
+Tensor CopyColumns(const Tensor &in, const size_t start, const size_t end);
/// Return a tensor which is vertically stacked from tensors in 'in'. Each
/// tensor in 'in' is a 2D tensor. Values are copied, no memory sharing.
-Tensor ConcatenateRows(const vector<Tensor>& in);
+Tensor ConcatenateRows(const vector<Tensor> &in);
/// Return a tensor which is horizontally stacked from tensors in 'in'. Each
/// tensor in 'in' is a 2D tensor. Values are copied, no memory sharing.
-Tensor ConcatenateColumns(const vector<Tensor>& in);
+Tensor ConcatenateColumns(const vector<Tensor> &in);
} // namespace singa
#endif // SINGA_CORE_TENSOR_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/include/singa/io/snapshot.h
----------------------------------------------------------------------
diff --git a/include/singa/io/snapshot.h b/include/singa/io/snapshot.h
new file mode 100644
index 0000000..7545572
--- /dev/null
+++ b/include/singa/io/snapshot.h
@@ -0,0 +1,79 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_SNAPSHOT_H_
+#define SINGA_UTILS_SNAPSHOT_H_
+
+#include "singa/io/reader.h"
+#include "singa/io/writer.h"
+#include "singa/utils/logging.h"
+#include "singa/proto/core.pb.h"
+#include "singa/core/tensor.h"
+
+#include <string>
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+
+namespace singa {
+/// Snapshot management for model parameters.
+/// It dumps model parameter snapshots as checkpoint files, which could be
+/// used for fine-tuning and deployment.
+/// The model parameters are kept separate from the model definition, i.e.,
+/// the net construction. After creating the neural network, users either
+/// randomly initialize the layer parameters or load them from checkpoint
+/// files through Snapshot.
+class Snapshot {
+ public:
+  enum Mode { kRead, kWrite };
+  /// Open the snapshot files named by 'prefix'.
+  /// <prefix>.model is the binary file holding one serialized tensor per
+  /// parameter key. <prefix>.desc is a text file describing the parameters
+  /// (name, data type, rank and shape), one line per parameter; it is only
+  /// written in kWrite mode.
+  /// Use kRead to load an existing snapshot (every parameter stored in
+  /// <prefix>.model is loaded when the object is constructed) and kWrite to
+  /// dump out a new one.
+  Snapshot(const std::string& prefix, Mode mode);
+  ~Snapshot() {}
+  /// Return all (name, tensor) pairs read from the checkpoint file, in
+  /// unspecified order. Requires kRead mode.
+  std::vector<std::pair<std::string, Tensor>> Read();
+  /// Return the (name, shape) pairs of all parameters, in unspecified order.
+  /// Requires kRead mode.
+  std::vector<std::pair<std::string, Shape>> ReadShape();
+  /// Return the tensor stored under parameter name 'key'.
+  /// Requires kRead mode; 'key' must exist in the snapshot.
+  Tensor Read(const std::string& key);
+  /// Return the shape of the parameter stored under name 'key'.
+  /// Requires kRead mode; 'key' must exist in the snapshot.
+  Shape ReadShape(const std::string& key);
+  /// Serialize and dump out one parameter. This writes one record to each of
+  /// two files: the binary <prefix>.model gets the serialized tensor, and the
+  /// text <prefix>.desc gets a tab-separated line with the parameter name,
+  /// data type and shape. Requires kWrite mode; each key may be written only
+  /// once.
+  void Write(const std::string& key, const Tensor& param);
+
+ private:
+  /// File name prefix shared by <prefix>.model and <prefix>.desc.
+  std::string prefix_;
+  Mode mode_;
+  /// Writers used in kWrite mode (null otherwise).
+  std::unique_ptr<io::Writer> bin_writer_ptr_, text_writer_ptr_;
+  /// Reader used in kRead mode (null otherwise).
+  std::unique_ptr<io::Reader> bin_reader_ptr_;
+  /// Names seen so far; used to check that parameter names are unique.
+  std::unordered_set<std::string> param_names_;
+  /// Key-parameter tensor pairs preloaded in kRead mode for lookup by name.
+  std::unordered_map<std::string, Tensor> param_map_;
+};
+} // namespace singa
+
+#endif // SINGA_UTILS_SNAPSHOT_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index b07a23c..3501ecd 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -159,6 +159,106 @@ void Tensor::CopyData(const Tensor &src) {
}
}
+/// Deserialize shape, data type, transpose flag and values from 'proto'.
+/// Values are read from the repeated field that matches proto.data_type():
+/// float_data for kFloat32, double_data for kDouble and int_data for kInt.
+/// That field must hold exactly Product(shape) values. The tensor keeps its
+/// current device; the block it held before is released.
+void Tensor::FromProto(const singa::TensorProto &proto) {
+  Shape shape;
+  for (uint32_t s : proto.shape()) shape.push_back(s);
+  const DataType dtype = proto.data_type();
+  const size_t num = Product(shape);
+  // Fill a freshly constructed tensor so a block of the right size and type
+  // always exists. The previous code freed block_, set it to nullptr and
+  // relied on Reshape() to allocate a new one. NOTE(review): if Reshape()
+  // only reallocates when the element count changes, block_ stayed nullptr
+  // whenever the old and new element counts matched.
+  Tensor tmp(shape, device_, dtype);
+  switch (dtype) {
+    case kFloat32: {
+      // Guard against truncated/corrupted protos instead of reading past
+      // the end of the repeated field.
+      CHECK_EQ(static_cast<size_t>(proto.float_data_size()), num);
+      std::vector<float> values(proto.float_data().begin(),
+                                proto.float_data().end());
+      tmp.CopyDataFromHostPtr<float>(values.data(), num);
+      break;
+    }
+    case kDouble: {
+      CHECK_EQ(static_cast<size_t>(proto.double_data_size()), num);
+      std::vector<double> values(proto.double_data().begin(),
+                                 proto.double_data().end());
+      tmp.CopyDataFromHostPtr<double>(values.data(), num);
+      break;
+    }
+    case kInt: {
+      CHECK_EQ(static_cast<size_t>(proto.int_data_size()), num);
+      std::vector<int> values(proto.int_data().begin(),
+                              proto.int_data().end());
+      tmp.CopyDataFromHostPtr<int>(values.data(), num);
+      break;
+    }
+    // TODO(wangji): support kChar and kUChar. Protobuf's bytes type maps to
+    // std::string, so these types need handling that differs from the
+    // numeric cases above.
+    default: {
+      LOG(FATAL) << "Unsupported Type: " << DataType_Name(dtype);
+    }
+  }
+  // Share tmp's block. The Tensor assignment operator is expected to drop
+  // this tensor's reference to its previous block.
+  *this = tmp;
+  transpose_ = proto.transpose();
+}
+
+/// Serialize shape, data type, transpose flag and values into 'proto'.
+/// Values go to the repeated field that matches data_type_ (float_data for
+/// kFloat32, double_data for kDouble, int_data for kInt). All value fields
+/// are cleared first, so a reused proto never carries stale data from a
+/// tensor of another type.
+/// NOTE(review): values are read through data<T>(), i.e. block()->data();
+/// confirm the tensor lives on a host-accessible device before calling.
+void Tensor::ToProto(singa::TensorProto *proto) const {
+  proto->clear_shape();
+  for (const size_t s : shape_) {
+    // TensorProto stores dimensions as uint32.
+    proto->add_shape(static_cast<uint32_t>(s));
+  }
+  proto->set_data_type(data_type_);
+  proto->set_transpose(transpose_);
+  proto->clear_float_data();
+  proto->clear_double_data();
+  proto->clear_int_data();
+  proto->clear_bytes_data();
+  const size_t num = Product(shape_);
+  switch (data_type_) {
+    case kFloat32: {
+      const float *src = data<float>();
+      proto->mutable_float_data()->Reserve(static_cast<int>(num));
+      for (size_t i = 0; i < num; ++i) proto->add_float_data(src[i]);
+      break;
+    }
+    case kDouble: {
+      const double *src = data<double>();
+      proto->mutable_double_data()->Reserve(static_cast<int>(num));
+      for (size_t i = 0; i < num; ++i) proto->add_double_data(src[i]);
+      break;
+    }
+    case kInt: {
+      const int *src = data<int>();
+      proto->mutable_int_data()->Reserve(static_cast<int>(num));
+      for (size_t i = 0; i < num; ++i) proto->add_int_data(src[i]);
+      break;
+    }
+    // TODO(wangji): serialize kChar and kUChar into bytes_data.
+    default: {
+      LOG(FATAL) << "Unsupported Type: " << DataType_Name(data_type_);
+    }
+  }
+}
+
Tensor Tensor::Clone(std::shared_ptr<Device> device) const {
if (device == nullptr) device = device_;
Tensor t(shape_, device_, data_type_);
@@ -292,6 +392,11 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
{ __VA_ARGS__ } \
break; \
} \
+ case kDouble: { \
+ typedef double DType; \
+ { __VA_ARGS__ } \
+ break; \
+ } \
default: \
LOG(FATAL) << "Unknow data type = " << DataType_Name(type); \
} \
@@ -357,7 +462,6 @@ float Tensor::L2() const {
return nrm / Size();
}
-
template <typename SType>
void Tensor::SetValue(const SType x) {
CHECK_EQ(sizeof(SType), SizeOf(data_type_));
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/io/snapshot.cc
----------------------------------------------------------------------
diff --git a/src/io/snapshot.cc b/src/io/snapshot.cc
new file mode 100644
index 0000000..3b9b8ce
--- /dev/null
+++ b/src/io/snapshot.cc
@@ -0,0 +1,104 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "singa/io/snapshot.h"
+
+#include <string>
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+#include <utility>
+#include <iostream>
+
+namespace singa {
+/// Create the file handles needed for 'mode'. In kWrite mode the binary
+/// <prefix>.model and text <prefix>.desc files are created. In kRead mode
+/// every record of <prefix>.model is deserialized into param_map_ up front.
+Snapshot::Snapshot(const std::string& prefix, Mode mode)
+    : prefix_(prefix), mode_(mode) {
+  // Handles not needed for the chosen mode stay null.
+  switch (mode_) {
+    case kWrite:
+      bin_writer_ptr_.reset(new io::BinFileWriter);
+      text_writer_ptr_.reset(new io::TextFileWriter);
+      bin_writer_ptr_->Open(prefix + ".model", io::kCreate);
+      text_writer_ptr_->Open(prefix + ".desc", io::kCreate);
+      break;
+    case kRead: {
+      bin_reader_ptr_.reset(new io::BinFileReader);
+      bin_reader_ptr_->Open(prefix + ".model");
+      std::string key, serialized_str;
+      singa::TensorProto tp;
+      while (bin_reader_ptr_->Read(&key, &serialized_str)) {
+        // Parameter names must be unique within one checkpoint.
+        CHECK(param_names_.count(key) == 0);
+        param_names_.insert(key);
+        CHECK(tp.ParseFromString(serialized_str));
+        param_map_[key].FromProto(tp);
+      }
+      break;
+    }
+    default:
+      LOG(FATAL)
+          << "Mode for snapshot should be Snapshot::kWrite or Snapshot::kRead";
+  }
+}
+
+/// Serialize 'param' under 'key': a binary record goes to <prefix>.model and
+/// a human-readable description line goes to <prefix>.desc.
+void Snapshot::Write(const std::string& key, const Tensor& param) {
+  CHECK(mode_ == kWrite);
+  CHECK(param_names_.count(key) == 0);
+  param_names_.insert(key);
+
+  // Binary record: key -> serialized TensorProto.
+  std::string serialized_str;
+  {
+    TensorProto tp;
+    param.ToProto(&tp);
+    CHECK(tp.SerializeToString(&serialized_str));
+  }
+  bin_writer_ptr_->Write(key, serialized_str);
+
+  // Text record: tab-separated name, numeric data type, rank and shape.
+  const Shape& dims = param.shape();
+  std::string desc_str = "parameter name: " + key + "\tdata type: " +
+                         std::to_string(param.data_type()) + "\tdim: " +
+                         std::to_string(dims.size()) + "\tshape:";
+  for (const size_t d : dims) desc_str += " " + std::to_string(d);
+  text_writer_ptr_->Write(key, desc_str);
+}
+
+/// Return a copy of every preloaded (name, tensor) pair. Iteration follows
+/// the unordered_map, so the order is unspecified. Requires kRead mode.
+std::vector<std::pair<std::string, Tensor>> Snapshot::Read() {
+  CHECK(mode_ == kRead);
+  return std::vector<std::pair<std::string, Tensor>>(param_map_.begin(),
+                                                    param_map_.end());
+}
+
+/// Return the (name, shape) pair of every preloaded parameter, in the
+/// unordered_map's (unspecified) order. Requires kRead mode.
+std::vector<std::pair<std::string, Shape>> Snapshot::ReadShape() {
+  CHECK(mode_ == kRead);
+  std::vector<std::pair<std::string, Shape>> shapes;
+  for (const auto& entry : param_map_)
+    shapes.emplace_back(entry.first, entry.second.shape());
+  return shapes;
+}
+
+/// Return the tensor stored under 'key'. Requires kRead mode; aborts with a
+/// message naming the key if it is not present in the snapshot.
+Tensor Snapshot::Read(const std::string& key) {
+  CHECK(mode_ == kRead);
+  // Single lookup; reports which key is missing instead of a bare count check.
+  const auto it = param_map_.find(key);
+  CHECK(it != param_map_.end())
+      << "Parameter '" << key << "' not found in snapshot " << prefix_;
+  return it->second;
+}
+
+/// Return the shape of the parameter stored under 'key'. Requires kRead
+/// mode; aborts with a message naming the key if it is not present.
+Shape Snapshot::ReadShape(const std::string& key) {
+  CHECK(mode_ == kRead);
+  // Single lookup; reports which key is missing instead of a bare count check.
+  const auto it = param_map_.find(key);
+  CHECK(it != param_map_.end())
+      << "Parameter '" << key << "' not found in snapshot " << prefix_;
+  return it->second.shape();
+}
+
+} // namespace singa
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/proto/core.proto
----------------------------------------------------------------------
diff --git a/src/proto/core.proto b/src/proto/core.proto
index b853b30..da32bc9 100644
--- a/src/proto/core.proto
+++ b/src/proto/core.proto
@@ -58,3 +58,14 @@ message MemPoolConf {
// cnmemflag = 2: prevent the manager from stealing memory
optional uint32 cnmemflag = 4 [default = 0];
}
+
+// Serialized form of a singa::Tensor, produced by Tensor::ToProto and
+// consumed by Tensor::FromProto. Only the value field matching data_type is
+// filled: float_data for kFloat32, double_data for kDouble and int_data for
+// kInt.
+message TensorProto {
+ // Dimensions of the tensor (size_t values narrowed to uint32).
+ repeated uint32 shape = 1;
+ // Element type; selects which *_data field holds the values.
+ optional DataType data_type = 2;
+ // Transpose flag of the tensor.
+ optional bool transpose = 3;
+ repeated float float_data = 4;
+ repeated double double_data = 5;
+ repeated int32 int_data = 6;
+ // Intended for kChar/kUChar; not written or read yet (see the TODOs in
+ // Tensor::ToProto/FromProto).
+ repeated bytes bytes_data = 7;
+}
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/test/singa/test_snapshot.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_snapshot.cc b/test/singa/test_snapshot.cc
new file mode 100644
index 0000000..26f1f8c
--- /dev/null
+++ b/test/singa/test_snapshot.cc
@@ -0,0 +1,116 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "gtest/gtest.h"
+#include "singa/io/snapshot.h"
+#include "singa/io/reader.h"
+#include "singa/core/tensor.h"
+
+#include <string>
+#include <fstream>
+
+namespace {
+// Path prefix of the snapshot files written by these tests.
+const std::string prefix = "./snapshot_test";
+// Float values for the 1-D parameter "Param_1" and the 2x2 "Param_2".
+const float param_1_data[] = {0.1, 0.2, 0.3, 0.4};
+const float param_2_data[] = {0.2, 0.1, 0.4, 0.3};
+// Expected <prefix>.desc lines; "data type" is the numeric DataType value
+// that Snapshot::Write emits for these float tensors.
+const std::string desc_1 =
+    "parameter name: Param_1\tdata type: 0\tdim: 1\tshape: 4";
+const std::string desc_2 =
+    "parameter name: Param_2\tdata type: 0\tdim: 2\tshape: 2 2";
+// Values for the kInt and kDouble round-trip tests.
+const int int_data[] = {1, 3, 5, 7};
+const double double_data[] = {0.2, 0.4, 0.6, 0.8};
+}  // namespace
+
+// Dump two float parameters (a 4-vector and a 2x2 matrix) to <prefix>.*.
+TEST(Snapshot, WriteTest) {
+  singa::Tensor param_1(singa::Shape{4});
+  param_1.CopyDataFromHostPtr(param_1_data, 4);
+  singa::Tensor param_2(singa::Shape{2, 2});
+  param_2.CopyDataFromHostPtr(param_2_data, 4);
+
+  singa::Snapshot snapshot(prefix, singa::Snapshot::kWrite);
+  snapshot.Write("Param_1", param_1);
+  snapshot.Write("Param_2", param_2);
+}
+
+// Round-trip two float parameters and check shapes, values and the .desc
+// description lines.
+TEST(Snapshot, ReadTest) {
+  // Write the snapshot here rather than relying on WriteTest having run
+  // first, so this test also passes under --gtest_filter/--gtest_shuffle.
+  // The writer is scoped so its files are closed before reading.
+  {
+    singa::Snapshot snapshot(prefix, singa::Snapshot::kWrite);
+    singa::Tensor param_1(singa::Shape{4}), param_2(singa::Shape{2, 2});
+    param_1.CopyDataFromHostPtr(param_1_data, 4);
+    param_2.CopyDataFromHostPtr(param_2_data, 4);
+    snapshot.Write("Param_1", param_1);
+    snapshot.Write("Param_2", param_2);
+  }
+
+  singa::Snapshot snapshot(prefix, singa::Snapshot::kRead);
+  // ASSERT on ranks so the element checks below never index out of range.
+  const singa::Shape shape1 = snapshot.ReadShape("Param_1");
+  ASSERT_EQ(shape1.size(), 1u);
+  EXPECT_EQ(shape1[0], 4u);
+  const singa::Shape shape2 = snapshot.ReadShape("Param_2");
+  ASSERT_EQ(shape2.size(), 2u);
+  EXPECT_EQ(shape2[0], 2u);
+  EXPECT_EQ(shape2[1], 2u);
+
+  const singa::Tensor param_1 = snapshot.Read("Param_1");
+  const float *data_1 = param_1.data<float>();
+  for (size_t i = 0; i < singa::Product(shape1); ++i)
+    EXPECT_FLOAT_EQ(data_1[i], param_1_data[i]);
+  const singa::Tensor param_2 = snapshot.Read("Param_2");
+  const float *data_2 = param_2.data<float>();
+  for (size_t i = 0; i < singa::Product(shape2); ++i)
+    EXPECT_FLOAT_EQ(data_2[i], param_2_data[i]);
+
+  std::ifstream desc_file(prefix + ".desc");
+  ASSERT_TRUE(desc_file.is_open());
+  std::string line;
+  std::getline(desc_file, line);
+  EXPECT_EQ(line, desc_1);
+  std::getline(desc_file, line);
+  EXPECT_EQ(line, desc_2);
+}
+
+// Round-trip a 4-element kInt tensor through <prefix>.int.*.
+TEST(Snapshot, ReadIntTest) {
+  const std::string int_prefix = prefix + ".int";
+  // Writer is scoped so its files are closed before reading them back.
+  {
+    singa::Tensor int_param(singa::Shape{4});
+    int_param.AsType(singa::kInt);
+    int_param.CopyDataFromHostPtr(int_data, 4);
+    singa::Snapshot int_snapshot_write(int_prefix, singa::Snapshot::kWrite);
+    int_snapshot_write.Write("IntParam", int_param);
+  }
+  {
+    singa::Snapshot int_snapshot_read(int_prefix, singa::Snapshot::kRead);
+    const singa::Shape shape = int_snapshot_read.ReadShape("IntParam");
+    EXPECT_EQ(shape.size(), 1u);
+    EXPECT_EQ(shape[0], 4u);
+    const singa::Tensor int_param = int_snapshot_read.Read("IntParam");
+    const int *param_data = int_param.data<int>();
+    const size_t count = singa::Product(shape);
+    for (size_t i = 0; i < count; ++i) EXPECT_EQ(param_data[i], int_data[i]);
+  }
+}
+
+// Round-trip a 4-element kDouble tensor through <prefix>.double.*.
+TEST(Snapshot, ReadDoubleTest) {
+  const std::string double_prefix = prefix + ".double";
+  // Writer is scoped so its files are closed before reading them back.
+  {
+    singa::Tensor double_param(singa::Shape{4});
+    double_param.AsType(singa::kDouble);
+    double_param.CopyDataFromHostPtr(double_data, 4);
+    singa::Snapshot double_snapshot_write(double_prefix,
+                                          singa::Snapshot::kWrite);
+    double_snapshot_write.Write("DoubleParam", double_param);
+  }
+  {
+    singa::Snapshot double_snapshot_read(double_prefix,
+                                         singa::Snapshot::kRead);
+    const singa::Shape shape = double_snapshot_read.ReadShape("DoubleParam");
+    EXPECT_EQ(shape.size(), 1u);
+    EXPECT_EQ(shape[0], 4u);
+    const singa::Tensor double_param =
+        double_snapshot_read.Read("DoubleParam");
+    const double *param_data = double_param.data<double>();
+    const size_t count = singa::Product(shape);
+    // Values round-trip bit-exactly, so exact equality is intended.
+    for (size_t i = 0; i < count; ++i)
+      EXPECT_EQ(param_data[i], double_data[i]);
+  }
+}