SINGA-21 Code review 5

review trainer.cc/h, driver.cc/.h, singa.h, main.cc
- rewrite headers in driver.h
- move template impl from driver.h to driver.cc
- format code
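The second bullet is the interesting one: the Register* template member functions now live in driver.cc instead of driver.h. This still links because Driver::Init, in the same translation unit, calls each Register* function for every built-in type, so the compiler implicitly instantiates them there. A minimal sketch of the pattern, using hypothetical names (Registry, ConvLayer, DenseLayer) rather than SINGA's real API:

    // registry.h -- only the declaration is visible to other files
    #ifndef REGISTRY_H_
    #define REGISTRY_H_
    class Registry {
     public:
      template <typename Subclass>
      int Register(int type);  // declared here, defined in registry.cc
      void Init();             // registers all built-in subclasses
    };
    #endif  // REGISTRY_H_

    // registry.cc -- definition plus the calls that instantiate it
    #include <cstdio>
    #include "registry.h"

    template <typename Subclass>
    int Registry::Register(int type) {
      std::printf("registered type %d\n", type);
      return 1;
    }

    struct ConvLayer {};   // stand-ins for the built-in subclasses
    struct DenseLayer {};

    void Registry::Init() {
      // These calls implicitly instantiate Register<ConvLayer> and
      // Register<DenseLayer> in this translation unit, so the linker
      // can resolve them without the definition appearing in the header.
      Register<ConvLayer>(0);
      Register<DenseLayer>(1);
    }

The trade-off is that only the instantiations created inside driver.cc exist: user code calling these templates with its own subclasses would hit a link error unless the definitions return to the header or explicit instantiations are added.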
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/366e6a82
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/366e6a82
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/366e6a82

Branch: refs/heads/master
Commit: 366e6a82684aff9c0b31e904e3c45dcca2163490
Parents: f50d293
Author: wang sheng <[email protected]>
Authored: Wed Sep 23 15:20:20 2015 +0800
Committer: wang sheng <[email protected]>
Committed: Wed Sep 23 15:28:43 2015 +0800

----------------------------------------------------------------------
 include/driver.h           |  45 +------
 include/trainer/trainer.h  |  80 ++++++------
 src/driver.cc              |  52 ++++++++-
 src/main.cc                |  10 +-
 src/neuralnet/neuralnet.cc |   4 +-
 src/trainer/trainer.cc     | 250 +++++++++++++++++++---------------------
 6 files changed, 211 insertions(+), 230 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/include/driver.h
----------------------------------------------------------------------
diff --git a/include/driver.h b/include/driver.h
index 7d15c98..563be77 100644
--- a/include/driver.h
+++ b/include/driver.h
@@ -22,7 +22,8 @@
 
 #ifndef SINGA_DRIVER_H_
 #define SINGA_DRIVER_H_
 
-#include "singa.h"
+#include "proto/job.pb.h"
+#include "proto/singa.pb.h"
 
 namespace singa {
@@ -119,48 +120,6 @@ class Driver {
   SingaProto singa_conf_;
 };
 
-template<typename Subclass, typename Type>
-int Driver::RegisterLayer(const Type& type) {
-  auto factory = Singleton<Factory<singa::Layer>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, Layer));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterParam(const Type& type) {
-  auto factory = Singleton<Factory<singa::Param>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, Param));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterParamGenerator(const Type& type) {
-  auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, ParamGenerator));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterUpdater(const Type& type) {
-  auto factory = Singleton<Factory<singa::Updater>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, Updater));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterLRGenerator(const Type& type) {
-  auto factory = Singleton<Factory<singa::LRGenerator>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, LRGenerator));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterWorker(const Type& type) {
-  auto factory = Singleton<Factory<singa::Worker>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, Worker));
-  return 1;
-}
-
 }  // namespace singa
 
 #endif  // SINGA_DRIVER_H_


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index d3d332f..1c0e039 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -19,26 +19,24 @@
  *
 *************************************************************/
 
-#ifndef INCLUDE_TRAINER_TRAINER_H_
-#define INCLUDE_TRAINER_TRAINER_H_
+#ifndef SINGA_TRAINER_TRAINER_H_
+#define SINGA_TRAINER_TRAINER_H_
 
 #include <queue>
-#include <vector>
 #include <unordered_map>
+#include <vector>
+#include "communication/socket.h"
+#include "neuralnet/neuralnet.h"
 #include "proto/job.pb.h"
 #include "proto/singa.pb.h"
+#include "trainer/server.h"
+#include "trainer/worker.h"
+#include "utils/factory.h"
 #include "utils/param.h"
 #include "utils/singleton.h"
-#include "utils/factory.h"
-#include "neuralnet/neuralnet.h"
-#include "trainer/worker.h"
-#include "trainer/server.h"
-#include "communication/socket.h"
 
 namespace singa {
 
-using std::vector;
-
 /**
  * Every running process has a training object which launches one or more
  * worker (and server) threads.
@@ -77,7 +75,7 @@ class Trainer{
    * @param jobConf
    * @return server instances
    */
-  vector<Server*> CreateServers(const JobProto& jobConf);
+  std::vector<Server*> CreateServers(const JobProto& jobConf);
   /**
    * Create workers instances.
   * @param nthread total num of threads in current procs which is used to
@@ -86,8 +84,7 @@ class Trainer{
    * @param jobConf
    * @return worker instances
    */
-  vector<Worker*> CreateWorkers(const JobProto& jobConf);
-
+  std::vector<Worker*> CreateWorkers(const JobProto& jobConf);
   /**
    * Setup workers and servers.
    *
@@ -98,12 +95,11 @@ class Trainer{
    * @param workers
    * @param servers
    */
-  void SetupWorkerServer(
-      const JobProto& jobConf,
-      const vector<Worker*>& workers,
-      const vector<Server*>& servers);
-
-  void Run(const vector<Worker*>& workers, const vector<Server*>& servers);
+  void SetupWorkerServer(const JobProto& jobConf,
+                         const std::vector<Worker*>& workers,
+                         const std::vector<Server*>& servers);
+  void Run(const std::vector<Worker*>& workers,
+           const std::vector<Server*>& servers);
   /**
    * Display metrics to log (standard output)
    */
@@ -118,24 +114,20 @@ class Trainer{
    * Handle messages to local servers and local stub
    */
   void HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg);
-
-  /**
-   * Generate a request message to Get the parameter object.
-   */
-  const vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
-  void HandleGetResponse(ParamEntry* entry, Msg** msg);
-
-  /**
-   * Generate a request message to Update the parameter object.
-   */
-  const vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
+  /**
+   * Generate a request message to Get the parameter object.
+   */
+  const std::vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
+  void HandleGetResponse(ParamEntry* entry, Msg** msg);
+  /**
+   * Generate a request message to Update the parameter object.
+   */
+  const std::vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
   void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
-  /**
-   * Generate a request message to Put the parameter object.
-   */
-  const vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
-
+   * Generate a request message to Put the parameter object.
+   */
+  const std::vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
   /**
    * Called by HandlePut, HandleUpdate and HandleGet functions
    * @param type message type
@@ -145,7 +137,7 @@ class Trainer{
    * @param ret generated messages
    */
   void GenMsgs(int type, int version, ParamEntry* entry,
-    Msg* msg, vector<Msg*> *ret);
+    Msg* msg, std::vector<Msg*> *ret);
   /**
    * Get a hash id for a Param object from a group.
    *
@@ -157,13 +149,15 @@ class Trainer{
   }
 
  protected:
-  int procs_id_;
-  Router *router_;
+  int procs_id_ = -1;
+  Router *router_ = nullptr;
   std::unordered_map<int, ParamEntry*> worker_shard_;
   //!< map from slice to the server that updates it
-  vector<int> slice2server_;
-  //stub will destroy all neuralnets in the end
-  vector<NeuralNet*> nets_;
+  std::vector<int> slice2server_;
+  // a buffer of created nets, will destroy them all in destructor
+  std::vector<NeuralNet*> nets_;
 };
-} /* singa */
-#endif  // INCLUDE_TRAINER_TRAINER_H_
+
+}  // namespace singa
+
+#endif  // SINGA_TRAINER_TRAINER_H_


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 28d21c2..42a1330 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -24,24 +24,27 @@
 #include <cblas.h>
 #include <glog/logging.h>
 #include <string>
+#include "neuralnet/neuralnet.h"
+#include "neuralnet/layer.h"
+#include "trainer/trainer.h"
+#include "utils/common.h"
+#include "utils/factory.h"
+#include "utils/singleton.h"
 #include "utils/tinydir.h"
 
 namespace singa {
 
 void Driver::Init(int argc, char **argv) {
   google::InitGoogleLogging(argv[0]);
-
   // unique job ID generated from singa-run.sh, passed in as "-singa_job <id>"
   int arg_pos = ArgPos(argc, argv, "-singa_job");
   job_id_ = (arg_pos != -1) ? atoi(argv[arg_pos+1]) : -1;
-
   // global signa conf passed by singa-run.sh as "-singa_conf <path>"
   arg_pos = ArgPos(argc, argv, "-singa_conf");
   if (arg_pos != -1)
     ReadProtoFromTextFile(argv[arg_pos+1], &singa_conf_);
   else
     ReadProtoFromTextFile("conf/singa.conf", &singa_conf_);
-
   // job conf passed by users as "-conf <path>"
   arg_pos = ArgPos(argc, argv, "-conf");
   CHECK_NE(arg_pos, -1);
@@ -107,7 +110,47 @@ void Driver::Init(int argc, char **argv) {
   RegisterParamGenerator<UniformSqrtFanInOutGen>(kUniformSqrtFanInOut);
 }
 
+template<typename Subclass, typename Type>
+int Driver::RegisterLayer(const Type& type) {
+  auto factory = Singleton<Factory<singa::Layer>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, Layer));
+  return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterParam(const Type& type) {
+  auto factory = Singleton<Factory<singa::Param>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, Param));
+  return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterParamGenerator(const Type& type) {
+  auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, ParamGenerator));
+  return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterUpdater(const Type& type) {
+  auto factory = Singleton<Factory<singa::Updater>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, Updater));
+  return 1;
+}
+template<typename Subclass, typename Type>
+int Driver::RegisterLRGenerator(const Type& type) {
+  auto factory = Singleton<Factory<singa::LRGenerator>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, LRGenerator));
+  return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterWorker(const Type& type) {
+  auto factory = Singleton<Factory<singa::Worker>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, Worker));
+  return 1;
+}
 
 void Driver::Submit(bool resume, const JobProto& jobConf) {
   if (singa_conf_.has_log_dir())
@@ -118,9 +161,8 @@ void Driver::Submit(bool resume, const JobProto& jobConf) {
     LOG(FATAL) << "workspace does not exist: " << jobConf.cluster().workspace();
   if (jobConf.num_openblas_threads() != 1)
     LOG(WARNING) << "openblas with "
-    << jobConf.num_openblas_threads() << " threads";
+                 << jobConf.num_openblas_threads() << " threads";
   openblas_set_num_threads(jobConf.num_openblas_threads());
-
   JobProto job;
   job.CopyFrom(jobConf);
   job.set_id(job_id_);


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/main.cc
----------------------------------------------------------------------
diff --git a/src/main.cc b/src/main.cc
index 5e94de4..5d2ab2f 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -45,20 +45,20 @@
  */
 
 int main(int argc, char **argv) {
-  //  must create driver at the beginning and call its Init method.
+  // must create driver at the beginning and call its Init method.
   singa::Driver driver;
   driver.Init(argc, argv);
 
-  //  if -resume in argument list, set resume to true; otherwise false
+  // if -resume in argument list, set resume to true; otherwise false
   int resume_pos = singa::ArgPos(argc, argv, "-resume");
   bool resume = (resume_pos != -1);
 
-  //  users can register new subclasses of layer, updater, etc.
+  // users can register new subclasses of layer, updater, etc.
 
-  //  get the job conf, and custmize it if need
+  // get the job conf, and custmize it if need
   singa::JobProto jobConf = driver.job_conf();
 
-  //  submit the job
+  // submit the job
   driver.Submit(resume, jobConf);
   return 0;
 }


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index 200824a..775a5a7 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -19,10 +19,10 @@
  *
 *************************************************************/
 
+#include "neuralnet/neuralnet.h"
+
 #include <algorithm>
 #include <queue>
-
-#include "neuralnet/neuralnet.h"
 #include "utils/singleton.h"
 
 namespace singa {


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 8a0589e..ecfc94a 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -19,25 +19,21 @@
  *
 *************************************************************/
 
-#include <thread>
-#include <vector>
-#include <map>
-#include <chrono>
+#include "trainer/trainer.h"
+
 #include <glog/logging.h>
-#include "utils/tinydir.h"
 #include <unistd.h>
+#include <map>
+#include <thread>
+#include "mshadow/tensor.h"
+#include "proto/common.pb.h"
 #include "utils/cluster.h"
 #include "utils/common.h"
-#include "proto/common.pb.h"
-#include "trainer/trainer.h"
-#include "mshadow/tensor.h"
-
+#include "utils/tinydir.h"
 
 namespace singa {
+
 using std::vector;
-using std::map;
-using std::queue;
-using namespace std::chrono;
 using std::string;
 
 /***********************Trainer****************************/
@@ -47,12 +43,82 @@ Trainer::~Trainer() {
     delete p;
 }
 
+void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
+  // register job to zookeeper at the beginning
+  auto cluster = Cluster::Setup(job->id(), singaConf, job->cluster());
+  if (resume) Resume(job);
+  router_ = new Router();
+  router_->Bind(kInprocRouterEndpoint);
+  const string hostip = cluster->hostip();
+  int port = router_->Bind("tcp://" + hostip + ":*");
+  // register endpoint to zookeeper
+  cluster->Register(getpid(), hostip + ":" + std::to_string(port));
+  const vector<Worker*> workers = CreateWorkers(*job);
+  const vector<Server*> servers = CreateServers(*job);
+  SetupWorkerServer(*job, workers, servers);
+#ifdef USE_MPI
+  int nthreads = workers.size() + servers.size();
+  for (int i = 0; i < nthreads; i++)
+    MPIQueues.push_back(make_shared<SafeQueue>());
+#endif
+  vector<std::thread> threads;
+  for (auto server : servers)
+    threads.push_back(std::thread(&Server::Run, server));
+  for (auto worker : workers)
+    threads.push_back(std::thread(&Worker::Run, worker));
+  Run(workers, servers);
+  for (auto& thread : threads)
+    thread.join();
+  for (auto server : servers)
+    delete server;
+  for (auto worker : workers)
+    delete worker;
+}
+
+void Trainer::Resume(JobProto* jobConf) {
+  tinydir_dir dir;
+  string folder = Cluster::Get()->checkpoint_folder();
+  tinydir_open(&dir, folder.c_str());
+  int latest_step = 0;
+  // there would be multi checkpoint files (from diff workers) for one step
+  vector<string> ck_files;
+  // iterate all files to get the files for the last checkpoint
+  while (dir.has_next) {
+    tinydir_file file;
+    tinydir_readfile(&dir, &file);
+    tinydir_next(&dir);
+    char* ch = strstr(file.name, "step");
+    if (ch == nullptr) {
+      if (file.name[0] != '.')
+        LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
+      continue;
+    }
+    LOG(INFO) << "Add checkpoint file for resume: " << ch;
+    int step = atoi(ch+4);
+    if (step == latest_step) {
+      ck_files.push_back(file.name);
+    } else if (step > latest_step) {
+      latest_step = step;
+      ck_files.clear();
+      ck_files.push_back(string(file.name));
+    }
+  }
+  if (latest_step > 0) {
+    jobConf->set_step(latest_step);
+    if (!jobConf->has_reset_param_version())
+      jobConf->set_reset_param_version(false);
+    jobConf->clear_checkpoint_path();
+    for (auto ck_file : ck_files)
+      jobConf->add_checkpoint_path(folder + "/" + ck_file);
+  }
+  tinydir_close(&dir);
+}
+
 const vector<int> SliceParams(const vector<Param*>& params) {
   // for load-balance among servers in a group and among server groups
   int nserver_grps = Cluster::Get()->nserver_groups();
   int nservers_per_grp = Cluster::Get()->nservers_per_group();
   int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
-
   // collect sizes of unique Params
   std::vector<int> paramsize;
   for (auto param : params)
@@ -86,10 +152,9 @@ const vector<int> SliceParams(const vector<Param*>& params) {
   return slices;
 }
 
-void Trainer::SetupWorkerServer(
-    const JobProto& job_conf,
-    const vector<Worker*>& workers,
-    const vector<Server*>& servers) {
+void Trainer::SetupWorkerServer(const JobProto& job_conf,
+                                const vector<Worker*>& workers,
+                                const vector<Server*>& servers) {
   auto cluster = Cluster::Get();
   int grp_size = cluster->nworkers_per_group();
   const auto& net_conf = job_conf.neuralnet();
@@ -97,7 +162,6 @@ void Trainer::SetupWorkerServer(
   nets_.push_back(net);
   // MUST do SliceParam before share param/net with others
   auto slices = SliceParams(net->params());
-
   std::unordered_map<int, NeuralNet*> grp_net;
   int first_grp = workers.size() ? workers.at(0)->grp_id() : -1;
   for (auto worker : workers) {
@@ -107,13 +171,17 @@
     NeuralNet* valid_net = nullptr;
     if (grp_net.find(grp_id) == grp_net.end()) {
       if (grp_id == first_grp) {
-        // test are performed only by the first group now. TODO update.
+        // test are performed only by the first group now.
+        // TODO(wangwei) update.
         if (first_grp == 0 && job_conf.test_steps() && worker_id == 0) {
-          test_net = NeuralNet::Create(net_conf, kTest, 1);  // hard code for exp
+          // hard code for exp
+          // TODO(wangwei) move test unit out as an independent module
+          test_net = NeuralNet::Create(net_conf, kTest, 1);
           test_net->ShareParamsFrom(net);
           nets_.push_back(test_net);
         }
-        // validation are performed only by the first group. TODO update.
+        // validation are performed only by the first group.
+        // TODO(wangwei) update.
         if (first_grp == 0 && job_conf.valid_steps() && worker_id == 0) {
           valid_net = NeuralNet::Create(net_conf, kValidation, 1);
           valid_net->ShareParamsFrom(net);
@@ -123,7 +191,7 @@
       } else {
         grp_net[grp_id] = NeuralNet::Create(net_conf, kTrain, grp_size);
         nets_.push_back(grp_net[grp_id]);
-        if(cluster->share_memory())
+        if (cluster->share_memory())
           grp_net[grp_id]->ShareParamsFrom(net);
       }
       for (auto layer : grp_net[grp_id]->layers()) {
@@ -141,12 +209,10 @@
             << worker->id() << " net " << grp_net[grp_id];
     worker->Setup(job_conf, grp_net[grp_id], valid_net, test_net);
   }
-
   // partition among server groups, each group maintains one sub-set for sync
   auto slice2group = PartitionSlices(cluster->nserver_groups(), slices);
   // partition within one server group, each server updates for one sub-set
   slice2server_ = PartitionSlices(cluster->nservers_per_group(), slices);
-
   for (auto server : servers)
     server->Setup(job_conf.updater(), slice2group, slice2server_);
 }
@@ -156,14 +222,13 @@ vector<Server*> Trainer::CreateServers(const JobProto& job) {
   vector<Server*> servers;
   if (!cluster->has_server())
     return servers;
-
   int server_procs = cluster->procs_id();
   // if true, server procs (logical) id starts after worker procs
   if (cluster->server_worker_separate())
     server_procs -= cluster->nworker_procs();
   const vector<int> rng = cluster->ExecutorRng(server_procs,
-      cluster->nservers_per_group(),
-      cluster->nservers_per_procs());
+                                               cluster->nservers_per_group(),
+                                               cluster->nservers_per_procs());
   int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
   for (int gid = gstart; gid < gend; gid++) {
     for (int sid = start; sid < end; sid++) {
@@ -174,15 +239,14 @@
   return servers;
 }
 
-
 vector<Worker*> Trainer::CreateWorkers(const JobProto& job) {
-  auto cluster=Cluster::Get();
+  auto cluster = Cluster::Get();
   vector<Worker*> workers;
-  if(!cluster->has_worker())
+  if (!cluster->has_worker())
    return workers;
   const vector<int> rng = cluster->ExecutorRng(cluster->procs_id(),
-      cluster->nworkers_per_group(),
-      cluster->nworkers_per_procs());
+                                               cluster->nworkers_per_group(),
+                                               cluster->nworkers_per_procs());
   int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3];
   for (int gid = gstart; gid < gend; gid++) {
     for (int wid = wstart; wid < wend; wid++) {
@@ -194,93 +258,13 @@ vector<Worker*> Trainer::CreateWorkers(const JobProto& job) {
   return workers;
 }
 
-void Trainer::Resume(JobProto* jobConf) {
-  tinydir_dir dir;
-  string folder = Cluster::Get()->checkpoint_folder();
-  tinydir_open(&dir, folder.c_str());
-  int latest_step = 0;
-  // there would be multi checkpoint files (from diff workers) for one step
-  vector<string> ck_files;
-  // iterate all files to get the files for the last checkpoint
-  while (dir.has_next) {
-    tinydir_file file;
-    tinydir_readfile(&dir, &file);
-    tinydir_next(&dir);
-    char* ch = strstr(file.name, "step");
-    if (ch == nullptr) {
-      if (file.name[0] != '.')
-        LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
-      continue;
-    }
-
-    LOG(INFO) << "Add checkpoint file for resume: " << ch;
-    int step = atoi(ch+4);
-    if (step == latest_step) {
-      ck_files.push_back(file.name);
-    } else if(step > latest_step) {
-      latest_step = step;
-      ck_files.clear();
-      ck_files.push_back(string(file.name));
-    }
-  }
-
-  if (latest_step > 0) {
-    jobConf->set_step(latest_step);
-    if (!jobConf->has_reset_param_version())
-      jobConf->set_reset_param_version(false);
-    jobConf->clear_checkpoint_path();
-    for (auto ck_file : ck_files)
-      jobConf->add_checkpoint_path(folder + "/" + ck_file);
-  }
-  tinydir_close(&dir);
-}
-
-void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
-  // register job to zookeeper at the beginning
-  auto cluster = Cluster::Setup(job->id(), singaConf, job->cluster());
-  if (resume)
-    Resume(job);
-
-  router_ = new Router();
-  router_->Bind(kInprocRouterEndpoint);
-  const string hostip = cluster->hostip();
-  int port = router_->Bind("tcp://" + hostip + ":*");
-  // register endpoint to zookeeper
-  cluster->Register(getpid(), hostip + ":" + std::to_string(port));
-
-  const vector<Worker*> workers = CreateWorkers(*job);
-  const vector<Server*> servers = CreateServers(*job);
-  SetupWorkerServer(*job, workers, servers);
-
-#ifdef USE_MPI
-  int nthreads = workers.size() + servers.size();
-  for (int i = 0; i < nthreads; i++)
-    MPIQueues.push_back(make_shared<SafeQueue>());
-#endif
-  vector<std::thread> threads;
-  for(auto server : servers)
-    threads.push_back(std::thread(&Server::Run, server));
-  for(auto worker : workers)
-    threads.push_back(std::thread(&Worker::Run, worker));
-  Run(workers, servers);
-  for(auto& thread : threads)
-    thread.join();
-  for(auto server : servers)
-    delete server;
-  for(auto worker : workers)
-    delete worker;
-}
-
-void Trainer::Run(
-    const vector<Worker*>& workers,
-    const vector<Server*>& servers) {
+void Trainer::Run(const vector<Worker*>& workers,
+                  const vector<Server*>& servers) {
   int nworkers = workers.size(), nservers = servers.size();
   auto cluster = Cluster::Get();
   procs_id_ = cluster->procs_id();
   LOG(INFO) << "Stub in process " << procs_id_ << " starts";
-
-  map<int, Dealer*> inter_dealers;  // for sending msg to other procs
-
+  std::map<int, Dealer*> inter_dealers;  // for sending msg to other procs
   std::queue<Msg*> msg_queue;
   while (true) {
     Msg* msg = nullptr;
@@ -343,26 +327,27 @@ Dealer* Trainer::CreateInterProcsDealer(int dst_procs) {
   // forward to other procs
   auto cluster = Cluster::Get();
   auto dealer = new Dealer();
-  while(cluster->endpoint(dst_procs)=="") {
-    //kCollectSleepTime));
+  while (cluster->endpoint(dst_procs) == "") {
+    // kCollectSleepTime));
     std::this_thread::sleep_for(std::chrono::milliseconds(3000));
-    LOG(ERROR)<<"waiting for procs "<< dst_procs<<" to register";
+    LOG(ERROR) << "waiting for procs " << dst_procs << " to register";
   }
   dealer->Connect("tcp://"+cluster->endpoint(dst_procs));
   return dealer;
 }
 
-void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
+void Trainer::HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg) {
   Msg* msgg = *msg;
   int paramid = ParamID(msgg->trgt_val());
   int type = msgg->type();
   int grp;
   ParamEntry *entry = nullptr;
-  switch (type) {  // TODO process other requests, e.g. RESTful
+  // TODO(wangwei) process other requests, e.g. RESTful
+  switch (type) {
     case kUpdate:
       grp = AddrGrp(msgg->src());
       entry = worker_shard_.at(Hash(grp, paramid));
-      for(auto update_msg : HandleUpdate(entry, msg))
+      for (auto update_msg : HandleUpdate(entry, msg))
        msg_queue->push(update_msg);
      break;
     case kRUpdate:
@@ -373,7 +358,7 @@ void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
     case kGet:
       grp = AddrGrp(msgg->src());
       entry = worker_shard_.at(Hash(grp, paramid));
-      for(auto get_msg : HandleGet(entry, msg))
+      for (auto get_msg : HandleGet(entry, msg))
        msg_queue->push(get_msg);
      break;
     case kRGet:
@@ -384,22 +369,22 @@ void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
     case kPut:
       grp = AddrGrp(msgg->src());
       entry = worker_shard_.at(Hash(grp, paramid));
-      for(auto put_msg : HandlePut(entry, msg))
+      for (auto put_msg : HandlePut(entry, msg))
        msg_queue->push(put_msg);
      break;
     default:
-      LOG(ERROR)<<"Unknow message type:"<<type;
+      LOG(ERROR) << "Unknow message type:" << type;
      break;
   }
 }
 
-void Trainer::GenMsgs(int type, int version, ParamEntry* entry,
-    Msg* msg, vector<Msg*> *ret) {
+void Trainer::GenMsgs(int type, int version, ParamEntry* entry, Msg* msg,
+                      vector<Msg*> *ret) {
   int src_grp = AddrGrp(msg->src());
   int dst_grp = src_grp / Cluster::Get()->nworker_groups_per_server_group();
-  auto param=entry->shares.at(0);
+  auto param = entry->shares.at(0);
   for (int idx = 0 ; idx < param->num_slices(); idx++) {
-    int slice_id =param->slice_start() + idx;
+    int slice_id = param->slice_start() + idx;
     int server = slice2server_[slice_id];
     int dst_procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
     Msg* new_msg = nullptr;
@@ -440,10 +425,10 @@ const vector<Msg*> Trainer::HandleUpdate(ParamEntry *entry, Msg** msg) {
   // average local gradient
   if (entry->num_local > 1) {
     auto it = entry->shares.begin();
-    auto shape=mshadow::Shape1((*it)->size());
-    mshadow::Tensor<mshadow::cpu,1> sum((*it)->mutable_cpu_grad(), shape);
+    auto shape = mshadow::Shape1((*it)->size());
+    mshadow::Tensor<mshadow::cpu, 1> sum((*it)->mutable_cpu_grad(), shape);
     for (++it; it != entry->shares.end(); it++) {
-      mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+      mshadow::Tensor<mshadow::cpu, 1> grad((*it)->mutable_cpu_grad(), shape);
       sum += grad;
     }
   }
@@ -480,4 +465,5 @@ void Trainer::HandleUpdateResponse(ParamEntry* entry, Msg** msg) {
   param->set_version(version);
   DeleteMsg(msg);
 }
-} /* singa */
+
+}  // namespace singa
