Repository: incubator-singa Updated Branches: refs/heads/master 2bbed5fc1 -> 56d32e8a0
SINGA-19 Slice large Param objects for load-balance Tested with single worker, two worker group and two worker groups TODO test with multiple servers and server groups for distributed hogwild and allreduce. Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e0a52a62 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e0a52a62 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e0a52a62 Branch: refs/heads/master Commit: e0a52a62577cc9845130b9d2c664007ec354804c Parents: 2bbed5f Author: wang wei <[email protected]> Authored: Tue Jun 23 00:06:13 2015 +0800 Committer: wang wei <[email protected]> Committed: Tue Jun 23 15:57:27 2015 +0800 ---------------------------------------------------------------------- examples/cifar10/model.conf | 3 +- include/communication/msg.h | 2 +- include/trainer/server.h | 20 +- include/trainer/trainer.h | 138 +++++----- include/utils/cluster.h | 13 +- include/utils/common.h | 2 + include/utils/param.h | 145 ++++++++--- src/neuralnet/layer.cc | 8 +- src/proto/cluster.proto | 36 ++- src/proto/model.proto | 88 ++----- src/trainer/server.cc | 33 +-- src/trainer/trainer.cc | 544 +++++++++++++++++++++++---------------- src/trainer/worker.cc | 13 +- src/utils/cluster.cc | 49 +++- src/utils/common.cc | 15 ++ src/utils/param.cc | 221 +++++++++------- 16 files changed, 786 insertions(+), 544 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/examples/cifar10/model.conf ---------------------------------------------------------------------- diff --git a/examples/cifar10/model.conf b/examples/cifar10/model.conf index 0cdf8b0..bfd7683 100644 --- a/examples/cifar10/model.conf +++ b/examples/cifar10/model.conf @@ -1,8 +1,9 @@ name: "cifar10-convnet" train_steps: 1000 -test_steps:100 +test_steps:10 test_frequency:300 display_frequency:30 +alg: kBackPropagation updater{ momentum:0.0 weight_decay:0.004 http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/communication/msg.h ---------------------------------------------------------------------- diff --git a/include/communication/msg.h b/include/communication/msg.h index ba7d064..b83c738 100644 --- a/include/communication/msg.h +++ b/include/communication/msg.h @@ -45,7 +45,7 @@ class Msg { } inline int target_first() const { return target_first_; } inline int target_second() const { return target_second_; } - /** + /** * Copy src and dst address, including first, id, flag */ inline Msg* CopyAddr() { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/trainer/server.h ---------------------------------------------------------------------- diff --git a/include/trainer/server.h b/include/trainer/server.h index f9fc80b..b07741f 100644 --- a/include/trainer/server.h +++ b/include/trainer/server.h @@ -1,6 +1,7 @@ #ifndef INCLUDE_TRAINER_SERVER_H_ #define INCLUDE_TRAINER_SERVER_H_ #include <memory> +#include <unordered_map> #include <utils/param.h> #include <utils/updater.h> #include "proto/model.pb.h" @@ -8,6 +9,7 @@ using std::shared_ptr; namespace singa { +typedef std::unordered_map<int, shared_ptr<Param>> ServerShard; /* Repsond to worker's get/put/udpate request, and periodically syncing with * other servers. * @@ -20,10 +22,10 @@ namespace singa { */ class Server{ public: - typedef std::map<int, shared_ptr<Param>> ParamShard; Server(int thread_id, int group_id, int server_id); - void Setup(const UpdaterProto& proto, shared_ptr<ParamShard> shard); + void Setup(const UpdaterProto& proto, shared_ptr<ServerShard> shard, + const vector<int>& slice2group); void Run(); protected: @@ -48,7 +50,7 @@ class Server{ * @return the original message or response message. If we don't want need to * acknowledge the put request, then return nullptr. */ - virtual Msg* HandlePut(shared_ptr<Param> param, Msg **msg); + virtual void HandlePut(shared_ptr<Param> param, Msg **msg); /** * TODO Process SYNC request. @@ -57,21 +59,15 @@ class Server{ /** * TODO Process SYNC response. - */ virtual int HandleSyncResponse(shared_ptr<Param> param, Msg** msg); - - /** - * Scheduler for synchronizing server groups. - * - * TODO implement the Caffe's synchronization scheduler for data parallelism - */ - virtual bool SyncNow(); + */ protected: int thread_id_,group_id_, server_id_; shared_ptr<Dealer> dealer_; shared_ptr<Updater> updater_; - shared_ptr<ParamShard> shard_; + shared_ptr<ServerShard> shard_; + vector<int> slice2group_; }; } /* Server */ #endif //INCLUDE_TRAINER_SERVER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/trainer/trainer.h ---------------------------------------------------------------------- diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h index 18250af..6e93f80 100644 --- a/include/trainer/trainer.h +++ b/include/trainer/trainer.h @@ -1,5 +1,6 @@ #ifndef INCLUDE_TRAINER_TRAINER_H_ #define INCLUDE_TRAINER_TRAINER_H_ +#include <unordered_map> #include "proto/cluster.pb.h" #include "proto/model.pb.h" #include "utils/updater.h" @@ -13,63 +14,73 @@ namespace singa { /** - * Every running process has a training object which launches one or more - * worker (and server) threads. - * - * The main thread runs a loop to forward messages between workers and servers. + * Callback function for zookeeper */ - -class Trainer{ +void HandleWorkerFinish(void * ctx); /** - * ParamInfo is used to construct a parameter shard. - * - * For each worker group: - * Every unique Param object is associated with a ParamCounter object whose - * param field points the to Param object itself. - * - * Param objects sharing the same values (due to data parallelism) are - * associated with the same ParamCounter whose param field also shares the - * same values. - * - * Usage: we need to aggregate gradients from all workers for the shared - * parameters before sending the update request. The nUpdate counter counts - * the number. - * - * TODO test with different physical architectures. + * Zookeeper handler context used by HandleWorkerFinish(void*)function. */ - public: - class ParamInfo{ +typedef struct HandleContext_{ + shared_ptr<Dealer> dealer; + int group_id, id; +} HandleContext; +/** + * ParamInfo is used to construct a parameter shard. + * + * For each worker group: + * Every unique Param object is associated with a ParamCounter object whose + * param field points the to Param object itself. + * + * Param objects sharing the same values (due to data parallelism) are + * associated with the same ParamCounter whose param field also shares the + * same values. + * + * Usage: we need to aggregate gradients from all workers for the shared + * parameters before sending the update request. The nUpdate counter counts + * the number. + * + * TODO test with different physical architectures. + */ +class ParamInfo{ public: - ParamInfo(shared_ptr<Param> p,int local, int owner): - num_update(0), next_version(0),num_local(local), num_total(1), - owner_procs(owner){ - shares.push_back(p); - } - - /** - * Associate the counter to a Param object. - * - * @param p - * @param local 1 if this Param object is used by workers in this procs, 0 - * otherwise - * @param owner the procs id of the worker who ownes this Param object - */ - void AddParam(shared_ptr<Param> p, bool local){ - num_local+=local; - num_total+=1; - if(local) - shares.push_back(p); + ParamInfo(shared_ptr<Param> p,int local, int owner): + num_update(0), next_version(0),num_local(local), num_total(1), + owner_procs(owner){ + shares.push_back(p); } - int num_update, next_version; //!< all counters are atomic - int num_local; //!< # local workers uses the shared parameter - int num_total; //!< # total workers uses the shared parameter - int owner_procs; //!< the procs id of the worker that owns the parameter - vector<shared_ptr<Param>> shares; - }; + /** + * Associate the counter to a Param object. + * + * @param p + * @param local 1 if this Param object is used by workers in this procs, 0 + * otherwise + * @param owner the procs id of the worker who ownes this Param object + */ + void AddParam(shared_ptr<Param> p, bool local){ + num_local+=local; + num_total+=1; + if(local) + shares.push_back(p); + } + int num_update, next_version; //!< all counters are atomic + + int num_local; //!< # local workers uses the shared parameter + int num_total; //!< # total workers uses the shared parameter + int owner_procs; //!< the procs id of the worker that owns the parameter + vector<shared_ptr<Param>> shares; +}; + +typedef std::map<int, shared_ptr<ParamInfo>> WorkerShard; - typedef std::map<int, shared_ptr<ParamInfo>> ParamShard; +/** + * Every running process has a training object which launches one or more + * worker (and server) threads. + * + * The main thread runs a loop to forward messages between workers and servers. + */ +class Trainer{ public: /** * Start the training in one process @@ -84,8 +95,13 @@ class Trainer{ // point. protected: - void Run(int nworkers, int nservers, - const std::map<int, shared_ptr<ParamShard>>& shards); + + vector<shared_ptr<Server>> CreateServers(int nthread, const ModelProto& mproto, + const vector<int> slices, vector<HandleContext>* ctx); + vector<shared_ptr<Worker>> CreateWorkers(int nthread, + const ModelProto& mproto, vector<int> *slice_size); + + void Run(int nworkers, int nservers); /** * Register default implementations for all base classes used in the system, * e.g., the Updater, BaseMsg, etc. @@ -99,37 +115,35 @@ class Trainer{ /** * Workers from the same group resident in the same process share the same - * ParamShard which contains ParamCounters for Param objects used/updated by + * WorkerShard which contains ParamCounters for Param objects used/updated by * these worekrs. Shared Param objects are associated with the same * ParamCounter. */ - /** - * @return server id where the parameter is maintained. - */ - virtual int Sharding(int param_id); - /** * Generate a request message to Get the parameter object. */ - virtual Msg* HandleGet(shared_ptr<ParamInfo>counter, Msg** msg); - virtual Msg* HandleGetResponse(shared_ptr<ParamInfo>counter, Msg** msg); + virtual const vector<Msg*> HandleGet(shared_ptr<ParamInfo>counter, Msg** msg); + virtual void HandleGetResponse(shared_ptr<ParamInfo>counter, Msg** msg); /** * Generate a request message to Update the parameter object. */ - virtual Msg* HandleUpdate(shared_ptr<ParamInfo>counter, Msg** msg); - virtual int HandleUpdateResponse(shared_ptr<ParamInfo>counter, Msg** msg); + virtual const vector<Msg*> HandleUpdate(shared_ptr<ParamInfo>counter, Msg** msg); + virtual void HandleUpdateResponse(shared_ptr<ParamInfo>counter, Msg** msg); /** * Generate a request message to Put the parameter object. */ - virtual Msg* HandlePut(shared_ptr<ParamInfo>counter, Msg** msg); + virtual const vector<Msg*> HandlePut(shared_ptr<ParamInfo>counter, Msg** msg); virtual Msg* HandleConnect(Msg** msg); protected: int procs_id_; shared_ptr<Router> router_; + std::unordered_map<int, shared_ptr<WorkerShard>> worker_shards_; + shared_ptr<ServerShard> server_shard_; + vector<int> slice2server_; }; } /* singa */ #endif // INCLUDE_TRAINER_TRAINER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/utils/cluster.h ---------------------------------------------------------------------- diff --git a/include/utils/cluster.h b/include/utils/cluster.h index fcdb241..9648bfe 100644 --- a/include/utils/cluster.h +++ b/include/utils/cluster.h @@ -5,6 +5,7 @@ #include <utility> #include <memory> #include <vector> +#include <unordered_map> #include "proto/cluster.pb.h" #include "utils/cluster_rt.h" @@ -39,7 +40,7 @@ class Cluster { */ bool has_server()const { if(server_worker_separate()){ - CHECK_LT(procs_id_, nprocs()); + CHECK_LT(procs_id_, nprocs_); return procs_id_>=nworker_procs(); }else return procs_id_<nserver_procs(); @@ -51,7 +52,7 @@ class Cluster { if(server_worker_separate()){ return procs_id_<nworker_procs(); }else - return procs_id_<nprocs(); + return procs_id_<nprocs_; } /** * @return global procs id, which starts from 0. @@ -67,7 +68,7 @@ class Cluster { return nserver_groups()*nservers_per_group()/nservers_per_procs(); } int nprocs() const { - return cluster_.nprocs(); + return nprocs_; } const string endpoint() const { @@ -77,7 +78,7 @@ class Cluster { * @return endpoint of the router of a procs with the specified id */ const string endpoint(int procs_id) const { - CHECK_LT(procs_id, nprocs()); + CHECK_LT(procs_id, nprocs_); CHECK_GE(procs_id, 0); return endpoints_.at(procs_id); } @@ -121,18 +122,22 @@ class Cluster { return cluster_rt_; } + int ProcsIDOf(int group_id, int id, int flag); private: Cluster(const ClusterProto &cluster, int procs_id) ; void SetupFolders(const ClusterProto &cluster); + int Hash(int gid, int id, int flag); private: int procs_id_; + int nprocs_; std::vector<std::string> endpoints_; // cluster config proto ClusterProto cluster_; shared_ptr<ClusterRuntime> cluster_rt_; // make this class a singlton static shared_ptr<Cluster> instance_; + std::unordered_map<int, int> procs_ids_; }; } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/utils/common.h ---------------------------------------------------------------------- diff --git a/include/utils/common.h b/include/utils/common.h index 1644598..98a1cd7 100644 --- a/include/utils/common.h +++ b/include/utils/common.h @@ -42,6 +42,8 @@ inline void Sleep(int millisec=1){ } */ +int gcd(int a, int b); +int LeastCommonMultiple(int a, int b); inline float rand_real(){ return static_cast<float>(rand())/(RAND_MAX+1.0f); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/utils/param.h ---------------------------------------------------------------------- diff --git a/include/utils/param.h b/include/utils/param.h index e55480b..897c97a 100644 --- a/include/utils/param.h +++ b/include/utils/param.h @@ -12,33 +12,110 @@ namespace singa { class Param { public: Param(); - virtual ~Param(){}; - - virtual Msg* GenGetMsg(bool copy, int v=-1); - virtual Msg* GenPutMsg(bool copy, int v=-1); - virtual Msg* GenUpdateMsg(bool copy, int v=-1); - virtual Msg* GenSyncMsg(bool copy, int v=-1); + virtual ~Param(){ } + /** + * Generate the message for a get request, i.e., get parameters from a server + * + * This function is called at worker/stub side. + * @param copy decides whether to copy the parameter values from the server. + * @param slice_idx index of the slice from which the message is generated. + * @return generated message without setting src, dst, target fields. + */ + virtual Msg* GenGetMsg(bool copy, int slice_idx); + /** + * Generate the message for a put request, i.e., put parameters to a server. + * \copydetails GenGetMsg(bool, int); + */ + virtual Msg* GenPutMsg(bool copy, int slice_idx); + /** + * Generate the message for a update request, i.e., pass info to server for + * parameter update. + * \copydetails GenGetMsg(bool, int); + */ + virtual Msg* GenUpdateMsg(bool copy, int slice_idx); + /** + * Generate the message for a synchronization request between server groups. + * + * This function is called at server side where the Param is actually a slice + * of an original Param object. + * */ + virtual Msg* GenSyncMsg(); + /** + * Generate the message to response the update request. + * + * This function is called at the server side, where the Param is actually a slice + * of an original Param object. + * @param copy if true copy the parameter value into the message, otherwise + * only transfer the pointer of the parameter values. + * @return response message pointer + */ + virtual Msg* GenUpdateResponseMsg(bool copy); + /** + * Server handling function for get request. + * + * @param msg request message + * @return resposne message + */ virtual Msg* HandleGetMsg(Msg** msg); + /** + * Server handling function for put request. + * + * \copydetails HandleGetMsg(Msg**) + */ virtual Msg* HandlePutMsg(Msg** msg); + /** + * Server handling function for synchronization message + * + * \copydetails HandleGetMsg(Msg**) + */ virtual Msg* HandleSyncMsg(Msg** msg); - virtual const std::pair<bool, int> ParseUpdateMsg(Msg** msg); - virtual Msg* GenUpdateResponseMsg(bool copy, int v=-1); - - virtual int ParseGetResponseMsg(Msg** msg); - virtual int ParsePutResponseMsg(Msg** msg); - virtual int ParseUpdateResponseMsg(Msg** msg); - virtual int ParseSyncResponseMsg(Msg** msg); + /** + * Server parses update request message. + * + * @param msg + * @return 1 for copy, 0 for no copy + */ + virtual int ParseUpdateMsg(Msg** msg); + /** + * Worker/Stub parsing function for get response. + * + * @param msg + * @param slice_idx index for the slice + */ + virtual int ParseGetResponseMsg(Msg** msg, int slice_idx); + /** + * Worker/Server parsing function for update response + * + * \copydetails ParseGetResponseMsg(Msg**, int); + */ + virtual int ParseUpdateResponseMsg(Msg** msg, int slice_idx); + /** + * Server parsing function for synchronization response. + * + * \copydetails ParseGetResponseMsg(Msg** , int); + */ + virtual int ParseSyncResponseMsg(Msg** msg, int slice_idx); /** - * setup param shape + * Setup param object + * + * @param proto includes learning rate/weight decay multipliers + * @param shape */ - virtual void Setup(const ParamProto& proto, const std::vector<int>& shape, int fan_in); + virtual void Setup(const ParamProto& proto, const std::vector<int>& shape); /* - * fill the data according to initmethod, i.e., random/gaussian/fixed value + * Fill the values according to initmethod, e.g., gaussian distribution + * + * @param version initial version + */ + virtual void InitValues(int version=0); + /** + * Share the data blob from other Param objects. + * + * @param other the Param object whose owner owns the data blob */ - virtual void Init(int v=0); void ShareData(shared_ptr<Param> other){ proto_.set_owner(other->owner()); if(data_!=nullptr) @@ -52,11 +129,6 @@ class Param { float weight_decay_multiplier() { return proto_.weight_decay_multiplier(); } - /* - const int split_threshold(){ - return proto_.split_threshold(); - } - */ const std::string& name() { return proto_.name(); } @@ -137,28 +209,35 @@ class Param { float* mutable_cpu_history(){ return history_.mutable_cpu_data(); } + int slice_start() const { + return slice_start_; + } + + int num_slices() const { + return num_slices_; + } + + void AddSlice(int slice_id, int size); + protected: + void ParseResponseMsg(Msg** msg, int slice_idx); + + protected: + /** * name of the parameter used to share wights between neuralnets */ std::string name_; shared_ptr<Blob<float>> data_; + int slice_start_, num_slices_; + vector<int> slice_offset_, slice_size_; + vector<bool> pending_put_,pending_get_, pending_update_; + int num_pending_requests_; //! gradient, history gradient of this parameter Blob<float> grad_, history_; ParamProto proto_; - int fan_in_; int local_version_; }; -/** - * To support the shared memory and distributed Hogwild algorithm. - * Each worker group has one worker. Workers from the same process share the - * memory space for parameter values. Each process has one server group which - * also shares the same memory space. Messages except synchronization messages - * only transfer pointers to parameter value or gradient space. Hence memory - * copy is avoided for intra-process communication. - */ -class HogwildParam: public Param{ -}; } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/neuralnet/layer.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc index de13ba7..04ce72a 100644 --- a/src/neuralnet/layer.cc +++ b/src/neuralnet/layer.cc @@ -46,9 +46,9 @@ void ConvolutionLayer::Setup(const LayerProto& proto, Factory<Param>* factory=Singleton<Factory<Param>>::Instance(); weight_=shared_ptr<Param>(factory->Create("Param")); - weight_->Setup(proto.param(0), vector<int>{num_filters_, col_height_}, col_height_); + weight_->Setup(proto.param(0), vector<int>{num_filters_, col_height_}); bias_=shared_ptr<Param>(factory->Create("Param")); - bias_->Setup(proto.param(1), vector<int>{num_filters_},0); + bias_->Setup(proto.param(1), vector<int>{num_filters_}); } void ConvolutionLayer::SetupAfterPartition(const LayerProto& proto, @@ -173,8 +173,8 @@ void InnerProductLayer::Setup(const LayerProto& proto, Factory<Param>* factory=Singleton<Factory<Param>>::Instance(); weight_=shared_ptr<Param>(factory->Create("Param")); bias_=shared_ptr<Param>(factory->Create("Param")); - weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_}, vdim_*hdim_); - bias_->Setup(proto.param(1), vector<int>{hdim_},0); + weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_}); + bias_->Setup(proto.param(1), vector<int>{hdim_}); } void InnerProductLayer::SetupAfterPartition(const LayerProto& proto, const vector<int> &shape, http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/proto/cluster.proto ---------------------------------------------------------------------- diff --git a/src/proto/cluster.proto b/src/proto/cluster.proto index f79d7d4..52cfd51 100644 --- a/src/proto/cluster.proto +++ b/src/proto/cluster.proto @@ -1,8 +1,8 @@ package singa; message ClusterProto{ - optional int32 nworker_groups=1; - optional int32 nserver_groups=2; + optional int32 nworker_groups=1 [default=1]; + optional int32 nserver_groups=2 [default=1]; optional int32 nworkers_per_group=3 [default=1]; optional int32 nservers_per_group=4 [default=1]; optional int32 nworkers_per_procs=5 [default=1]; @@ -11,27 +11,24 @@ message ClusterProto{ // Used in standalone mode, one ip or hostname per line // For YARN or Mesos version, the processes are allocted dynamically, // hence no need to specify the hosts statically - optional string hostfile=10; + optional string hostfile=10 [default=""]; // servers and workers in different processes? optional bool server_worker_separate=11 [default=false]; - // if configured, must be consistent with the one computed from 1-6 - optional int32 nprocs=12; - // port number is used by ZeroMQ optional int32 start_port=13 [default=6723]; // local workspace, train/val/test shards, checkpoint files required string workspace=14; // relative path to workspace. if not set, use the default dir of glog - optional string log_dir=15; + optional string log_dir=15 [default="/tmp"]; // ip/hostname : port [, ip/hostname : port] optional string zookeeper_host=16 [default="localhost:2181"]; // message size limit, default 1MB // optional int32 largest_message=20 [default=1048576]; // optional float bandwidth=21 [default=100];//MB/s - repeated ServerTopology server_group = 20; + //repeated ServerTopology server_group = 20; optional int32 stub_timeout=30 [default=5000]; optional int32 worker_timeout=31 [default=5000]; @@ -50,3 +47,26 @@ message ServerTopology{ // neighbor group id repeated int32 neighbor = 3; } +enum MsgType{ + kGet=0; + kPut=1; + kSync=2; + kUpdate=3; + kSyncRequest=4; + kSyncResponse=5; + kStop=6; + kData=7; + kRGet=8; + kRUpdate=9; + kConnect=10; + kMetric=11; +}; + +enum EntityType{ + kWorkerParam=0; + kWorkerLayer=1; + kServer=2; + kStub=3; + kRuntime=4; +}; + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/proto/model.proto ---------------------------------------------------------------------- diff --git a/src/proto/model.proto b/src/proto/model.proto index c6e3495..8cb45a3 100644 --- a/src/proto/model.proto +++ b/src/proto/model.proto @@ -1,26 +1,4 @@ package singa; -enum MsgType{ - kGet=0; - kPut=1; - kSync=2; - kUpdate=3; - kSyncRequest=4; - kSyncResponse=5; - kStop=6; - kData=7; - kRGet=8; - kRUpdate=9; - kConnect=10; - kMetric=11; -}; - -enum EntityType{ - kWorkerParam=0; - kWorkerLayer=1; - kServer=2; - kStub=3; - kRuntime=4; -}; enum Phase { kTrain = 0; kValidation=1; @@ -33,25 +11,17 @@ enum ShareOption{ kWhole=1; }; message ModelProto{ - optional string name = 1; - // relative path to system folder - optional string train_folder=2 [default="train"]; - optional string test_folder=3 [default="test"]; - optional string validation_folder=4 [default="validation"]; + required string name = 1; // start display after this num steps optional int32 display_after_steps = 6 [default = 0]; // frequency of display optional int32 display_frequency = 7 [default = 0]; - // the time of validation - //optional int32 validation_step = 9 [default = 0]; // start validation after this num steps optional int32 validation_after_steps = 10 [default = 0]; // frequency of validation optional int32 validation_frequency = 11 [default = 0]; - // the time of test - //optional int32 test_step = 12 [default = 0]; // start test after this num steps optional int32 test_after_steps = 13 [default = 0]; // frequency of test @@ -63,24 +33,23 @@ message ModelProto{ // total num of steps for training - optional int32 train_steps = 20; + required int32 train_steps = 20; // total num of steps for validation - optional int32 validation_steps=21; + optional int32 validation_steps=21 [default=0]; // total num of steps for test - optional int32 test_steps=22; + optional int32 test_steps=22 [default=0]; // last snapshot step - optional int32 step=29 [default=0]; + optional int32 step=29; - optional UpdaterProto updater=31; + required UpdaterProto updater=31; // There are two basic algorithms for calculating gradients. // Different deep learning models use different algorithms. enum GradCalcAlg{ kBackPropagation = 1; kContrastiveDivergence = 2; } - optional GradCalcAlg alg= 32 [default = kBackPropagation]; - optional bool hogwild=33 [default=false]; - optional NetProto neuralnet = 40; + required GradCalcAlg alg= 32 [default = kBackPropagation]; + required NetProto neuralnet = 40; optional bool debug=41 [default=false]; optional int32 warmup_steps=50 [default=0]; } @@ -93,7 +62,7 @@ message NetProto{ message ParamProto { // for the program to identify it and share among layers. // e.g., "conv1_weight","fc_bias" - optional string name = 1; + required string name = 1; optional int32 id=2; // in most situations, user do not need to config this, // the program will calculate it @@ -149,25 +118,24 @@ message BlobProtos{ repeated string names=3; } - - enum PartitionType{ kDataPartition=0; kLayerPartition=1; kNone=2; } + enum ConnectionType{ kOneToOne=0; kOneToAll=1; } message LayerProto { - optional string name = 1; // the layer name - optional string type = 2; // the layer type from the enum above + required string name = 1; // the layer name + required string type = 2; // the layer type from the enum above repeated string srclayers=3; optional int32 locationid=4 [default=0]; // todo make locationID an array optional int32 partitionid=5 [default=0]; - optional PartitionType partition_type=6; + optional PartitionType partition_type=6 [default=kNone]; optional string datablob=7; // can be pos/neg neuron value for CD, neuron value/grad for BP //repeated DAryProto ary = 10; @@ -204,10 +172,10 @@ message RGBImage { optional float scale=1 [default=1.0]; optional int32 cropsize=2 [default=0]; optional bool mirror=3 [default=false]; - optional string meanfile=4; + optional string meanfile=4 [default=""]; } message SplitProto{ - optional int32 num_splits=1; + required int32 num_splits=1; } // scaled tan: A*tan(B*x) message TanhProto{ @@ -225,7 +193,7 @@ message SoftmaxLossProto { } // Message that stores parameters used by ConvolutionLayer message ConvolutionProto { - optional uint32 num_filters = 1; // The number of outputs for the layer + required uint32 num_filters = 1; // The number of outputs for the layer optional bool bias_term = 2 [default = true]; // whether to have bias terms // Pad, kernel size, and stride are all given as a single value for equal // dimensions in height and width or as Y, X pairs. @@ -235,19 +203,17 @@ message ConvolutionProto { } message ConcateProto{ - optional int32 concate_dimension=1; - optional int32 concate_num=2; + required int32 concate_dimension=1; + required int32 concate_num=2; } // Message that stores parameters used by DataLayer message DataProto { - // Specify the data source. - optional string source = 1; // path to the data file/folder, absolute or relative to the // ClusterProto::workspace - optional string path=2; + required string path=2; // Specify the batch size. - optional uint32 batchsize = 4; + required uint32 batchsize = 4; // skip [0,random_skip] records optional uint32 random_skip=5 [default=0]; } @@ -273,13 +239,13 @@ message DropoutProto { } // Message that stores parameters used by InnerProductLayer message InnerProductProto { - optional uint32 num_output = 1; // The number of outputs for the layer + required uint32 num_output = 1; // The number of outputs for the layer optional bool bias_term = 2 [default = true]; // whether to have bias terms } // Message that stores parameters used by LRNLayer message LRNProto { - optional uint32 local_size = 1 [default = 5]; + optional int32 local_size = 1 [default = 5]; optional float alpha = 2 [default = 1.]; optional float beta = 3 [default = 0.75]; enum NormRegion { @@ -305,8 +271,8 @@ message PoolingProto { } message SliceProto{ - optional int32 slice_dimension=1; - optional int32 slice_num=2; + required int32 slice_dimension=1; + required int32 slice_num=2; } // Message that stores parameters used by ReLULayer message ReLUProto { @@ -356,9 +322,9 @@ message UpdaterProto { optional float pow=7 [default=0]; optional float delta=8 [default=0.0000001]; optional float rho=9 [default=0.9]; - optional float base_learning_rate=12; - optional float final_learning_rate=13; - optional int32 learning_rate_change_frequency = 14; + optional float base_learning_rate=12 [default=0]; + optional float final_learning_rate=13 [default=0]; + optional int32 learning_rate_change_frequency = 14 [default=0]; enum ChangeProto { kFixed = 0; kInverse_t= 1; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/trainer/server.cc ---------------------------------------------------------------------- diff --git a/src/trainer/server.cc b/src/trainer/server.cc index e0fcb48..36b04b6 100644 --- a/src/trainer/server.cc +++ b/src/trainer/server.cc @@ -13,15 +13,15 @@ Server::Server(int thread_id,int group_id, int server_id): thread_id_(thread_id),group_id_(group_id), server_id_(server_id){} void Server::Setup(const UpdaterProto& proto, - shared_ptr<Server::ParamShard> shard){ + shared_ptr<ServerShard> shard, const vector<int>& slice2group){ //VLOG(3) << "Parsing config file for host "<<hosts[id_] << " server id = " <<id_; updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance() ->Create("Updater")); updater_->Init(proto); shard_=shard; + slice2group_=slice2group; } - void Server::Run(){ dealer_=std::make_shared<Dealer>(2*thread_id_); dealer_->Connect(kInprocRouterEndpoint); @@ -61,7 +61,7 @@ void Server::Run(){ param->set_id(pid); (*shard_)[pid]=param; } - response = HandlePut(param, &msg); + param->HandlePutMsg(&msg); }else{ int pid=msg->target_first(); if(shard_->find(pid)==shard_->end()){ @@ -83,10 +83,6 @@ void Server::Run(){ VLOG(3)<<"Handle SYNC-REQUEST"; response = HandleSyncRequest(param, &msg); break; - case kSyncResponse: - VLOG(3) << "Handle SYNC response"; - HandleSyncResponse(param, &msg); - break; } if (response!=nullptr){ dealer_->Send(&response); @@ -96,11 +92,8 @@ void Server::Run(){ } } -bool Server::SyncNow(){ - return false; -} -Msg* Server::HandlePut(shared_ptr<Param> param, Msg **msg){ - return param->HandlePutMsg(msg); +void Server::HandlePut(shared_ptr<Param> param, Msg **msg){ + param->HandlePutMsg(msg); } Msg* Server::HandleGet(shared_ptr<Param> param, Msg **msg){ @@ -110,11 +103,15 @@ Msg* Server::HandleGet(shared_ptr<Param> param, Msg **msg){ Msg* Server::HandleUpdate(shared_ptr<Param> param, Msg **msg) { //repsonse of the format: <identity><type: kData><paramId><param content> auto* tmp=static_cast<Msg*>((*msg)->CopyAddr()); - const std::pair<bool, int> copy_step=param->ParseUpdateMsg(msg); - updater_->Update(copy_step.second, param); - param->set_version(param->version()+1); - auto response=param->GenUpdateResponseMsg(copy_step.first, param->version()); tmp->SwapAddr(); + int paramid=(*msg)->target_first(); + int sliceid=(*msg)->target_second(); + int step=(*msg)->target_third(); + bool copy=param->ParseUpdateMsg(msg); + updater_->Update(step, param); + param->set_version(param->version()+1); + auto response=param->GenUpdateResponseMsg(copy); + response->set_target(paramid, sliceid, param->version()); response->SetAddr(tmp); delete tmp; return response; @@ -124,8 +121,4 @@ Msg* Server::HandleSyncRequest(shared_ptr<Param> param, Msg **msg){ return param->HandleSyncMsg(msg); } -int Server::HandleSyncResponse(shared_ptr<Param> param, Msg **msg){ - return param->ParseSyncResponseMsg(msg); -} - } /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/trainer/trainer.cc ---------------------------------------------------------------------- diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc index cd0189c..989a020 100644 --- a/src/trainer/trainer.cc +++ b/src/trainer/trainer.cc @@ -1,6 +1,7 @@ #include <thread> #include <vector> #include <map> +#include <queue> #include <glog/logging.h> #include "trainer/trainer.h" #include "mshadow/tensor.h" @@ -8,24 +9,6 @@ using std::vector; using std::map; namespace singa { -int ProcsIDOf(int group_id, int id, int flag){ - int procsid=-1; - auto cluster=Cluster::Get(); - if(flag==kServer){ - procsid=group_id*cluster->nservers_per_group()/ - cluster->nservers_per_procs()+id/cluster->nservers_per_procs(); - if(cluster->server_worker_separate()) - procsid+=cluster->nworker_procs(); - }else if(flag==kWorkerLayer || flag==kWorkerParam){ - procsid=group_id*cluster->nworkers_per_group() - /cluster->nworkers_per_procs(); - if(cluster->nworkers_per_group()>cluster->nworkers_per_procs()) - procsid+=id/cluster->nworkers_per_procs(); - }else{ - LOG(ERROR)<<"Unkown flag ("<<flag<<")"; - } - return procsid; -} void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){ // register all layers appearing in the neural net @@ -36,11 +19,6 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){ "Updater", CreateInstance(singa::SGDUpdater, singa::Updater)); } -typedef struct HandleContext_{ - shared_ptr<Dealer> dealer; - int group_id, id; -} HandleContext; - void HandleWorkerFinish(void * ctx){ HandleContext* hctx=static_cast<HandleContext*> (ctx); Msg* msg=new Msg(); @@ -50,125 +28,225 @@ void HandleWorkerFinish(void * ctx){ hctx->dealer->Send(&msg); } -void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, - int procs_id){ - procs_id_=procs_id; - RegisterDefaultClasses(mproto); - - auto cluster=Cluster::Get(cproto, procs_id); - router_=make_shared<Router>(); - router_->Bind(kInprocRouterEndpoint); - if(cluster->nprocs()>1) - router_->Bind(cluster->endpoint()); +const std::unordered_map<int, vector<std::pair<int, int>>> SliceParams(int num, + const vector<shared_ptr<Param>>& params){ + CHECK_GT(num,0); + vector<int> param_size; + int avg=0; + for(const auto& x:params){ + if(x->owner()==x->id()) + avg+=x->size(); + } + avg/=num; + int diff=avg/10; + LOG(INFO)<<"Slicer, param avg="<<avg<<", diff= "<<diff; - // create servers - vector<shared_ptr<Server>> servers; - vector<HandleContext> ctx; - int nthreads=1; // the first socket is the router - if(cluster->has_server()){ // todo move sever creation to a method - int pid=cluster->procs_id(); - if(cluster->server_worker_separate()) - pid-=cluster->nworker_procs(); - int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group(); - int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group(); - int end=start+cluster->nservers_per_group(); - // the ParamShard for servers consists of a dictionary of Param objects - auto shard=make_shared<Server::ParamShard>(); - if(start<end){ - auto dealer=make_shared<Dealer>(); - dealer->Connect(kInprocRouterEndpoint); - for(int sid=start;sid<end;sid++){ - auto server=make_shared<Server>(nthreads++, gid, sid); - server->Setup(mproto.updater(), shard); - servers.push_back(server); - HandleContext hc{dealer, gid, sid}; - ctx.push_back(hc); - CHECK(cluster->runtime()->sWatchSGroup(gid, sid, HandleWorkerFinish, - &ctx.back())); + int capacity=avg, sliceid=0; + std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices; + for(auto& param: params){ + if(param->id()!=param->owner()) + continue; + int x=param->size(), paramid=param->id(); + LOG(INFO)<<"param id="<<paramid<<", total size="<<x; + while(x>0){ + int size=0; + if(capacity>x){ + capacity-=x; + size=x; + x=0; + }else if(capacity+diff>x){ + capacity=avg; + size=x; + x=0; + }else if(capacity>diff){ + x-=capacity; + size=capacity; + capacity=avg; + }else{ + capacity=avg; + } + if(size){ + paramid2slices[paramid].push_back(std::make_pair(sliceid++, size)); + LOG(INFO)<<"param id="<<paramid<<", slice size="<<size; } } } - // create workers - vector<shared_ptr<Worker>> workers; - std::map<int, shared_ptr<Trainer::ParamShard>> shards; - if(cluster->has_worker()){ //move worker creation to a method - auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain, - cluster->nworkers_per_group()); - //LOG(ERROR)<<net->ToString(); - int pid=cluster->procs_id(); - int gstart, gend, wstart, wend; - if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){ - // all workers in this procs are from the same group - gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group(); - gend=gstart+1; - wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group(); - wend=wstart+cluster->nworkers_per_group(); + return paramid2slices; +} +const vector<int> PartitionSlice(int num, const vector<int>& slices){ + int avg=0; + for(int x: slices) + avg+=x; + avg/=num; + int box=avg, boxid=0, diff=avg/10; + vector<int> slice2box; + for(int x: slices){ + if(box>=x){ + box-=x; + slice2box.push_back(boxid); + }else if(box+diff>=x){ + slice2box.push_back(boxid); + box=avg; + boxid++; }else{ - // there are multiple groups in this procs - CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0); - int groups_per_procs= - cluster->nworkers_per_procs()/cluster->nworkers_per_group(); - gstart=pid*groups_per_procs; - gend=(pid+1)*groups_per_procs; - wstart=0; - wend=cluster->nworkers_per_group(); + box=avg; + boxid++; } - for(int gid=gstart;gid<gend;gid++){ - shared_ptr<NeuralNet> train_net, test_net, validation_net; - if(gid==gstart) - train_net=net; - else{ - train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain, + } + CHECK_LE(boxid, num); + return slice2box; +} +vector<shared_ptr<Server>> Trainer::CreateServers(int nthreads, + const ModelProto & mproto, + const vector<int> slices, + vector<HandleContext>* ctx){ + auto cluster=Cluster::Get(); + vector<shared_ptr<Server>> servers; + if(!cluster->has_server()) + return servers; + + int pid=cluster->procs_id(); + if(cluster->server_worker_separate()) + pid-=cluster->nworker_procs(); + int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group(); + int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group(); + int end=start+cluster->nservers_per_group(); + // the ServerShard for servers consists of a dictionary of Param objects + server_shard_=make_shared<ServerShard>(); + auto slice2group=PartitionSlice(cluster->nserver_groups(), slices); + if(start<end){ + auto dealer=make_shared<Dealer>(); + dealer->Connect(kInprocRouterEndpoint); + for(int sid=start;sid<end;sid++){ + auto server=make_shared<Server>(nthreads++, gid, sid); + server->Setup(mproto.updater(), server_shard_, slice2group); + servers.push_back(server); + HandleContext hc{dealer, gid, sid}; + ctx->push_back(hc); + CHECK(cluster->runtime()->sWatchSGroup(gid, sid, HandleWorkerFinish, + &(ctx->back()))); + } + } + return servers; +} + +vector<shared_ptr<Worker>> Trainer::CreateWorkers(int nthreads, + const ModelProto& mproto, vector<int> *slice_size){ + auto cluster=Cluster::Get(); + vector<shared_ptr<Worker>> workers; + if(!cluster->has_worker()) + return workers; + //LOG(ERROR)<<net->ToString(); + int pid=cluster->procs_id(); + int gstart, gend, wstart, wend; + if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){ + // all workers in this procs are from the same group + gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group(); + gend=gstart+1; + wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group(); + wend=wstart+cluster->nworkers_per_group(); + }else{ + // there are multiple groups in this procs + CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0); + int groups_per_procs= + cluster->nworkers_per_procs()/cluster->nworkers_per_group(); + gstart=pid*groups_per_procs; + gend=(pid+1)*groups_per_procs; + wstart=0; + wend=cluster->nworkers_per_group(); + } + auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain, + cluster->nworkers_per_group()); + int lcm=LeastCommonMultiple(cluster->nserver_groups(), cluster->nservers_per_group()); + auto paramid2slices=SliceParams(lcm, net->params()); // sliceid, size + for(auto param: net->params()){ + if(param->id()==param->owner()) + for(auto entry: paramid2slices[param->id()]) + slice_size->push_back(entry.second); + } + + for(int gid=gstart;gid<gend;gid++){ + shared_ptr<NeuralNet> train_net, test_net, validation_net; + if(gid==gstart) + train_net=net; + else{ + train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain, + cluster->nworkers_per_group()); + // the train net for other groups may share parameter values from the + // first group + if(cluster->share_memory()) + train_net->ShareParams(net, kValueOnly); + } + if(gid==0){ + // validation and test are performed only by the first group + if(mproto.test_steps()){ + test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest, cluster->nworkers_per_group()); - // the train net for other groups may share parameter values from the - // first group - if(cluster->share_memory()) - train_net->ShareParams(net, kValueOnly); + if(test_net!=nullptr) + test_net->ShareParams(train_net, kValueOnly); } - if(gid==0){ - // validation and test are performed only by the first group - if(mproto.test_steps()){ - test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest, - cluster->nworkers_per_group()); - if(test_net!=nullptr) - test_net->ShareParams(train_net, kValueOnly); - } - if(mproto.validation_steps()){ - validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kValidation, - cluster->nworkers_per_group()); - if(validation_net!=nullptr) - validation_net->ShareParams(train_net, kValueOnly); - } + if(mproto.validation_steps()){ + validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kValidation, + cluster->nworkers_per_group()); + if(validation_net!=nullptr) + validation_net->ShareParams(train_net, kValueOnly); } - // create ParamShard for the workers - auto shard=make_shared<Trainer::ParamShard>(); - shards[gid]=shard; - for(auto layer: train_net->layers()){ - int procsid=ProcsIDOf(gid, layer->partitionid(),kWorkerParam); - bool local=procsid==cluster->procs_id(); - for(auto param: layer->GetParams()){ - int owner_procs=param->owner()==param->id()?procsid:procs_id_; - if(shard->find(param->owner())==shard->end()) - (*shard)[param->owner()]= - make_shared<ParamInfo>(param, local, owner_procs); - else - shard->at(param->owner())->AddParam(param, local); + } + // create ServerShard for the workers + auto shard=make_shared<WorkerShard>(); + worker_shards_[gid]=shard; + for(auto layer: train_net->layers()){ + int procsid=cluster->ProcsIDOf(gid, layer->partitionid(), kWorkerLayer); + bool local=procsid==cluster->procs_id(); + for(auto param: layer->GetParams()){ + for(auto entry :paramid2slices[param->owner()]){ + param->AddSlice(entry.first, entry.second); } + int owner_procs=param->owner()==param->id()?procsid:procs_id_; + if(shard->find(param->owner())==shard->end()) + (*shard)[param->owner()]= + make_shared<ParamInfo>(param, local, owner_procs); + else + shard->at(param->owner())->AddParam(param, local); } - for(int wid=wstart;wid<wend;wid++){ - shared_ptr<Worker> worker=nullptr; - if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation) - worker=make_shared<BPWorker>(nthreads++,gid, wid); - else{ + } + for(int wid=wstart;wid<wend;wid++){ + shared_ptr<Worker> worker=nullptr; + if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation) + worker=make_shared<BPWorker>(nthreads++,gid, wid); + else{ // TODO add CDWorker - } - worker->Setup(mproto, train_net); - worker->set_test_net(test_net); - worker->set_validation_net(validation_net); - workers.push_back(worker); } + worker->Setup(mproto, train_net); + worker->set_test_net(test_net); + worker->set_validation_net(validation_net); + workers.push_back(worker); } } + return workers; +} + +void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, + int procs_id){ + procs_id_=procs_id; + RegisterDefaultClasses(mproto); + + auto cluster=Cluster::Get(cproto, procs_id); + router_=make_shared<Router>(); + router_->Bind(kInprocRouterEndpoint); + if(cluster->nprocs()>1) + router_->Bind(cluster->endpoint()); + + int nthreads=1; + // create workers + vector<int> slices; + vector<shared_ptr<Worker>> workers=CreateWorkers(nthreads, mproto, &slices); + slice2server_=PartitionSlice(cluster->nservers_per_group(), slices); + nthreads+=workers.size(); + // create servers + vector<HandleContext> ctx; + vector<shared_ptr<Server>> servers=CreateServers(nthreads, mproto, slices, + &ctx); #ifdef USE_MPI for(int i=0;i<nSocket;i++){ @@ -180,17 +258,16 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, threads.push_back(std::thread(&Server::Run,server.get())); for(auto worker: workers) threads.push_back(std::thread(&Worker::Run,worker.get())); - Run(workers.size(), servers.size(), shards); + Run(workers.size(), servers.size()); for(auto& thread: threads) thread.join(); } -void Trainer::Run(int nworkers, int nservers, - const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){ +void Trainer::Run(int nworkers, int nservers){ auto cluster=Cluster::Get(); procs_id_=cluster->procs_id(); map<int, shared_ptr<Dealer>> interprocs_dealers; - Metric perf; + std::queue<Msg*> msg_queue; bool stop=false; while(!stop){ Msg* msg=router_->Receive(); @@ -198,13 +275,16 @@ void Trainer::Run(int nworkers, int nservers, LOG(ERROR)<<"Connection broken!"; exit(0); } - while(msg!=nullptr){ + msg_queue.push(msg); + while(!msg_queue.empty()){ + msg=msg_queue.front(); + msg_queue.pop(); int dst_flag=msg->dst_flag(); int type=msg->type(); int dst_procs=msg->dst_first(); if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){ if(type==kConnect){ - msg =HandleConnect(&msg); + msg_queue.push(HandleConnect(&msg)); }else if(type==kStop){ if(msg->src_flag()==kServer) nservers--; @@ -223,63 +303,63 @@ void Trainer::Run(int nworkers, int nservers, msg->next_frame(); Metric cur; cur.ParseString(string((char*)msg->frame_data(), msg->frame_size())); - perf.AddMetrics(cur); - LOG(ERROR)<<prefix<<" step-" <<step<<", "<<perf.ToString(); - perf.Reset(); + LOG(ERROR)<<prefix<<" step-" <<step<<", "<<cur.ToString(); } DeleteMsg(&msg); }else if(cluster->nserver_groups()>0){ - int group_id=msg->src_first(); + int group_id; int paramid=msg->target_first(); - auto entry=shards.at(group_id)->at(paramid); + shared_ptr<ParamInfo> entry; switch (type){ // TODO process other requests, e.g. RESTful case kUpdate: - msg=HandleUpdate(entry, &msg); + group_id=msg->src_first(); + entry=worker_shards_.at(group_id)->at(paramid); + for(auto x:HandleUpdate(entry, &msg)) + msg_queue.push(x); break; case kRUpdate: + group_id=msg->dst_second(); + entry=worker_shards_.at(group_id)->at(paramid); HandleUpdateResponse(entry, &msg); break; case kGet: - msg=HandleGet(entry, &msg); + group_id=msg->src_first(); + entry=worker_shards_.at(group_id)->at(paramid); + for(auto x:HandleGet(entry, &msg)) + msg_queue.push(x); break; case kRGet: - msg=HandleGetResponse(entry, &msg); + group_id=msg->dst_second(); + entry=worker_shards_.at(group_id)->at(paramid); + HandleGetResponse(entry, &msg); break; case kPut: - msg=HandlePut(entry, &msg); + group_id=msg->src_first(); + entry=worker_shards_.at(group_id)->at(paramid); + for(auto x:HandlePut(entry, &msg)) + msg_queue.push(x); break; default: break; } }else{ - delete msg; - msg=nullptr; + DeleteMsg(&msg); } }else{ int dst_procs_id; if(dst_flag==kStub){ dst_procs_id=msg->dst_first(); }else{ - dst_procs_id=ProcsIDOf(msg->dst_first(), msg->dst_second(), msg->dst_flag()); + dst_procs_id=cluster->ProcsIDOf(msg->dst_first(), + msg->dst_second(), msg->dst_flag()); } if(dst_procs_id!=procs_id_){ - /* - // forward to other procs - if (interprocs_dealers.find(procs_id)==interprocs_dealers.end()) - interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id); - interprocs_dealers[procs_id]->Send(&msg); - */ }else{ router_->Send(&msg); } } } } - /* - perf.Avg(); - if(perf_step>=0) - LOG(ERROR)<<perf_prefix<<" step-"<<perf_step<<", "<<perf.ToString(); - */ } Msg* Trainer::HandleConnect(Msg** msg){ string ping((char*)(*msg)->frame_data(), (*msg)->frame_size()); @@ -294,64 +374,45 @@ Msg* Trainer::HandleConnect(Msg** msg){ *msg=NULL; return reply; } -int Trainer::Sharding(int param_id){ - return param_id%Cluster::Get()->nservers_per_group(); -} -/* -int Worker::Sharding(int param_id){ - static map<int, int> id2procs; - if(id2procs.find(param_id)==id2procs.end()){ - auto cluster=Cluster::Get(); - int server_group=group_id_%cluster->nserver_groups(); - int nprocs_per_server_group= - cluster->nservers_per_group()/cluster->nservers_per_procs(); - int procsid=server_group*nprocs_per_server_group+ - param_id%nprocs_per_server_group; - procsid= cluster->server_worker_separate()? - cluster->nworker_procs()+procsid:procsid; - id2procs[param_id]=procsid; - } - return id2procs[param_id]; -} -*/ - -Msg* Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){ - Msg* msgg=*msg, *reply=nullptr; +const vector<Msg*> Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){ + Msg* msgg=*msg; + vector<Msg*> replies; int version=msgg->target_second(); if(msgg->src_flag()==kStub){ if(version<=pi->shares.at(0)->version()){ - reply=pi->shares.at(0)->HandleGetMsg(msg); + pi->shares.at(0)->HandleGetMsg(msg); }else if(version>pi->next_version){ // reinsert into a msg queue. } }else if(version>pi->next_version){ pi->next_version=version; - int gid=msgg->src_first(), pid=msgg->target_first(); - int dstgroup=gid/Cluster::Get()->nworker_groups_per_server_group(); - int dstid=Sharding(pid); - int dstprocs=ProcsIDOf(dstgroup, dstid, kServer); - reply=pi->shares.at(0)->GenGetMsg(dstprocs!=procs_id_); - reply->set_src(procs_id_, gid, kStub); - reply->set_dst(dstgroup, dstid, kServer); + int gid=msgg->src_first(); + int group=gid/Cluster::Get()->nworker_groups_per_server_group(); + auto param=pi->shares.at(0); + for(int idx=0, id=param->slice_start();idx<param->num_slices();idx++){ + int server=slice2server_[id+idx]; + int procs=Cluster::Get()->ProcsIDOf(group, server, kServer); + auto x=param->GenGetMsg(procs!=procs_id_, idx); + x->set_target(param->owner(), id+idx, param->local_version()+1); + x->set_src(procs_id_, gid, kStub); + x->set_dst(group, server, kServer); + replies.push_back(x); + } } - return reply; -} - -Msg* Trainer::HandleGetResponse(shared_ptr<ParamInfo>pi, Msg** msg){ - pi->shares.at(0)->ParseGetResponseMsg(msg); - return nullptr; - // process get requests in waiting queue + return replies; } -Msg* Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){ - Msg* msgg=*msg, *update=nullptr; +const vector<Msg*> Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){ + Msg* msgg=*msg ; + vector<Msg*> ret; int step= msgg->target_second(); if(msgg->src_flag()==kStub){ - if(pi->num_update<pi->num_local) - return *msg; //wait unitl local updates are ready - int n; - sscanf((char*)(*msg)->frame_data(), "%d", &n); + if(pi->num_update<pi->num_local){ + ret.push_back(*msg); + return ret; //wait unitl local updates are ready + } + int n; sscanf((char*)(*msg)->frame_data(), "%d", &n); pi->num_update+=n; auto it=pi->shares.begin(); auto shape=mshadow::Shape1((*it)->size()); @@ -368,45 +429,70 @@ Msg* Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){ } agg/=pi->num_total; if(pi->num_local<pi->num_total){ - update=pi->shares.at(0)->GenUpdateMsg(pi->owner_procs!=procs_id_, step); + /* int gid=msgg->src_first(); - update->set_src(procs_id_, gid,kStub); - update->set_dst(pi->owner_procs, gid, kStub); + for(auto update: pi->shares.at(0)->GenUpdateMsg(step)){ + update->set_src(procs_id_, gid,kStub); + update->set_dst(pi->owner_procs, gid, kStub); + ret.push_back(update); + } pi->num_update=0; + */ } } if(pi->num_update==pi->num_total){ - int gid=msgg->src_first(); - int dstgroup=gid/Cluster::Get()->nworker_groups_per_server_group(); - int dstid=Sharding(msgg->target_first()); - int dstprocs=ProcsIDOf(dstgroup, dstid, kServer); - update=pi->shares.at(0)->GenUpdateMsg(dstprocs!=procs_id_, step); - update->set_src(procs_id_, gid, kStub); - update->set_dst(dstgroup, dstid, kServer); + auto param=pi->shares.at(0); + int group=msgg->src_first()/Cluster::Get()->nworker_groups_per_server_group(); + int srcgid=msgg->src_first(); + for(int idx=0, id=param->slice_start(); idx<param->num_slices();idx++){ + int server=slice2server_[idx+id]; + int procs=Cluster::Get()->ProcsIDOf(group, server, kServer); + auto x=param->GenUpdateMsg(procs!=procs_id_, idx); + x->set_target(param->owner(), id+idx, step); + x->set_src(procs_id_, srcgid, kStub); + x->set_dst(group, server, kServer); + ret.push_back(x); + } pi->num_update=0; } - delete *msg; - *msg=NULL; - return update; + DeleteMsg(msg); + return ret; } -int Trainer::HandleUpdateResponse(shared_ptr<Trainer::ParamInfo> pi, Msg** msg){ - HandleGetResponse(pi, msg); - return 1; -} - -Msg* Trainer::HandlePut(shared_ptr<Trainer::ParamInfo>pi, Msg** msg){ +const vector<Msg*> Trainer::HandlePut(shared_ptr<ParamInfo>pi, Msg** msg){ + vector<Msg*> ret; CHECK_NE((*msg)->src_flag(), kStub); int gid=(*msg)->src_first(); - int id=(*msg)->target_first(); - int dstgroup=gid/Cluster::Get()->nworker_groups_per_server_group(); - int dstid=Sharding(id); - int dstprocs=ProcsIDOf(dstgroup, dstid, kServer); - Msg* put=pi->shares.at(0)->GenPutMsg(dstprocs!=procs_id_); - put->set_src(procs_id_, gid , kStub); - put->set_dst(dstgroup, dstid, kServer); - delete *msg; - *msg=NULL; - return put; + auto param=pi->shares.at(0); + int group=gid/Cluster::Get()->nworker_groups_per_server_group(); + for(int idx=0, start=param->slice_start();idx<param->num_slices(); idx++){ + int server=slice2server_[start+idx]; + int procs=Cluster::Get()->ProcsIDOf(group, server, kServer); + auto x=param->GenPutMsg(procs!=procs_id_, idx); + x->set_target(param->owner(), start+idx, param->version()); + x->set_src(procs_id_, gid, kStub); + x->set_dst(group, server, kServer); + ret.push_back(x); + } + DeleteMsg(msg); + return ret; +} + +void Trainer::HandleGetResponse(shared_ptr<ParamInfo>pi, Msg** msg){ + int version=(*msg)->target_third(); + int sliceid=(*msg)->target_second(); + auto param=pi->shares.at(0); + if(param->ParseGetResponseMsg(msg,sliceid-param->slice_start())) + param->set_version(version); + // process get requests in waiting queue +} + + +void Trainer::HandleUpdateResponse(shared_ptr<ParamInfo> pi, Msg** msg){ + int version=(*msg)->target_third(); + int sliceid=(*msg)->target_second(); + auto param=pi->shares.at(0); + if(param->ParseUpdateResponseMsg(msg,sliceid-param->slice_start())) + param->set_version(version); } } /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/trainer/worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc index f047f0f..0835bbb 100644 --- a/src/trainer/worker.cc +++ b/src/trainer/worker.cc @@ -19,10 +19,7 @@ void Worker::Setup(const ModelProto& model, train_net_=train_net; modelproto_=model; auto cluster=Cluster::Get(); - if(cluster->nserver_groups()&&cluster->server_update()){ - int sgid=group_id_/cluster->nworker_groups_per_server_group(); - CHECK(cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid)); - }else{ + if(!(cluster->nserver_groups()&&cluster->server_update())){ updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance() ->Create("Updater")); updater_->Init(model.updater()); @@ -30,6 +27,12 @@ void Worker::Setup(const ModelProto& model, } void Worker::ConnectStub(shared_ptr<Dealer> dealer, EntityType type){ + if(updater_==nullptr){ + auto cluster=Cluster::Get(); + int sgid=group_id_/cluster->nworker_groups_per_server_group(); + CHECK(cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid)); + } + dealer->Connect(kInprocRouterEndpoint); Msg* ping=new Msg(); ping->set_src(group_id_, worker_id_, type); @@ -60,7 +63,7 @@ void Worker::Run(){ for(auto param: layer->GetParams()){ if(param->owner() == param->id()){ if(group_id_==0) - param->Init(0); + param->InitValues(0); else Get(param, modelproto_.warmup_steps()); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/utils/cluster.cc ---------------------------------------------------------------------- diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc index b00a3cd..c344627 100644 --- a/src/utils/cluster.cc +++ b/src/utils/cluster.cc @@ -12,24 +12,38 @@ Cluster::Cluster(const ClusterProto &cluster, int procs_id) { procs_id_=procs_id; cluster_ = cluster; SetupFolders(cluster); - size_t nprocs; if(server_worker_separate()) - nprocs=nworker_procs()+nserver_procs(); + nprocs_=nworker_procs()+nserver_procs(); else - nprocs=std::max(nworker_procs(), nserver_procs()); - CHECK_LT(procs_id, nprocs); - if (cluster_.has_nprocs()) - CHECK_EQ(cluster.nprocs(), nprocs); - else - cluster_.set_nprocs(nprocs); - if(nprocs>1){ + nprocs_=std::max(nworker_procs(), nserver_procs()); + CHECK_LT(procs_id, nprocs_); + if(nprocs_>1){ std::ifstream ifs(cluster.hostfile(), std::ifstream::in); std::string line; - while(std::getline(ifs, line)&&endpoints_.size()<nprocs){ + while(std::getline(ifs, line)&&endpoints_.size()<nprocs_){ endpoints_.push_back(line); } - CHECK_EQ(endpoints_.size(), nprocs); + CHECK_EQ(endpoints_.size(), nprocs_); + } + + // locate the process id of every worker/server + int ngrps=cluster_.nworker_groups(), grp_size=cluster_.nworkers_per_group(); + int procs; + for(int i=0;i<ngrps;i++){ + for(int j=0;j<grp_size;j++){ + procs=(i*grp_size+j) / cluster_.nworkers_per_procs(); + procs_ids_[Hash(i,j,kWorkerLayer)]=procs; + procs_ids_[Hash(i,j,kWorkerParam)]=procs; + } + } + ngrps=cluster_.nserver_groups(), grp_size=cluster_.nservers_per_group(); + int offset=cluster_.server_worker_separate()? procs:0; + for(int i=0;i<ngrps;i++){ + for(int j=0;j<grp_size;j++){ + procs_ids_[Hash(i,j,kServer)]=(i*grp_size+j) / cluster_.nservers_per_procs()+offset; + } } + auto rt=new ZKClusterRT(cluster_.zookeeper_host()); rt->Init(); cluster_rt_=shared_ptr<ClusterRuntime>(static_cast<ClusterRuntime*>(rt)); @@ -52,4 +66,17 @@ shared_ptr<Cluster> Cluster::Get() { } return instance_; } +int Cluster::Hash(int gid, int id, int flag){ + int ret=-1; + if(flag==kServer){ + ret=(flag*cluster_.nserver_groups()+gid)*cluster_.nservers_per_group() + id; + }else{ + ret=(flag*cluster_.nworker_groups()+gid)*cluster_.nworkers_per_group() + id; + } + return ret; +} +int Cluster::ProcsIDOf(int group_id, int id, int flag){ + return procs_ids_.at(Hash(group_id, id, flag)); +} + } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/utils/common.cc ---------------------------------------------------------------------- diff --git a/src/utils/common.cc b/src/utils/common.cc index 0697060..783b1f9 100644 --- a/src/utils/common.cc +++ b/src/utils/common.cc @@ -85,5 +85,20 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename) { int fd= open(filename, O_CREAT|O_WRONLY|O_TRUNC, 0644); CHECK(proto.SerializeToFileDescriptor(fd)); } +int gcd(int a, int b) +{ + for (;;) + { + if (a == 0) return b; + b %= a; + if (b == 0) return a; + a %= b; + } +} +int LeastCommonMultiple(int a, int b) +{ + int temp = gcd(a, b); + return temp ? (a / temp * b) : 0; +} } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index a7e5230..75cc4cc 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -3,6 +3,7 @@ #include <chrono> #include <random> #include "utils/param.h" +#include "proto/cluster.pb.h" #include "mshadow/tensor.h" #include "utils/singleton.h" using namespace mshadow; @@ -10,59 +11,136 @@ using std::vector; using std::string; namespace singa { -Param::Param():data_(nullptr), local_version_(-1){} +Param::Param():data_(nullptr), slice_start_(0), num_slices_(0), + num_pending_requests_(0),local_version_(-1){ +} +void Param::Setup(const ParamProto& proto, const vector<int>& shape){ + data_=std::make_shared<Blob<float>>(shape); + grad_.Reshape(shape); + history_.Reshape(shape); + proto_=proto; +} -Msg* Param::GenPutMsg(bool copy, int v){ +void Param::AddSlice(int slice_id, int size){ + int offset=0; + if(slice_size_.size()>0){ + //must be added in order + CHECK_EQ(slice_start_+num_slices_, slice_id); + offset=slice_offset_.back()+slice_size_.back(); + } + else{ + slice_start_=slice_id; + offset=0; + } + slice_offset_.push_back(offset); + slice_size_.push_back(size); + pending_get_.push_back(false); + pending_update_.push_back(false); + pending_put_.push_back(false); + num_slices_++; +} + +void Param::InitValues(int version){ + Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size())); + auto random=TSingleton<Random<cpu>>::Instance(); + switch (proto_.init_method()) { + case ParamProto::kConstant: + data=proto_.value(); + break; + case ParamProto::kUniform: + random->SampleUniform(data, proto_.low(), proto_.high()); + if(proto_.value()) + data*= proto_.value(); + break; + /* + case ParamProto::kUniformSqrtFanIn: + CHECK_GT(fan_in_,0); + random->SampleUniform(data, proto_.low(), proto_.high()); + if(proto_.value()) + data*= proto_.value()/ sqrt(fan_in_ / 3.0f); + break; + */ + case ParamProto::kUniformSqrtFanInOut: + random->SampleUniform(data, proto_.low(), proto_.high()); + if(proto_.value()) + data*= proto_.value()/ sqrt(data_->shape()[0] +data_->shape()[1]); + break; + case ParamProto::kGaussian: + random->SampleGaussian(data, proto_.mean(), proto_.std()); + if(proto_.value()) + data*= proto_.value(); + break; + case ParamProto::kGaussainSqrtFanIn: + random->SampleGaussian(data, proto_.mean(), proto_.std()); + if(proto_.value()) + data*= proto_.value()/ sqrt(data_->shape()[0]); + break; + default: + LOG(ERROR) << "Illegal parameter init method "; + break; + } + set_version(version); +} + +/**************Message related functions********/ +Msg* Param::GenPutMsg(bool copy, int idx){ + CHECK_LT(idx, num_slices_); Msg* msg=new Msg(); msg->set_type(kPut); - msg->set_target(owner(), version()); char buf[128]; - sprintf(buf, "%d %f %f", size(), + sprintf(buf, "%d %f %f", slice_size_[idx], learning_rate_multiplier(), weight_decay_multiplier()); + void *ptr=mutable_cpu_data()+slice_offset_[idx]; if(copy){ sprintf(buf+strlen(buf), " %p ", nullptr); msg->add_frame(buf, strlen(buf)); - msg->add_frame(mutable_cpu_data(), size()*sizeof(float)); + msg->add_frame(ptr, slice_size_[idx]*sizeof(float)); }else{ - //share the data blob which includes the blob version - sprintf(buf+strlen(buf), " %p ", data_.get()); + sprintf(buf+strlen(buf), " %p ", ptr); msg->add_frame(buf, strlen(buf)); } + pending_put_[idx]=true; + num_pending_requests_++; return msg; } -Msg* Param::GenGetMsg(bool copy, int v){ +Msg* Param::GenGetMsg(bool copy, int idx){ + CHECK_LT(idx, num_slices_); Msg* msg=new Msg(); msg->set_type(kGet); - msg->set_target(owner(), local_version()+1); msg->add_frame(©, sizeof(bool)); + pending_get_[idx]=true; + num_pending_requests_++; return msg; } -Msg* Param::GenUpdateMsg(bool copy, int v){ +Msg* Param::GenUpdateMsg(bool copy, int idx){ + CHECK_LT(idx, num_slices_); Msg* msg=new Msg(); msg->set_type(kUpdate); - msg->set_target(owner(), v); msg->add_frame(©, sizeof(bool)); + void* ptr=grad_.mutable_cpu_data()+slice_offset_[idx]; if(copy) - msg->add_frame(mutable_cpu_grad(), size()*sizeof(float)); + msg->add_frame(ptr, slice_size_[idx]*sizeof(float)); else{ // to share values of grad blob - char buf[32]; sprintf(buf, " %p ", &grad_); + char buf[32]; sprintf(buf, " %p ", ptr); msg->add_frame(buf, strlen(buf)); } + pending_update_[idx]=true; + num_pending_requests_++; return msg; } -Msg* Param::GenSyncMsg(bool copy, int v){ +Msg* Param::GenSyncMsg(){ return nullptr; } Msg* Param::HandlePutMsg(Msg** msg){ int size; float lr, wc; - void* ptr; - sscanf(static_cast<char*>((*msg)->frame_data()), "%d %f %f %p ", - &size, &lr, &wc, &ptr); + float* ptr; + sscanf(static_cast<char*>((*msg)->frame_data()), + "%d %f %f %p ", &size, &lr, &wc, &ptr); proto_.set_learning_rate_multiplier(lr); proto_.set_weight_decay_multiplier(wc); vector<int> shape{size}; @@ -70,49 +148,45 @@ Msg* Param::HandlePutMsg(Msg** msg){ history_.Reshape(shape); data_=std::make_shared<Blob<float>>(shape); if(ptr==nullptr){ - data_->set_version((*msg)->target_second()); CHECK((*msg)->next_frame()); CHECK_EQ(size* sizeof(float), (*msg)->frame_size()); memcpy(mutable_cpu_data(), (*msg)->frame_data(), size*sizeof(float)); } else{ - data_->ShareData(*static_cast<Blob<float>*>(ptr)); + data_->set_cpu_data(ptr); } DeleteMsg(msg); return nullptr; } Msg* Param::HandleGetMsg(Msg** msg){ - if((*msg)->target_second()<=version()){ - bool* copy=static_cast<bool*>((*msg)->frame_data()); - (*msg)->next_frame(); - if(*copy) - (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size()); - (*msg)->SwapAddr(); - (*msg)->set_type(kRGet); - } + bool* copy=static_cast<bool*>((*msg)->frame_data()); + (*msg)->next_frame(); + if(*copy) + (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size()); + // else the mem space is shared among all worker and servers + (*msg)->SwapAddr(); + (*msg)->set_type(kRGet); return *msg; } -const std::pair<bool, int> Param::ParseUpdateMsg(Msg** msg){ - int step=(*msg)->target_second(); +int Param::ParseUpdateMsg(Msg** msg){ bool* copy=static_cast<bool*>((*msg)->frame_data()); (*msg)->next_frame(); if(*copy){ CHECK((*msg)->frame_size()); memcpy(mutable_cpu_grad(), (*msg)->frame_data(),(*msg)->frame_size()); }else {// use the same data field of the grad blob - Blob<float>* ptr=nullptr; + float* ptr=nullptr; sscanf(static_cast<char*>((*msg)->frame_data()), " %p ", &ptr); - grad_.ShareData(*ptr); + grad_.set_cpu_data(ptr); } DeleteMsg(msg); - return std::make_pair(*copy, step); + return *copy; } -Msg* Param::GenUpdateResponseMsg(bool copy, int v){ +Msg* Param::GenUpdateResponseMsg(bool copy){ Msg* msg=new Msg(); msg->set_type(kRUpdate); - msg->set_target(owner(), v); msg->add_frame(©, sizeof(bool)); if(copy) msg->add_frame(mutable_cpu_data(), size()*sizeof(float)); @@ -124,74 +198,35 @@ Msg* Param::HandleSyncMsg(Msg** msg){ return nullptr; } -int Param::ParseSyncResponseMsg(Msg** msg){ +int Param::ParseSyncResponseMsg(Msg** msg, int slice_idx){ DeleteMsg(msg); return 1; } -int Param::ParsePutResponseMsg(Msg **msg){ - return ParseSyncResponseMsg(msg); + +int Param::ParseGetResponseMsg(Msg **msg, int slice_idx){ + CHECK_EQ(pending_get_[slice_idx], true); + pending_get_[slice_idx]=false; + ParseResponseMsg(msg, slice_idx); + return (--num_pending_requests_)%num_slices_==0; +} + +int Param::ParseUpdateResponseMsg(Msg **msg, int slice_idx){ + CHECK_EQ(pending_update_[slice_idx], true); + pending_update_[slice_idx]=false; + ParseResponseMsg(msg, slice_idx); + return (--num_pending_requests_)%num_slices_==0; } -int Param::ParseGetResponseMsg(Msg **msg){ + +void Param::ParseResponseMsg(Msg** msg, int slice_idx){ bool *copy=static_cast<bool*>((*msg)->frame_data()); (*msg)->next_frame(); if(*copy){ CHECK((*msg)->frame_size()); - memcpy(mutable_cpu_data(), (*msg)->frame_data(), (*msg)->frame_size()); - } // must be set after all other settings are done! - set_version((*msg)->target_second()); + memcpy(mutable_cpu_data()+slice_offset_[slice_idx-slice_start_], + (*msg)->frame_data(), (*msg)->frame_size()); + } DeleteMsg(msg); - return 1; } -int Param::ParseUpdateResponseMsg(Msg **msg){ - return ParseGetResponseMsg(msg); -} - -void Param::Setup(const ParamProto& proto, const vector<int>& shape, - int fan_in){ - data_=std::make_shared<Blob<float>>(shape); - grad_.Reshape(shape); - history_.Reshape(shape); - proto_=proto; - fan_in_=fan_in; } -void Param::Init(int v){ - Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size())); - auto random=TSingleton<Random<cpu>>::Instance(); - switch (proto_.init_method()) { - case ParamProto::kConstant: - data=proto_.value(); - break; - case ParamProto::kUniform: - random->SampleUniform(data, proto_.low(), proto_.high()); - if(proto_.value()) - data*= proto_.value(); - break; - case ParamProto::kUniformSqrtFanIn: - CHECK_GT(fan_in_,0); - random->SampleUniform(data, proto_.low(), proto_.high()); - if(proto_.value()) - data*= proto_.value()/ sqrt(fan_in_ / 3.0f); - break; - case ParamProto::kUniformSqrtFanInOut: - random->SampleUniform(data, proto_.low(), proto_.high()); - if(proto_.value()) - data*= proto_.value()/ sqrt(data_->shape()[0] +data_->shape()[1]); - break; - case ParamProto::kGaussian: - random->SampleGaussian(data, proto_.mean(), proto_.std()); - if(proto_.value()) - data*= proto_.value(); - break; - case ParamProto::kGaussainSqrtFanIn: - random->SampleGaussian(data, proto_.mean(), proto_.std()); - if(proto_.value()) - data*= proto_.value()/ sqrt(data_->shape()[0]); - break; - default: - LOG(ERROR) << "Illegal parameter init method "; - break; - } - set_version(v); -} -} // namespace singa +// namespace singa
