SINGA-50 Improve the code of the ParseUpdateMsgs function

Refactored the function to reuse the memory of one update message, or of one
local worker's grad blob, for storing the aggregated parameter gradients.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b33e50d6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b33e50d6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b33e50d6

Branch: refs/heads/master
Commit: b33e50d65fb67ee9d3cff2d3242ff36005204fca
Parents: d269b67
Author: Wei Wang <[email protected]>
Authored: Fri Aug 14 22:02:46 2015 +0800
Committer: Wei Wang <[email protected]>
Committed: Sat Aug 15 14:59:11 2015 +0800

----------------------------------------------------------------------
 include/utils/param.h |  18 +++---
 src/utils/param.cc    | 142 +++++++++++++++++++++------------------------
 2 files changed, 74 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b33e50d6/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index 18293b6..8fabe71 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -1,5 +1,5 @@
-#ifndef INCLUDE_UTILS_PARAM_H_
-#define INCLUDE_UTILS_PARAM_H_
+#ifndef SINGA_UTILS_PARAM_H_
+#define SINGA_UTILS_PARAM_H_
 #include <vector>
 #include <string>
 #include "proto/job.pb.h"
@@ -28,7 +28,7 @@ namespace singa {
 class Param {
  public:
   Param();
-  virtual ~Param(){ }
+  virtual ~Param() {}
   /**
    * Setup param object
    *
@@ -41,7 +41,7 @@ class Param {
    *
    * @param version initial version
    */
-  virtual void InitValues(int version=0);
+  virtual void InitValues(int version = 0);
   /**
    * Share the data blob from other Param objects.
    *
@@ -113,7 +113,7 @@ class Param {
   }
   void set_local_version(int v) {
-    local_version_=v;
+    local_version_ = v;
   }
   const std::string& share_from() const {
     return proto_.share_from();
@@ -280,7 +280,7 @@ class Param {
   vector<int> slice_offset_, slice_size_;

   //!< for debug checking
-  vector<bool> pending_put_,pending_get_, pending_update_;
+  vector<bool> pending_put_, pending_get_, pending_update_;
   int num_pending_requests_;
   shared_ptr<Blob<float>> data_;
@@ -309,8 +309,8 @@ class ParamEntry{
    */
   void AddParam(bool local, Param* p);
   int num_update, next_version;
-  int num_local; //!< # local workers using the shared parameter
-  int num_total; //!< # total workers using the shared parameter
+  int num_local;  //!< # local workers using the shared parameter
+  int num_total;  //!< # total workers using the shared parameter
   //!< Shares are deleted by neuralnet's destructor
   vector<Param*> shares;
 };
@@ -329,4 +329,4 @@ inline int SliceID(int param_trgt) {
 }

 }  // namespace singa
-#endif  // INCLUDE_UTILS_PARAM_H_
+#endif  // SINGA_UTILS_PARAM_H_


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b33e50d6/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index b470ea2..61173cb 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -6,30 +6,30 @@
 #include "proto/job.pb.h"
 #include "mshadow/tensor.h"
 #include "utils/singleton.h"
+namespace singa {
 using namespace mshadow;
 using std::vector;
 using std::string;

-namespace singa {
 Param::Param():local_version_(-1), slice_start_(0), num_slices_(0),
   num_pending_requests_(0), data_(nullptr) {
 }
-void Param::Setup(const ParamProto& proto, const vector<int>& shape){
-  data_=std::make_shared<Blob<float>>(shape);
+void Param::Setup(const ParamProto& proto, const vector<int>& shape) {
+  data_ = std::make_shared<Blob<float>>(shape);
   grad_.Reshape(shape);
   history_.Reshape(shape);
   proto_.CopyFrom(proto);
 }
-void Param::AddSlice(int slice_id, int size){
-  int offset=0;
-  if(slice_size_.size()>0){
-    //must be added in order
-    CHECK_EQ(slice_start_+num_slices_, slice_id);
-    offset=slice_offset_.back()+slice_size_.back();
+void Param::AddSlice(int slice_id, int size) {
+  int offset = 0;
+  if (slice_size_.size() > 0) {
+    // must be added in order
+    CHECK_EQ(slice_start_ + num_slices_, slice_id);
+    offset = slice_offset_.back() + slice_size_.back();
   } else {
-    slice_start_=slice_id;
-    offset=0;
+    slice_start_ = slice_id;
+    offset = 0;
   }
   slice_offset_.push_back(offset);
   slice_size_.push_back(size);
@@ -39,11 +39,11 @@ void Param::AddSlice(int slice_id, int size){
   num_slices_++;
 }

-void Param::InitValues(int version){
+void Param::InitValues(int version) {
   Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
-  auto random=TSingleton<Random<cpu>>::Instance();
+  auto random = TSingleton<Random<cpu>>::Instance();
   switch (proto_.init_method()) {
-    case InitMethod::kConstant:
+    case ParamProto::kConstant:
       data = proto_.value();
       break;
     case InitMethod::kUniform:
@@ -61,8 +61,8 @@ void Param::InitValues(int version){
       break;
     case InitMethod::kUniformSqrtFanInOut:
       random->SampleUniform(data, proto_.low(), proto_.high());
-      if(proto_.value())
-        data *= proto_.value()/ sqrt(data_->shape()[0] +data_->shape()[1]);
+      if (proto_.value())
+        data *= proto_.value() / sqrt(data_->shape()[0] + data_->shape()[1]);
       break;
     case InitMethod::kGaussian:
       random->SampleGaussian(data, proto_.mean(), proto_.std());
@@ -71,8 +71,8 @@ void Param::InitValues(int version){
       break;
     case InitMethod::kGaussainSqrtFanIn:
       random->SampleGaussian(data, proto_.mean(), proto_.std());
-      if(proto_.value())
-        data *= proto_.value()/ sqrt(data_->shape()[0]);
+      if (proto_.value())
+        data *= proto_.value() / sqrt(data_->shape()[0]);
       break;
     default:
       LOG(ERROR) << "Illegal parameter init method ";
@@ -90,50 +90,49 @@ void Param::ToProto(BlobProto* blob) {
 /**************Message related functions********/
 Msg* Param::GenPutMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
-  Msg* msg=new Msg();
+  Msg* msg = new Msg();
   msg->set_type(kPut);
-  void *ptr=mutable_cpu_data()+slice_offset_[idx];
+  void *ptr = mutable_cpu_data() + slice_offset_[idx];
   void *p = ptr;
   if (copy) p = nullptr;
   msg->AddFormatFrame("iffp", slice_size_[idx], learning_rate_multiplier(),
       weight_decay_multiplier(), p);
   if (copy) {
-    msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
+    msg->AddFrame(ptr, slice_size_[idx] * sizeof(float));
   }
-  //pending_put_[idx]=true;
-  //num_pending_requests_++;
-  return msg;
+  // pending_put_[idx]=true;
+  // num_pending_requests_++;
+  return msg;
 }

 Msg* Param::GenGetMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
-  Msg* msg=new Msg();
+  Msg* msg = new Msg();
   msg->set_type(kGet);
-  msg->AddFormatFrame("ip", copy, data_->cpu_data()+slice_offset_[idx]);
-  pending_get_[idx]=true;
+  msg->AddFormatFrame("ip", copy, data_->cpu_data() + slice_offset_[idx]);
+  pending_get_[idx] = true;
   num_pending_requests_++;
   return msg;
 }

 Msg* Param::GenUpdateMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
-  Msg* msg=new Msg();
+  Msg* msg = new Msg();
   msg->set_type(kUpdate);
   msg->AddFormatFrame("i", copy);
-  void* ptr=grad_.mutable_cpu_data()+slice_offset_[idx];
-  if(copy){
-    //LOG(ERROR)<<"Copy in gen update";
+  void* ptr = grad_.mutable_cpu_data() + slice_offset_[idx];
+  if (copy) {
     msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
-  } else { // to share values of grad blob
-    msg->AddFormatFrame("p", ptr);
+  } else {
+    msg->AddFormatFrame("p", ptr);  // to share values of grad blob
   }
-  pending_update_[idx]=true;
+  pending_update_[idx] = true;
   num_pending_requests_++;
   return msg;
 }

 Msg* Param::GenSyncMsg(int offset, int size) {
-  Msg* msg=new Msg();
+  Msg* msg = new Msg();
   msg->set_type(kSyncRequest);
   msg->set_trgt(ParamTrgt(-1, id()), local_version());
   // always copy data because syn is between server groups in diff procs
@@ -155,7 +154,7 @@ Msg* Param::HandlePutMsg(Msg** msg, bool reserve) {
     CHECK((*msg)->NextFrame());
     CHECK_EQ(size* sizeof(float), (*msg)->FrameSize());
     memcpy(mutable_cpu_data(), (*msg)->FrameData(), size*sizeof(float));
-  }else{
+  } else {
     data_->set_cpu_data(ptr);
   }
   if (!reserve)
@@ -167,9 +166,9 @@ Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
   int copy;
   float* ptr;
   (*msg)->ParseFormatFrame("ip", &copy, &ptr);
-  if(copy)
+  if (copy) {
     (*msg)->AddFrame(mutable_cpu_data(), sizeof(float)*size());
-  else if(ptr!=data_->cpu_data()){
+  } else if (ptr != data_->cpu_data()) {
     memcpy(ptr, data_->cpu_data(), sizeof(float)*size());
     data_->set_cpu_data(ptr);
   }
@@ -180,42 +179,33 @@ Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
 }

 void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
-  bool reset = true;
-  vector<int> copies;
+  CHECK_GT(msgs.size(), 0);
+  float* server_grad = nullptr;
+  vector<float*> worker_grad;
   for (auto *msg : msgs) {
     int copy;
     msg->ParseFormatFrame("i", &copy);
-    reset = reset && copy;
-    copies.push_back(copy);
-  }
-  int idx = 0;
-  for (auto *msg : msgs) {
-    CHECK(msg->NextFrame());
-    if (copies.at(idx++)) {
-      float* server_grad = mutable_cpu_grad();
-      float* worker_grad = static_cast<float*> (msg->FrameData());
-      if (reset) {
-        memcpy(server_grad, worker_grad, sizeof(float) * size());
-        reset = false;
-      } else {
-        for (int i =0; i < size(); i++)
-          server_grad[i] += worker_grad[i];
-      }
+    msg->NextFrame();
+    float* ptr = nullptr;
+    if (copy) {
+      ptr = static_cast<float*> (msg->FrameData());
+      CHECK_EQ(size() * sizeof(float), msg->FrameSize());
     } else {
-      float* ptr = nullptr;
       msg->ParseFormatFrame("p", &ptr);
-      if (grad_.cpu_data() != ptr) {
-        memcpy(ptr, grad_.cpu_data(), msg->FrameSize());
-        grad_.set_cpu_data(ptr);
-      }
+      server_grad = ptr;
     }
+    worker_grad.push_back(ptr);
   }
-
-  if (msgs.size() > 1) {
-    float* server_grad = mutable_cpu_grad();
-    for (int i = 0; i < size(); i++)
-      server_grad[i] /= msgs.size();
+  if (server_grad == nullptr)
+    server_grad = worker_grad.at(0);
+  for (float* grad : worker_grad) {
+    if (grad != server_grad) {
+      for (int i =0; i < size(); i++) {
+        server_grad[i] += grad[i];
+      }
+    }
   }
+  grad_.set_cpu_data(server_grad);
 }

 const vector<Msg*> Param::GenUpdateResponseMsgs(const vector<Msg*>& msgs) {
@@ -248,33 +238,33 @@ int Param::ParseSyncResponseMsg(Msg* msg, int slice_idx) {

 int Param::ParseGetResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_get_[slice_idx], true);
-  pending_get_[slice_idx]=false;
+  pending_get_[slice_idx] = false;
   ParseResponseMsg(msg, slice_idx);
-  return (--num_pending_requests_)%num_slices_==0;
+  return (--num_pending_requests_)%num_slices_ == 0;
 }

 int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_update_[slice_idx], true);
-  pending_update_[slice_idx]=false;
+  pending_update_[slice_idx] = false;
   ParseResponseMsg(msg, slice_idx);
-  return (--num_pending_requests_) % num_slices_==0;
+  return (--num_pending_requests_) % num_slices_ == 0;
 }

 void Param::ParseResponseMsg(Msg* msg, int slice_idx) {
   int copy;
   msg->ParseFormatFrame("i", &copy);
   msg->NextFrame();
-  if(copy) {
+  if (copy) {
     CHECK_EQ(msg->FrameSize(), slice_size_[slice_idx]*sizeof(float));
     memcpy(mutable_cpu_data()+slice_offset_[slice_idx],
        msg->FrameData(), msg->FrameSize());
   }
-  //LOG(ERROR)<<"parse response norm "<<data_->asum_data()<<" of "<<id();
+  // LOG(ERROR)<<"parse response norm "<<data_->asum_data()<<" of "<<id();
 }

 void Param::ShareFrom(const Param& other) {
   proto_.set_owner(other.owner());
-  if(data_!=nullptr) {
+  if (data_ != nullptr) {
     CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
         other.data_->shape().begin()));
   }
@@ -299,9 +289,7 @@ ParamEntry::ParamEntry(int total, Param* p) : num_update(0), num_total(total) {
 void ParamEntry::AddParam(bool local, Param* p) {
   num_local += local;
   num_total += 1;
-  if(local)
+  if (local)
     shares.push_back(p);
 }
-}
-
-// namespace singa
+}  // namespace singa
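For context, the following is a minimal, self-contained sketch of the aggregation pattern that the
refactored ParseUpdateMsgs above follows: one existing gradient buffer (the locally shared grad blob
when a worker passed it by pointer, otherwise the buffer of the first message) is reused as the
accumulation target, and the remaining workers' gradients are summed into it, so no extra buffer is
allocated. The function name AggregateGrads and the raw float* buffers are illustrative only and are
not part of SINGA's Msg/Blob API; the real code additionally parses the pointers out of the update
messages.

#include <cstddef>
#include <iostream>
#include <vector>

// Pick one existing gradient buffer as the accumulation target (preferring a
// locally shared buffer if one exists, otherwise the first buffer) and add
// every other worker's gradients into it. Returns the chosen buffer, which a
// caller could then install as the grad blob's cpu data.
float* AggregateGrads(const std::vector<float*>& worker_grad,
                      float* shared_grad,  // may be nullptr if all workers copied
                      size_t size) {
  float* server_grad = (shared_grad != nullptr) ? shared_grad : worker_grad.at(0);
  for (float* grad : worker_grad) {
    if (grad == server_grad) continue;  // skip the buffer we accumulate into
    for (size_t i = 0; i < size; ++i)
      server_grad[i] += grad[i];
  }
  return server_grad;
}

int main() {
  float g1[] = {1.f, 2.f};  // gradients from a local worker (shared by pointer)
  float g2[] = {3.f, 4.f};  // gradients copied out of a remote worker's message
  std::vector<float*> grads = {g1, g2};
  float* out = AggregateGrads(grads, g1, 2);
  std::cout << out[0] << " " << out[1] << std::endl;  // prints "4 6", stored in g1
  return 0;
}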
