1. Move functions from pm_server (pm_worker) into server (trainer) to simplify the logic. Workers now send simple messages to the stub thread, which constructs the real update/get/put requests; the stub thread also handles the responses from servers, e.g., get/update responses are now processed by the stub. Workers then wait in the Collect function until their Param's version is updated. This avoids deadlocks between param_dealer_ and layer_dealer_.
2. Tested data partitioning for a single group within one process.
3. Generate a JSON file under workspace/visualization representing the neural net structure. Users can create an image from the JSON file using the Python script (script/graph.py).
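A note for readers of the diff below: msg.h replaces the group_id/id address accessors with a generic first/second pair packed into one integer. A minimal standalone sketch of that packing, using the kOff1/kOff2 constants from the patch; Flag() is an assumption of how src_flag() applies kMask2, and the demo ids are made up.

#include <cassert>

// Address layout from include/communication/msg.h below:
// first in the high bits, second in the middle bits, flag in the low 4 bits.
static const unsigned int kOff1 = 16, kOff2 = 4;
static const unsigned int kMask1 = (1 << kOff1) - 1;
static const unsigned int kMask2 = (1 << kOff2) - 1;

int Pack(int first, int second, int flag) {
  return (first << kOff1) | (second << kOff2) | flag;
}
int First(int addr)  { return addr >> kOff1; }
int Second(int addr) { return (addr & kMask1) >> kOff2; }
int Flag(int addr)   { return addr & kMask2; }  // assumed mirror of src_flag()

int main() {
  int src = Pack(3, 7, 1);  // e.g., worker 7 in group 3; flag 1 means worker
  assert(First(src) == 3 && Second(src) == 7 && Flag(src) == 1);
  return 0;
}
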
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b5b943c7 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b5b943c7 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b5b943c7 Branch: refs/heads/master Commit: b5b943c7deb33c976272471446545f72d2b5a2be Parents: 39969f2 Author: wang wei <[email protected]> Authored: Sat May 16 22:59:49 2015 +0800 Committer: wang wei <[email protected]> Committed: Sat May 16 22:59:49 2015 +0800 ---------------------------------------------------------------------- examples/cifar10/model.conf | 4 +- include/communication/msg.h | 77 +++++---- include/communication/socket.h | 6 +- include/neuralnet/base_layer.h | 14 ++ include/trainer/pm_server.h | 91 ---------- include/trainer/pm_worker.h | 172 ------------------- include/trainer/server.h | 62 ++++++- include/trainer/trainer.h | 91 +++++++++- include/trainer/worker.h | 30 +--- include/utils/common.h | 31 ++++ src/communication/socket.cc | 16 +- src/main.cc | 1 + src/neuralnet/base_layer.cc | 37 ++-- src/neuralnet/neuralnet.cc | 19 ++- src/proto/model.pb.h | 82 ++++----- src/proto/model.proto | 2 +- src/trainer/pm_server.cc | 99 ----------- src/trainer/pm_worker.cc | 324 ------------------------------------ src/trainer/server.cc | 123 ++++++++++---- src/trainer/trainer.cc | 220 +++++++++++++++++++----- src/trainer/worker.cc | 241 ++++++++++----------------- src/utils/graph.cc | 50 ++++-- src/utils/param.cc | 71 ++++---- 23 files changed, 758 insertions(+), 1105 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/examples/cifar10/model.conf ---------------------------------------------------------------------- diff --git a/examples/cifar10/model.conf b/examples/cifar10/model.conf index bace74d..76ce8db 100644 --- a/examples/cifar10/model.conf +++ b/examples/cifar10/model.conf @@ -1,8 +1,8 @@ name: "cifar10-convnet" train_steps: 70000 -test_steps:5 +test_steps:100 test_frequency:1000 -display_frequency:1 +display_frequency:30 updater{ momentum:0.9 weight_decay:0.004 http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/communication/msg.h ---------------------------------------------------------------------- diff --git a/include/communication/msg.h b/include/communication/msg.h index 54b4601..21ac78e 100644 --- a/include/communication/msg.h +++ b/include/communication/msg.h @@ -13,27 +13,28 @@ class BaseMsg{ */ virtual ~BaseMsg(){}; /** - * @param group_id worker/server group id + * @param first worker/server group id * @param id worker/server id within the group * @param flag 0 for server, 1 for worker, 2 for stub */ - virtual void set_src(int group_id, int id, int flag)=0; - virtual void set_dst(int group_id, int id, int flag)=0; + virtual void set_src(int first, int second, int flag)=0; + virtual void set_dst(int first, int second, int flag)=0; virtual void set_src(int procs_id, int flag)=0; virtual void set_dst(int procs_id, int flag)=0; - virtual int src_group_id() const=0; - virtual int dst_group_id() const=0; - virtual int src_id() const=0; - virtual int dst_id() const=0; + virtual int src_first() const=0; + virtual int dst_first() const=0; + virtual int src_second() const=0; + virtual int dst_second() const=0; virtual int src_flag() const=0; virtual int dst_flag() const=0; virtual void set_type(int type)=0; virtual int type() const=0; - virtual void set_target(int 
target)=0; - virtual int target() const=0; + virtual void set_target(int first, int second)=0; + virtual int target_first() const=0; + virtual int target_second() const=0; /** - * Copy src and dst address, including group_id, id, flag + * Copy src and dst address, including first, id, flag */ virtual BaseMsg* CopyAddr()=0; virtual void SetAddr(BaseMsg* msg)=0; @@ -64,11 +65,11 @@ class Msg : public BaseMsg{ if(msg_!=NULL) zmsg_destroy(&msg_); } - virtual void set_src(int group_id, int id, int flag){ - src_=(group_id<<kOff1)|(id<<kOff2)|flag; + virtual void set_src(int first, int second, int flag){ + src_=(first<<kOff1)|(second<<kOff2)|flag; } - virtual void set_dst(int group_id, int id, int flag){ - dst_=(group_id<<kOff1)|(id<<kOff2)|flag; + virtual void set_dst(int first, int second, int flag){ + dst_=(first<<kOff1)|(second<<kOff2)|flag; } virtual void set_src(int procs_id, int flag){ set_src(procs_id, 0, flag); @@ -82,20 +83,20 @@ class Msg : public BaseMsg{ int dst() const { return dst_; } - virtual int src_group_id() const { + virtual int src_first() const { int ret=src_>>kOff1; return ret; } - virtual int dst_group_id() const{ + virtual int dst_first() const{ int ret=dst_>>kOff1; return ret; } - virtual int src_id() const{ + virtual int src_second() const{ int ret=(src_&kMask1)>>kOff2; return ret; } - virtual int dst_id() const{ + virtual int dst_second() const{ int ret=(dst_&kMask1)>>kOff2; return ret; } @@ -113,22 +114,24 @@ class Msg : public BaseMsg{ } virtual void set_type(int type){ - target_=(type<<kOff3)|(target_&kMask3); - } - virtual void set_target(int target){ - target_=(target_>>kOff3)<<kOff3; - target_=target_|target; + type_=type; } virtual int type() const{ - int ret=target_>>kOff3; - return ret; + return type_; } - virtual int target() const{ - int ret=target_&kMask3; - return ret; + + virtual void set_target(int first, int second){ + target_first_=first; + target_second_=second; + } + virtual int target_first() const{ + return target_first_; + } + virtual int target_second() const{ + return target_second_; } - virtual BaseMsg* CopyAddr(){ + virtual BaseMsg* CopyAddr(){ Msg* msg=new Msg(); msg->src_=src_; msg->dst_=dst_; @@ -158,25 +161,27 @@ class Msg : public BaseMsg{ void ParseFromZmsg(zmsg_t* msg){ char* tmp=zmsg_popstr(msg); - sscanf(tmp, "%d %d %d", &src_, &dst_, &target_); + sscanf(tmp, "%d %d %d %d %d", + &src_, &dst_, &type_, &target_first_, &target_second_); //LOG(ERROR)<<"recv "<<src_<<" "<<dst_<<" "<<target_; frame_=zmsg_next(msg); msg_=msg; } zmsg_t* DumpToZmsg(){ - zmsg_pushstrf(msg_, "%d %d %d",src_, dst_,target_); + zmsg_pushstrf(msg_, "%d %d %d %d %d", + src_, dst_, type_, target_first_, target_second_); //LOG(ERROR)<<"send "<<src_<<" "<<dst_<<" "<<target_; - zmsg_t* tmp=msg_; + zmsg_t *tmp=msg_; msg_=NULL; return tmp; } protected: - static const unsigned int kOff1=16, kOff2=4, kOff3=24; - static const unsigned int kMask1=(1<<kOff1)-1, kMask2=(1<<kOff2)-1, - kMask3=(1<<kOff3)-1; - unsigned int src_, dst_, target_; + static const unsigned int kOff1=16, kOff2=4; + static const unsigned int kMask1=(1<<kOff1)-1, kMask2=(1<<kOff2)-1; + int src_, dst_; + int type_, target_first_, target_second_; zmsg_t* msg_; zframe_t *frame_; }; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/communication/socket.h ---------------------------------------------------------------------- diff --git a/include/communication/socket.h b/include/communication/socket.h index 771c3da..4b1f467 100644 --- a/include/communication/socket.h +++ 
b/include/communication/socket.h @@ -16,7 +16,7 @@ class Socket{ * @param the message to be sent * @return 1 for success queuing the message for sending, 0 for failure */ - virtual int Send(Msg* msg)=0; + virtual int Send(Msg** msg)=0; /** * Receive a message from any connected socket. * @@ -84,7 +84,7 @@ class Dealer : public Socket{ * @return 1 connection sets up successfully; 0 otherwise */ virtual int Connect(string endpoint); - virtual int Send(Msg* msg); + virtual int Send(Msg** msg); virtual Msg* Receive(); virtual void* InternalID() const{ return dealer_; @@ -123,7 +123,7 @@ class Router : public Socket{ /** * If the destination socket has not connected yet, buffer this the message. */ - virtual int Send(Msg* msg); + virtual int Send(Msg** msg); virtual Msg* Receive(); virtual void* InternalID() const{ return router_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/neuralnet/base_layer.h ---------------------------------------------------------------------- diff --git a/include/neuralnet/base_layer.h b/include/neuralnet/base_layer.h index 8e49059..777c2cb 100644 --- a/include/neuralnet/base_layer.h +++ b/include/neuralnet/base_layer.h @@ -288,6 +288,19 @@ class BridgeSrcLayer: public Layer { virtual void ComputeFeature(bool training, const vector<SLayer>& srclayers); virtual void ComputeGradient(const vector<SLayer>& srclayers); + virtual const Blob<float>& data(const Layer* from) const { + return srclayers_[0]->data(this); + } + virtual Blob<float>* mutable_data(const Layer* from){ + return srclayers_[0]->mutable_data(this); + } + + virtual const Blob<float>& grad(const Layer* from) const { + return srclayers_[0]->grad(this); + } + virtual Blob<float>* mutable_grad(const Layer* from) { + return srclayers_[0]->mutable_grad(this); + } int dst_partition() const; virtual bool is_bridgesrclayer() const { return true; @@ -478,6 +491,7 @@ class SliceLayer: public Layer { protected: int SliceID(const Layer* layer) const; vector<Blob<float>> datavec_, gradvec_; + int slice_dim_, slice_num_; }; /** http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/pm_server.h ---------------------------------------------------------------------- diff --git a/include/trainer/pm_server.h b/include/trainer/pm_server.h deleted file mode 100644 index b759844..0000000 --- a/include/trainer/pm_server.h +++ /dev/null @@ -1,91 +0,0 @@ -#ifndef INCLUDE_TRAINER_PM_SERVER_H_ -#define INCLUDE_TRAINER_PM_SERVER_H_ - -#include <czmq.h> -#include <memory> -#include <vector> -#include <map> -#include <string.h> -#include "proto/model.pb.h" -#include "utils/updater.h" -#include "utils/param.h" -#include "communication/msg.h" -#include "communication/socket.h" -using std::vector; -using std::string; -using std::shared_ptr; - -namespace singa{ - -/** - * Parameter manager at the server side. - * - * Repsond to worker's get/put/udpate request, and periodically syncing with - * other servers. - * - * Normally, the PMServer creates a response message for each request which - * will be sent back to the one who issued the request. However, if the request - * are not processed successfully, the original message will be returned. The - * sever does not know the returned message (response or the original message), - * it just sends it to the router. The router will decide to re-send the - * request to the server or send it to the worker. 
- * - */ -class PMServer{ -public: - typedef std::map<int, shared_ptr<Param>> ParamShard; - - void Setup(int group_id, int server_id, shared_ptr<ParamShard> shard, - const UpdaterProto& proto); - - ~PMServer(); - - /** - * Process GET request. - * - * @return the orignal message or response message - */ - virtual Msg* HandleGet(Msg** msg); - - /** - * Process Update request. - * - * @return the orignal message or response message - */ - virtual Msg* HandleUpdate(Msg** msg); - - /** - * Process PUT request. - * - * @return the original message or response message. If we don't want need to - * acknowledge the put request, then return nullptr. - */ - virtual Msg* HandlePut(Msg **msg); - - /** - * TODO Process SYNC request. - */ - virtual Msg* HandleSyncRequest(Msg** msg); - - /** - * TODO Process SYNC response. - */ - virtual int HandleSyncResponse(Msg** msg); - - /** - * Scheduler for synchronizing server groups. - * - * TODO implement the Caffe's synchronization scheduler for data parallelism - */ - virtual bool SyncNow(); - - protected: - int group_id_, server_id_; - shared_ptr<ParamShard> shard_; - shared_ptr<Dealer> dealer_; - shared_ptr<Updater> updater_; -}; - -} // namespace singa - -#endif // INCLUDE_TRAINER_PM_SERVER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/pm_worker.h ---------------------------------------------------------------------- diff --git a/include/trainer/pm_worker.h b/include/trainer/pm_worker.h deleted file mode 100644 index 9b973d6..0000000 --- a/include/trainer/pm_worker.h +++ /dev/null @@ -1,172 +0,0 @@ -#ifndef INCLUDE_TRAINER_PM_WORKER_H_ -#define INCLUDE_TRAINER_PM_WORKER_H_ - -#include <memory> -#include <vector> -#include <map> -#include <string> -#include <atomic> -#include "utils/param.h" -#include "communication/msg.h" - -using std::string; -using std::vector; -using std::shared_ptr; -using std::map; - -namespace singa { - -/** - * Counters used to construct a parameter shard. - * - * For each worker group: - * Every unique Param object is associated with a ParamCounter object whose - * param field points the to Param object itself. - * - * Param objects sharing the same values (due to data parallelism) are - * associated with the same ParamCounter whose param field also shares the - * same values. - * - * Usage: we need to aggregate gradients from all workers for the shared - * parameters before sending the update request. The nUpdate counter counts - * the number. - * - * TODO test with different physical architectures. - */ -class ParamCounter{ - public: - ParamCounter(shared_ptr<Param> p,int local, int owner): - nUpdate(0), nGet(0), nPut(0), nCollect(0), nLocal(local), nTotal(1), - owner_procs(owner){ - shares.push_back(p); - } - - /** - * Associate the counter to a Param object. - * - * @param p - * @param local 1 if this Param object is used by workers in this procs, 0 - * otherwise - * @param owner the procs id of the worker who ownes this Param object - */ - void AddParam(shared_ptr<Param> p, int local, int owner){ - nLocal+=local; - nTotal+=1; - if(owner>-1) - owner_procs=owner; - if(local>0){ - shares.push_back(p); - } - } - std::atomic<int> nUpdate, nGet, nPut, nCollect; //!< all counters are atomic - - int nLocal; //!< # local workers uses the shared parameter - int nTotal; //!< # total workers uses the shared parameter - int owner_procs; //!< the procs id of the worker that owns the parameter - vector<shared_ptr<Param>> shares; -}; - -/** - * Parameter manager at the worker side. 
- */ -class PMWorker{ -public: - /** - * Workers from the same group resident in the same process share the same - * ParamShard which contains ParamCounters for Param objects used/updated by - * these worekrs. Shared Param objects are associated with the same - * ParamCounter. - */ - typedef std::map<int, shared_ptr<ParamCounter>> ParamShard; - - - void Setup(int group_id, int worker_id, shared_ptr<ParamShard> shard); - - void set_id(int group_id, int worker_id){ - group_id_=group_id; - worker_id_=worker_id; - } - - /** - * @return server id where the parameter is maintained. - */ - virtual int Sharding(int param_id); - - /** - * Generate a request message to Get the parameter object. - */ - virtual Msg* Get(shared_ptr<Param> param, int step); - virtual Msg* Get(Msg** msg); - - /** - * Generate a request message to Update the parameter object. - */ - virtual Msg* Update(shared_ptr<Param> param, int step); - virtual Msg* Update(Msg** msg); - - /** - * Collect a Param object returned from server. - */ - virtual Msg* Collect(Msg**); - - /** - * Generate a request message to Put the parameter object. - */ - virtual Msg* Put(shared_ptr<Param> param, int step); - virtual Msg* Put(Msg** msg); - - protected: - int group_id_, worker_id_; - shared_ptr<ParamShard> shard_; -}; - -/** - * Testing worker functionality.The main thread reads the config file and set up the socket. - * - * Create the shared ParamShard, then starts worker thread which basically carries out the work. - * Each thread creates a PMClient object. - * - * The main thread then enter the loops to forward messages. - * - * Requests from the worker thread is prepend the paramId, which is stripped by the main thread - * before forwarding to the correct server. - * - * The 1st thread in Client 0 populates the servers with data (PUT request). Wait - * for a while before starting the client thread (which does get/update - * continuously). -class SingaClient { -public: - SingaClient(int worker_id, Topology &topology, vector<string> &hosts); - void StartClient(); - - int id() { - return id_; - } - ParamShard *param_shard() { - return param_shard_; - } - char *backend_endpoint() { - return backend_endpoint_; - } - -private: - int id_, local_id_, group_id_; - char backend_endpoint_[256]; - vector<char*> neighbors_; - ParamShard *param_shard_; - - int param_to_server_id(int paramId);//< mapping paramId to server ID -}; - -//Zthread function for the worker thread, in the global namespace. -//Basically a loop of: compute, get, update, compute, etc. -void ClientThread(void *args, zctx_t *ctx, void *pipe); - -vector<Param*> gen_random_params(); -void test_get(PMClient *client); -void test_update(PMClient *client, vector<Param*> params); -void test_collect(PMClient *client); - */ - -} // namespace singa -#endif // INCLUDE_TRAINER_PM_WORKER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/server.h ---------------------------------------------------------------------- diff --git a/include/trainer/server.h b/include/trainer/server.h index 6ae09e4..f9fc80b 100644 --- a/include/trainer/server.h +++ b/include/trainer/server.h @@ -1,21 +1,77 @@ #ifndef INCLUDE_TRAINER_SERVER_H_ #define INCLUDE_TRAINER_SERVER_H_ #include <memory> -#include "trainer/pm_server.h" +#include <utils/param.h> +#include <utils/updater.h> +#include "proto/model.pb.h" #include "communication/socket.h" using std::shared_ptr; namespace singa { +/* Repsond to worker's get/put/udpate request, and periodically syncing with + * other servers. 
+ * + * Normally, the Server creates a response message for each request which + * will be sent back to the one who issued the request. However, if the request + * are not processed successfully, the original message will be returned. The + * sever does not know the returned message (response or the original message), + * it just sends it to the router. The router will decide to re-send the + * request to the server or send it to the worker. + */ class Server{ public: + typedef std::map<int, shared_ptr<Param>> ParamShard; + Server(int thread_id, int group_id, int server_id); - void Setup(const UpdaterProto& proto, shared_ptr<PMServer::ParamShard> shard); + void Setup(const UpdaterProto& proto, shared_ptr<ParamShard> shard); void Run(); protected: + + /** + * Process GET request. + * + * @return the orignal message or response message + */ + virtual Msg* HandleGet(shared_ptr<Param> param, Msg** msg); + + /** + * Process Update request. + * + * @return the orignal message or response message + */ + virtual Msg* HandleUpdate(shared_ptr<Param> param, Msg** msg); + + /** + * Process PUT request. + * + * @return the original message or response message. If we don't want need to + * acknowledge the put request, then return nullptr. + */ + virtual Msg* HandlePut(shared_ptr<Param> param, Msg **msg); + + /** + * TODO Process SYNC request. + */ + virtual Msg* HandleSyncRequest(shared_ptr<Param> param, Msg** msg); + + /** + * TODO Process SYNC response. + */ + virtual int HandleSyncResponse(shared_ptr<Param> param, Msg** msg); + + /** + * Scheduler for synchronizing server groups. + * + * TODO implement the Caffe's synchronization scheduler for data parallelism + */ + virtual bool SyncNow(); + + protected: int thread_id_,group_id_, server_id_; - shared_ptr<PMServer> pmserver_; shared_ptr<Dealer> dealer_; + shared_ptr<Updater> updater_; + shared_ptr<ParamShard> shard_; }; } /* Server */ #endif //INCLUDE_TRAINER_SERVER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/trainer.h ---------------------------------------------------------------------- diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h index 34d95f1..f5d2591 100644 --- a/include/trainer/trainer.h +++ b/include/trainer/trainer.h @@ -7,8 +7,6 @@ #include "utils/singleton.h" #include "utils/factory.h" #include "neuralnet/neuralnet.h" -#include "trainer/pm_worker.h" -#include "trainer/pm_server.h" #include "trainer/worker.h" #include "trainer/server.h" @@ -19,7 +17,61 @@ namespace singa { * * The main thread runs a loop to forward messages between workers and servers. */ + class Trainer{ +/** + * ParamInfo is used to construct a parameter shard. + * + * For each worker group: + * Every unique Param object is associated with a ParamCounter object whose + * param field points the to Param object itself. + * + * Param objects sharing the same values (due to data parallelism) are + * associated with the same ParamCounter whose param field also shares the + * same values. + * + * Usage: we need to aggregate gradients from all workers for the shared + * parameters before sending the update request. The nUpdate counter counts + * the number. + * + * TODO test with different physical architectures. + */ + public: + class ParamInfo{ + public: + ParamInfo(shared_ptr<Param> p,int local, int owner): + num_update(0), next_version(0),num_local(local), num_total(1), + owner_procs(owner){ + shares.push_back(p); + } + + /** + * Associate the counter to a Param object. 
+ * + * @param p + * @param local 1 if this Param object is used by workers in this procs, 0 + * otherwise + * @param owner the procs id of the worker who ownes this Param object + */ + void AddParam(shared_ptr<Param> p, int local, int owner){ + num_local+=local; + num_total+=1; + if(owner>-1) + owner_procs=owner; + if(local>0){ + shares.push_back(p); + } + } + int num_update, next_version; //!< all counters are atomic + + int num_local; //!< # local workers uses the shared parameter + int num_total; //!< # total workers uses the shared parameter + int owner_procs; //!< the procs id of the worker that owns the parameter + vector<shared_ptr<Param>> shares; + }; + + typedef std::map<int, shared_ptr<ParamInfo>> ParamShard; + public: /** * Start the training in one process @@ -34,7 +86,7 @@ class Trainer{ // point. protected: - void Run(); + void Run(const std::map<int, shared_ptr<ParamShard>>& shards); /** * Register default implementations for all base classes used in the system, * e.g., the Updater, BaseMsg, etc. @@ -45,6 +97,39 @@ class Trainer{ * implementation class as the value, e.g., <"Updater" SGDUpdater>. */ void RegisterDefaultClasses(const singa::ModelProto& proto); + + /** + * Workers from the same group resident in the same process share the same + * ParamShard which contains ParamCounters for Param objects used/updated by + * these worekrs. Shared Param objects are associated with the same + * ParamCounter. + */ + + /** + * @return server id where the parameter is maintained. + */ + virtual int Sharding(int param_id); + + /** + * Generate a request message to Get the parameter object. + */ + virtual Msg* HandleGet(shared_ptr<ParamInfo>counter, Msg** msg); + virtual Msg* HandleGetResponse(shared_ptr<ParamInfo>counter, Msg** msg); + + /** + * Generate a request message to Update the parameter object. + */ + virtual Msg* HandleUpdate(shared_ptr<ParamInfo>counter, Msg** msg); + virtual int HandleUpdateResponse(shared_ptr<ParamInfo>counter, Msg** msg); + + /** + * Generate a request message to Put the parameter object. + */ + virtual Msg* HandlePut(shared_ptr<ParamInfo>counter, Msg** msg); + virtual Msg* HandleConnect(Msg** msg); + + protected: + int procs_id_; }; } /* singa */ #endif // INCLUDE_TRAINER_TRAINER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/worker.h ---------------------------------------------------------------------- diff --git a/include/trainer/worker.h b/include/trainer/worker.h index 13f7798..afa56ae 100644 --- a/include/trainer/worker.h +++ b/include/trainer/worker.h @@ -4,35 +4,12 @@ #include <exception> #include "neuralnet/neuralnet.h" #include "proto/model.pb.h" -#include "trainer/pm_worker.h" #include "utils/cluster.h" #include "communication/socket.h" #include "communication/msg.h" namespace singa { /** - * Collecting metrics, like accuracy, loss, etc. - */ -class Performance{ - public: - /** - * Collect from LossLayer of net. - */ - explicit Performance(shared_ptr<NeuralNet> net); - /** - * aggregate metrics from LossLayerS - */ - void Update(); - void Reset(); - string ToString(); - private: - vector<string> name_; - shared_ptr<NeuralNet> net_; - vector<vector<float>> metric_; - int counter_; //!< inc by 1 for every Update -}; - -/** * The Worker class which runs the training algorithm. * The first worker group will initialize parameters of the Net, * and put them into the distributed memory/table. 
@@ -41,8 +18,7 @@ class Worker { public: Worker(int thread_id, int group_id, int worker_id); ~Worker(){} - void Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net, - shared_ptr<PMWorker::ParamShard> shard); + void Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net); void set_test_net(shared_ptr<NeuralNet> test_net){ test_net_=test_net; } @@ -61,7 +37,7 @@ class Worker { * Hence, no need to collect performance in every thread. * Only the main thread will pass none null perf. */ - void RunOneBatch(int step, Performance* perf=nullptr); + void RunOneBatch(int step, Metric* perf=nullptr); /** * Train one mini-batch. * Test/Validation is done before training. @@ -105,6 +81,7 @@ class Worker { const bool DisplayDebugInfo(const int step) const { return DisplayNow(step)&&modelproto_.debug()&&group_id_==0; } + const void DisplayPerformance(const Metric & perf, const string& prefix); /** * return true if the stop condition is satisfied, e.g., the maximum number @@ -163,7 +140,6 @@ class Worker { int thread_id_,group_id_, worker_id_; int step_; ModelProto modelproto_; - shared_ptr<PMWorker> pmworker_; shared_ptr<NeuralNet> train_net_, test_net_, validation_net_; shared_ptr<Dealer> layer_dealer_, param_dealer_; Poller layer_poller_, param_poller_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/utils/common.h ---------------------------------------------------------------------- diff --git a/include/utils/common.h b/include/utils/common.h index 993c153..5a59127 100644 --- a/include/utils/common.h +++ b/include/utils/common.h @@ -47,5 +47,36 @@ inline float rand_real(){ return static_cast<float>(rand())/(RAND_MAX+1.0f); } +class Metric{ + public: + Metric():counter_(0){} + void AddMetric(string name, float value){ + if(data_.find(name)==data_.end()) + data_[name]=value; + else + data_[name]+=value; + } + void Reset(){ + data_.clear(); + counter_=0; + } + void Avg(){ + for(auto& entry: data_) + entry.second/=counter_; + } + void Inc(){ + counter_++; + } + const string ToString() const{ + string disp=""; + for(const auto& entry: data_){ + disp+=entry.first+":"+std::to_string(entry.second)+"\t"; + } + return disp; + } + private: + map<string, float> data_; + int counter_; +}; } /* singa */ #endif // INCLUDE_UTILS_COMMON_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/communication/socket.cc ---------------------------------------------------------------------- diff --git a/src/communication/socket.cc b/src/communication/socket.cc index ef6174a..385950b 100644 --- a/src/communication/socket.cc +++ b/src/communication/socket.cc @@ -29,10 +29,11 @@ int Dealer::Connect(string endpoint){ CHECK_EQ(zsock_connect(dealer_,"%s", endpoint.c_str()),0); return 1; } -int Dealer::Send(Msg *msg){ - zmsg_t* zmsg=(static_cast<Msg*>(msg))->DumpToZmsg(); +int Dealer::Send(Msg** msg){ + zmsg_t* zmsg=(*msg)->DumpToZmsg(); zmsg_send(&zmsg, dealer_); - delete msg; + delete *msg; + *msg=NULL; return 1; } @@ -61,9 +62,9 @@ int Router::Bind(string endpoint){ return 1; } -int Router::Send(Msg *msg){ - zmsg_t* zmsg=static_cast<Msg*>(msg)->DumpToZmsg(); - int dstid=static_cast<Msg*>(msg)->dst(); +int Router::Send(Msg **msg){ + zmsg_t* zmsg=(*msg)->DumpToZmsg(); + int dstid=(*msg)->dst(); if(id2addr_.find(dstid)!=id2addr_.end()){ // the connection has already been set up zframe_t* addr=zframe_dup(id2addr_[dstid]); @@ -77,7 +78,8 @@ int Router::Send(Msg *msg){ nBufmsg_++; CHECK_LE(nBufmsg_, bufsize_); } - delete msg; + delete *msg; + *msg=NULL; return 1; } 
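The Send signature change in socket.cc above (Msg* to Msg**) makes ownership transfer explicit: Send consumes the message, deletes it, and nulls the caller's pointer. A self-contained sketch of the idiom, with a stand-in type rather than SINGA's actual classes:

#include <cassert>
#include <cstddef>

struct Msg { int type; };  // stand-in for singa::Msg, for illustration only

int Send(Msg** msg) {
  // serialize and queue *msg here; then consume it, as Dealer::Send does above
  delete *msg;
  *msg = NULL;  // invalidate the caller's pointer so the Msg cannot be reused
  return 1;
}

int main() {
  Msg* m = new Msg{1};
  Send(&m);
  assert(m == NULL);  // accidental reuse after Send now fails fast
  return 0;
}
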
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/main.cc ---------------------------------------------------------------------- diff --git a/src/main.cc b/src/main.cc index 89306d8..eaf88cc 100644 --- a/src/main.cc +++ b/src/main.cc @@ -20,6 +20,7 @@ DEFINE_int32(procsID, 0, "Global process ID"); DEFINE_string(cluster, "examples/mnist/cluster.conf", "Cluster config file"); DEFINE_string(model, "examples/mnist/conv.conf", "Model config file"); +DEFINE_int32(sleep, 5, "sleep seconds"); /** * Register layers, and other customizable classes. http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/neuralnet/base_layer.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/base_layer.cc b/src/neuralnet/base_layer.cc index e7702e9..d3ff24b 100644 --- a/src/neuralnet/base_layer.cc +++ b/src/neuralnet/base_layer.cc @@ -180,19 +180,20 @@ PrefetchLayer::~PrefetchLayer(){ /************* Implementation for SliceLayer****************/ void SliceLayer::Setup(const LayerProto& proto, const vector<SLayer>& srclayers){ - int slice_dim=proto.slice_param().slice_dimension(); - int slice_num=proto.slice_param().slice_num(); - CHECK_GE(slice_dim,0); - CHECK_EQ(slice_num, dstlayers_.size()); + slice_dim_=proto.slice_param().slice_dimension(); + slice_num_=proto.slice_param().slice_num(); + CHECK_GE(slice_dim_,0); + CHECK_EQ(slice_num_, dstlayers_.size()); data_.Reshape(srclayers[0]->data(this).shape()); grad_.ReshapeLike(data_); - datavec_.resize(slice_num); - gradvec_.resize(slice_num); + datavec_.resize(slice_num_); + gradvec_.resize(slice_num_); + CHECK_EQ(data_.count()%slice_num_, 0); // restrict equal slicing //LOG(ERROR)<<"slice dim "<<slice_dim<<" slice num "<<slice_num; - for(int i=0;i<slice_num;i++){ + for(int i=0;i<slice_num_;i++){ vector<int> newshape(data_.shape()); - newshape[slice_dim]=newshape[slice_dim]/slice_num+ - ((i==slice_num-1)?newshape[slice_dim]%slice_num:0); + newshape[slice_dim_]=newshape[slice_dim_]/slice_num_+ + ((i==slice_num_-1)?newshape[slice_dim_]%slice_num_:0); datavec_[i].Reshape(newshape); gradvec_[i].Reshape(newshape); //LOG(ERROR)<<"slice "<<IntVecToString(newshape); @@ -236,8 +237,22 @@ Blob<float>* SliceLayer::mutable_grad(const Layer* layer){ return &grad_; return &gradvec_[SliceID(layer)]; } -void SliceLayer::ComputeFeature(bool training, const vector<shared_ptr<Layer>>& srclayers){} -void SliceLayer::ComputeGradient(const vector<shared_ptr<Layer>>& srclayers){} +void SliceLayer::ComputeFeature(bool training, + const vector<shared_ptr<Layer>>& srclayers){ + CHECK_EQ(srclayers.size(),1); + if(slice_dim_==0){ + const auto& blob=srclayers.at(0)->data(this); + int size=blob.count()/slice_num_; + for(int i=0;i<slice_num_;i++){ + float* dst=datavec_[i].mutable_cpu_data(); + const float* src=blob.cpu_data()+i*size; + memcpy(dst, src, size*sizeof(float)); + } + } +} +void SliceLayer::ComputeGradient(const vector<shared_ptr<Layer>>& srclayers){ + +} void SplitLayer::Setup(const LayerProto& proto, const vector<SLayer>& srclayers){ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/neuralnet/neuralnet.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc index 4dac512..ca371f6 100644 --- a/src/neuralnet/neuralnet.cc +++ b/src/neuralnet/neuralnet.cc @@ -5,7 +5,7 @@ #include "utils/singleton.h" #include "utils/factory.h" #include "utils/graph.h" - +#include "utils/cluster.h" namespace singa 
{ #define CreateLayer(id) CreateInstance(id, Layer) @@ -61,8 +61,21 @@ NeuralNet::NeuralNet(NetProto net_proto, int group_size) { LOG(INFO)<<"Construct Neural Net..."; ConstructNeuralNet(net_proto); - if(group_size_>1) + { + string vis_folder=Cluster::Get()->vis_folder(); + std::ofstream fout(vis_folder+"/nopartition.json", std::ofstream::out); + fout<<ToString(); + fout.flush(); + fout.close(); + } + if(group_size_>1){ PartitionNeuralNet(); + string vis_folder=Cluster::Get()->vis_folder(); + std::ofstream fout(vis_folder+"/partition.json", std::ofstream::out); + fout<<ToString(); + fout.flush(); + fout.close(); + } for(auto layer: layers_){ DLOG(INFO)<<layer->name(); } @@ -88,7 +101,7 @@ void NeuralNet::ConstructNeuralNet(const NetProto& net_proto){ // topology sort graph_.Sort(); - //DLOG(INFO)<<"pure graph without partition\n"<< graph_.ToString(); + //LOG(ERROR)<<"pure graph without partition\n"<< graph_.ToString(); auto* factory=Singleton<Factory<Layer>>::Instance(); // create Layers according to topology order http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/proto/model.pb.h ---------------------------------------------------------------------- diff --git a/src/proto/model.pb.h b/src/proto/model.pb.h index 4f68462..567021a 100644 --- a/src/proto/model.pb.h +++ b/src/proto/model.pb.h @@ -567,6 +567,13 @@ class ModelProto : public ::google::protobuf::Message { inline bool debug() const; inline void set_debug(bool value); + // optional int32 warmup_steps = 50 [default = 0]; + inline bool has_warmup_steps() const; + inline void clear_warmup_steps(); + static const int kWarmupStepsFieldNumber = 50; + inline ::google::protobuf::int32 warmup_steps() const; + inline void set_warmup_steps(::google::protobuf::int32 value); + // @@protoc_insertion_point(class_scope:singa.ModelProto) private: inline void set_has_name(); @@ -613,6 +620,8 @@ class ModelProto : public ::google::protobuf::Message { inline void clear_has_neuralnet(); inline void set_has_debug(); inline void clear_has_debug(); + inline void set_has_warmup_steps(); + inline void clear_has_warmup_steps(); ::google::protobuf::UnknownFieldSet _unknown_fields_; @@ -641,9 +650,10 @@ class ModelProto : public ::google::protobuf::Message { bool debug_; int alg_; ::singa::NetProto* neuralnet_; + ::google::protobuf::int32 warmup_steps_; mutable int _cached_size_; - ::google::protobuf::uint32 _has_bits_[(22 + 31) / 32]; + ::google::protobuf::uint32 _has_bits_[(23 + 31) / 32]; friend void protobuf_AddDesc_model_2eproto(); friend void protobuf_AssignDesc_model_2eproto(); @@ -3577,13 +3587,6 @@ class UpdaterProto : public ::google::protobuf::Message { inline ::google::protobuf::int32 sync_frequency() const; inline void set_sync_frequency(::google::protobuf::int32 value); - // optional int32 warmup_steps = 25 [default = 10]; - inline bool has_warmup_steps() const; - inline void clear_warmup_steps(); - static const int kWarmupStepsFieldNumber = 25; - inline ::google::protobuf::int32 warmup_steps() const; - inline void set_warmup_steps(::google::protobuf::int32 value); - // optional float moving_rate = 26 [default = 0]; inline bool has_moving_rate() const; inline void clear_moving_rate(); @@ -3651,8 +3654,6 @@ class UpdaterProto : public ::google::protobuf::Message { inline void clear_has_learning_rate_change_method(); inline void set_has_sync_frequency(); inline void clear_has_sync_frequency(); - inline void set_has_warmup_steps(); - inline void clear_has_warmup_steps(); inline void set_has_moving_rate(); inline void 
clear_has_moving_rate(); inline void set_has_param_type(); @@ -3671,15 +3672,14 @@ class UpdaterProto : public ::google::protobuf::Message { ::google::protobuf::int32 learning_rate_change_frequency_; int learning_rate_change_method_; ::google::protobuf::int32 sync_frequency_; - ::google::protobuf::int32 warmup_steps_; + float moving_rate_; ::std::string* param_type_; static ::std::string* _default_param_type_; ::google::protobuf::RepeatedField< ::google::protobuf::int32 > step_; ::google::protobuf::RepeatedField< float > step_lr_; - float moving_rate_; mutable int _cached_size_; - ::google::protobuf::uint32 _has_bits_[(16 + 31) / 32]; + ::google::protobuf::uint32 _has_bits_[(15 + 31) / 32]; friend void protobuf_AddDesc_model_2eproto(); friend void protobuf_AssignDesc_model_2eproto(); @@ -4544,6 +4544,28 @@ inline void ModelProto::set_debug(bool value) { debug_ = value; } +// optional int32 warmup_steps = 50 [default = 0]; +inline bool ModelProto::has_warmup_steps() const { + return (_has_bits_[0] & 0x00400000u) != 0; +} +inline void ModelProto::set_has_warmup_steps() { + _has_bits_[0] |= 0x00400000u; +} +inline void ModelProto::clear_has_warmup_steps() { + _has_bits_[0] &= ~0x00400000u; +} +inline void ModelProto::clear_warmup_steps() { + warmup_steps_ = 0; + clear_has_warmup_steps(); +} +inline ::google::protobuf::int32 ModelProto::warmup_steps() const { + return warmup_steps_; +} +inline void ModelProto::set_warmup_steps(::google::protobuf::int32 value) { + set_has_warmup_steps(); + warmup_steps_ = value; +} + // ------------------------------------------------------------------- // NetProto @@ -7917,37 +7939,15 @@ inline void UpdaterProto::set_sync_frequency(::google::protobuf::int32 value) { sync_frequency_ = value; } -// optional int32 warmup_steps = 25 [default = 10]; -inline bool UpdaterProto::has_warmup_steps() const { - return (_has_bits_[0] & 0x00000800u) != 0; -} -inline void UpdaterProto::set_has_warmup_steps() { - _has_bits_[0] |= 0x00000800u; -} -inline void UpdaterProto::clear_has_warmup_steps() { - _has_bits_[0] &= ~0x00000800u; -} -inline void UpdaterProto::clear_warmup_steps() { - warmup_steps_ = 10; - clear_has_warmup_steps(); -} -inline ::google::protobuf::int32 UpdaterProto::warmup_steps() const { - return warmup_steps_; -} -inline void UpdaterProto::set_warmup_steps(::google::protobuf::int32 value) { - set_has_warmup_steps(); - warmup_steps_ = value; -} - // optional float moving_rate = 26 [default = 0]; inline bool UpdaterProto::has_moving_rate() const { - return (_has_bits_[0] & 0x00001000u) != 0; + return (_has_bits_[0] & 0x00000800u) != 0; } inline void UpdaterProto::set_has_moving_rate() { - _has_bits_[0] |= 0x00001000u; + _has_bits_[0] |= 0x00000800u; } inline void UpdaterProto::clear_has_moving_rate() { - _has_bits_[0] &= ~0x00001000u; + _has_bits_[0] &= ~0x00000800u; } inline void UpdaterProto::clear_moving_rate() { moving_rate_ = 0; @@ -7963,13 +7963,13 @@ inline void UpdaterProto::set_moving_rate(float value) { // optional string param_type = 27 [default = "Param"]; inline bool UpdaterProto::has_param_type() const { - return (_has_bits_[0] & 0x00002000u) != 0; + return (_has_bits_[0] & 0x00001000u) != 0; } inline void UpdaterProto::set_has_param_type() { - _has_bits_[0] |= 0x00002000u; + _has_bits_[0] |= 0x00001000u; } inline void UpdaterProto::clear_has_param_type() { - _has_bits_[0] &= ~0x00002000u; + _has_bits_[0] &= ~0x00001000u; } inline void UpdaterProto::clear_param_type() { if (param_type_ != _default_param_type_) { 
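The generated-code churn above reflects a single schema change, shown in the model.proto hunk that follows: warmup_steps moves from UpdaterProto (default 10) to ModelProto (default 0). A hedged sketch of reading the relocated field; the warm-up predicate itself is illustrative, not SINGA's logic:

#include "proto/model.pb.h"

// Illustrative only: whether a training step falls in the warm-up phase,
// based on the relocated ModelProto::warmup_steps field (default 0).
bool InWarmup(const singa::ModelProto& proto, int step) {
  return step < proto.warmup_steps();
}
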
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/proto/model.proto ---------------------------------------------------------------------- diff --git a/src/proto/model.proto b/src/proto/model.proto index 950bc2e..19727a9 100644 --- a/src/proto/model.proto +++ b/src/proto/model.proto @@ -78,6 +78,7 @@ message ModelProto{ optional bool hogwild=33 [default=false]; optional NetProto neuralnet = 40; optional bool debug=41 [default=false]; + optional int32 warmup_steps=50 [default=0]; } message NetProto{ @@ -366,7 +367,6 @@ message UpdaterProto { optional ChangeProto learning_rate_change_method = 16 [default = kFixed]; optional int32 sync_frequency=17 [default=1]; // warmup the parameters and then send to parameter servers. - optional int32 warmup_steps=25 [default=10]; optional float moving_rate=26 [default=0]; optional string param_type=27[default="Param"]; repeated int32 step=28; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/pm_server.cc ---------------------------------------------------------------------- diff --git a/src/trainer/pm_server.cc b/src/trainer/pm_server.cc deleted file mode 100644 index 28fa28d..0000000 --- a/src/trainer/pm_server.cc +++ /dev/null @@ -1,99 +0,0 @@ -#include <gflags/gflags.h> -#include <glog/logging.h> -#include "trainer/pm_server.h" -#include "utils/singleton.h" -#include "utils/factory.h" -#include <vector> - -using std::vector; - -namespace singa{ -void PMServer::Setup(int group_id, int server_id, shared_ptr<ParamShard> shard, - const UpdaterProto& proto){ - group_id_=group_id; - server_id_=server_id; - shard_=shard; - updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance() - ->Create("Updater")); - updater_->Init(proto); -} - -PMServer::~PMServer(){ -} - -bool PMServer::SyncNow(){ - return false; -} -Msg* PMServer::HandlePut(Msg **msg){ - int id=(*msg)->target(); - shared_ptr<Param> param=nullptr; - if(shard_->find(id)!=shard_->end()){ - LOG(ERROR)<<"Param ("<<id<<") is put more than once"; - param=shard_->at(id); - }else{ - param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance() - ->Create("Param")); - param->set_id(id); - (*shard_)[id]=param; - } - return param->HandlePutMsg(msg); -} - -Msg* PMServer::HandleGet(Msg **msg){ - int id=(*msg)->target(); - shared_ptr<Param> param=nullptr; - if(shard_->find(id)!=shard_->end()){ - param=shard_->at(id); - return param->HandleGetMsg(msg); - } else { - //re-construct msg to be re-queued. - //the calling function will send this message off - return *msg; - } -} - -Msg* PMServer::HandleUpdate(Msg **msg) { - int id=(*msg)->target(); - shared_ptr<Param> param=nullptr; - if(shard_->find(id)!=shard_->end()){ - //repsonse of the format: <identity><type: kData><paramId><param content> - param=shard_->at(id); - Msg* tmp=static_cast<Msg*>((*msg)->CopyAddr()); - param->ParseUpdateMsg(msg); - updater_->Update(param->version(), param); - param->set_version(param->version()+1); - auto response=param->GenUpdateResponseMsg(); - tmp->SwapAddr(); - response->SetAddr(tmp); - delete tmp; - return response; - } else { - LOG(ERROR)<<"Param ("<<id<<") is not maintained by server ("<<group_id_ - <<", "<<server_id_<<")"; - //re-construct msg to be re-queued. 
- return *msg; - } -} - -Msg* PMServer::HandleSyncRequest(Msg **msg){ - int id=(*msg)->target(); - shared_ptr<Param> param=nullptr; - if(shard_->find(id)!=shard_->end()){ - //repsonse of the format: <identity><type: kData><paramId><param content> - param=shard_->at(id); - return param->HandleSyncMsg(msg); - } else { - //re-construct msg to be re-queued. - return *msg; - } -} - -int PMServer::HandleSyncResponse(Msg **msg){ - int id=(*msg)->target(); - CHECK(shard_->find(id)!=shard_->end()); - return shard_->at(id)->ParseSyncResponseMsg(msg); -} - -} // namespace singa - - http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/pm_worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/pm_worker.cc b/src/trainer/pm_worker.cc deleted file mode 100644 index d2531e0..0000000 --- a/src/trainer/pm_worker.cc +++ /dev/null @@ -1,324 +0,0 @@ -#include <sys/types.h> -#include <sys/stat.h> -#include <fcntl.h> -#include "gflags/gflags.h" -#include <glog/logging.h> -#include "proto/model.pb.h" -#include "trainer/pm_worker.h" -#include "mshadow/tensor.h" -#include "utils/cluster.h" - - -namespace singa{ - -void PMWorker::Setup(int group_id, int worker_id, - shared_ptr<ParamShard> shard){ - group_id_=group_id; - worker_id_=worker_id; - shard_=shard; -} -int PMWorker::Sharding(int param_id){ - return param_id%Cluster::Get()->nservers_per_group(); -} -/* -int PMWorker::Sharding(int param_id){ - static map<int, int> id2procs; - if(id2procs.find(param_id)==id2procs.end()){ - auto cluster=Cluster::Get(); - int server_group=group_id_%cluster->nserver_groups(); - int nprocs_per_server_group= - cluster->nservers_per_group()/cluster->nservers_per_procs(); - int procsid=server_group*nprocs_per_server_group+ - param_id%nprocs_per_server_group; - procsid= cluster->server_worker_separate()? 
- cluster->nworker_procs()+procsid:procsid; - id2procs[param_id]=procsid; - } - return id2procs[param_id]; -} -*/ - -Msg* PMWorker::Put(Msg** msg){ - return *msg; -} - -Msg* PMWorker::Put(shared_ptr<Param> param, int step){ - int id=param->owner(); - auto entry=shard_->at(id); - Msg* msg= param->GenPutMsg(&step); - msg->set_src(group_id_, worker_id_, kWorkerParam); - msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(), - Sharding(id), kServer); - msg->set_type(kPut); - msg->set_target(id); - return msg; -} - -Msg* PMWorker::Get(Msg** msg){ - return *msg; -} - -Msg* PMWorker::Get(shared_ptr<Param> param, int step){ - int id=param->owner(); - shared_ptr<ParamCounter> entry=shard_->at(id); - Msg *msg=nullptr; - if((entry->nGet+1)%entry->nLocal==0&¶m->version()<step){ - msg=param->GenGetMsg(&step); - msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(), - Sharding(id), kServer); - msg->set_src(group_id_, worker_id_, kWorkerParam); - msg->set_type(kGet); - msg->set_target(id); - } - entry->nGet++; - return msg; -} - -Msg* PMWorker::Update(Msg** msg){ - return *msg; -} -Msg* PMWorker::Update(shared_ptr<Param> param, int step){ - int id=param->owner(); - shared_ptr<ParamCounter> entry=shard_->at(id); - Msg* msg=nullptr; - if((entry->nUpdate+1)%entry->nLocal==0){ - auto shape=mshadow::Shape1(param->size()); - auto it=entry->shares.begin(); - mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape); - for(++it;it!=entry->shares.end();it++){ - mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape); - agg+=grad/entry->nTotal; - } - msg=entry->shares.at(0)->GenUpdateMsg(&step); - msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(), - Sharding(id), kServer); - /* - entry->param->GenUpdateMsg(&step); - msg->set_dst(entry->owner_procs,kStub); - memset(param->mutable_cpu_data(), 0, sizeof(float)*param->size()); - */ - msg->set_type(kUpdate); - msg->set_target(id); - msg->set_src(group_id_, worker_id_, kWorkerParam); - } - entry->nUpdate++; - return msg; -} - -Msg* PMWorker::Collect(Msg** msg){ - int id=(*msg)->target(); - int type=(*msg)->type(); - auto pp=shard_->at(id)->shares.at(0); - if(type==kRGet){ - pp->ParseGetResponseMsg(msg); - }else if(type==kRUpdate){ - pp->ParseUpdateResponseMsg(msg); - } - if(pp->owner()!=pp->id()){ - // forwarding to workers on other procs - } - delete (*msg); - *msg=nullptr; - return nullptr; -} - -/* -//id is the global worker id -SingaClient::SingaClient(int global_id, Topology &topology, vector<string> &hosts) { - //Read the config files and store endpoints - id_ = global_id; - - int n_workers = hosts.size() - topology.nservers(); - int n_worker_groups = topology.nworker_groups(); - int group_size = n_workers/n_worker_groups; - int server_group_size = topology.nservers()/topology.server_group_size(); - FLAGS_client_threads = topology.worker_threads(); - - local_id_ = (id_-topology.nservers())%group_size;//local worker id. 
- group_id_ = (id_-topology.nservers())/group_size; - - VLOG(3) << "Parsing client config for "<<hosts[id_]; - - //connect to all server in the server group group_id_ - int start_server_idx = group_id_*server_group_size; - int end_server_idx = start_server_idx+server_group_size; - - for (int i = start_server_idx; i < end_server_idx; i++) { - char *neighbor_endpoint = (char*) malloc(256); - sprintf(neighbor_endpoint, "tcp://%s:%d", hosts[i].c_str(), topology.port()); - neighbors_.push_back(neighbor_endpoint); - VLOG(3) << "Worker neighbor (server): "<<neighbor_endpoint; - } - - sprintf(backend_endpoint_, "inproc://singanus%d",id_); - - //Create shared paramshard - param_shard_ = new ParamShard(id_,0); -} - -void SingaClient::StartClient(){ - //Create and connect sockets to the server - vector<void *> server_sockets; - zctx_t *context = zctx_new(); - int nservers = neighbors_.size(); - int rc; - for (int i=0; i<nservers; i++){ - void *socket = zsocket_new(context, ZMQ_DEALER); - rc = zsocket_connect(socket, neighbors_[i]); - VLOG(3) << "Connected to neighbor " <<neighbors_[i]; - assert(rc==0); - server_sockets.push_back(socket); - } - - //Create and bind backend socket - void *backend = zsocket_new(context, ZMQ_ROUTER); - rc = zsocket_bind(backend, backend_endpoint_); - assert(rc==0); - - //Start client threads - for (int i=0; i<FLAGS_client_threads; i++){ - void * socket = zthread_fork(context, ClientThread, this); - zmsg_t *control_msg = zmsg_new(); - if (i==0 && local_id_==0) - zmsg_pushstr(control_msg,POPULATE); - else - zmsg_pushstr(control_msg, WAIT); - zmsg_send(&control_msg, socket); - } - - //Star the message loop - bool is_running = true; - int nsockets= nservers+1; - while (is_running) { - zmq_pollitem_t items[nsockets]; - for (int i = 0; i < nsockets-1; i++) - items[i] = {server_sockets[i], 0, ZMQ_POLLIN, 0}; - items[nsockets-1] = {backend, 0, ZMQ_POLLIN, 0}; - - int rc = zmq_poll(items,nsockets,-1); - if (rc<0) break; - - for (int i=0; i<nsockets-1; i++){ - if (items[i].revents & ZMQ_POLLIN){ - zmsg_t *msg = zmsg_recv(server_sockets[i]); - if (!msg){ - is_running = false; - break; - } - //forward to backend - zmsg_send(&msg, backend); - } - } - if (items[nsockets-1].revents & ZMQ_POLLIN){ - //compute serverId from paramId and forward to the socket - zmsg_t *msg = zmsg_recv(backend); - if (!msg) is_running=false; - zframe_t *identity = zmsg_pop(msg); - zframe_t *type = zmsg_pop(msg); - int paramId; - sscanf(zmsg_popstr(msg), "%d", ¶mId); - zmsg_pushstrf(msg,"%d",paramId); - zmsg_prepend(msg,&type); - zmsg_prepend(msg,&identity); - zmsg_send(&msg, server_sockets[param_to_server_id(paramId)]); - } - } - - zsocket_destroy(context, backend); - for (int i=0; i<nsockets-1; i++) - zsocket_destroy(context, server_sockets[i]); - zctx_destroy(&context); -} - -vector<Param*> gen_random_params() { - int size[] = { 1960000, 2500, 5000000, 2000, 3000000, 1500, 1500000, 1000, 500000, 500, 5000, 10 }; - vector<Param*> params; - for (int i = 0; i < 12; i++) { - ParamProto proto; - proto.set_id(i); - proto.set_init_method(ParamProto::kGaussain); - Param* p = new Param(); - p->Setup(proto, vector<int> { size[i] }, 0); - p->Init(); - params.push_back(p); - } - return params; -} - -//simple mapping -int SingaClient::param_to_server_id(int paramId){ - return paramId % neighbors_.size(); -} - -void ClientThread(void *args, zctx_t *ctx, void *pipe){ - SingaClient *client = static_cast<SingaClient*>(args); - - //Create back-end socket and connect to the main thread - void *backend = zsocket_new(ctx, 
ZMQ_DEALER); - int rc = zsocket_connect(backend, client->backend_endpoint()); - assert(rc==0); - //Create PMClient object - PMClient *pmclient = new PMClient(client->id(), client->param_shard(), backend); - - //FOR TESTING ONLY. REMOVE THIS! - //wait for control from main thread - vector<Param*> params = gen_random_params(); - zmsg_t *control_msg = zmsg_recv(pipe); - zframe_t *msg = zmsg_pop(control_msg); - if (zframe_streq(msg,WAIT)) - zclock_sleep(2000); //2s - else{ - for (int i=0; i<params.size(); i++){ - pmclient->Put(i, params[i]); - } - VLOG(3)<<"Done PUT requests for populating servers."; - zclock_sleep(2000); - } - zframe_destroy(&msg); - //END TESTING - LOG(ERROR) << "Done putting"; - - //first, get the params - - test_get(pmclient); - test_collect(pmclient); - - - int iterations = 1; - while (iterations<=200){ - VLOG(3) << "Iteration "<<iterations; - test_update(pmclient, params); - test_collect(pmclient); - iterations++; - } - - zsocket_destroy(ctx, backend); -} - -void test_get(PMClient *client){ - for (int i=0; i<12; i++){ - Param pm; - int status = client->Get(i, &pm); - assert(status==NON_LOCAL); - } -} - -void test_collect(PMClient *client){ - for (int i=0; i<12; i++){ - Param pm; - int64_t start_time = zclock_time(); - while (!client->Collect(&pm)) - zclock_sleep(1); - int64_t end_time = zclock_time(); - VLOG(3) << "Collected: " <<(end_time-start_time); - } -} - -void test_update(PMClient *client, vector<Param*> params){ - for (int i=0; i<params.size(); i++) - client->Update(i, params[i]); -} -*/ - - -} //namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/server.cc ---------------------------------------------------------------------- diff --git a/src/trainer/server.cc b/src/trainer/server.cc index 36c1302..cd2bc02 100644 --- a/src/trainer/server.cc +++ b/src/trainer/server.cc @@ -13,11 +13,12 @@ Server::Server(int thread_id,int group_id, int server_id): thread_id_(thread_id),group_id_(group_id), server_id_(server_id){} void Server::Setup(const UpdaterProto& proto, - shared_ptr<PMServer::ParamShard> shard){ + shared_ptr<Server::ParamShard> shard){ //VLOG(3) << "Parsing config file for host "<<hosts[id_] << " server id = " <<id_; - pmserver_=shared_ptr<PMServer>(Singleton<Factory<PMServer>>::Instance() - ->Create("PMServer")); - pmserver_->Setup(group_id_, server_id_, shard, proto); + updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance() + ->Create("Updater")); + updater_->Init(proto); + shard_=shard; } void Server::Run(){ @@ -26,12 +27,10 @@ void Server::Run(){ Msg* ping=new Msg(); ping->set_src(group_id_, server_id_, kServer); - ping->set_dst(0,0,kStub); + ping->set_dst(-1,-1,kStub); ping->add_frame("PING", 4); ping->set_type(kConnect); - dealer_->Send(ping); - Poller poller; - poller.Add(dealer_.get()); + dealer_->Send(&ping); //start recv loop and process requests while (true){ Msg* msg=dealer_->Receive(); @@ -39,39 +38,89 @@ void Server::Run(){ break; Msg* response=nullptr; int type=msg->type(); - switch (type){ - case kConnect:{ - string pong((char*)msg->frame_data(), msg->frame_size()); - CHECK_STREQ("PONG", pong.c_str()); - delete msg; - break; - } - case kPut: - response = pmserver_->HandlePut(&msg); - break; - case kGet: - response = pmserver_->HandleGet(&msg); - break; - case kUpdate: - response = pmserver_->HandleUpdate(&msg); - break; - case kSyncRequest: - VLOG(3)<<"Handle SYNC-REQUEST"; - response = pmserver_->HandleSyncRequest(&msg); - break; - case kSyncResponse: - VLOG(3) << "Handle SYNC 
response"; - pmserver_->HandleSyncResponse(&msg); - break; - } - - if (response!=nullptr){ - //LOG(ERROR)<<"type: "<<type<<" response to "<<response->dst_id(); - dealer_->Send(response); + if (type==kConnect){ + // TODO remove receiving pong msg + string pong((char*)msg->frame_data(), msg->frame_size()); + CHECK_STREQ("PONG", pong.c_str()); + delete msg; + }else if(type==kPut){ + int pid=msg->target_first(); + shared_ptr<Param> param=nullptr; + if(shard_->find(pid)!=shard_->end()){ + LOG(ERROR)<<"Param ("<<pid<<") is put more than once"; + param=shard_->at(pid); + }else{ + param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance() + ->Create("Param")); + param->set_id(pid); + (*shard_)[pid]=param; + } + response = HandlePut(param, &msg); + }else{ + int pid=msg->target_first(); + if(shard_->find(pid)==shard_->end()){ + // delay the processing by re-queue the msg. + response=msg; + } else{ + CHECK(shard_->find(pid)!=shard_->end()) <<"Param ("<<pid + <<") is not maintained by server (" + <<group_id_ <<", " <<server_id_<<")"; + auto param=shard_->at(pid); + switch (type){ + case kGet: + response=HandleGet(param, &msg); + break; + case kUpdate: + response = HandleUpdate(param, &msg); + break; + case kSyncRequest: + VLOG(3)<<"Handle SYNC-REQUEST"; + response = HandleSyncRequest(param, &msg); + break; + case kSyncResponse: + VLOG(3) << "Handle SYNC response"; + HandleSyncResponse(param, &msg); + break; + } + if (response!=nullptr){ + dealer_->Send(&response); + } + } } } } +bool Server::SyncNow(){ + return false; +} +Msg* Server::HandlePut(shared_ptr<Param> param, Msg **msg){ + return param->HandlePutMsg(msg); +} + +Msg* Server::HandleGet(shared_ptr<Param> param, Msg **msg){ + return param->HandleGetMsg(msg); +} +Msg* Server::HandleUpdate(shared_ptr<Param> param, Msg **msg) { + //repsonse of the format: <identity><type: kData><paramId><param content> + auto* tmp=static_cast<Msg*>((*msg)->CopyAddr()); + int v=(*msg)->target_second()+1; + param->ParseUpdateMsg(msg); + updater_->Update(param->version(), param); + param->set_version(param->version()+1); + auto response=param->GenUpdateResponseMsg(&v); + tmp->SwapAddr(); + response->SetAddr(tmp); + delete tmp; + return response; +} + +Msg* Server::HandleSyncRequest(shared_ptr<Param> param, Msg **msg){ + return param->HandleSyncMsg(msg); +} + +int Server::HandleSyncResponse(shared_ptr<Param> param, Msg **msg){ + return param->ParseSyncResponseMsg(msg); +} } /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/trainer.cc ---------------------------------------------------------------------- diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc index 4ac51ce..35b8f6c 100644 --- a/src/trainer/trainer.cc +++ b/src/trainer/trainer.cc @@ -3,6 +3,7 @@ #include <map> #include <glog/logging.h> #include "trainer/trainer.h" +#include "mshadow/tensor.h" using std::vector; using std::map; @@ -33,16 +34,11 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){ "Param", CreateInstance(singa::Param, singa::Param)); Singleton<Factory<singa::Updater>>::Instance() ->Register( "Updater", CreateInstance(singa::SGDUpdater, singa::Updater)); - Singleton<Factory<singa::PMWorker>>::Instance() ->Register( - "PMWorker", CreateInstance(singa::PMWorker, singa::PMWorker)); - Singleton<Factory<singa::PMServer>>::Instance() ->Register( - "PMServer", CreateInstance(singa::PMServer, singa::PMServer)); - Singleton<Factory<singa::PMServer>>::Instance() ->Register( - "PMServer", CreateInstance(singa::PMServer, 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 4ac51ce..35b8f6c 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -3,6 +3,7 @@
 #include <map>
 #include <glog/logging.h>
 #include "trainer/trainer.h"
+#include "mshadow/tensor.h"
 
 using std::vector;
 using std::map;
@@ -33,16 +34,11 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
       "Param", CreateInstance(singa::Param, singa::Param));
   Singleton<Factory<singa::Updater>>::Instance() ->Register(
       "Updater", CreateInstance(singa::SGDUpdater, singa::Updater));
-  Singleton<Factory<singa::PMWorker>>::Instance() ->Register(
-      "PMWorker", CreateInstance(singa::PMWorker, singa::PMWorker));
-  Singleton<Factory<singa::PMServer>>::Instance() ->Register(
-      "PMServer", CreateInstance(singa::PMServer, singa::PMServer));
-  Singleton<Factory<singa::PMServer>>::Instance() ->Register(
-      "PMServer", CreateInstance(singa::PMServer, singa::PMServer));
 }
 
 void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     int procs_id){
+  procs_id_=procs_id;
   RegisterDefaultClasses(mproto);
 
   auto cluster=Cluster::Get(cproto, procs_id);
@@ -57,7 +53,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
       int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group();
       int end=start+cluster->nservers_per_group();
       // the ParamShard for servers consists of a dictionary of Param objects
-      auto shard=make_shared<PMServer::ParamShard>();
+      auto shard=make_shared<Server::ParamShard>();
       for(int sid=start;sid<end;sid++){
         auto server=make_shared<Server>(nthreads++,gid, sid);
         server->Setup(mproto.updater(), shard);
@@ -67,6 +63,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
   // create workers
   vector<shared_ptr<Worker>> workers;
+  std::map<int, shared_ptr<Trainer::ParamShard>> shards;
   if(cluster->has_worker()){
     auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
         cluster->nworkers_per_group());
@@ -117,14 +114,15 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
        }
      }
      // create ParamShard for the workers
-      auto shard=make_shared<PMWorker::ParamShard>();
+      auto shard=make_shared<Trainer::ParamShard>();
+      shards[gid]=shard;
      for(auto layer: train_net->layers()){
        int procsid=ProcsIDOf(gid, layer->partitionid(),kWorkerParam);
        int local=procsid==cluster->procs_id();
        for(auto param: layer->GetParams()){
          int owner=param->owner()<0||param->owner()==param->id()?procsid:-1;
          if(shard->find(param->owner())==shard->end())
-            (*shard)[param->owner()]=make_shared<ParamCounter>(param, local, owner);
+            (*shard)[param->owner()]=make_shared<ParamInfo>(param, local, owner);
          else
            shard->at(param->owner())->AddParam(param, local, owner);
        }
@@ -136,7 +134,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
      else{
        // TODO add CDWorker
      }
-      worker->Setup(mproto, train_net, shard);
+      worker->Setup(mproto, train_net);
      worker->set_test_net(test_net);
      worker->set_validation_net(validation_net);
      workers.push_back(worker);
@@ -154,13 +152,14 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
    threads.push_back(std::thread(&Server::Run,server.get()));
  for(auto worker: workers)
    threads.push_back(std::thread(&Worker::Run,worker.get()));
-  Run();
+  Run(shards);
  for(auto& thread: threads)
    thread.join();
 }
 
-void Trainer::Run(){
+void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
  auto cluster=Cluster::Get();
+  procs_id_=cluster->procs_id();
  auto router=make_shared<Router>();
  router->Bind(kInprocRouterEndpoint);
  if(cluster->nprocs()>1)
@@ -173,38 +172,179 @@ void Trainer::Run(){
      LOG(ERROR)<<"Connection broken!";
      exit(0);
    }
-    int dst_flag=msg->dst_flag();
-    int type=msg->type();
-    int group_id, id, procs_id;
-    switch (dst_flag){ // TODO process other requests, e.g. RESTful
-      case kStub:
+    while(msg!=nullptr){
+      int dst_flag=msg->dst_flag();
+      int type=msg->type();
+      int dst_procs=msg->dst_first();
+      if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){
        if(type==kConnect){
-          string ping((char*)msg->frame_data(), msg->frame_size());
-          CHECK_STREQ("PING", ping.c_str());
-          msg->SwapAddr();
-          Msg* reply=new Msg();
-          reply->SetAddr(msg);
-          reply->add_frame("PONG", 4);
-          reply->set_type(kConnect);
-          delete msg;
-          router->Send(reply);
+          msg =HandleConnect(&msg);
        }else{
-          // TODO processing requests for worker group spanning multiple procs.
-          LOG(ERROR)<<"Unkown message type ("<<type<<") to stub";
+          int group_id=msg->src_first();
+          int paramid=msg->target_first();
+          auto entry=shards.at(group_id)->at(paramid);
+          switch (type){ // TODO process other requests, e.g. RESTful
+            case kUpdate:
+              msg=HandleUpdate(entry, &msg);
+              break;
+            case kRUpdate:
+              HandleUpdateResponse(entry, &msg);
+              break;
+            case kGet:
+              msg=HandleGet(entry, &msg);
+              break;
+            case kRGet:
+              msg=HandleGetResponse(entry, &msg);
+              break;
+            case kPut:
+              msg=HandlePut(entry, &msg);
+              break;
+            default:
+              break;
+          }
        }
-        break;
-      default:
-        group_id=msg->dst_group_id();
-        id=msg->dst_id();
-        procs_id=ProcsIDOf(group_id, id, dst_flag);
-        if(procs_id!=cluster->procs_id()){
-          if (interprocs_dealers.find(procs_id)==interprocs_dealers.end())
-            interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id);
-          interprocs_dealers[procs_id]->Send(msg);
-        } else
-          router->Send(msg);
-        break;
+      }else{
+        int dst_procs_id;
+        if(dst_flag==kStub){
+          dst_procs_id=msg->dst_first();
+        }else{
+          dst_procs_id=ProcsIDOf(msg->dst_first(), msg->dst_second(), msg->dst_flag());
+        }
+        if(dst_procs_id!=procs_id_){
+          /*
+          // forward to other procs
+          if (interprocs_dealers.find(procs_id)==interprocs_dealers.end())
+            interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id);
+          interprocs_dealers[procs_id]->Send(&msg);
+          */
+        }else{
+          router->Send(&msg);
+        }
+      }
    }
  }
 }
+Msg* Trainer::HandleConnect(Msg** msg){
+  string ping((char*)(*msg)->frame_data(), (*msg)->frame_size());
+  CHECK_STREQ("PING", ping.c_str());
+  // ping-pong for debug
+  (*msg)->SwapAddr();
+  Msg* reply=new Msg();
+  reply->SetAddr(*msg);
+  reply->add_frame("PONG", 4);
+  reply->set_type(kConnect);
+  delete *msg;
+  *msg=NULL;
+  return reply;
+}
+int Trainer::Sharding(int param_id){
+  return param_id%Cluster::Get()->nservers_per_group();
+}
+/*
+int Worker::Sharding(int param_id){
+  static map<int, int> id2procs;
+  if(id2procs.find(param_id)==id2procs.end()){
+    auto cluster=Cluster::Get();
+    int server_group=group_id_%cluster->nserver_groups();
+    int nprocs_per_server_group=
+      cluster->nservers_per_group()/cluster->nservers_per_procs();
+    int procsid=server_group*nprocs_per_server_group+
+      param_id%nprocs_per_server_group;
+    procsid= cluster->server_worker_separate()?
+      cluster->nworker_procs()+procsid:procsid;
+    id2procs[param_id]=procsid;
+  }
+  return id2procs[param_id];
+}
+*/
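Trainer::Sharding above maps a parameter to a server within the group with a plain modulo on the parameter ID (the procs-aware variant is kept commented out for reference). As a worked example, with 4 servers per group, parameter IDs 0..7 land on servers 0,1,2,3,0,1,2,3. A tiny sketch with the group size passed in explicitly, since the real code reads it from the Cluster singleton (the value 4 is arbitrary):

    #include <cstdio>

    // Mirrors Trainer::Sharding, with the group size as an explicit argument
    // instead of Cluster::Get()->nservers_per_group().
    static int Sharding(int param_id, int nservers_per_group) {
      return param_id % nservers_per_group;
    }

    int main() {
      for (int pid = 0; pid < 8; pid++)  // prints servers 0 1 2 3 0 1 2 3
        std::printf("param %d -> server %d\n", pid, Sharding(pid, 4));
      return 0;
    }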
+
+
+Msg* Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){
+  Msg* msgg=*msg, *reply=nullptr;
+  int version=msgg->target_second();
+  if(msgg->src_flag()==kStub){
+    if(version<=pi->shares.at(0)->version()){
+      reply=pi->shares.at(0)->HandleGetMsg(msg);
+    }else if(version>pi->next_version){
+      // reinsert into a msg queue.
+    }
+  }else if(version>pi->next_version){
+    pi->next_version=version;
+    reply=pi->shares.at(0)->GenGetMsg(&version);
+    int gid=msgg->src_first(), pid=msgg->target_first();
+    reply->set_src(procs_id_, gid, kStub);
+    reply->set_dst(gid/Cluster::Get()->nworker_groups_per_server_group(),
+        Sharding(pid), kServer);
+  }
+  return reply;
+}
+
+Msg* Trainer::HandleGetResponse(shared_ptr<ParamInfo>pi, Msg** msg){
+  pi->shares.at(0)->ParseGetResponseMsg(msg);
+  return nullptr;
+  // process get requests in waiting queue
+}
+
+Msg* Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){
+  Msg* msgg=*msg, *update=nullptr;
+  if(msgg->src_flag()==kStub){
+    if(pi->num_update<pi->num_local)
+      return *msg; //wait until local updates are ready
+    int n;
+    sscanf((char*)(*msg)->frame_data(), "%d", &n);
+    pi->num_update+=n;
+    auto it=pi->shares.begin();
+    auto shape=mshadow::Shape1((*it)->size());
+    mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape);
+    mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+    agg+=grad;
+  }else if(++pi->num_update>=pi->num_local){
+    auto it=pi->shares.begin();
+    auto shape=mshadow::Shape1((*it)->size());
+    mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape);
+    for(++it;it!=pi->shares.end();it++){
+      mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+      agg+=grad;
+    }
+    agg/=pi->num_total;
+    if(pi->num_local<pi->num_total){
+      int v=msgg->target_second();
+      update=pi->shares.at(0)->GenUpdateMsg(&v);
+      int gid=msgg->src_first();
+      update->set_src(procs_id_, gid,kStub);
+      update->set_dst(pi->owner_procs, gid, kStub);
+      pi->num_update=0;
+    }
+  }
+  if(pi->num_update==pi->num_total){
+    int v=msgg->target_second();
+    update=pi->shares.at(0)->GenUpdateMsg(&v);
+    int gid=msgg->src_first();
+    update->set_src(procs_id_, gid, kStub);
+    update->set_dst(gid/Cluster::Get()->nworker_groups_per_server_group(),
+        Sharding((*msg)->target_first()), kServer);
+    pi->num_update=0;
+  }
+  delete *msg;
+  *msg=NULL;
+  return update;
+}
+
+int Trainer::HandleUpdateResponse(shared_ptr<Trainer::ParamInfo> pi, Msg** msg){
+  HandleGetResponse(pi, msg);
+  return 1;
+}
+
+Msg* Trainer::HandlePut(shared_ptr<Trainer::ParamInfo>pi, Msg** msg){
+  CHECK_NE((*msg)->src_flag(), kStub);
+  Msg* put=pi->shares.at(0)->GenPutMsg();
+  int gid=(*msg)->src_first();
+  int id=(*msg)->target_first();
+  put->set_src(procs_id_, gid , kStub);
+  put->set_dst(gid/Cluster::Get()->nworker_groups_per_server_group(),
+      Sharding(id), kServer);
+  delete *msg;
+  *msg=NULL;
+  return put;
+}
 } /* singa */
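Trainer::HandleUpdate above waits until all local replicas of a Param have contributed a gradient, sums the replicas' gradients into the first share with mshadow tensor arithmetic, and averages over the total replica count before a single update message is generated. A compact sketch of that aggregate-then-average step, using plain vectors in place of mshadow tensors (the replica count and values are made up):

    #include <cstdio>
    #include <vector>

    // Sum all replicas' gradients into replica 0, then average, mimicking
    // the mshadow agg+=grad and agg/=num_total steps in Trainer::HandleUpdate.
    static void AggregateGrads(std::vector<std::vector<float>>* replicas) {
      std::vector<float>& agg = replicas->at(0);
      for (size_t r = 1; r < replicas->size(); r++)
        for (size_t i = 0; i < agg.size(); i++)
          agg[i] += replicas->at(r)[i];
      for (size_t i = 0; i < agg.size(); i++)
        agg[i] /= replicas->size();
    }

    int main() {
      std::vector<std::vector<float>> grads = {{1, 2}, {3, 4}, {5, 6}};
      AggregateGrads(&grads);
      std::printf("averaged grad: %.1f %.1f\n", grads[0][0], grads[0][1]);  // 3.0 4.0
      return 0;
    }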

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index 3f1c83f..7565d49 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -2,24 +2,23 @@
 #include <thread>
 #include <memory>
 #include <iostream>
+#include <chrono>
+#include <thread>
 #include "utils/singleton.h"
 #include "utils/factory.h"
 #include "trainer/worker.h"
 #include "proto/model.pb.h"
 using std::thread;
+DECLARE_int32(sleep);
 namespace singa {
 Worker::Worker(int thread_id, int group_id, int worker_id):
-  thread_id_(thread_id),group_id_(group_id), worker_id_(worker_id){
+  thread_id_(thread_id), group_id_(group_id), worker_id_(worker_id){
 }
 
 void Worker::Setup(const ModelProto& model,
-    shared_ptr<NeuralNet> train_net,
-    shared_ptr<PMWorker::ParamShard> shard){
+    shared_ptr<NeuralNet> train_net){
  train_net_=train_net;
  modelproto_=model;
-  pmworker_=shared_ptr<PMWorker>(Singleton<Factory<PMWorker>>::Instance()
-      ->Create("PMWorker"));
-  pmworker_->Setup(group_id_, worker_id_, shard);
 }
 
 void Worker::Run(){
@@ -29,13 +28,13 @@ void Worker::Run(){
  layer_dealer_=make_shared<Dealer>(2*thread_id_+1);
  layer_dealer_->Connect(kInprocRouterEndpoint);
 
-  {
+  { // TODO remove waiting pong msg
    Msg* ping=new Msg();
    ping->set_src(group_id_, worker_id_, kWorkerParam);
-    ping->set_dst(0,0,kStub);
+    ping->set_dst(-1,-1,kStub);
    ping->set_type(kConnect);
    ping->add_frame("PING", 4);
-    param_dealer_->Send(ping);
+    param_dealer_->Send(&ping);
    ping=param_dealer_->Receive();
    string pong((char*)ping->frame_data(), ping->frame_size());
    CHECK_STREQ("PONG", pong.c_str());
@@ -45,10 +44,10 @@ void Worker::Run(){
  {
    Msg* ping=new Msg();
    ping->set_src(group_id_, worker_id_, kWorkerLayer);
-    ping->set_dst(0,0,kStub);
+    ping->set_dst(-1,-1,kStub);
    ping->set_type(kConnect);
    ping->add_frame("PING", 4);
-    layer_dealer_->Send(ping);
+    layer_dealer_->Send(&ping);
    ping=layer_dealer_->Receive();
    string pong((char*)ping->frame_data(), ping->frame_size());
    CHECK_STREQ("PONG", pong.c_str());
@@ -60,37 +59,60 @@ void Worker::Run(){
    //LOG(ERROR)<<layer->partitionid()<<" : "<<layer->name();
    if(layer->partitionid()==worker_id_)
      for(auto param: layer->GetParams()){
-        if(group_id_==0&&param->owner()==param->id()){
-          param->Init(0);
-          Put(param, step_);
+        if(group_id_==0){
+          if(param->owner()==param->id()){
+            param->Init(0);
+            Put(param, step_);
+          }else{
+            Get(param, 0);
+          }
        }else{
-          Get(param, step_);
+          Get(param, modelproto_.warmup_steps());
        }
      }
  }
-  step_=modelproto_.step();
-  Performance perf(train_net_);
+  Metric perf;
+  if(group_id_==0&&step_<modelproto_.warmup_steps()){
+    for(step_=0;step_<modelproto_.warmup_steps();step_++)
+      RunOneBatch(step_, &perf);
+    for(auto layer: train_net_->layers()){
+      //LOG(ERROR)<<layer->partitionid()<<" : "<<layer->name();
+      if(layer->partitionid()==worker_id_)
+        for(auto param: layer->GetParams())
+          if(param->owner()==param->id())
+            Put(param, step_);
+    }
+  }
  while(!StopNow(step_)){
    RunOneBatch(step_, &perf);
    step_++;
  }
 }
 
 int Worker::Put(shared_ptr<Param> param, int step){
-  auto msg=pmworker_->Put(param, step);
-  if(msg!=nullptr)
-    param_dealer_->Send(msg);
+  Msg* msg=new Msg();
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(-1, -1, kStub);
+  msg->set_type(kPut);
+  msg->set_target(param->owner(), step);
+  param_dealer_->Send(&msg);
  return 1;
 }
 
 int Worker::Get(shared_ptr<Param> param, int step){
-  auto msg=pmworker_->Get(param, step);
-  if(msg!=nullptr)
-    param_dealer_->Send(msg);
+  Msg* msg=new Msg();
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(-1, -1, kStub);
+  msg->set_type(kGet);
+  msg->set_target(param->owner(), step);
+  param_dealer_->Send(&msg);
  return 1;
 }
 
 int Worker::Update(shared_ptr<Param> param, int step){
-  auto msg=pmworker_->Update(param, step);
-  if(msg!=nullptr)
-    param_dealer_->Send(msg);
+  Msg* msg=new Msg();
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(-1, -1, kStub);
+  msg->set_type(kUpdate);
+  msg->set_target(param->owner(), step);
+  param_dealer_->Send(&msg);
  return 1;
 }
 
@@ -106,22 +128,24 @@ int Worker::CollectAll(shared_ptr<NeuralNet> net, int step){
 }
 int Worker::Collect(shared_ptr<Param> param, int step){
  while(param->version()<step){
-    Socket* which=param_poller_.Wait(10);
-    if(which!=nullptr){
-      Msg* msg=param_dealer_->Receive();
-      if(msg==nullptr)
-        return 0;
-      pmworker_->Collect(&msg);
-    }
+    std::this_thread::sleep_for(std::chrono::milliseconds(FLAGS_sleep));
  }
  return 1;
 }
+const void Worker::DisplayPerformance(const Metric & perf, const string& prefix){
+  /* TODO send perf to Stub thread for printing
+  Msg* msg=new Msg();
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(-1,-1, kStub);
+  msg->set_type(kMetric);
+  const string disp=perf.ToString();
+  msg->AddFrame(disp.c_str(), disp.length());
+  param_dealer_->Send(&msg);
+  */
+  LOG(ERROR)<<prefix<<" "<<perf.ToString();
+}
 
-void Worker::RunOneBatch(int step, Performance* perf){
-  //DLOG(ERROR)<<"Step "<<step;
-  // Test will call Pull which updates the sync time
-  // Hence we store the sync time, and restore it later
-  //float tSyncData=tSyncData_, tSyncParam=tSyncParam_;
+void Worker::RunOneBatch(int step, Metric* perf){
  if(ValidateNow(step)){
    LOG(ERROR)<<"Validation at step "<<step;
    CollectAll(validation_net_, step);
@@ -132,20 +156,23 @@ void Worker::RunOneBatch(int step, Performance* perf){
    CollectAll(test_net_, step);
    Test(test_net_, modelproto_.test_steps(), perf!=nullptr);
  }
-  //tSyncData_=tSyncData; tSyncParam_=tSyncParam;
-
  CollectAll(train_net_, step);
  TrainOneBatch(step);
  if(perf!=nullptr){
-    perf->Update();
+    auto losslayers=train_net_->losslayers();
+    for(auto layer: losslayers){
+      if(layer->partitionid()==worker_id_){
+        const float * ptr=layer->metric().cpu_data();
+        for(int j=0;j<layer->metric().count();j++)
+          perf->AddMetric(layer->name()+"-"+std::to_string(j), ptr[j]);
+      }
+    }
+    perf->Inc();
    if(DisplayNow(step)){
-      LOG(ERROR)<<"Training at step "<<step;
-      LOG(ERROR)<<"\t"<<perf->ToString();
+      perf->Avg();
+      DisplayPerformance(*perf, "Train at step "+std::to_string(step));
      perf->Reset();
-      //LOG(ERROR)<<"\t"<<TimerInfo();
    }
  }
-
  /*
  if(CheckpointNow(step)){
    pm_->Checkpoint(cluster_->workspace()+"/snapshot-"+std::to_string(step));
@@ -154,44 +181,32 @@ void Worker::RunOneBatch(int step, Performance* perf){
 }
 
 void Worker::ReceiveBlobs(shared_ptr<NeuralNet> net){
-  /*
-  int type;
-  char *name;
-  int64_t tick=zclock_mono();
-  zframe_t* frame=zframe_new_empty();
-
-  zsock_recv(pull_, "isf", &type, &name, &frame);
-  if(type==kDataFrame){
-    auto* dst=static_cast<BridgeDstLayer*>(
-        net->name2layer(string(name)).get());
-    memcpy(dst->mutable_data()->mutable_cpu_data(), zframe_data(frame),
-        zframe_size(frame));
-    dst->set_ready(true);
-  }else if(type==kGradFrame){
-    auto* src=static_cast<BridgeSrcLayer*>(net->name2layer(string(name)).get());
-    memcpy(src->mutable_grad()->mutable_cpu_data(), zframe_data(frame),
-        zframe_size(frame));
-    src->set_ready(true);
-  }
-  zframe_destroy(&frame);
-  delete name;
-  tSyncData_+=zclock_mono()-tick;
-  */
 }
 
 void Worker::SendBlob(){
-
 }
 
 void Worker::Test(shared_ptr<NeuralNet> net, int nsteps, bool disperf){
-  Performance perf(net);
+  const auto& losslayers=net->losslayers();
+  Metric perf;
  for(int step=0;step<nsteps;step++){
    TestOneBatch(net, step, kTest);
-    if(disperf)
-      perf.Update();
+    if(disperf){
+      for(auto layer: losslayers){
+        if(layer->partitionid()==worker_id_){
+          const float * ptr=layer->metric().cpu_data();
+          for(int j=0;j<layer->metric().count();j++)
+            perf.AddMetric(layer->name()+"-"+std::to_string(j), ptr[j]);
+        }
+      }
+      perf.Inc();
+    }
+  }
+  if(disperf){
+    perf.Avg();
+    DisplayPerformance(perf, "Test");
+    perf.Reset();
  }
-  if(disperf)
-    LOG(ERROR)<<"\t"<<perf.ToString();
 }
 
 /****************************BPWorker**********************************/
@@ -204,7 +219,7 @@ void BPWorker::Forward(shared_ptr<NeuralNet> net, int step, bool training){
      auto* dst=static_cast<BridgeDstLayer*>(layer.get());
      while(!dst->ready()){
        auto msg=layer_dealer_->Receive();
-        CHECK_EQ(msg->src_group_id(), group_id_);
+        CHECK_EQ(msg->src_first(), group_id_);
        string name((char*)msg->frame_data(), msg->frame_size());
        auto tmp=net->name2layer(name);
        CHECK(tmp->is_bridgedstlayer());
@@ -232,7 +247,7 @@ void BPWorker::Forward(shared_ptr<NeuralNet> net, int step, bool training){
        msg->add_frame(dst->name().c_str(), dst->name().length());
        auto const & blob=layer->data(nullptr);
        msg->add_frame(blob.cpu_data(), blob.count()*sizeof(float));
-        layer_dealer_->Send(msg);
+        layer_dealer_->Send(&msg);
      }
      if(training&&DisplayDebugInfo(step)&&layer->mutable_data(nullptr)!=nullptr){
        LOG(INFO)<<StringPrintf("Forward layer %10s data norm1 %13.9f",
@@ -280,76 +295,4 @@ void BPWorker::TestOneBatch(shared_ptr<NeuralNet> net,int step, Phase phase){
  Forward(net, step, false);
 }
 
-/*********************Implementation for Performance class*******************/
-Performance::Performance(shared_ptr<NeuralNet> net):net_(net), counter_(0){
-  for(auto& layer: net->losslayers()){
-    name_.push_back(layer->name());
-    metric_.push_back(vector<float>{});
-    metric_.back().resize(layer->metric().count(),0.f);
-  }
-}
-
-void Performance::Update(){
-  const auto& losslayers=net_->losslayers();
-  for(size_t i=0;i<losslayers.size();i++){
-    const float * ptr=losslayers[i]->metric().cpu_data();
-    vector<float>& m=metric_.at(i);
-    for(int j=0;j<losslayers[i]->metric().count();j++)
-      m[j]+=ptr[j];
-  }
-  counter_++;
-}
-
-void Performance::Reset(){
-  for(auto& m: metric_)
-    for(auto& x: m)
-      x=0.f;
-  counter_=0;
-}
-
-string Performance::ToString(){
-  string disp="";
-  for(size_t i=0;i<metric_.size();i++){
-    disp+="Output from "+name_[i]+" layer ";
-    vector<float> m=metric_.at(i);
-    for(size_t j=0;j<m.size();j++)
-      disp+=std::to_string(j)+" : "+std::to_string(m[j]/counter_)+"\t";
-    disp+="\n";
-  }
-  return disp;
-}
-/*
-void Executor::Setup(int local_threadid, const ModelProto& model){
-  tForward_=tBackward_=tSyncData_=tSyncParam_=0;
-  modelproto_=model;
-  local_threadid_=local_threadid;
-  if(model.prefetch()){
-    for(auto& layer: train_net_->datalayers()){
-      if(cluster_->group_threadid(local_threadid_)==layer->locationid())
-        localDataLayers_.push_back(layer);
-    }
-    if(localDataLayers_.size())
-      prefetch_thread_=std::thread(Executor::PrefetchData,
-          std::ref(localDataLayers_), true,1);
-  }
-  int gthreadid=cluster_->group_threadid(local_threadid);
-}
-
-void Executor::PrefetchData(const vector<DataLayer*>& datalayers, bool training,
-    int steps){
-  if(datalayers.size()==0)
-    return;
-  for(int i=0;i<steps;i++){
-    for(auto& layer: datalayers){
-      layer->Prefetching(training);
-      for(auto& dstlayer: layer->dstlayers()){
-        CHECK(dstlayer->is_parserlayer());
-        auto parserlayer=static_cast<ParserLayer*>(dstlayer.get());
-        parserlayer->Prefetching(training);
-      }
-    }
-  }
-}
-*/
-
 } // namespace singa
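With PMWorker gone, Worker::Collect above no longer polls a socket: the stub thread parses the get/update responses and bumps each Param's version, so the worker merely sleeps until the version it needs appears. A self-contained sketch of that wait-on-version handshake, with a std::atomic standing in for the Param version field written by the stub thread:

    #include <atomic>
    #include <chrono>
    #include <cstdio>
    #include <thread>

    std::atomic<int> version{0};  // stands in for Param::version()

    // Worker side: block until the parameter reaches the wanted version,
    // sleeping between checks like the new Worker::Collect.
    static void Collect(int step) {
      while (version.load() < step)
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }

    int main() {
      // The "stub" thread applies the get/update response and bumps the version.
      std::thread stub([] {
        std::this_thread::sleep_for(std::chrono::milliseconds(5));
        version.store(3);
      });
      Collect(3);
      std::printf("collected param at version %d\n", version.load());
      stub.join();
      return 0;
    }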

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/utils/graph.cc
----------------------------------------------------------------------
diff --git a/src/utils/graph.cc b/src/utils/graph.cc
index 076c971..6fff83b 100644
--- a/src/utils/graph.cc
+++ b/src/utils/graph.cc
@@ -1,4 +1,5 @@
 #include <algorithm>
+#include <queue>
 #include "utils/graph.h"
 
 const string Graph::ToString() const {
@@ -78,30 +79,43 @@ void Graph::topology_sort_inner(SNode node,
 
 // sort to make `bottom' nodes be placed in the front positions
 void Graph::Sort() {
-  // adjacent list from upper layers to lower layers
-  std::map<string, bool> visited;
-  // prepare adjacent list; input layers will be processed firstly,
-  // hence no need to sort them (mark them as visited)
-  for (SNode node: nodes_) {
-    visited[node->name()] = false;
-  }
-  // the `top' layer in the net will be placed at the bottom of the stack
-  // and then be processed (i.e., forward) at last
-  std::stack<string > stack;
-  for (SNode node: nodes_) {
-    if (visited[node->name()] == false)
-      topology_sort_inner(node, &visited, &stack);
+  SNode start=nullptr;
+  map<string, bool> visited;
+  for(auto node: nodes_){
+    if(node->srcnodes().size()==0){
+      CHECK(start==nullptr);
+      start=node;
+    }
+    visited[node->name()]=false;
  }
+  int n=nodes_.size();
+  std::queue<SNode> tmp;
+  tmp.push(start);
  nodes_.clear();
-
-  while (!stack.empty()) {
-    nodes_.push_back(name2node_[stack.top()]);
-    stack.pop();
+  while(!tmp.empty()){
+    auto node=tmp.front();
+    tmp.pop();
+    bool visit=true;
+    for(auto src: node->srcnodes())
+      if(visited[src->name()]==false){
+        visit=false;
+        break;
+      }
+    if(visit){
+      nodes_.push_back(node);
+      visited[node->name()]=true;
+      for(auto dst: node->dstnodes()){
+        CHECK(visited.find(dst->name())!=visited.end())<<dst->name();
+        if(visited[dst->name()]==false){
+          tmp.push(dst);
+        }
+      }
+    }
  }
+  CHECK_EQ(nodes_.size(), n);
 }
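The new Graph::Sort above replaces the recursive, stack-based ordering with a breadth-first topological sort: a node is appended only after all of its source nodes have been visited, so `bottom' layers come first. A compact standalone rendering of the same idea over a toy DAG, using explicit in-degree counters instead of the SNode visited map:

    #include <cstdio>
    #include <queue>
    #include <vector>

    int main() {
      // Toy DAG: 0->1, 0->2, 1->3, 2->3; node 0 is the single start node,
      // analogous to the node with no srcnodes that Graph::Sort CHECKs for.
      std::vector<std::vector<int>> dst = {{1, 2}, {3}, {3}, {}};
      std::vector<int> indegree = {0, 1, 1, 2};
      std::queue<int> ready;
      ready.push(0);
      while (!ready.empty()) {
        int node = ready.front();
        ready.pop();
        std::printf("visit %d\n", node);  // 0 1 2 3: sources before sinks
        for (int d : dst[node])
          if (--indegree[d] == 0)  // all source nodes visited -> ready
            ready.push(d);
      }
      return 0;
    }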
-
 SNode Graph::InsertSliceNode(SNode srcnode, const vector<SNode>& dstnodes,
     const V& info, bool connect_dst){
   V myinfo=info;