You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/29 07:20:28 UTC

incubator-singa git commit: SINGA-210 Enable checkpoint and resume for v1.0

Repository: incubator-singa
Updated Branches:
  refs/heads/dev d3c1bae61 -> 62c6603ff


SINGA-210 Enable checkpoint and resume for v1.0

This ticket is going to add code for dumping the model parameters as
checkpoint files, which could be used for fine-tuning and deployment.

Serialize each Tensor into a TensorProto and save it in a BinFile
stored as <prefix>.model, and write a description of the parameters
to <prefix>.desc.

Unit tests pass for the kFloat32, kInt and kDouble data types.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/62c6603f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/62c6603f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/62c6603f

Branch: refs/heads/dev
Commit: 62c6603ff7a3fe9f9749021e84ad9ec35f3fef7d
Parents: d3c1bae
Author: WANG Ji <ij...@gmail.com>
Authored: Tue Jun 28 23:30:36 2016 +0800
Committer: WANG Ji <ij...@gmail.com>
Committed: Wed Jun 29 13:52:30 2016 +0800

----------------------------------------------------------------------
 include/singa/core/tensor.h |  38 +++++++------
 include/singa/io/snapshot.h |  79 ++++++++++++++++++++++++++
 src/core/tensor/tensor.cc   | 106 ++++++++++++++++++++++++++++++++++-
 src/io/snapshot.cc          | 104 +++++++++++++++++++++++++++++++++++
 src/proto/core.proto        |  11 ++++
 test/singa/test_snapshot.cc | 116 +++++++++++++++++++++++++++++++++++++++
 6 files changed, 437 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 15e7b7f..4ef3286 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -34,8 +34,9 @@ namespace singa {
 
 typedef vector<size_t> Shape;
 /// hardcode the width of types defined in DataType
-const size_t kDataWidth[] = {sizeof(float), sizeof(float) / 2, sizeof(int),
-                             sizeof(char), sizeof(double), sizeof(unsigned char)};
+/// Byte width of each data type, presumably listed in DataType enum order
+/// (TODO confirm against the DataType definition); the sizeof(float) / 2
+/// entry presumably stands in for kFloat16. SizeOf() static_asserts that the
+/// number of entries equals kNumDataType.
+const size_t kDataWidth[] = {sizeof(float),  sizeof(float) / 2,
+                             sizeof(int),    sizeof(char),
+                             sizeof(double), sizeof(unsigned char)};
 inline size_t SizeOf(DataType t) {
   static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(size_t),
                 "Num of data types not match num of data width");
@@ -70,14 +71,14 @@ class Tensor {
   /// Users should not operate against Block directly.
   /// block_ is allocated in constructors.
   Block *block() const { return block_; }
-  void SetBlock(Block* block);
+  void SetBlock(Block *block);
 
   std::shared_ptr<Device> device() const { return device_; }
 
   /// return immutable Tensor values with given type.
+  /// NOTE(review): SType is not checked against data_type(), and block() is
+  /// dereferenced without a null check even though block_ may be nullptr
+  /// (Size() tests for that case explicitly).
   template <typename SType>
-  const SType* data() const {
-    return static_cast<const SType*>(block()->data());
+  const SType *data() const {
+    return static_cast<const SType *>(block()->data());
   }
 
   /// data type, including kFloat16, kFloat32, kInt
@@ -96,8 +97,7 @@ class Tensor {
 
   /// return number of total elements
+  /// (0 when no block has been allocated).
   size_t Size() const {
-    if (block_ == nullptr)
-      return 0u;
+    if (block_ == nullptr) return 0u;
+    // The block size must be a whole multiple of the per-element width.
     CHECK_EQ(block_->size() % SizeOf(data_type_), 0u);
     return block_->size() / SizeOf(data_type_);
   }
@@ -110,7 +110,8 @@ class Tensor {
   void Reshape(Shape &&shape);
 
   /// Reset the shape, device, and data type as given tensor.
-  /// If block size changes, then reallocate a new block. The previous block would
+  /// If block size changes, then reallocate a new block. The previous block
+  /// would
   /// be deleted.
   void ResetLike(const Tensor &t);
 
@@ -138,6 +139,12 @@ class Tensor {
   /// Meta data would not be copied!
   void CopyData(const Tensor &other);
 
+  /// Deserialize data, shape and transpose from protobuf object.
+  void FromProto(const singa::TensorProto &proto);
+
+  /// Serialize data, shape and transpose to protobuf object.
+  void ToProto(singa::TensorProto *proto) const;
+
   /// return an exactly the same Tensor with data been deep copied to the given
   /// device. If 'device' is nullptr, then clone it one the current device.
   Tensor Clone(std::shared_ptr<Device> device = nullptr) const;
@@ -248,7 +255,6 @@ void Sqrt(const Tensor &in, Tensor *out);
 void Square(const Tensor &in, Tensor *out);
 void Tanh(const Tensor &in, Tensor *out);
 
-
 /// Element-wise opeartion, out[i]=in[i]^x
 template <typename SType>
 Tensor Pow(const Tensor &in, const SType x);
@@ -404,27 +410,27 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
 /// Compute the cross entropy loss given the prediction probability 'p' and
 /// the target (ground truth) labels 't'. 'p' and 't' are either 1-d vector
 /// or 2-d matrix. 'loss' is 1-d vector. The loss is computed into p.
-void ComputeCrossEntropy(const Tensor& p, const Tensor& t, Tensor* loss);
+void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss);
 /// Compute the dx, given prediction probability 'p' (p=softmax(x)) and
 /// the target (ground truth) labels 't'. 'p' and 't' are either 1-d vector
 /// or 2-d matrix. 'grad' has the same shape as 'p'. dx is computed into p.
-void SoftmaxCrossEntropyBwd(const Tensor& t, Tensor* p);
+void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p);
 
 /// Return a tensor consisting of rows ([start, end)) from 'in'. It shares the
 /// memory with 'in'. 'in' is a 1D or 2D Tensor.
-Tensor SliceRows(const Tensor& in, const size_t start, const size_t end);
+Tensor SliceRows(const Tensor &in, const size_t start, const size_t end);
 /// Return a tensor consisting of rows ([start, end)) from 'in'. It copies the
 /// values from 'in'. 'in' ia a 2D Tensor.
-Tensor CopyRows(const Tensor& in, const size_t start, const size_t end);
+Tensor CopyRows(const Tensor &in, const size_t start, const size_t end);
 /// Return a tensor consisting of columns ([start, end)) from 'in'. It copies
 /// the values from 'in'. 'in' is a  2D Tensor.
-Tensor CopyColumns(const Tensor& in, const size_t start, const size_t end);
+Tensor CopyColumns(const Tensor &in, const size_t start, const size_t end);
 /// Return a tensor which is vertically stacked from tensors in 'in'. Each
 /// tensor in 'in' is a 2D tensor. Values are copied, no memory sharing.
-Tensor ConcatenateRows(const vector<Tensor>& in);
+Tensor ConcatenateRows(const vector<Tensor> &in);
 /// Return a tensor which is horizontally stacked from tensors in 'in'. Each
 /// tensor in 'in' is a 2D tensor. Values are copied, no memory sharing.
-Tensor ConcatenateColumns(const vector<Tensor>& in);
+Tensor ConcatenateColumns(const vector<Tensor> &in);
 }  // namespace singa
 
 #endif  // SINGA_CORE_TENSOR_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/include/singa/io/snapshot.h
----------------------------------------------------------------------
diff --git a/include/singa/io/snapshot.h b/include/singa/io/snapshot.h
new file mode 100644
index 0000000..7545572
--- /dev/null
+++ b/include/singa/io/snapshot.h
@@ -0,0 +1,79 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_SNAPSHOT_H_
+#define SINGA_UTILS_SNAPSHOT_H_
+
+#include "singa/io/reader.h"
+#include "singa/io/writer.h"
+#include "singa/utils/logging.h"
+#include "singa/proto/core.pb.h"
+#include "singa/core/tensor.h"
+
+#include <string>
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+
+namespace singa {
+/// Snapshot manages model-parameter checkpoints.
+/// It dumps model parameters as checkpoint files, which could be used for
+/// fine-tuning and deployment, and loads them back.
+/// Model parameters are kept separate from the model definition (i.e., the
+/// net construction): after creating the neural network, users either
+/// initialize the layer parameters randomly or load them from checkpoint
+/// files through Snapshot.
+class Snapshot {
+ public:
+  enum Mode { kRead, kWrite };
+  /// Bind the snapshot to two files named after 'prefix':
+  ///   - <prefix>.model: binary file of (parameter name, serialized
+  ///     TensorProto) records;
+  ///   - <prefix>.desc: text file describing each parameter (name, data
+  ///     type, dim and shape), one tab-separated line per parameter.
+  /// kWrite creates both files for dumping out a snapshot; kRead loads every
+  /// parameter stored in <prefix>.model into memory up front.
+  Snapshot(const std::string& prefix, Mode mode);
+  ~Snapshot() {}
+  /// Return all parameters loaded from <prefix>.model (kRead mode only).
+  /// The order of the returned pairs is unspecified.
+  std::vector<std::pair<std::string, Tensor>> Read();
+  /// Return the shapes of all loaded parameters (kRead mode only). The
+  /// order of the returned pairs is unspecified.
+  std::vector<std::pair<std::string, Shape>> ReadShape();
+  /// Return the parameter tensor with the given name (kRead mode only).
+  Tensor Read(const std::string& key);
+  /// Return the shape of the parameter with the given name (kRead mode only).
+  Shape ReadShape(const std::string& key);
+  /// Serialize and dump out one parameter (kWrite mode only). Appends one
+  /// record (the serialized tensor) to <prefix>.model and one description
+  /// line (name, data type, dim and shape) to <prefix>.desc. Parameter names
+  /// must be unique within a snapshot.
+  void Write(const std::string& key, const Tensor& param);
+
+ private:
+  std::string prefix_;
+  Mode mode_;
+  std::unique_ptr<io::Writer> bin_writer_ptr_, text_writer_ptr_;
+  std::unique_ptr<io::Reader> bin_reader_ptr_;
+  /// Names seen so far; used to reject duplicate parameter names.
+  std::unordered_set<std::string> param_names_;
+  /// Preloaded name -> tensor pairs, used to look up a specific parameter.
+  std::unordered_map<std::string, Tensor> param_map_;
+};
+}  //  namespace singa
+
+#endif  //  SINGA_UTILS_SNAPSHOT_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index b07a23c..3501ecd 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -159,6 +159,106 @@ void Tensor::CopyData(const Tensor &src) {
   }
 }
 
+/// Deserialize a tensor previously produced by ToProto().
+/// Drops this tensor's reference to its current block (freeing the block when
+/// its ref count reaches zero), takes shape, data type and transpose flag from
+/// 'proto', calls Reshape(), and copies the stored values in through
+/// CopyDataFromHostPtr(). The number of stored values must equal the product
+/// of the stored shape; kFloat32, kDouble and kInt are supported, any other
+/// data type aborts via LOG(FATAL).
+void Tensor::FromProto(const singa::TensorProto &proto) {
+  if (block_ != nullptr && block_->DecRefCount() == 0)
+    device_->FreeBlock(block_);
+  block_ = nullptr;
+  Shape shape;
+  for (uint32_t s : proto.shape()) shape.push_back(s);
+  data_type_ = proto.data_type();
+  Reshape(shape);
+  transpose_ = proto.transpose();
+  const size_t num = Product(shape_);
+  switch (data_type_) {
+    case kFloat32: {
+      // Reject truncated/malformed messages before reading the values.
+      CHECK_EQ(static_cast<size_t>(proto.float_data_size()), num);
+      std::vector<float> values(proto.float_data().begin(),
+                                proto.float_data().end());
+      CopyDataFromHostPtr<float>(values.data(), num);
+      break;
+    }
+    case kDouble: {
+      CHECK_EQ(static_cast<size_t>(proto.double_data_size()), num);
+      std::vector<double> values(proto.double_data().begin(),
+                                 proto.double_data().end());
+      CopyDataFromHostPtr<double>(values.data(), num);
+      break;
+    }
+    case kInt: {
+      CHECK_EQ(static_cast<size_t>(proto.int_data_size()), num);
+      // int_data is declared int32 in core.proto; convert element-wise.
+      std::vector<int> values(proto.int_data().begin(),
+                              proto.int_data().end());
+      CopyDataFromHostPtr<int>(values.data(), num);
+      break;
+    }
+    // TODO(wangji): support kChar and kUChar. They would be carried in the
+    // 'bytes_data' field, which protobuf exposes as strings rather than as
+    // numeric elements, so they need different handling than the cases above.
+    default: {
+      LOG(FATAL) << "Unsupported Type " << DataType_Name(data_type_);
+    }
+  }
+}
+
+/// Serialize the shape, data type, transpose flag and element values of this
+/// tensor into 'proto'. Every value field of 'proto' is cleared first, so a
+/// reused message cannot keep values left over from a tensor of another data
+/// type. Values are read through data<T>(). kFloat32, kDouble and kInt are
+/// supported; any other data type aborts via LOG(FATAL).
+void Tensor::ToProto(singa::TensorProto *proto) const {
+  proto->clear_shape();
+  for (auto s : shape_) {
+    proto->add_shape(s);
+  }
+  proto->set_data_type(data_type_);
+  proto->set_transpose(transpose_);
+  proto->clear_float_data();
+  proto->clear_double_data();
+  proto->clear_int_data();
+  proto->clear_bytes_data();
+  const size_t num = Product(shape_);
+  switch (data_type_) {
+    case kFloat32: {
+      const float *data_ptr = data<float>();
+      proto->mutable_float_data()->Reserve(static_cast<int>(num));
+      for (size_t i = 0; i < num; ++i) proto->add_float_data(data_ptr[i]);
+      break;
+    }
+    case kDouble: {
+      const double *data_ptr = data<double>();
+      proto->mutable_double_data()->Reserve(static_cast<int>(num));
+      for (size_t i = 0; i < num; ++i) proto->add_double_data(data_ptr[i]);
+      break;
+    }
+    case kInt: {
+      const int *data_ptr = data<int>();
+      proto->mutable_int_data()->Reserve(static_cast<int>(num));
+      for (size_t i = 0; i < num; ++i) proto->add_int_data(data_ptr[i]);
+      break;
+    }
+    // TODO(wangji): support kChar and kUChar. They would go into the
+    // 'bytes_data' field, which protobuf exposes as strings rather than as
+    // numeric elements.
+    default: {
+      LOG(FATAL) << "Unsupported Type " << DataType_Name(data_type_);
+    }
+  }
+}
+
 Tensor Tensor::Clone(std::shared_ptr<Device> device) const {
   if (device == nullptr) device = device_;
   Tensor t(shape_, device_, data_type_);
@@ -292,6 +392,11 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
         { __VA_ARGS__ }                                             \
         break;                                                      \
       }                                                             \
+      case kDouble: {                                               \
+        typedef double DType;                                       \
+        { __VA_ARGS__ }                                             \
+        break;                                                      \
+      }                                                             \
       default:                                                      \
         LOG(FATAL) << "Unknow data type = " << DataType_Name(type); \
     }                                                               \
