Repository: incubator-singa Updated Branches: refs/heads/master 856fc1fbe -> 4df2bb5a8
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/806826eb/src/trainer/trainer.cc ---------------------------------------------------------------------- diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc index 3f343af..dbf8a48 100644 --- a/src/trainer/trainer.cc +++ b/src/trainer/trainer.cc @@ -180,7 +180,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, threads.push_back(std::thread(&Server::Run,server.get())); for(auto worker: workers) threads.push_back(std::thread(&Worker::Run,worker.get())); - Run(servers.size(), workers.size(), shards); + Run(workers.size(), servers.size(), shards); for(auto& thread: threads) thread.join(); } @@ -191,8 +191,6 @@ void Trainer::Run(int nworkers, int nservers, procs_id_=cluster->procs_id(); map<int, shared_ptr<Dealer>> interprocs_dealers; Metric perf; - int perf_step=-1; - string perf_prefix; bool stop=false; while(!stop){ Msg* msg=router_->Receive(); @@ -209,9 +207,9 @@ void Trainer::Run(int nworkers, int nservers, msg =HandleConnect(&msg); }else if(type==kStop){ if(msg->src_flag()==kServer) - nworkers--; - else if (msg->src_flag()==kWorkerParam) nservers--; + else if (msg->src_flag()==kWorkerParam) + nworkers--; delete msg; msg=nullptr; if(nworkers==0&&nservers==0){ @@ -219,26 +217,19 @@ void Trainer::Run(int nworkers, int nservers, break; } }else if(type==kMetric){ - int step=msg->target_first(); - string prefix((char*)msg->frame_data(), msg->frame_size()); - if(step!=perf_step||perf_prefix!=prefix){ - if(perf_step>=0){ - perf.Avg(); - LOG(ERROR)<<perf_prefix<<" step-" - <<perf_step<<", "<<perf.ToString(); - perf.Reset(); - } - perf_step=step; - perf_prefix=prefix; + if(msg->src_first()==0){ + int step=msg->target_first(); + string prefix((char*)msg->frame_data(), msg->frame_size()); + msg->next_frame(); + Metric cur; + cur.ParseString(string((char*)msg->frame_data(), msg->frame_size())); + perf.AddMetrics(cur); + LOG(ERROR)<<prefix<<" step-" <<step<<", "<<perf.ToString(); + perf.Reset(); } - msg->next_frame(); - Metric cur; - cur.ParseString(string((char*)msg->frame_data(), msg->frame_size())); - perf.AddMetrics(cur); - perf.Inc(); delete msg; msg=nullptr; - }else { + }else if(cluster->nserver_groups()>1){ int group_id=msg->src_first(); int paramid=msg->target_first(); auto entry=shards.at(group_id)->at(paramid); @@ -261,6 +252,9 @@ void Trainer::Run(int nworkers, int nservers, default: break; } + }else{ + delete msg; + msg=nullptr; } }else{ int dst_procs_id; @@ -282,9 +276,11 @@ void Trainer::Run(int nworkers, int nservers, } } } + /* perf.Avg(); if(perf_step>=0) LOG(ERROR)<<perf_prefix<<" step-"<<perf_step<<", "<<perf.ToString(); + */ } Msg* Trainer::HandleConnect(Msg** msg){ string ping((char*)(*msg)->frame_data(), (*msg)->frame_size()); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/806826eb/src/trainer/worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc index abfcdf0..3d400ee 100644 --- a/src/trainer/worker.cc +++ b/src/trainer/worker.cc @@ -22,6 +22,11 @@ void Worker::Setup(const ModelProto& model, auto cluster=Cluster::Get(); int sgid=group_id_/cluster->nworker_groups_per_server_group(); CHECK(cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid)); + if(model.hogwild()){ + updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance() + ->Create("Updater")); + updater_->Init(model.updater()); + } } void Worker::Run(){ @@ -124,12 +129,17 @@ int Worker::Get(shared_ptr<Param> param, int step){ return 1; } int Worker::Update(shared_ptr<Param> param, int step){ - Msg* msg=new Msg(); - msg->set_src(group_id_, worker_id_, kWorkerParam); - msg->set_dst(-1, -1, kStub); - msg->set_type(kUpdate); - msg->set_target(param->owner(), step); - param_dealer_->Send(&msg); + if(updater_){ + updater_->Update(step, param); + param->set_version(param->version()+1); + }else{ + Msg* msg=new Msg(); + msg->set_src(group_id_, worker_id_, kWorkerParam); + msg->set_dst(-1, -1, kStub); + msg->set_type(kUpdate); + msg->set_target(param->owner(), step); + param_dealer_->Send(&msg); + } return 1; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/806826eb/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index ac5566c..e616a1c 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -175,6 +175,116 @@ void Param::Init(int v){ set_version(v); } +/********************HogwildParam***************************/ +Msg* HogwildParam::GenPutMsg(void* arg){ + char buf[128]; + sprintf(buf, "%d %f %f %p", size(), + learning_rate_multiplier(), weight_decay_multiplier(), mutable_cpu_data()); + Msg* msg=new Msg(); + msg->set_type(kPut); + int v=version(); + if(arg!=nullptr) + v=*(int*)arg; + msg->set_target(owner(), v); + msg->add_frame(buf, strlen(buf)); + return msg; +} + +Msg* HogwildParam::GenGetMsg(void* arg){ + Msg* msg=new Msg(); + msg->set_type(kGet); + int v=version(); + if(arg!=nullptr) + v=*(int*)arg; + msg->set_target(owner(), v); + return msg; +} + +Msg* HogwildParam::GenUpdateMsg(void* arg){ + Msg* msg=new Msg(); + msg->set_type(kUpdate); + int v=version(); + if(arg!=nullptr) + v=*(int*)arg; + msg->set_target(owner(), v); + void* p=mutable_cpu_grad(); + msg->add_frame(p, sizeof(void*)); + return msg; +} + +Msg* HogwildParam::GenSyncMsg(void* arg){ + return nullptr; +} + +Msg* HogwildParam::HandlePutMsg(Msg** msg){ + int size; + float lr, wc; + sscanf(static_cast<char*>((*msg)->frame_data()), "%d %f %f", + &size, &lr, &wc); + proto_.set_learning_rate_multiplier(lr); + proto_.set_weight_decay_multiplier(wc); + CHECK((*msg)->next_frame()); + vector<int> shape{size}; + // set pointer + data_=std::make_shared<Blob<float>>(shape); + data_->set_version((*msg)->target_second()); + grad_.Reshape(shape); + history_.Reshape(shape); + delete (*msg); + *msg=nullptr; + return nullptr; +} + +Msg* HogwildParam::HandleGetMsg(Msg** msg){ + if((*msg)->target_second()<=version()){ + (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size()); + (*msg)->SwapAddr(); + (*msg)->set_type(kRGet); + } + return *msg; +} + +int HogwildParam::ParseUpdateMsg(Msg** msg){ + delete (*msg); + *msg=nullptr; + return 1; +} + +Msg* HogwildParam::GenUpdateResponseMsg(void* arg){ + Msg* msg=new Msg(); + msg->set_type(kRUpdate); + int v=version(); + if(arg!=nullptr) + v=*(int*)arg; + msg->set_target(owner(), v); + return msg; +} + +Msg* HogwildParam::HandleSyncMsg(Msg** msg){ + delete *msg; + *msg=nullptr; + return nullptr; +} + +int HogwildParam::ParseSyncResponseMsg(Msg** msg){ + delete *msg; + *msg=nullptr; + return 1; +} +int HogwildParam::ParsePutResponseMsg(Msg **msg){ + return ParseSyncResponseMsg(msg); +} +int HogwildParam::ParseGetResponseMsg(Msg **msg){ + // must be set after all other settings are done! + set_version((*msg)->target_second()); + delete *msg; + *msg=nullptr; + return 1; +} +int HogwildParam::ParseUpdateResponseMsg(Msg **msg){ + return ParseGetResponseMsg(msg); +} + /**************************RandomSyncParam******************************** const vector<int> RandomSyncParam::RandomSample(int seed, int m, int n){ vector<int> samples(m);
