Repository: incubator-singa
Updated Branches:
  refs/heads/master f2b0aef12 -> d0438b42c
SINGA-21 Code review 3 review blob.h, blob.cc -- wrap all classes/functions into singa namespace -- remove Blob.data() function for more secure access -- make some funtions inlined -- move Blob.data_ to protected domain -- format the code Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e28b0394 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e28b0394 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e28b0394 Branch: refs/heads/master Commit: e28b0394c0831b4c6b0933cedc070d4cd000fb47 Parents: f2b0aef Author: wang sheng <[email protected]> Authored: Tue Aug 18 12:23:42 2015 +0800 Committer: wang sheng <[email protected]> Committed: Tue Aug 18 12:23:42 2015 +0800 ---------------------------------------------------------------------- include/trainer/server.h | 12 +-- include/utils/blob.h | 120 +++++++++++---------- include/utils/param.h | 15 +-- src/utils/blob.cc | 246 +++++++++++++++++++----------------------- 4 files changed, 185 insertions(+), 208 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e28b0394/include/trainer/server.h ---------------------------------------------------------------------- diff --git a/include/trainer/server.h b/include/trainer/server.h index ef6e599..8cc37c5 100644 --- a/include/trainer/server.h +++ b/include/trainer/server.h @@ -24,7 +24,7 @@ class Server{ virtual ~Server(); void Setup(const UpdaterProto& proto, std::unordered_map<int, ParamEntry*>* shard, - const vector<int>& slice2group); + const std::vector<int>& slice2group); void Run(); const int grp_id() const { return grp_id_; @@ -47,7 +47,7 @@ class Server{ * * @return the orignal message or response message */ - const vector<Msg*> HandleUpdate(Msg **msg); + const std::vector<Msg*> HandleUpdate(Msg **msg); /** * Process PUT request. @@ -68,15 +68,15 @@ class Server{ * @param param slice to be sync with others * @return sync messages */ - const vector<Msg*> GenSyncMsgs(Param* param); + const std::vector<Msg*> GenSyncMsgs(Param* param); protected: int thread_id_,grp_id_, id_; Updater* updater_; std::unordered_map<int, ParamEntry*> *shard_; - vector<int> slice2group_; - std::unordered_map<int, shared_ptr<Blob<float>>> last_data_; - std::unordered_map<int, vector<Msg*>> buffer_requests_; + std::vector<int> slice2group_; + std::unordered_map<int, std::shared_ptr<Blob<float>>> last_data_; + std::unordered_map<int, std::vector<Msg*>> buffer_requests_; }; } /* Server */ #endif //INCLUDE_TRAINER_SERVER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e28b0394/include/utils/blob.h ---------------------------------------------------------------------- diff --git a/include/utils/blob.h b/include/utils/blob.h index 9ff1d47..8769e34 100644 --- a/include/utils/blob.h +++ b/include/utils/blob.h @@ -38,16 +38,16 @@ * license and copyright terms herein. 
* */ -#ifndef INCLUDE_UTILS_BLOB_ -#define INCLUDE_UTILS_BLOB_ +#ifndef SINGA_UTILS_BLOB_H_ +#define SINGA_UTILS_BLOB_H_ + +#include <glog/logging.h> #include <memory> #include <vector> -#include <glog/logging.h> #include "proto/common.pb.h" -using std::shared_ptr; -using std::vector; -#define NOT_IMPLEMENTED LOG(FATAL) << "Not implemented function" +namespace singa { + inline void MallocHost(void** ptr, size_t size) { *ptr = malloc(size); } @@ -64,39 +64,40 @@ inline void FreeHost(void* ptr) { */ class SyncedMemory { public: - SyncedMemory() - : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED), - own_cpu_data_(false) {} - explicit SyncedMemory(size_t size) - : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED), - own_cpu_data_(false) {} + enum SyncedHead { UNINITIALIZED, + HEAD_AT_CPU, + HEAD_AT_GPU, + SYNCED }; + + SyncedMemory() {} + explicit SyncedMemory(size_t size) : size_(size) {} ~SyncedMemory(); + const void* cpu_data(); - void set_cpu_data(void* data); const void* gpu_data(); void* mutable_cpu_data(); void* mutable_gpu_data(); - enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED }; - SyncedHead head() { return head_; } - size_t size() { return size_; } + void set_cpu_data(void* data); + inline SyncedHead head() { return head_; } + inline size_t size() { return size_; } private: void to_cpu(); void to_gpu(); - void* cpu_ptr_; - void* gpu_ptr_; - size_t size_; - SyncedHead head_; - bool own_cpu_data_; + void* cpu_ptr_ = nullptr; + void* gpu_ptr_ = nullptr; + size_t size_ = 0; + SyncedHead head_ = UNINITIALIZED; + bool own_cpu_data_ = false; }; // class SyncedMemory template <typename Dtype> class Blob { public: - Blob(): count_(0), capacity_(0) , version_(-1){} - Blob(const vector<int>&shape); + Blob() {} + explicit Blob(const std::vector<int>& shape) { Reshape(shape); } /** * @brief Change the dimensions of the blob, allocating new memory if * necessary. @@ -111,18 +112,8 @@ class Blob { * an error; either Net::Forward or Net::Reshape need to be called to * propagate the new input shape to higher layers. */ - void Reshape(const vector<int>& shape); + void Reshape(const std::vector<int>& shape); void ReshapeLike(const Blob& other); - const vector<int>& shape() const{ - return shape_; - } - inline int count() const { return count_; } - void set_version(int v){ - version_=v; - } - const int version() const { - return version_; - } /** * @brief Copy from a source Blob. * @@ -131,25 +122,10 @@ class Blob { * of other (and die otherwise); if true, Reshape this Blob to other's * shape if necessary */ - void CopyFrom(const Blob<Dtype>& source, bool reshape = false); - - inline const shared_ptr<SyncedMemory>& data() const { - CHECK(data_); - return data_; - } - - const Dtype* cpu_data() const; - void set_cpu_data(Dtype* data); - const Dtype* gpu_data() const; - Dtype* mutable_cpu_data(); - Dtype* mutable_gpu_data(); + void CopyFrom(const Blob<Dtype>& source); + void CopyFrom(const Blob<Dtype>& source, bool reshape); void FromProto(const singa::BlobProto& proto); void ToProto(singa::BlobProto* proto) const; - - /// @brief Compute the sum of absolute values (L1 norm) of the data. 
- Dtype asum_data() const; - Dtype sum_data() const; - /** * @brief Set the data_ shared_ptr to point to the SyncedMemory holding the * data_ of Blob other -- useful in Layer&s which simply perform a copy @@ -160,12 +136,42 @@ class Blob { */ void ShareData(const Blob& other); void Swap(Blob& other); - shared_ptr<SyncedMemory> data_; + inline const std::vector<int>& shape() const { return shape_; } + inline int count() const { return count_; } + inline const int version() const { return version_; } + inline void set_version(int v) { version_ = v; } + inline const Dtype* cpu_data() const { + CHECK(data_); + return static_cast<const Dtype*>(data_->cpu_data()); + } + inline void set_cpu_data(Dtype* data) { + CHECK(data); + data_->set_cpu_data(data); + } + inline const Dtype* gpu_data() const { + CHECK(data_); + return static_cast<const Dtype*>(data_->gpu_data()); + } + inline Dtype* mutable_cpu_data() { + CHECK(data_); + return static_cast<Dtype*>(data_->mutable_cpu_data()); + } + inline Dtype* mutable_gpu_data() { + CHECK(data_); + return static_cast<Dtype*>(data_->mutable_gpu_data()); + } + /// @brief Compute the sum of absolute values (L1 norm) of the data. + Dtype asum_data() const; + Dtype sum_data() const; + protected: - vector<int> shape_; - int count_; - int capacity_; - int version_; + std::shared_ptr<SyncedMemory> data_ = nullptr; + std::vector<int> shape_; + int count_ = 0; + int capacity_ = 0; + int version_ = -1; }; // class Blob -#endif // INCLUDE_UTILS_BLOB_ +} // namespace singa + +#endif // SINGA_UTILS_BLOB_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e28b0394/include/utils/param.h ---------------------------------------------------------------------- diff --git a/include/utils/param.h b/include/utils/param.h index 2eb66db..be465f4 100644 --- a/include/utils/param.h +++ b/include/utils/param.h @@ -216,7 +216,8 @@ class Param { * request in msgs. * @return response messages */ - virtual const vector<Msg*> GenUpdateResponseMsgs(const vector<Msg*>& msgs); + virtual const std::vector<Msg*> + GenUpdateResponseMsgs(const std::vector<Msg*>& msgs); /** * Server handling function for get request. @@ -254,9 +255,9 @@ class Param { virtual int ParseUpdateResponseMsg(Msg* msg, int slice_idx); /** * Server parse update requests. - * \copydetails GenUpdateResponseMsgs(const vector<Msg*>& msgs); + * \copydetails GenUpdateResponseMsgs(const std::vector<Msg*>& msgs); */ - virtual void ParseUpdateMsgs(const vector<Msg*>& msgs); + virtual void ParseUpdateMsgs(const std::vector<Msg*>& msgs); /** * Server parsing function for synchronization response. * @@ -277,13 +278,13 @@ class Param { int slice_start_; int num_slices_; //!< offset and size of each slice - vector<int> slice_offset_, slice_size_; + std::vector<int> slice_offset_, slice_size_; //!< for debug checking - vector<bool> pending_put_, pending_get_, pending_update_; + std::vector<bool> pending_put_, pending_get_, pending_update_; int num_pending_requests_; - shared_ptr<Blob<float>> data_; + std::shared_ptr<Blob<float>> data_; //! 
gradient, history gradient of this parameter Blob<float> grad_, history_; ParamProto proto_; @@ -312,7 +313,7 @@ class ParamEntry{ int num_local; //!< # local workers using the shared parameter int num_total; //!< # total workers using the shared parameter //!< Shares are deleted by neuralnet's destructor - vector<Param*> shares; + std::vector<Param*> shares; }; inline int ParamTrgt(int param_id, int slice_id) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e28b0394/src/utils/blob.cc ---------------------------------------------------------------------- diff --git a/src/utils/blob.cc b/src/utils/blob.cc index 14c772f..fd402a8 100644 --- a/src/utils/blob.cc +++ b/src/utils/blob.cc @@ -37,12 +37,13 @@ * or otherwise, the contributor releases their content to the * license and copyright terms herein. */ -#include <utility> -#include <math.h> -#include <cblas.h> #include "utils/blob.h" -/*********************SyncedMemory implementation************************/ +#include <cblas.h> +#include <math.h> +#include <utility> + +#define NOT_IMPLEMENTED LOG(FATAL) << "Not implemented function" #define NO_GPU LOG(FATAL) << "CPU-only Mode: cannot make GPU call." // Instantiate a class with float and double specifications. #define INSTANTIATE_CLASS(classname) \ @@ -77,14 +78,14 @@ private:\ << caffe::curandGetErrorString(status); \ } while (0) -#endif // CPU_ONLY +#endif // CPU_ONLY +namespace singa { SyncedMemory::~SyncedMemory() { if (cpu_ptr_ && own_cpu_data_) { FreeHost(cpu_ptr_); } - #ifndef CPU_ONLY if (gpu_ptr_) { CUDA_CHECK(cudaFree(gpu_ptr_)); @@ -92,11 +93,53 @@ SyncedMemory::~SyncedMemory() { #endif // CPU_ONLY } -inline void SyncedMemory::to_cpu() { +const void* SyncedMemory::cpu_data() { + to_cpu(); + return cpu_ptr_; +} + +const void* SyncedMemory::gpu_data() { +#ifndef CPU_ONLY + to_gpu(); + return gpu_ptr_; +#else + NO_GPU; +#endif + return nullptr; +} + +void* SyncedMemory::mutable_cpu_data() { + to_cpu(); + head_ = HEAD_AT_CPU; + return cpu_ptr_; +} + +void* SyncedMemory::mutable_gpu_data() { +#ifndef CPU_ONLY + to_gpu(); + head_ = HEAD_AT_GPU; + return gpu_ptr_; +#else + NO_GPU; +#endif + return nullptr; +} + +void SyncedMemory::set_cpu_data(void* data) { + CHECK(data); + if (own_cpu_data_) { + FreeHost(cpu_ptr_); + } + cpu_ptr_ = data; + head_ = HEAD_AT_CPU; + own_cpu_data_ = false; +} + +void SyncedMemory::to_cpu() { switch (head_) { case UNINITIALIZED: MallocHost(&cpu_ptr_, size_); - memset(cpu_ptr_,0, size_); + memset(cpu_ptr_, 0, size_); head_ = HEAD_AT_CPU; own_cpu_data_ = true; break; @@ -118,19 +161,19 @@ inline void SyncedMemory::to_cpu() { } } -inline void SyncedMemory::to_gpu() { +void SyncedMemory::to_gpu() { #ifndef CPU_ONLY switch (head_) { case UNINITIALIZED: CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); - CUDA_CHECK(cudaMemset(gpu_ptr_, 0, N)); // NOLINT(caffe/alt_fn) + CUDA_CHECK(cudaMemset(gpu_ptr_, 0, N)); head_ = HEAD_AT_GPU; break; case HEAD_AT_CPU: if (gpu_ptr_ == NULL) { CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); } - CUDA_CHECK(cudaMemcpy( gpu_ptr_,cpu_ptr_, size_, cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyDefault)); head_ = SYNCED; break; case HEAD_AT_GPU: @@ -142,64 +185,13 @@ inline void SyncedMemory::to_gpu() { #endif } -const void* SyncedMemory::cpu_data() { - to_cpu(); - return (const void*)cpu_ptr_; -} - -void SyncedMemory::set_cpu_data(void* data) { - CHECK(data); - if (own_cpu_data_) { - FreeHost(cpu_ptr_); - } - cpu_ptr_ = data; - head_ = HEAD_AT_CPU; - own_cpu_data_ = false; -} - -const void* 
SyncedMemory::gpu_data() { -#ifndef CPU_ONLY - to_gpu(); - return (const void*)gpu_ptr_; -#else - NO_GPU; -#endif - return nullptr; -} - -void* SyncedMemory::mutable_cpu_data() { - to_cpu(); - head_ = HEAD_AT_CPU; - return cpu_ptr_; -} - -void* SyncedMemory::mutable_gpu_data() { -#ifndef CPU_ONLY - to_gpu(); - head_ = HEAD_AT_GPU; - return gpu_ptr_; -#else - NO_GPU; -#endif - return nullptr; -} - -/*********************Blob implementation************************/ - template <typename Dtype> -Blob<Dtype>::Blob(const vector<int>& shape) - // capacity_ must be initialized before calling Reshape - : capacity_(0), version_(-1) { - Reshape(shape); -} - -template <typename Dtype> -void Blob<Dtype>::Reshape(const vector<int>& shape) { - count_=1; +void Blob<Dtype>::Reshape(const std::vector<int>& shape) { + count_ = 1; shape_ = shape; - for(size_t i=0;i<shape.size();i++){ + for (size_t i = 0; i < shape.size(); ++i) { CHECK(shape[i]); - count_*=shape[i]; + count_ *= shape[i]; } if (count_ > capacity_) { capacity_ = count_; @@ -213,76 +205,13 @@ void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) { } template <typename Dtype> -const Dtype* Blob<Dtype>::cpu_data() const { - CHECK(data_); - return (const Dtype*)data_->cpu_data(); -} - -template <typename Dtype> -void Blob<Dtype>::set_cpu_data(Dtype* data) { - CHECK(data); - data_->set_cpu_data(data); -} - -template <typename Dtype> -const Dtype* Blob<Dtype>::gpu_data() const { - CHECK(data_); - return (const Dtype*)data_->gpu_data(); -} - -template <typename Dtype> -Dtype* Blob<Dtype>::mutable_cpu_data() { - CHECK(data_); - return static_cast<Dtype*>(data_->mutable_cpu_data()); -} - -template <typename Dtype> -Dtype* Blob<Dtype>::mutable_gpu_data() { - CHECK(data_); - return static_cast<Dtype*>(data_->mutable_gpu_data()); -} - -template <typename Dtype> -void Blob<Dtype>::ShareData(const Blob& other) { - CHECK_EQ(count_, other.count()); - data_ = other.data(); -} - -template <> float Blob<float>::asum_data() const { - if(count()==0) - return 0.f; - return cblas_sasum(count(), cpu_data(), 1)/count(); -} -template <> float Blob<float>::sum_data() const { - if(count()==0) - return 0.f; - float sum=0.f; - const float *dptr=cpu_data(); - for(int i=0;i<count();i++) - sum+=dptr[i]; - return sum/count(); -} -template <> unsigned int Blob<unsigned int>::asum_data() const { - NOT_IMPLEMENTED; - return 0; -} - -template <> int Blob<int>::asum_data() const { - NOT_IMPLEMENTED; - return 0; -} - -template <typename Dtype> -void Blob<Dtype>::Swap(Blob& other){ - CHECK_EQ(other.count(), count()); - CHECK(std::equal(shape_.begin(), shape_.end(), other.shape_.begin())); - std::swap(data_, other.data_); - std::swap(capacity_, other.capacity_); +void Blob<Dtype>::CopyFrom(const Blob& source) { + CopyFrom(source, false); } template <typename Dtype> void Blob<Dtype>::CopyFrom(const Blob& source, bool reshape) { - if (!std::equal(shape_.begin(),shape_.end(),source.shape_.begin())) { + if (!std::equal(shape_.begin(), shape_.end(), source.shape_.begin())) { if (reshape) { Reshape(source.shape_); } else { @@ -291,17 +220,18 @@ void Blob<Dtype>::CopyFrom(const Blob& source, bool reshape) { } #ifndef CPU_ONLY CUDA_CHECK(cudaMemcpy(static_cast<Dtype*>(data_->mutable_gpu_data()), - source.gpu_data(), sizeof(Dtype) * count_, cudaMemcpyDefault)); + source.gpu_data(), sizeof(Dtype) * count_, cudaMemcpyDefault)); #endif - memcpy(static_cast<Dtype*>(data_->mutable_cpu_data()),source.cpu_data(), - sizeof(Dtype)*count_); + memcpy(static_cast<Dtype*>(data_->mutable_cpu_data()), 
source.cpu_data(), + sizeof(Dtype)*count_); } template <typename Dtype> void Blob<Dtype>::FromProto(const singa::BlobProto& proto) { - vector<int> shape; - for (int s : proto.shape()) + std::vector<int> shape; + for (int s : proto.shape()) { shape.push_back(s); + } int count = count_; Reshape(shape); if (count != count_) @@ -315,8 +245,9 @@ void Blob<Dtype>::FromProto(const singa::BlobProto& proto) { template <typename Dtype> void Blob<Dtype>::ToProto(singa::BlobProto* proto) const { - for (int s : shape_) + for (int s : shape_) { proto->add_shape(s); + } proto->clear_data(); const Dtype* data_vec = cpu_data(); for (int i = 0; i < count_; ++i) { @@ -324,6 +255,45 @@ void Blob<Dtype>::ToProto(singa::BlobProto* proto) const { } } +template <typename Dtype> +void Blob<Dtype>::ShareData(const Blob& other) { + CHECK_EQ(count_, other.count()); + data_ = other.data_; +} + +template <typename Dtype> +void Blob<Dtype>::Swap(Blob& other) { + CHECK_EQ(other.count(), count()); + CHECK(std::equal(shape_.begin(), shape_.end(), other.shape_.begin())); + std::swap(data_, other.data_); + std::swap(capacity_, other.capacity_); +} + +template <> float Blob<float>::asum_data() const { + if (count() == 0) return 0.f; + return cblas_sasum(count(), cpu_data(), 1) / count(); +} +template <> float Blob<float>::sum_data() const { + if (count() == 0) return 0.f; + float sum = 0.f; + const float* dptr = cpu_data(); + for (int i = 0; i < count(); ++i) + sum += dptr[i]; + return sum / count(); +} + +template <> unsigned int Blob<unsigned int>::asum_data() const { + NOT_IMPLEMENTED; + return 0; +} + +template <> int Blob<int>::asum_data() const { + NOT_IMPLEMENTED; + return 0; +} + INSTANTIATE_CLASS(Blob); template class Blob<int>; template class Blob<unsigned int>; + +} // namespace singa
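----------------------------------------------------------------------
For reference, a minimal usage sketch of the refactored Blob API (not part of
the patch above). It assumes the "utils/blob.h" include path from the diff, a
CPU-only build, and linking against glog, cblas and protobuf; that the shape
constructor allocates the underlying SyncedMemory inside Reshape() is inferred
from the accessor CHECKs rather than shown in the hunks.

#include <vector>

#include "utils/blob.h"  // assumed include path, matching the diff

int main() {
  // The explicit shape constructor calls Reshape(), which sizes data_.
  singa::Blob<float> a(std::vector<int>{2, 3});

  // Fill the blob on the CPU side; SyncedMemory lazily allocates and zeroes
  // the host buffer on the first to_cpu() transition.
  float* dptr = a.mutable_cpu_data();
  for (int i = 0; i < a.count(); ++i)
    dptr[i] = static_cast<float>(i);

  // CopyFrom(source) keeps the current shape and dies on a mismatch;
  // CopyFrom(source, true) would reshape the destination first.
  singa::Blob<float> b;
  b.ReshapeLike(a);
  b.CopyFrom(a);

  // ShareData() only re-points the underlying shared_ptr<SyncedMemory>, so
  // both blobs now view the same buffer (element counts must match).
  singa::Blob<float> c(a.shape());
  c.ShareData(a);

  // Note: asum_data() and sum_data() divide by count(), so they return the
  // mean absolute value and the mean value, not the raw L1 norm or sum.
  float mean_abs = b.asum_data();
  float mean = b.sum_data();
  return (mean_abs >= mean) ? 0 : 1;
}

With Blob::data() removed, callers reach raw memory only through the typed
cpu_data()/mutable_cpu_data() accessors, and all names are now qualified with
singa:: instead of relying on the old file-scope using declarations.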

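----------------------------------------------------------------------
A second sketch, this one of the SyncedMemory lazy-allocation behaviour that
blob.cc now implements out of line (again assuming the "utils/blob.h" include
path and a CPU_ONLY build, in which the gpu_data()/mutable_gpu_data() calls
would abort with NO_GPU instead).

#include "utils/blob.h"  // assumed include path; SyncedMemory is declared here

int main() {
  // Reserve room for 64 floats; head() starts at UNINITIALIZED and nothing
  // is allocated until the first data access.
  singa::SyncedMemory mem(64 * sizeof(float));

  // First CPU access: to_cpu() MallocHost's and zeroes the host buffer.
  const float* ro = static_cast<const float*>(mem.cpu_data());

  // mutable_cpu_data() additionally moves head_ to HEAD_AT_CPU, marking the
  // host copy as the one a later to_gpu() transition would upload.
  float* rw = static_cast<float*>(mem.mutable_cpu_data());
  rw[0] = 1.0f;
  bool ok = (ro != nullptr && rw != nullptr &&
             mem.size() == 64 * sizeof(float));

  // set_cpu_data() frees the owned buffer, adopts the external one, and
  // leaves own_cpu_data_ false, so the destructor will not free it.
  static float external[64] = {0};
  mem.set_cpu_data(external);

  return ok ? 0 : 1;
}

Keeping the state transitions inside to_cpu()/to_gpu() is what lets the same
Blob accessors serve both CPU-only and GPU builds without extra branching at
the call sites.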