@@ -357,7 +462,6 @@ float Tensor::L2() const {
   return nrm / Size();
 }
 
-
 template <typename SType>
 void Tensor::SetValue(const SType x) {
   CHECK_EQ(sizeof(SType), SizeOf(data_type_));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/io/snapshot.cc
----------------------------------------------------------------------
diff --git a/src/io/snapshot.cc b/src/io/snapshot.cc
new file mode 100644
index 0000000..3b9b8ce
--- /dev/null
+++ b/src/io/snapshot.cc
@@ -0,0 +1,104 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "singa/io/snapshot.h"
+
+#include <string>
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+#include <utility>
+#include <iostream>
+
+namespace singa {
+/// Only the I/O objects needed for 'mode' are allocated. In kWrite mode both
+/// output files are created; in kRead mode every (key, TensorProto) record in
+/// <prefix>.model is deserialized eagerly into param_map_.
+Snapshot::Snapshot(const std::string& prefix, Mode mode)
+    : prefix_(prefix),
+      mode_(mode),
+      bin_writer_ptr_(mode_ == kWrite ? (new io::BinFileWriter) : nullptr),
+      text_writer_ptr_(mode_ == kWrite ? (new io::TextFileWriter) : nullptr),
+      bin_reader_ptr_(mode_ == kRead ? (new io::BinFileReader) : nullptr) {
+  switch (mode_) {
+    case kWrite:
+      bin_writer_ptr_->Open(prefix + ".model", io::kCreate);
+      text_writer_ptr_->Open(prefix + ".desc", io::kCreate);
+      break;
+    case kRead: {
+      bin_reader_ptr_->Open(prefix + ".model");
+      std::string key, serialized_str;
+      singa::TensorProto tp;
+      // Duplicate names in the file are a hard error.
+      while (bin_reader_ptr_->Read(&key, &serialized_str)) {
+        CHECK(param_names_.count(key) == 0);
+        param_names_.insert(key);
+        CHECK(tp.ParseFromString(serialized_str));
+        param_map_[key].FromProto(tp);
+      }
+      break;
+    }
+    default:
+      LOG(FATAL)
+          << "Mode for snapshot should be Snapshot::kWrite or Snapshot::kRead";
+  }
+}
+
+/// Append 'param' under 'key' to <prefix>.model and a one-line description
+/// of it to <prefix>.desc. Aborts if not in kWrite mode or if 'key' was
+/// already written.
+void Snapshot::Write(const std::string& key, const Tensor& param) {
+  CHECK(mode_ == kWrite);
+  CHECK(param_names_.count(key) == 0);
+  param_names_.insert(key);
+
+  // Binary record: the tensor serialized as a TensorProto.
+  std::string serialized_str;
+  {
+    TensorProto tp;
+    param.ToProto(&tp);
+    CHECK(tp.SerializeToString(&serialized_str));
+  }
+  bin_writer_ptr_->Write(key, serialized_str);
+
+  // Text record, e.g.
+  // "parameter name: Param_2\tdata type: 0\tdim: 2\tshape: 2 2".
+  const Shape& dims = param.shape();
+  std::string line = "parameter name: " + key;
+  line += "\tdata type: " + std::to_string(param.data_type());
+  line += "\tdim: " + std::to_string(dims.size());
+  line += "\tshape:";
+  for (size_t i = 0; i < dims.size(); ++i)
+    line += " " + std::to_string(dims[i]);
+  text_writer_ptr_->Write(key, line);
+}
+
+/// Return a copy of every preloaded (name, tensor) pair, in the iteration
+/// order of param_map_ (an unordered_map, so the order is unspecified).
+std::vector<std::pair<std::string, Tensor>> Snapshot::Read() {
+  CHECK(mode_ == kRead);
+  return std::vector<std::pair<std::string, Tensor>>(param_map_.begin(),
+                                                      param_map_.end());
+}
+
+/// Return (name, shape) for every preloaded parameter, in the iteration
+/// order of param_map_ (unspecified).
+std::vector<std::pair<std::string, Shape>> Snapshot::ReadShape() {
+  CHECK(mode_ == kRead);
+  std::vector<std::pair<std::string, Shape>> shapes;
+  shapes.reserve(param_map_.size());
+  for (const auto& entry : param_map_)
+    shapes.emplace_back(entry.first, entry.second.shape());
+  return shapes;
+}
+
+/// Return the preloaded tensor stored under 'key'. Aborts if not in kRead
+/// mode or if no parameter with that name was loaded.
+Tensor Snapshot::Read(const std::string& key) {
+  CHECK(mode_ == kRead);
+  // A single find(): unlike operator[], it never inserts into param_map_.
+  auto it = param_map_.find(key);
+  CHECK(it != param_map_.end()) << "Parameter '" << key << "' not found";
+  return it->second;
+}
+
+/// Return the shape of the preloaded tensor stored under 'key'. Aborts if not
+/// in kRead mode or if no parameter with that name was loaded.
+Shape Snapshot::ReadShape(const std::string& key) {
+  CHECK(mode_ == kRead);
+  // A single find(): unlike operator[], it never inserts into param_map_.
+  auto it = param_map_.find(key);
+  CHECK(it != param_map_.end()) << "Parameter '" << key << "' not found";
+  return it->second.shape();
+}
+
+}  //  namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/proto/core.proto
----------------------------------------------------------------------
diff --git a/src/proto/core.proto b/src/proto/core.proto
index b853b30..da32bc9 100644
--- a/src/proto/core.proto
+++ b/src/proto/core.proto
@@ -58,3 +58,14 @@ message MemPoolConf {
 	// cnmemflag = 2: prevent the manager from stealing memory
 	optional uint32 cnmemflag = 4 [default = 0];
 }
+
+// For tensor serialization (Tensor::ToProto / Tensor::FromProto).
+// Only the *_data field that matches data_type is populated; it holds
+// Product(shape) values. Numeric repeated fields use packed encoding (one
+// length-delimited record instead of one tag per element); protobuf parsers
+// accept both packed and unpacked encodings, so older files still parse.
+message TensorProto {
+  repeated uint32 shape = 1 [packed = true];
+  optional DataType data_type = 2;
+  optional bool transpose = 3;
+  repeated float float_data = 4 [packed = true];
+  repeated double double_data = 5 [packed = true];
+  repeated int32 int_data = 6 [packed = true];
+  repeated bytes bytes_data = 7;
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/test/singa/test_snapshot.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_snapshot.cc b/test/singa/test_snapshot.cc
new file mode 100644
index 0000000..26f1f8c
--- /dev/null
+++ b/test/singa/test_snapshot.cc
@@ -0,0 +1,116 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "gtest/gtest.h"
+#include "singa/io/snapshot.h"
+#include "singa/io/reader.h"
+#include "singa/core/tensor.h"
+
+#include <string>
+#include <fstream>
+
+// Snapshot files are created under this prefix in the working directory.
+const std::string prefix = "./snapshot_test";
+
+// Float parameters dumped by WriteTest and checked by ReadTest.
+const float param_1_data[] = {0.1, 0.2, 0.3, 0.4};
+const float param_2_data[] = {0.2, 0.1, 0.4, 0.3};
+// Expected <prefix>.desc lines for the two float parameters above.
+const std::string desc_1 =
+    "parameter name: Param_1\tdata type: 0\tdim: 1\tshape: 4";
+const std::string desc_2 =
+    "parameter name: Param_2\tdata type: 0\tdim: 2\tshape: 2 2";
+
+// Payloads for the kInt and kDouble round-trip tests.
+const int int_data[] = {1, 3, 5, 7};
+const double double_data[] = {0.2, 0.4, 0.6, 0.8};
+
+TEST(Snapshot, WriteTest) {
+  // Smoke test: dump a 1-D and a 2-D float tensor to <prefix>.model and
+  // <prefix>.desc.
+  singa::Snapshot snapshot(prefix, singa::Snapshot::kWrite);
+  singa::Tensor vec_param(singa::Shape{4});
+  singa::Tensor mat_param(singa::Shape{2, 2});
+  vec_param.CopyDataFromHostPtr(param_1_data, 4);
+  mat_param.CopyDataFromHostPtr(param_2_data, 4);
+  snapshot.Write("Param_1", vec_param);
+  snapshot.Write("Param_2", mat_param);
+}
+
+TEST(Snapshot, ReadTest) {
+  // Write the snapshot here instead of depending on WriteTest having run
+  // first, so this test also passes when run alone (--gtest_filter) or in
+  // shuffled order. The writer is scoped so it is destroyed before reading,
+  // as in ReadIntTest and ReadDoubleTest.
+  {
+    singa::Snapshot writer(prefix, singa::Snapshot::kWrite);
+    singa::Tensor param_1(singa::Shape{4}), param_2(singa::Shape{2, 2});
+    param_1.CopyDataFromHostPtr(param_1_data, 4);
+    param_2.CopyDataFromHostPtr(param_2_data, 4);
+    writer.Write("Param_1", param_1);
+    writer.Write("Param_2", param_2);
+  }
+
+  singa::Snapshot snapshot(prefix, singa::Snapshot::kRead);
+  // Unsigned literals avoid signed/unsigned comparisons against size_t, and
+  // ASSERT_EQ on the rank stops the test before shape[] is indexed.
+  const singa::Shape shape1 = snapshot.ReadShape("Param_1");
+  ASSERT_EQ(shape1.size(), 1u);
+  EXPECT_EQ(shape1[0], 4u);
+  const singa::Shape shape2 = snapshot.ReadShape("Param_2");
+  ASSERT_EQ(shape2.size(), 2u);
+  EXPECT_EQ(shape2[0], 2u);
+  EXPECT_EQ(shape2[1], 2u);
+
+  const singa::Tensor param_1 = snapshot.Read("Param_1");
+  const float* data_1 = param_1.data<float>();
+  for (size_t i = 0; i < singa::Product(shape1); ++i)
+    EXPECT_FLOAT_EQ(data_1[i], param_1_data[i]);
+  const singa::Tensor param_2 = snapshot.Read("Param_2");
+  const float* data_2 = param_2.data<float>();
+  for (size_t i = 0; i < singa::Product(shape2); ++i)
+    EXPECT_FLOAT_EQ(data_2[i], param_2_data[i]);
+
+  // The .desc file holds one line per parameter, in write order.
+  std::ifstream desc_file(prefix + ".desc");
+  ASSERT_TRUE(desc_file.is_open());
+  std::string line;
+  std::getline(desc_file, line);
+  EXPECT_EQ(line, desc_1);
+  std::getline(desc_file, line);
+  EXPECT_EQ(line, desc_2);
+}
+
+TEST(Snapshot, ReadIntTest) {
+  // Round-trip a kInt tensor through its own snapshot file pair.
+  {
+    singa::Snapshot int_snapshot_write(prefix + ".int",
+                                       singa::Snapshot::kWrite);
+    singa::Tensor int_param(singa::Shape{4});
+    int_param.AsType(singa::kInt);
+    int_param.CopyDataFromHostPtr(int_data, 4);
+    int_snapshot_write.Write("IntParam", int_param);
+  }  // The writer is destroyed here, before the files are read back.
+
+  {
+    singa::Snapshot int_snapshot_read(prefix + ".int", singa::Snapshot::kRead);
+    const singa::Shape shape = int_snapshot_read.ReadShape("IntParam");
+    // Unsigned literals avoid signed/unsigned comparisons against size_t;
+    // ASSERT_EQ on the rank stops the test before shape[0] is indexed.
+    ASSERT_EQ(shape.size(), 1u);
+    EXPECT_EQ(shape[0], 4u);
+    singa::Tensor int_param = int_snapshot_read.Read("IntParam");
+    const int* param_data = int_param.data<int>();
+    for (size_t i = 0; i < singa::Product(shape); ++i)
+      EXPECT_EQ(param_data[i], int_data[i]);
+  }
+}
+
+TEST(Snapshot, ReadDoubleTest) {
+  // Round-trip a kDouble tensor through its own snapshot file pair.
+  {
+    singa::Snapshot double_snapshot_write(prefix + ".double",
+                                          singa::Snapshot::kWrite);
+    singa::Tensor double_param(singa::Shape{4});
+    double_param.AsType(singa::kDouble);
+    double_param.CopyDataFromHostPtr(double_data, 4);
+    double_snapshot_write.Write("DoubleParam", double_param);
+  }  // The writer is destroyed here, before the files are read back.
+
+  {
+    singa::Snapshot double_snapshot_read(prefix + ".double",
+                                         singa::Snapshot::kRead);
+    const singa::Shape shape = double_snapshot_read.ReadShape("DoubleParam");
+    // Unsigned literals avoid signed/unsigned comparisons against size_t;
+    // ASSERT_EQ on the rank stops the test before shape[0] is indexed.
+    ASSERT_EQ(shape.size(), 1u);
+    EXPECT_EQ(shape[0], 4u);
+    singa::Tensor double_param = double_snapshot_read.Read("DoubleParam");
+    const double* param_data = double_param.data<double>();
+    // Exact comparison is intended: doubles are stored losslessly.
+    for (size_t i = 0; i < singa::Product(shape); ++i)
+      EXPECT_EQ(param_data[i], double_data[i]);
+  }
+}