Repository: incubator-singa Updated Branches: refs/heads/dev d3c1bae61 -> 62c6603ff
SINGA-210 Enable checkpoint and resume for v1.0 This ticket adds code for dumping the model parameters as checkpoint files, which can be used for fine-tuning and deployment. Each Tensor is serialized into a TensorProto and saved in a BinFile stored as <prefix>.model, and a description of the parameters is generated in <prefix>.desc. Unit test cases pass for the kFloat32, kInt and kDouble data types. Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/62c6603f Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/62c6603f Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/62c6603f Branch: refs/heads/dev Commit: 62c6603ff7a3fe9f9749021e84ad9ec35f3fef7d Parents: d3c1bae Author: WANG Ji <[email protected]> Authored: Tue Jun 28 23:30:36 2016 +0800 Committer: WANG Ji <[email protected]> Committed: Wed Jun 29 13:52:30 2016 +0800 ---------------------------------------------------------------------- include/singa/core/tensor.h | 38 +++++++------ include/singa/io/snapshot.h | 79 ++++++++++++++++++++++++++ src/core/tensor/tensor.cc | 106 ++++++++++++++++++++++++++++++++++- src/io/snapshot.cc | 104 +++++++++++++++++++++++++++++++++++ src/proto/core.proto | 11 ++++ test/singa/test_snapshot.cc | 116 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 437 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 15e7b7f..4ef3286 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -34,8 +34,9 @@ namespace singa { typedef vector<size_t> Shape; /// hardcode the width of types defined in DataType -const size_t kDataWidth[] = {sizeof(float), sizeof(float) / 
2, sizeof(int), - sizeof(char), sizeof(double), sizeof(unsigned char)}; +const size_t kDataWidth[] = {sizeof(float), sizeof(float) / 2, + sizeof(int), sizeof(char), + sizeof(double), sizeof(unsigned char)}; inline size_t SizeOf(DataType t) { static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(size_t), "Num of data types not match num of data width"); @@ -70,14 +71,14 @@ class Tensor { /// Users should not operate against Block directly. /// block_ is allocated in constructors. Block *block() const { return block_; } - void SetBlock(Block* block); + void SetBlock(Block *block); std::shared_ptr<Device> device() const { return device_; } /// return immutable Tensor values with given type. template <typename SType> - const SType* data() const { - return static_cast<const SType*>(block()->data()); + const SType *data() const { + return static_cast<const SType *>(block()->data()); } /// data type, including kFloat16, kFloat32, kInt @@ -96,8 +97,7 @@ class Tensor { /// return number of total elements size_t Size() const { - if (block_ == nullptr) - return 0u; + if (block_ == nullptr) return 0u; CHECK_EQ(block_->size() % SizeOf(data_type_), 0u); return block_->size() / SizeOf(data_type_); } @@ -110,7 +110,8 @@ class Tensor { void Reshape(Shape &&shape); /// Reset the shape, device, and data type as given tensor. - /// If block size changes, then reallocate a new block. The previous block would + /// If block size changes, then reallocate a new block. The previous block + /// would /// be deleted. void ResetLike(const Tensor &t); @@ -138,6 +139,12 @@ class Tensor { /// Meta data would not be copied! void CopyData(const Tensor &other); + /// Deserialize data, shape and transpose from protobuf object. + void FromProto(const singa::TensorProto &proto); + + /// Serialize data, shape and transpose to protobuf object. + void ToProto(singa::TensorProto *proto) const; + /// return an exactly the same Tensor with data been deep copied to the given /// device. 
If 'device' is nullptr, then clone it one the current device. Tensor Clone(std::shared_ptr<Device> device = nullptr) const; @@ -248,7 +255,6 @@ void Sqrt(const Tensor &in, Tensor *out); void Square(const Tensor &in, Tensor *out); void Tanh(const Tensor &in, Tensor *out); - /// Element-wise opeartion, out[i]=in[i]^x template <typename SType> Tensor Pow(const Tensor &in, const SType x); @@ -404,27 +410,27 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta, /// Compute the cross entropy loss given the prediction probability 'p' and /// the target (ground truth) labels 't'. 'p' and 't' are either 1-d vector /// or 2-d matrix. 'loss' is 1-d vector. The loss is computed into p. -void ComputeCrossEntropy(const Tensor& p, const Tensor& t, Tensor* loss); +void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss); /// Compute the dx, given prediction probability 'p' (p=softmax(x)) and /// the target (ground truth) labels 't'. 'p' and 't' are either 1-d vector /// or 2-d matrix. 'grad' has the same shape as 'p'. dx is computed into p. -void SoftmaxCrossEntropyBwd(const Tensor& t, Tensor* p); +void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p); /// Return a tensor consisting of rows ([start, end)) from 'in'. It shares the /// memory with 'in'. 'in' is a 1D or 2D Tensor. -Tensor SliceRows(const Tensor& in, const size_t start, const size_t end); +Tensor SliceRows(const Tensor &in, const size_t start, const size_t end); /// Return a tensor consisting of rows ([start, end)) from 'in'. It copies the /// values from 'in'. 'in' ia a 2D Tensor. -Tensor CopyRows(const Tensor& in, const size_t start, const size_t end); +Tensor CopyRows(const Tensor &in, const size_t start, const size_t end); /// Return a tensor consisting of columns ([start, end)) from 'in'. It copies /// the values from 'in'. 'in' is a 2D Tensor. 
-Tensor CopyColumns(const Tensor& in, const size_t start, const size_t end); +Tensor CopyColumns(const Tensor &in, const size_t start, const size_t end); /// Return a tensor which is vertically stacked from tensors in 'in'. Each /// tensor in 'in' is a 2D tensor. Values are copied, no memory sharing. -Tensor ConcatenateRows(const vector<Tensor>& in); +Tensor ConcatenateRows(const vector<Tensor> &in); /// Return a tensor which is horizontally stacked from tensors in 'in'. Each /// tensor in 'in' is a 2D tensor. Values are copied, no memory sharing. -Tensor ConcatenateColumns(const vector<Tensor>& in); +Tensor ConcatenateColumns(const vector<Tensor> &in); } // namespace singa #endif // SINGA_CORE_TENSOR_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/include/singa/io/snapshot.h ---------------------------------------------------------------------- diff --git a/include/singa/io/snapshot.h b/include/singa/io/snapshot.h new file mode 100644 index 0000000..7545572 --- /dev/null +++ b/include/singa/io/snapshot.h @@ -0,0 +1,79 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +#ifndef SINGA_UTILS_SNAPSHOT_H_ +#define SINGA_UTILS_SNAPSHOT_H_ + +#include "singa/io/reader.h" +#include "singa/io/writer.h" +#include "singa/utils/logging.h" +#include "singa/proto/core.pb.h" +#include "singa/core/tensor.h" + +#include <string> +#include <unordered_set> +#include <unordered_map> +#include <memory> + +namespace singa { +/// The snapshot management. +/// It dumps the model parameter snapshot as checkpoint files, which coud be +/// used for fine-tuning and deployment. +/// The model paramters are separated from model definition, i.e., net +/// construction. Users either randomly initialize the layer parameters or using +/// the parameters from checkpoint files using Snapshot after creating the +/// neural network. +class Snapshot { + public: + enum Mode { kRead, kWrite }; + /// <prefix>.model is the binary file for parameter key-value pair. + /// <prefix>.meta is the text file describing information about paramters, + /// i.e. + /// name and shape, one line per parameter. + /// kRead for reading snapshot, whereas kWrite for dumping out snapshot. + Snapshot(const std::string& prefix, Mode mode); + ~Snapshot() {} + /// Read parameters saved as tensors from checkpoint file. + std::vector<std::pair<std::string, Tensor>> Read(); + /// Read parameter shapes from description file. + std::vector<std::pair<std::string, Shape>> ReadShape(); + /// Read parameter returned as a tensor for a given parameter name. + Tensor Read(const std::string& Key); + /// Read parameter shape for a given parameter name. + Shape ReadShape(const std::string& key); + /// Serialize and dump out parameter. This method will write two files, one + /// binary file is for serialized tensors, the other csv file is for parameter + /// names and shapes. 
+ void Write(const std::string& key, const Tensor& param); + + private: + std::string prefix_; + Mode mode_; + std::unique_ptr<io::Writer> bin_writer_ptr_, text_writer_ptr_; + std::unique_ptr<io::Reader> bin_reader_ptr_; + /// Check whether parameter name is unique. + std::unordered_set<std::string> param_names_; + /// Preload key-parameter tensor pairs for seeking a specified key. + std::unordered_map<std::string, Tensor> param_map_; +}; +} // namespace singa + +#endif // SINGA_UTILS_SNAPSHOT_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index b07a23c..3501ecd 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -159,6 +159,106 @@ void Tensor::CopyData(const Tensor &src) { } } +void Tensor::FromProto(const singa::TensorProto &proto) { + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); + block_ = nullptr; + Shape shape; + for (uint32_t s : proto.shape()) shape.push_back(s); + data_type_ = proto.data_type(); + Reshape(shape); + transpose_ = proto.transpose(); + switch (data_type_) { + case kFloat32: { + std::unique_ptr<float[]> data_ptr(new float[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) + data_ptr[i] = static_cast<float>(proto.float_data(i)); + CopyDataFromHostPtr<float>(data_ptr.get(), Product(shape_)); + break; + } + case kDouble: { + std::unique_ptr<double[]> data(new double[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) + data[i] = proto.double_data(i); + CopyDataFromHostPtr<double>(data.get(), Product(shape_)); + break; + } + case kInt: { + std::unique_ptr<int[]> data(new int[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) data[i] = proto.int_data(i); + CopyDataFromHostPtr<int>(data.get(), Product(shape_)); + break; + } + ///TODO(wangji): Implement to support 
C++ type char using bytes type in protobuf + /// which is equivalent to string type is different from the other cases. The kchar + /// and kUChar case is to be implemented. + /* + case kChar: { + std::unique_ptr<char[]> data(new char[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) + data[i] = static_cast<char>(proto.bytes_data(i)); + break; + } + case kUChar: { + std::unique_ptr<unsigned char[]> data(new unsigned char[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) + data[i] = static_cast<unsigned char>(proto.bytes_data(i)); + break; + } + */ + default: { LOG(FATAL) << "Unsupported Type" << DataType_Name(data_type_); } + } +} + +void Tensor::ToProto(singa::TensorProto *proto) const { + proto->clear_shape(); + for (auto s : shape_) { + proto->add_shape(s); + } + proto->set_data_type(data_type_); + proto->set_transpose(transpose_); + switch (data_type_) { + case kFloat32: { + proto->clear_float_data(); + const float *data_ptr = data<float>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_float_data(data_ptr[i]); + break; + } + case kDouble: { + proto->clear_double_data(); + const double *data_ptr = data<double>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_double_data(data_ptr[i]); + break; + } + case kInt: { + proto->clear_int_data(); + const int *data_ptr = data<int>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_int_data(data_ptr[i]); + break; + } + /* + case kChar: { + proto->clear_bytes_data(); + const char *data = data<char>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_bytes_data(static_cast<unsigned char>(data[i])); + break; + } + case kUChar: { + proto->clear_bytes_data(); + const unsigned char *data = data<unsigned char>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_bytes_data(static_cast<unsigned char>(data[i])); + break; + } + */ + default: { LOG(FATAL) << "Unsupported Type" << DataType_Name(data_type_); } + } +} + Tensor 
Tensor::Clone(std::shared_ptr<Device> device) const { if (device == nullptr) device = device_; Tensor t(shape_, device_, data_type_); @@ -292,6 +392,11 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num, { __VA_ARGS__ } \ break; \ } \ + case kDouble: { \ + typedef double DType; \ + { __VA_ARGS__ } \ + break; \ + } \ default: \ LOG(FATAL) << "Unknow data type = " << DataType_Name(type); \ } \ @@ -357,7 +462,6 @@ float Tensor::L2() const { return nrm / Size(); } - template <typename SType> void Tensor::SetValue(const SType x) { CHECK_EQ(sizeof(SType), SizeOf(data_type_)); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/io/snapshot.cc ---------------------------------------------------------------------- diff --git a/src/io/snapshot.cc b/src/io/snapshot.cc new file mode 100644 index 0000000..3b9b8ce --- /dev/null +++ b/src/io/snapshot.cc @@ -0,0 +1,104 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +#include "singa/io/snapshot.h" + +#include <string> +#include <unordered_set> +#include <unordered_map> +#include <memory> +#include <utility> +#include <iostream> + +namespace singa { +Snapshot::Snapshot(const std::string& prefix, Mode mode) + : prefix_(prefix), + mode_(mode), + bin_writer_ptr_(mode_ == kWrite ? (new io::BinFileWriter) : nullptr), + text_writer_ptr_(mode_ == kWrite ? (new io::TextFileWriter) : nullptr), + bin_reader_ptr_(mode_ == kRead ? (new io::BinFileReader) : nullptr) { + if (mode_ == kWrite) { + bin_writer_ptr_->Open(prefix + ".model", io::kCreate); + text_writer_ptr_->Open(prefix + ".desc", io::kCreate); + } else if (mode == kRead) { + bin_reader_ptr_->Open(prefix + ".model"); + std::string key, serialized_str; + singa::TensorProto tp; + while (bin_reader_ptr_->Read(&key, &serialized_str)) { + CHECK(param_names_.count(key) == 0); + param_names_.insert(key); + CHECK(tp.ParseFromString(serialized_str)); + param_map_[key].FromProto(tp); + } + } else { + LOG(FATAL) + << "Mode for snapshot should be Snapshot::kWrite or Snapshot::kRead"; + } +} + +void Snapshot::Write(const std::string& key, const Tensor& param) { + CHECK(mode_ == kWrite); + CHECK(param_names_.count(key) == 0); + param_names_.insert(key); + TensorProto tp; + param.ToProto(&tp); + std::string serialized_str; + CHECK(tp.SerializeToString(&serialized_str)); + bin_writer_ptr_->Write(key, serialized_str); + + std::string desc_str = "parameter name: " + key; + Shape shape = param.shape(); + desc_str += "\tdata type: " + std::to_string(param.data_type()); + desc_str += "\tdim: " + std::to_string(shape.size()); + desc_str += "\tshape:"; + for (size_t s : shape) desc_str += " " + std::to_string(s); + text_writer_ptr_->Write(key, desc_str); +} + +std::vector<std::pair<std::string, Tensor>> Snapshot::Read() { + CHECK(mode_ == kRead); + std::vector<std::pair<std::string, Tensor>> ret; + for (auto it = param_map_.begin(); it 
!= param_map_.end(); ++it) + ret.push_back(*it); + return ret; +} + +std::vector<std::pair<std::string, Shape>> Snapshot::ReadShape() { + CHECK(mode_ == kRead); + std::vector<std::pair<std::string, Shape>> ret; + for (auto it = param_map_.begin(); it != param_map_.end(); ++it) + ret.push_back(std::make_pair(it->first, it->second.shape())); + return ret; +} + +Tensor Snapshot::Read(const std::string& key) { + CHECK(mode_ == kRead); + CHECK(param_map_.count(key) == 1); + return param_map_[key]; +} + +Shape Snapshot::ReadShape(const std::string& key) { + CHECK(mode_ == kRead); + CHECK(param_map_.count(key) == 1); + return param_map_[key].shape(); +} + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/src/proto/core.proto ---------------------------------------------------------------------- diff --git a/src/proto/core.proto b/src/proto/core.proto index b853b30..da32bc9 100644 --- a/src/proto/core.proto +++ b/src/proto/core.proto @@ -58,3 +58,14 @@ message MemPoolConf { // cnmemflag = 2: prevent the manager from stealing memory optional uint32 cnmemflag = 4 [default = 0]; } + +// For tensor serialization +message TensorProto { + repeated uint32 shape = 1; + optional DataType data_type = 2; + optional bool transpose = 3; + repeated float float_data = 4; + repeated double double_data = 5; + repeated int32 int_data = 6; + repeated bytes bytes_data = 7; +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/62c6603f/test/singa/test_snapshot.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_snapshot.cc b/test/singa/test_snapshot.cc new file mode 100644 index 0000000..26f1f8c --- /dev/null +++ b/test/singa/test_snapshot.cc @@ -0,0 +1,116 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. 
See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +#include "gtest/gtest.h" +#include "singa/io/snapshot.h" +#include "singa/io/reader.h" +#include "singa/core/tensor.h" + +#include <string> +#include <fstream> + +const std::string prefix = "./snapshot_test"; +const float param_1_data[] = {0.1, 0.2, 0.3, 0.4}; +const float param_2_data[] = {0.2, 0.1, 0.4, 0.3}; +const std::string desc_1 = "parameter name: Param_1\tdata type: 0\tdim: 1\tshape: 4"; +const std::string desc_2 = "parameter name: Param_2\tdata type: 0\tdim: 2\tshape: 2 2"; +const int int_data[] = {1, 3, 5, 7}; +const double double_data[] = {0.2, 0.4, 0.6, 0.8}; + +TEST(Snapshot, WriteTest) { + singa::Snapshot snapshot(prefix, singa::Snapshot::kWrite); + singa::Tensor param_1(singa::Shape{4}), param_2(singa::Shape{2, 2}); + param_1.CopyDataFromHostPtr(param_1_data, 4); + param_2.CopyDataFromHostPtr(param_2_data, 4); + snapshot.Write("Param_1", param_1); + snapshot.Write("Param_2", param_2); +} + +TEST(Snapshot, ReadTest) { + singa::Snapshot snapshot(prefix, singa::Snapshot::kRead); + singa::Tensor param_1, param_2; + singa::Shape shape1, shape2; + shape1 = snapshot.ReadShape("Param_1"); + EXPECT_EQ(shape1.size(), 1); + EXPECT_EQ(shape1[0], 4); + shape2 = snapshot.ReadShape("Param_2"); + 
EXPECT_EQ(shape2.size(), 2); + EXPECT_EQ(shape2[0], 2); + EXPECT_EQ(shape2[1], 2); + param_1 = snapshot.Read("Param_1"); + const float* data_1 = param_1.data<float>(); + for (size_t i = 0; i < singa::Product(shape1); ++i) + EXPECT_FLOAT_EQ(data_1[i], param_1_data[i]); + param_2 = snapshot.Read("Param_2"); + const float* data_2 = param_2.data<float>(); + for (size_t i = 0; i < singa::Product(shape2); ++i) + EXPECT_FLOAT_EQ(data_2[i], param_2_data[i]); + std::ifstream desc_file(prefix+".desc"); + std::string line; + getline(desc_file, line); + EXPECT_EQ(line, desc_1); + getline(desc_file, line); + EXPECT_EQ(line, desc_2); +} + +TEST(Snapshot, ReadIntTest) { + { + singa::Snapshot int_snapshot_write(prefix+".int", singa::Snapshot::kWrite); + singa::Tensor int_param(singa::Shape{4}); + int_param.AsType(singa::kInt); + int_param.CopyDataFromHostPtr(int_data, 4); + int_snapshot_write.Write("IntParam", int_param); + } + + { + singa::Snapshot int_snapshot_read(prefix+".int", singa::Snapshot::kRead); + singa::Shape shape; + shape = int_snapshot_read.ReadShape("IntParam"); + EXPECT_EQ(shape.size(), 1); + EXPECT_EQ(shape[0], 4); + singa::Tensor int_param = int_snapshot_read.Read("IntParam"); + const int* param_data = int_param.data<int>(); + for (size_t i = 0; i < singa::Product(shape); ++i) + EXPECT_EQ(param_data[i], int_data[i]); + } +} + +TEST(Snapshot, ReadDoubleTest) { + { + singa::Snapshot double_snapshot_write(prefix+".double", singa::Snapshot::kWrite); + singa::Tensor double_param(singa::Shape{4}); + double_param.AsType(singa::kDouble); + double_param.CopyDataFromHostPtr(double_data, 4); + double_snapshot_write.Write("DoubleParam", double_param); + } + + { + singa::Snapshot double_snapshot_read(prefix+".double", singa::Snapshot::kRead); + singa::Shape shape; + shape = double_snapshot_read.ReadShape("DoubleParam"); + EXPECT_EQ(shape.size(), 1); + EXPECT_EQ(shape[0], 4); + singa::Tensor double_param = double_snapshot_read.Read("DoubleParam"); + const double* param_data = 
double_param.data<double>(); + for (size_t i = 0; i < singa::Product(shape); ++i) + EXPECT_EQ(param_data[i], double_data[i]); + } +}
