http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/neuralnet/neuron_layer.h ---------------------------------------------------------------------- diff --git a/include/neuralnet/neuron_layer.h b/include/neuralnet/neuron_layer.h index 6c4647d..51ba304 100644 --- a/include/neuralnet/neuron_layer.h +++ b/include/neuralnet/neuron_layer.h @@ -7,9 +7,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at -* +* * http://www.apache.org/licenses/LICENSE-2.0 -* +* * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -38,9 +38,9 @@ class ConvolutionLayer : public NeuronLayer { public: ~ConvolutionLayer(); - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric* perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; const std::vector<Param*> GetParams() const override { std::vector<Param*> params{weight_, bias_}; return params; @@ -63,15 +63,15 @@ class ConvolutionLayer : public NeuronLayer { */ class CConvolutionLayer : public ConvolutionLayer { public: - void ComputeFeature(int flag, Metric* perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; }; class DropoutLayer : public NeuronLayer { public: - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric* perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; protected: // drop probability float pdrop_; @@ -90,9 +90,9 @@ class DropoutLayer : public NeuronLayer { * b_i, the neuron after normalization, N is the total num of kernels */ class LRNLayer : public NeuronLayer { - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric *perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; protected: //! 
shape of the bottom layer feature @@ -106,9 +106,9 @@ class LRNLayer : public NeuronLayer { class PoolingLayer : public NeuronLayer { public: - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric *perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; protected: int kernel_, pad_, stride_; @@ -121,26 +121,26 @@ class PoolingLayer : public NeuronLayer { */ class CPoolingLayer : public PoolingLayer { public: - void Setup(const LayerProto& proto, int npartitions); - void ComputeFeature(int flag, Metric *perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers); + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; private: Blob<float> mask_; }; class ReLULayer : public NeuronLayer { public: - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric *perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; }; class InnerProductLayer : public NeuronLayer { public: ~InnerProductLayer(); - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric* perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; const std::vector<Param*> GetParams() const override { std::vector<Param*> params{weight_, bias_}; return params; @@ -159,9 +159,9 @@ class InnerProductLayer : public NeuronLayer { */ class STanhLayer : public NeuronLayer { public: - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric *perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; }; /** @@ -174,19 +174,19 @@ class SigmoidLayer: public Layer { using Layer::ComputeFeature; using Layer::ComputeGradient; - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric* perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; }; /** * Base layer for RBM models. 
*/ -class RBMLayer: public Layer { +class RBMLayer: virtual public Layer { public: virtual ~RBMLayer() {} - void Setup(const LayerProto& proto, int npartitions) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; const Blob<float>& neg_data(const Layer* layer) { return neg_data_; } @@ -218,12 +218,12 @@ class RBMLayer: public Layer { /** * RBM visible layer */ -class RBMVisLayer: public RBMLayer { +class RBMVisLayer: public RBMLayer, public LossLayer { public: ~RBMVisLayer(); - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric* perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; private: RBMLayer* hid_layer_; @@ -235,9 +235,9 @@ class RBMVisLayer: public RBMLayer { class RBMHidLayer: public RBMLayer { public: ~RBMHidLayer(); - void Setup(const LayerProto& proto, int npartitions) override; - void ComputeFeature(int flag, Metric* perf) override; - void ComputeGradient(int flag, Metric* perf) override; + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; private: RBMLayer *vis_layer_;
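The recurring change in this header is mechanical: Setup, ComputeFeature and ComputeGradient now receive the source layers explicitly instead of an npartitions int and a Metric* collector. A minimal sketch of a user-defined layer written against the new signatures follows; the layer name and the tanh math are illustrative only, and the data_/grad_ blobs, the data(this)/mutable_grad(this) accessors and the using declaration for vector are assumed to come from the existing Layer base class rather than from this patch.

#include <cmath>
#include <glog/logging.h>
#include "neuralnet/neuron_layer.h"

namespace singa {

// Hypothetical layer against the refactored interface; only the method
// signatures are taken from the diff above.
class MyTanhLayer : public NeuronLayer {
 public:
  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override {
    Layer::Setup(proto, srclayers);
    CHECK_EQ(srclayers.size(), 1);
    data_.ReshapeLike(srclayers[0]->data(this));
    grad_.ReshapeLike(data_);
  }
  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override {
    const float* src = srclayers[0]->data(this).cpu_data();
    float* dst = data_.mutable_cpu_data();
    for (int i = 0; i < data_.count(); ++i)
      dst[i] = std::tanh(src[i]);
  }
  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override {
    const float* y = data_.cpu_data();
    const float* dy = grad_.cpu_data();
    float* dx = srclayers[0]->mutable_grad(this)->mutable_cpu_data();
    for (int i = 0; i < grad_.count(); ++i)
      dx[i] = dy[i] * (1.0f - y[i] * y[i]);  // d tanh(x)/dx = 1 - tanh(x)^2
  }
};

}  // namespace singa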
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/server.h ---------------------------------------------------------------------- diff --git a/include/server.h b/include/server.h new file mode 100644 index 0000000..4b75430 --- /dev/null +++ b/include/server.h @@ -0,0 +1,133 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +#ifndef SINGA_SERVER_H_ +#define SINGA_SERVER_H_ + +#include <unordered_map> +#include <vector> +#include "comm/socket.h" +#include "proto/job.pb.h" +#include "utils/param.h" +#include "utils/updater.h" + +namespace singa { + + /* Respond to workers' get/put/update requests, and periodically sync with + * other servers. + * + * Normally, the Server creates a response message for each request which + * will be sent back to the one who issued the request. However, if the request + * is not processed successfully, the original message will be returned. The + * server does not know whether the returned message is a response or the original + * message. It just sends it to the router. The router will decide whether to + * re-send the request to the server or send it to the worker. + */ +class Server { + public: + ~Server(); + Server(int group_id, int server_id, + const JobProto& job_conf, + const std::vector<int>& slice2group, + const std::vector<int>& slice2server); + void Run(); + inline int grp_id() const { return grp_id_; } + inline int id() const { return id_; } + + protected: + /** + * Process GET request. + * + * @return the original message or a response message which contains the values + * of the Param with the requested version. + */ + Msg* HandleGet(Msg** msg); + /** + * Process Update request. + * + * It waits until it has received the gradients from all workers of the same worker + * group. After updating, it responds to each sender with the new Param + * values. It may generate a sync message to the server group that maintains + * the global version of the updated Param (slice). + * + * Note: there is no counter for each worker group on the number of received + * update requests. Hence it is possible that the server would conduct the + * update when it receives x requests from group a and y requests from group + * b where x + y = group size. To avoid this problem, we can + * -# maintain a request list for each group for each Param at the server side + * -# avoid spanning a worker group across multiple nodes; then the updates from + * the same group would be locally aggregated on the worker node, and the + * server would conduct the update immediately after receiving the aggregated + * request. + * -# launch only one worker group. 
+ * + * @return the orignal message or response message + */ + const std::vector<Msg*> HandleUpdate(Msg **msg); + /** + * Process PUT request. + * + * @return the original message or response message. If we don't want to + * acknowledge the put request, then return nullptr. + */ + Msg* HandlePut(Msg **msg); + /** + * Handle sync request from other server groups. + * + * It adds updates of Param (slice) from other server groups directly to + * local Param (slice). Currently, each Param (slice) has a master group, + * i.e., slice2group_[sliceid], which would receive such requests from all + * other server groups for the Param object. + * + * @param msg request msg containing the parameter updates + * @return response msg that contains the fresh parameter values. + */ + Msg* HandleSyncRequest(Msg** msg); + /** + * Handle sync response. + * + * The response msg includes the latest values of a Param object from the + * server group that maintainers this Param object. + * The local Param values are replaced with the addition result of local + * udpates since the sync request was sent and the received Param values. + * + * @param response message + */ + void HandleSyncResponse(Msg** msg); + + protected: + int grp_id_ = -1; + int id_ = -1; + Updater* updater_ = nullptr; + //!< map from slice ID to slice and deleted in the destructor + std::unordered_map<int, ParamEntry*> shard_; + std::vector<int> slice2group_, slice2server_; + //!< num of updates from last sync with master server group for a param/slice + std::vector<int> n_updates_; + //!< num of sync requests that have not been responded + std::vector<int> n_pending_sync_; + std::vector<Blob<float>> last_sync_; + std::unordered_map<int, std::vector<Msg*>> buffer_requests_; +}; + +} // namespace singa + +#endif // SINGA_SERVER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/singa.h ---------------------------------------------------------------------- diff --git a/include/singa.h b/include/singa.h index d4ee557..6c801ab 100644 --- a/include/singa.h +++ b/include/singa.h @@ -7,9 +7,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at -* +* * http://www.apache.org/licenses/LICENSE-2.0 -* +* * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,16 +22,15 @@ #ifndef SINGA_SINGA_H_ #define SINGA_SINGA_H_ -#include "communication/socket.h" +#include "comm/socket.h" #include "neuralnet/neuralnet.h" #include "neuralnet/layer.h" #include "proto/job.pb.h" #include "proto/singa.pb.h" -#include "trainer/trainer.h" #include "utils/common.h" #include "utils/param.h" #include "utils/singleton.h" #include "utils/factory.h" -#include "driver.h" +#include "./driver.h" #endif // SINGA_SINGA_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/stub.h ---------------------------------------------------------------------- diff --git a/include/stub.h b/include/stub.h new file mode 100644 index 0000000..719f033 --- /dev/null +++ b/include/stub.h @@ -0,0 +1,109 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. 
The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +#ifndef SINGA_STUB_H_ +#define SINGA_STUB_H_ + +#include <queue> +#include <unordered_map> +#include <vector> +#include <string> +#include "comm/socket.h" +#include "neuralnet/neuralnet.h" +#include "proto/job.pb.h" +#include "proto/singa.pb.h" +#include "utils/factory.h" +#include "utils/param.h" +#include "utils/singleton.h" +#include "./server.h" +#include "./worker.h" + +namespace singa { + +class Stub { + public: + ~Stub(); + /** + * Find an endpoint to bind. + */ + void Setup(); + /** + * The Stub instance runs this function in the main thread to handle (e.g., + * forward) messages from workers and servers. + * + * @param[in] slice2server the k-th value is the ID of the server that is in + * charge of updating the Param slice with ID k. Large Param objects are + * sliced into subsets for load-balance. Different subsets are updated by + * different servers. + */ + void Run(const vector<int>& slice2server, + const std::vector<Worker*>& workers, + const std::vector<Server*>& servers); + + const std::string& endpoint() const { + return endpoint_; + } + + protected: + /** + * Create a socket to send msg to the specified process + * @param dst_procs the dst process (logical) ID + * @return the newly created socket + */ + Dealer* CreateInterProcsDealer(int dst_procs); + /** + * Generate a request message to Get the parameter object. + */ + const std::vector<Msg*> HandleGetRequest(ParamEntry* entry, Msg** msg); + void HandleGetResponse(ParamEntry* entry, Msg** msg); + /** + * Generate a request message to Update the parameter object. + */ + const std::vector<Msg*> HandleUpdateRequest(ParamEntry* entry, Msg** msg); + /** + * Handle response msg from servers for the update requests. + */ + void HandleUpdateResponse(ParamEntry* entry, Msg** msg); + /** + * Generate a request message to Put the parameter object. + */ + const std::vector<Msg*> HandlePutRequest(ParamEntry* entry, Msg** msg); + /** + * Called by HandlePut, HandleUpdate and HandleGet functions + * @param type message type + * @param version param version + * @param entry + * @param msg + * @param ret generated messages + */ + void GenMsgs(int type, int version, ParamEntry* entry, + Msg* msg, std::vector<Msg*> *ret); + + + protected: + Router *router_ = nullptr; + std::string endpoint_; + std::vector<int> slice2server_; +}; + +} // namespace singa + +#endif // SINGA_STUB_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/server.h ---------------------------------------------------------------------- diff --git a/include/trainer/server.h b/include/trainer/server.h deleted file mode 100644 index 84b3a41..0000000 --- a/include/trainer/server.h +++ /dev/null @@ -1,132 +0,0 @@ -/************************************************************ -* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. 
See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -* -*************************************************************/ - -#ifndef SINGA_TRAINER_SERVER_H_ -#define SINGA_TRAINER_SERVER_H_ - -#include <unordered_map> -#include <vector> -#include "communication/socket.h" -#include "proto/job.pb.h" -#include "utils/param.h" -#include "utils/updater.h" - -namespace singa { - - /* Repsond to worker's get/put/udpate request, and periodically syncing with - * other servers. - * - * Normally, the Server creates a response message for each request which - * will be sent back to the one who issued the request. However, if the request - * are not processed successfully, the original message will be returned. The - * sever does not know the returned message (response or the original message), - * it just sends it to the router. The router will decide to re-send the - * request to the server or send it to the worker. - */ -class Server { - public: - Server(int group_id, int server_id); - ~Server(); - void Setup(const UpdaterProto& proto, const std::vector<int>& slice2group, - const std::vector<int>& slice2server); - void Run(); - inline int grp_id() const { return grp_id_; } - inline int id() const { return id_; } - - protected: - /** - * Process GET request. - * - * @return the orignal message or a response message which contains the values - * of the Param with the request version. - */ - Msg* HandleGet(Msg** msg); - /** - * Process Update request. - * - * It waits until received the gradients from all workers from the same worker - * group. After updating, it responses to each sender with the new Param - * values. It may generate a sync message to the server group that maintains - * the global version of the updated Param (slice). - * - * Note: there is no counter for each worker group on the number of received - * update requests. Hence it is possible that the server would conduct the - * update when it receives x requests from group a and y requests from group - * b where x + y = group size. To avoid this problem, we can - * 1. maintain request list for each group for each Param at the server side - * 2. do not span a worker group among multiple nodes. then the updates from - * the same group would be locally aggregated on the worker node. And the - * server would conduct the update immediately after receiving the aggregated - * request. - * 3. launch only one worker group. - * - * @return the orignal message or response message - */ - const std::vector<Msg*> HandleUpdate(Msg **msg); - /** - * Process PUT request. - * - * @return the original message or response message. If we don't want to - * acknowledge the put request, then return nullptr. - */ - Msg* HandlePut(Msg **msg); - /** - * Handle sync request from other server groups. - * - * It adds updates of Param (slice) from other server groups directly to - * local Param (slice). 
Currently, each Param (slice) has a master group, - * i.e., slice2group_[sliceid], which would receive such requests from all - * other server groups for the Param object. - * - * @param msg request msg containing the parameter updates - * @return response msg that contains the fresh parameter values. - */ - Msg* HandleSyncRequest(Msg** msg); - /** - * Handle sync response. - * - * The response msg includes the latest values of a Param object, for which - * this server sent the sync request to the master/maintainer group. - * The local Param values are replaced with the addition result of local - * udpates since the sync request was sent and the received Param values. - * - * @param response message - */ - void HandleSyncResponse(Msg** msg); - - protected: - int grp_id_ = -1; - int id_ = -1; - Updater* updater_ = nullptr; - //!< map from slice ID to slice and deleted in the destructor - std::unordered_map<int, ParamEntry*> shard_; - std::vector<int> slice2group_, slice2server_; - //!< num of updates from last sync with master server group for a param/slice - std::vector<int> n_updates_; - //!< num of sync requests that have not been responded - std::vector<int> n_pending_sync_; - std::vector<Blob<float>> last_sync_; - std::unordered_map<int, std::vector<Msg*>> buffer_requests_; -}; - -} // namespace singa - -#endif // SINGA_TRAINER_SERVER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/trainer.h ---------------------------------------------------------------------- diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h deleted file mode 100644 index 1c0e039..0000000 --- a/include/trainer/trainer.h +++ /dev/null @@ -1,163 +0,0 @@ -/************************************************************ -* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -* -*************************************************************/ - -#ifndef SINGA_TRAINER_TRAINER_H_ -#define SINGA_TRAINER_TRAINER_H_ - -#include <queue> -#include <unordered_map> -#include <vector> -#include "communication/socket.h" -#include "neuralnet/neuralnet.h" -#include "proto/job.pb.h" -#include "proto/singa.pb.h" -#include "trainer/server.h" -#include "trainer/worker.h" -#include "utils/factory.h" -#include "utils/param.h" -#include "utils/singleton.h" - -namespace singa { - -/** - * Every running process has a training object which launches one or more - * worker (and server) threads. - * - * The main thread runs a loop to forward messages between workers and servers. - */ -class Trainer{ - public: - ~Trainer(); - /** - * Entrance function which construct the workers and servers, and luanch - * one thread per worker/server. 
- * - * @param resume if true resume the training from the latest checkpoint files - * @param singaConf global singa configuration including zookeeper and - * @param jobConf job configuration, including cluster and model configuration - */ - void Start(bool resume, const SingaProto& singaConf, JobProto* jobConf); - - protected: - /** - * Setting the checkpoint field of model configuration to resume training. - * - * The checkpoint folder will be searched to get the files for the latest - * checkpoint, which will be added into the checkpoint field. The workers - * would then load the values of params from the checkpoint files. - * - * @param jobConf job configuration - */ - void Resume(JobProto* jobConf); - /** - * Create server instances. - * @param nthread total num of threads in current procs which is used to - * assign each thread a local thread ID. The number of workers is extracted - * from Cluster - * @param jobConf - * @return server instances - */ - std::vector<Server*> CreateServers(const JobProto& jobConf); - /** - * Create workers instances. - * @param nthread total num of threads in current procs which is used to - * assign each thread a local thread ID. The number of workers is extracted - * from Cluster - * @param jobConf - * @return worker instances - */ - std::vector<Worker*> CreateWorkers(const JobProto& jobConf); - /** - * Setup workers and servers. - * - * For each worker, create and assign a neuralnet to it. - * For each server, create and assign the param shard to it. - * Create the partition map from slice ID to server - * @param modelConf - * @param workers - * @param servers - */ - void SetupWorkerServer(const JobProto& jobConf, - const std::vector<Worker*>& workers, - const std::vector<Server*>& servers); - void Run(const std::vector<Worker*>& workers, - const std::vector<Server*>& servers); - /** - * Display metrics to log (standard output) - */ - void DisplayMetric(Msg** msg); - /** - * Create a socket to send msg to the specified process - * @param dst_procs the dst process (logical) ID - * @return the newly created socket - */ - Dealer* CreateInterProcsDealer(int dst_procs); - /** - * Handle messages to local servers and local stub - */ - void HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg); - /** - * Generate a request message to Get the parameter object. - */ - const std::vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg); - void HandleGetResponse(ParamEntry* entry, Msg** msg); - /** - * Generate a request message to Update the parameter object. - */ - const std::vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg); - void HandleUpdateResponse(ParamEntry* entry, Msg** msg); - /** - * Generate a request message to Put the parameter object. - */ - const std::vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg); - /** - * Called by HandlePut, HandleUpdate and HandleGet functions - * @param type message type - * @param version param version - * @param entry - * @param msg - * @param ret generated messages - */ - void GenMsgs(int type, int version, ParamEntry* entry, - Msg* msg, std::vector<Msg*> *ret); - /** - * Get a hash id for a Param object from a group. - * - * Simple multiple group_id with a large prime number 997 (assuming there are - * no more than 997 worker groups) and plus owner param id. 
- */ - inline int Hash(int grp_id, int param_id) { - return grp_id * 997 + param_id; - } - - protected: - int procs_id_ = -1; - Router *router_ = nullptr; - std::unordered_map<int, ParamEntry*> worker_shard_; - //!< map from slice to the server that updates it - std::vector<int> slice2server_; - // a buffer of created nets, will destroy them all in destructor - std::vector<NeuralNet*> nets_; -}; - -} // namespace singa - -#endif // SINGA_TRAINER_TRAINER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/worker.h ---------------------------------------------------------------------- diff --git a/include/trainer/worker.h b/include/trainer/worker.h deleted file mode 100644 index 66439ec..0000000 --- a/include/trainer/worker.h +++ /dev/null @@ -1,258 +0,0 @@ -/************************************************************ -* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -* -*************************************************************/ - -#ifndef SINGA_TRAINER_WORKER_H_ -#define SINGA_TRAINER_WORKER_H_ - -#include <string> -#include "communication/socket.h" -#include "neuralnet/neuralnet.h" -#include "proto/job.pb.h" - -namespace singa { - -//!< sleep 5 milliseconds if the Param is not updated to the expected version -const int kCollectSleepTime = 5; -/** - * The Worker class which runs the training algorithm. - * The first worker group will initialize parameters of the Net, - * and put them into the distributed memory/table. - * The virtual function TrainOneBatch and TestOneBatch implement the - * training and test algorithm for one mini-batch data. - * - * Child workers override the two functions to implement their training - * algorithms, e.g., the BPWorker/CDWorker/BPTTWorker implements the BP/CD/BPTT - * algorithm respectively. - */ -class Worker { - public: - static Worker* Create(const JobProto& proto); - /** - * @param thread_id local thread index within the procs - * @param grp_id global worker group ID - * @param id worker ID within the group - */ - virtual void Init(int grp_id, int id); - virtual ~Worker(); - /** - * Setup members - */ - void Setup(const JobProto& job, NeuralNet* train_net, NeuralNet* valid_net, - NeuralNet* test_net); - /** - * Init all local params (i.e., params from layers resident in this worker). - * - * If the param is owned by the worker, then init it and put it to servers. - * Otherwise call Get() to get the param. The Get may not send get request. - * Because the param's own is in the same procs. Once the owner initializes - * the param, its version is visiable to all shares. - * If the training starts from scrath, the params are initialzed using random - * distributions, e.g., Gaussian distribution. 
After that, the worker may - * train for a couple of steps to warmup the params before put - * them to servers (warmup of JobProto controls this). - * - * If the owner param is available from checkpoint file, then its - * values are parsed from the checkpoint file instead of randomly initialized. - * For params who do not have checkpoints, randomly init them. - */ - void InitLocalParams(); - /** - * Main function of Worker. - * - * Train the neuralnet step by step, test/validation is done periodically. - */ - void Run(); - /** - * Checkpoint all params owned by the worker from the first group onto disk. - * The serialization is done using BlobProtos which includes the name, version - * and values of each Param. - * Different worker would generate different checkpoint files. The file path - * is <workspace>/checkpoint-<jobname>-step<step>-worker<worker_id>.bin - * @param step training step of this worker - * @param net the training net whose params will be dumped. - */ - void Checkpoint(int step, NeuralNet* net); - /** - * Test the perforance of the learned model on validation or test dataset. - * Test is done by the first group. - * @param net, neural network - */ - void Test(int nsteps, Phase phase, NeuralNet* net); - /** - * Train one mini-batch. - * Test/Validation is done before training. - */ - virtual void TrainOneBatch(int step, Metric* perf) = 0; - /** - * Test/validate one mini-batch. - */ - virtual void TestOneBatch(int step, Phase phase, NeuralNet* net, - Metric* perf) = 0; - /** - * Report performance to the stub. - * - * @param prefix display prefix, e.g., 'Train', 'Test' - * @param perf - */ - void Report(const std::string& prefix, const Metric & perf); - /** - * Put Param to server. - * @param param - * @param step used as current param version for the put request - */ - int Put(Param* param, int step); - /** - * Get Param with specific version from server - * If the current version >= the requested version, then return. - * Otherwise send a get request to stub who would forwards it to servers. - * @param param - * @param step requested param version - */ - int Get(Param* param, int step); - /** - * Update Param - * @param param - * @param step training step used for updating (e.g., deciding learning rate) - */ - int Update(Param* param, int step); - /** - * Block until the param is updated since sending the update request - * - * @param param - * @param step not used - */ - int Collect(Param* param, int step); - /** - * Call Collect for every param of net - */ - int CollectAll(NeuralNet* net, int step); - /** - * Receive blobs from other workers due to model partitions. - */ - void ReceiveBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net); - /** - * Send blobs to other workers due to model partitions. - */ - void SendBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net); - /** - * Check is it time to display training info, e.g., loss and precison. - */ - inline bool DisplayNow(int step) const { - return job_conf_.disp_freq() > 0 - && step >= job_conf_.disp_after() - && ((step - job_conf_.disp_after()) % job_conf_.disp_freq() == 0); - } - /** - * Check is it time to display training info, e.g., loss and precison. - */ - inline bool DisplayDebugInfo(int step) const { - return DisplayNow(step) && job_conf_.debug() && grp_id_ == 0; - } - /** - * Check is it time to stop - */ - inline bool StopNow(int step) const { - return step >= job_conf_.train_steps(); - } - /** - * Check is it time to do checkpoint. 
- */ - inline bool CheckpointNow(int step) const { - return grp_id_ == 0 - && job_conf_.checkpoint_freq() > 0 - && step >= job_conf_.checkpoint_after() - && ((step - job_conf_.checkpoint_after()) - % job_conf_.checkpoint_freq() == 0); - } - /** - * Check is it time to do test. - * @param step the ::Train() has been called this num times. - */ - inline bool TestNow(int step) const { - return grp_id_ == 0 - && job_conf_.test_freq() > 0 - && job_conf_.test_steps() > 0 - && step >= job_conf_.test_after() - && ((step - job_conf_.test_after()) % job_conf_.test_freq() == 0); - } - /** - * Check is it time to do validation. - * @param step the ::Train() has been called step times. - */ - inline bool ValidateNow(int step) const { - return grp_id_ == 0 - && job_conf_.valid_freq() > 0 - && job_conf_.valid_steps() > 0 - && step >= job_conf_.valid_after() - && ((step - job_conf_.valid_after()) % job_conf_.valid_freq() == 0); - } - /** - * @return group ID - */ - int grp_id() const { return grp_id_; } - /** - * @reutrn worker ID within the worker group. - */ - int id() const { return id_; } - - protected: - int grp_id_ = -1, id_ = -1; - int step_ = 0; - JobProto job_conf_; - NeuralNet* train_net_ = nullptr; - NeuralNet* test_net_ = nullptr; - NeuralNet* validation_net_ = nullptr; - Dealer* layer_dealer_ = nullptr; - Dealer* dealer_ = nullptr; -}; - -class BPWorker: public Worker { - public: - void TrainOneBatch(int step, Metric* perf) override; - void TestOneBatch(int step, Phase phase, NeuralNet* net, Metric* perf) - override; - void Forward(int step, Phase phase, NeuralNet* net, Metric* perf); - void Backward(int step, NeuralNet* net); -}; - -class CDWorker: public Worker { - public: - void TrainOneBatch(int step, Metric* perf) override; - void TestOneBatch(int step, Phase phase, NeuralNet* net, Metric* perf) - override; -}; - -inline int BlobTrgt(int grp, int layer) { - return (grp << 16) | layer; -} - -inline int BlobGrp(int blob_trgt) { - return blob_trgt >> 16; -} - -inline int BlobLayer(int blob_trgt) { - static int mask = (1 << 16) -1; - return blob_trgt & mask; -} - -} // namespace singa - -#endif // SINGA_TRAINER_WORKER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/utils/param.h ---------------------------------------------------------------------- diff --git a/include/utils/param.h b/include/utils/param.h index e6c8c7c..f690438 100644 --- a/include/utils/param.h +++ b/include/utils/param.h @@ -7,9 +7,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at -* +* * http://www.apache.org/licenses/LICENSE-2.0 -* +* * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,12 +25,13 @@ #include <memory> #include <string> #include <vector> -#include "communication/msg.h" + +#include "comm/msg.h" #include "proto/job.pb.h" #include "utils/blob.h" namespace singa { - +using std::vector; /** * Base parameter generator which intializes parameter values. */ @@ -92,7 +93,34 @@ class UniformSqrtFanInOutGen : public UniformGen { */ class Param { public: - static Param* Create(const ParamProto& proto); + /** + * Create an instance of (sub) Param class based on the type from the + * configuration. 
+ * + * @param[in] conf configuration + * @param a pointer to an instance + */ + static Param* Create(const ParamProto& conf); + + /** + * Try to slice the Param objects (from a neural net) into a given number of + * servers (groups) evenly. This is to achieve load-balance among servers. + * + * It does not change the Param objects, but just computes the length of each + * slice. + * + * @param num number of servers (groups) for maintaining the Param objects. + * @param params all Param objects from a neural net. + * @return the length of each slice. + */ + static const vector<int> ComputeSlices(int num, const vector<Param*>& params); + /** + * It computes the length of each slice and slices the Param objects by adding + * the slicing information into every Param object. + * + * @copydetails ComputeSlices() + */ + static void SliceParams(int num, const vector<Param*>& params); Param() {} virtual ~Param() {} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/worker.h ---------------------------------------------------------------------- diff --git a/include/worker.h b/include/worker.h new file mode 100644 index 0000000..58f02c4 --- /dev/null +++ b/include/worker.h @@ -0,0 +1,311 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +#ifndef SINGA_WORKER_H_ +#define SINGA_WORKER_H_ + +#include <string> +#include <vector> +#include "comm/socket.h" +#include "neuralnet/neuralnet.h" +#include "proto/job.pb.h" + +namespace singa { + +//!< sleep 5 milliseconds if the Param is not updated to the expected version +const int kCollectSleepTime = 5; +/** + * The Worker class which runs the training algorithm. + * The first worker group will initialize parameters of the Net, + * and put them into the distributed memory/table. + * The virtual function TrainOneBatch and TestOneBatch implement the + * training and test algorithm for one mini-batch data. + * + * Child workers override the two functions to implement their training + * algorithms, e.g., the BPWorker/CDWorker/BPTTWorker implements the BP/CD/BPTT + * algorithm respectively. + */ +class Worker { + public: + /** + * Create an instance of the subclass of Worker. + * + * @param[in] conf configuration of the TrainOneBatch algorithm. Different + * Worker subclasses implement different algorithms. Hence the creation is + * based on the TrainOneBatch algorithm type. Currently SINGA + * provides two algorithms: + * -# Back-propagation for the feed-forward models, e.g., CNN and MLP, and the + * recurrent neural networks. + * -# Contrastive divergence for the energy models, e.g., RBM. + * + * @return a pointer to the instance of the Worker subclass. 
+ */ + static Worker* Create(const AlgProto& conf); + virtual ~Worker(); + /** + * @param[in] grp_id global worker group ID + * @param[in] id worker ID within the group + * @param[in] conf job configuration + * @param[in] train_net pointer to the training neural net, which could be + * shared with other workers from the same group. Different workers run over + * different subsets of layers. + * @param[in] val_net pointer to the validation neural net. Currently only the + * first worker from the first group would have a validation neural net. All + * other workers receive nullptr for this argument. + * @param[in] test_net pointer to the test neural net. Currently only the + * first worker from the first group would have a test neural net. All other + * workers receive nullptr for this argument. + */ + virtual void Setup(int grp_id, int id, const JobProto& conf, + NeuralNet* train_net, NeuralNet* val_net, NeuralNet* test_net); + + /** + * Main function of Worker. + * + * Train the neural net step by step; test/validation is done periodically. + */ + void Run(); + + /** + * Init values of Param instances associated with local layers (i.e., layers + * dispatched to this worker). + * + * If one Param is owned by the worker, then it should be initialized and put + * to servers. Otherwise Get() should be called to get the Param. The Get() + * may not send a get request if the Param owner is in the same process, in + * which case the memory space of the Param objects is shared. But if this + * worker and the Param owner worker run on different devices (e.g., GPUs), + * then the get request would be sent. + * + * If the training starts from scratch, every Param object is initialized using + * a ParamGenerator. After that, the worker may + * train for a couple of steps to warm up the params before putting + * them to servers (the warmup field of JobProto controls this). + * + * If one Param object's name matches that of one Param object from the + * checkpoint files, its Param values would be loaded from checkpoint files. + * + * @param[in] job_conf job configuration which provides settings for + * checkpoint file paths, warmup steps and Param versions. + * @param[out] net pointer to a neural net whose Param values will be + * initialized. + */ + void InitNetParams(const JobProto& job_conf, NeuralNet* net); + + /** + * Checkpoint all Param objects owned by the worker onto disk. + * The serialization is done using BlobProtos which include the name, version + * and values of each Param object. + * Different workers would generate different checkpoint files. The file path + * is <workspace>/checkpoint-<jobname>-step<step>-worker<worker_id>.bin + * @param[in] step training step + * @param[in] folder directory to put the checkpoint file + * @param net the training net whose Param objects will be dumped. + */ + void Checkpoint(int step, const std::string& folder, NeuralNet* net); + + /** + * Train one mini-batch. + * Test/Validation is done before training. + * + * @param[in] step training step. + * @param[in] net neural net to be trained. + */ + virtual void TrainOneBatch(int step, NeuralNet* net) = 0; + + /** + * Test/validate one mini-batch of data. + * + * @param[in] step test step. + * @param[in] phase test could be done for validation or test phase. + * @param[in] net neural net for test + */ + virtual void TestOneBatch(int step, Phase phase, NeuralNet* net) = 0; + + /** + * Display information from layers. 
+ * + * @param flag could be a combination of multiple phases, e.g, kTest|kForward, + * it is passed to the Layer::ToString() function for each layer to decide + * what to display . + * @param prefix display prefix, e.g., 'Train step 100', 'Test step 90'. + * @param net display layers from this neural net. + */ + void Display(int flag, const std::string& prefix, NeuralNet* net); + + /** + * Put Param values to server. + * + * @param param + * @param step used as current param version for the put request + */ + int Put(int step, Param* param); + + /** + * Get Param with specific version from server + * If the current version >= the requested version, then return. + * Otherwise send a get request to stub who would forwards it to servers. + * @param param + * @param step requested param version + */ + int Get(int step, Param* param); + + /** + * Update Param. + * + * @param param + * @param step training step used for updating (e.g., deciding learning rate). + */ + int Update(int step, Param* param); + + /** + * Wait for the response of the update/get requests. + * + * @param param + * @param step not used now. + */ + int Collect(int step, Param* param); + + /** + * Call Collect() for every param of net + */ + int CollectAll(int step, NeuralNet* net); + + /** + * Receive blobs from other workers due to model partitions. + */ + void ReceiveBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net); + + /** + * Send blobs to other workers due to model partitions. + */ + void SendBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net); + + + /** + * @param[in] step + * @return true if it is time to display training info, e.g., loss; otherwise + * false. + */ + inline bool DisplayNow(int step) const { + return job_conf_.disp_freq() > 0 + && step >= job_conf_.disp_after() + && ((step - job_conf_.disp_after()) % job_conf_.disp_freq() == 0); + } + /** + * @param[in] step + * @return true if it is time to finish the training; otherwise false. + */ + inline bool StopNow(int step) const { + return step >= job_conf_.train_steps(); + } + /** + * @param[in] step + * @return true if it is time to do checkpoint Param objects; otherwise false. + */ + inline bool CheckpointNow(int step) const { + return job_conf_.checkpoint_freq() > 0 + && step >= job_conf_.checkpoint_after() + && ((step - job_conf_.checkpoint_after()) + % job_conf_.checkpoint_freq() == 0); + } + /** + * @param[in] step + * @return true if it is time to do test over the test dataset. + */ + inline bool TestNow(int step) const { + return job_conf_.test_freq() > 0 + && job_conf_.test_steps() > 0 + && step >= job_conf_.test_after() + && ((step - job_conf_.test_after()) % job_conf_.test_freq() == 0); + } + /** + * @param[in] step + * @return true if it is time to do test over the validation dataset. + */ + inline bool ValidateNow(int step) const { + return job_conf_.validate_freq() > 0 + && job_conf_.validate_steps() > 0 + && step >= job_conf_.validate_after() + && ((step - job_conf_.validate_after()) % job_conf_.validate_freq() == 0); + } + /** + * @return a vector with pointers to all neural nets. + */ + const std::vector<NeuralNet*> GetNets() const { + return std::vector<NeuralNet*> {train_net_, val_net_, test_net_}; + } + /** + * @return training net. + */ + inline NeuralNet* train_net() const { + return train_net_; + } + /** + * @return group ID + */ + inline int grp_id() const { return grp_id_; } + /** + * @reutrn worker ID within the worker group. 
+ */ + inline int id() const { return id_; } + + protected: + int grp_id_ = -1, id_ = -1; + int step_ = 0; + JobProto job_conf_; + NeuralNet* train_net_ = nullptr; + NeuralNet* test_net_ = nullptr; + NeuralNet* val_net_ = nullptr; + Dealer* layer_dealer_ = nullptr; + Dealer* dealer_ = nullptr; +}; + +class BPWorker: public Worker { + public: + void TrainOneBatch(int step, NeuralNet* net) override; + void TestOneBatch(int step, Phase phase, NeuralNet* net) override; + void Forward(int step, Phase phase, NeuralNet* net); + void Backward(int step, NeuralNet* net); +}; + +class CDWorker: public Worker { + public: + void TrainOneBatch(int step, NeuralNet* net) override; + void TestOneBatch(int step, Phase phase, NeuralNet* net) override; +}; + +inline int BlobTrgt(int grp, int layer) { + return (grp << 16) | layer; +} + +inline int BlobGrp(int blob_trgt) { + return blob_trgt >> 16; +} + +inline int BlobLayer(int blob_trgt) { + static int mask = (1 << 16) -1; + return blob_trgt & mask; +} + +} // namespace singa + +#endif // SINGA_WORKER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/comm/msg.cc ---------------------------------------------------------------------- diff --git a/src/comm/msg.cc b/src/comm/msg.cc new file mode 100644 index 0000000..2521c28 --- /dev/null +++ b/src/comm/msg.cc @@ -0,0 +1,215 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +#include "comm/msg.h" + +#include <glog/logging.h> + +namespace singa { + +#ifdef USE_ZMQ +Msg::~Msg() { + if (msg_ != nullptr) + zmsg_destroy(&msg_); + frame_ = nullptr; +} + +Msg::Msg() { + msg_ = zmsg_new(); +} + +Msg::Msg(const Msg& msg) { + src_ = msg.src_; + dst_ = msg.dst_; + type_ = msg.type_; + trgt_val_ = msg.trgt_val_; + trgt_version_ = msg.trgt_version_; + msg_ = zmsg_dup(msg.msg_); +} + +Msg::Msg(int src, int dst) { + src_ = src; + dst_ = dst; + msg_ = zmsg_new(); +} + +void Msg::SwapAddr() { + std::swap(src_, dst_); +} + +int Msg::size() const { + return zmsg_content_size(msg_); +} + +void Msg::AddFrame(const void* addr, int nBytes) { + zmsg_addmem(msg_, addr, nBytes); +} + +int Msg::FrameSize() { + return zframe_size(frame_); +} + +void* Msg::FrameData() { + return zframe_data(frame_); +} + +char* Msg::FrameStr() { + return zframe_strdup(frame_); +} +bool Msg::NextFrame() { + frame_ = zmsg_next(msg_); + return frame_ != nullptr; +} + +void Msg::FirstFrame() { + frame_ = zmsg_first(msg_); +} + +void Msg::LastFrame() { + frame_ = zmsg_last(msg_); +} + +void Msg::ParseFromZmsg(zmsg_t* msg) { + char* tmp = zmsg_popstr(msg); + sscanf(tmp, "%d %d %d %d %d", + &src_, &dst_, &type_, &trgt_val_, &trgt_version_); + frame_ = zmsg_first(msg); + msg_ = msg; +} + +zmsg_t* Msg::DumpToZmsg() { + zmsg_pushstrf(msg_, "%d %d %d %d %d", + src_, dst_, type_, trgt_val_, trgt_version_); + zmsg_t *tmp = msg_; + msg_ = nullptr; + return tmp; +} + +// frame marker indicating this frame is serialize like printf +#define FMARKER "*singa*" + +#define kMaxFrameLen 2048 + +int Msg::AddFormatFrame(const char *format, ...) { + va_list argptr; + va_start(argptr, format); + int size = strlen(FMARKER); + char dst[kMaxFrameLen]; + memcpy(dst, FMARKER, size); + dst[size++] = 0; + while (*format) { + if (*format == 'i') { + int x = va_arg(argptr, int); + dst[size++] = 'i'; + memcpy(dst + size, &x, sizeof(x)); + size += sizeof(x); + } else if (*format == 'f') { + float x = static_cast<float> (va_arg(argptr, double)); + dst[size++] = 'f'; + memcpy(dst + size, &x, sizeof(x)); + size += sizeof(x); + } else if (*format == '1') { + uint8_t x = va_arg(argptr, int); + memcpy(dst + size, &x, sizeof(x)); + size += sizeof(x); + } else if (*format == '2') { + uint16_t x = va_arg(argptr, int); + memcpy(dst + size, &x, sizeof(x)); + size += sizeof(x); + } else if (*format == '4') { + uint32_t x = va_arg(argptr, uint32_t); + memcpy(dst + size, &x, sizeof(x)); + size += sizeof(x); + } else if (*format == 's') { + char* x = va_arg(argptr, char *); + dst[size++] = 's'; + memcpy(dst + size, x, strlen(x)); + size += strlen(x); + dst[size++] = 0; + } else if (*format == 'p') { + void* x = va_arg(argptr, void *); + dst[size++] = 'p'; + memcpy(dst + size, &x, sizeof(x)); + size += sizeof(x); + } else { + LOG(ERROR) << "Unknown format " << *format; + } + format++; + CHECK_LE(size, kMaxFrameLen); + } + va_end(argptr); + zmsg_addmem(msg_, dst, size); + return size; +} + +int Msg::ParseFormatFrame(const char *format, ...) 
{ + va_list argptr; + va_start(argptr, format); + char* src = zframe_strdup(frame_); + CHECK_STREQ(FMARKER, src); + int size = strlen(FMARKER) + 1; + while (*format) { + if (*format == 'i') { + int *x = va_arg(argptr, int *); + CHECK_EQ(src[size++], 'i'); + memcpy(x, src + size, sizeof(*x)); + size += sizeof(*x); + } else if (*format == 'f') { + float *x = va_arg(argptr, float *); + CHECK_EQ(src[size++], 'f'); + memcpy(x, src + size, sizeof(*x)); + size += sizeof(*x); + } else if (*format == '1') { + uint8_t *x = va_arg(argptr, uint8_t *); + memcpy(x, src + size, sizeof(*x)); + size += sizeof(*x); + } else if (*format == '2') { + uint16_t *x = va_arg(argptr, uint16_t *); + memcpy(x, src + size, sizeof(*x)); + size += sizeof(*x); + } else if (*format == '4') { + uint32_t *x = va_arg(argptr, uint32_t *); + memcpy(x, src + size, sizeof(*x)); + size += sizeof(*x); + } else if (*format == 's') { + char* x = va_arg(argptr, char *); + CHECK_EQ(src[size++], 's'); + int len = strlen(src + size); + memcpy(x, src + size, len); + x[len] = 0; + size += len + 1; + } else if (*format == 'p') { + void** x = va_arg(argptr, void **); + CHECK_EQ(src[size++], 'p'); + memcpy(x, src + size, sizeof(*x)); + size += sizeof(*x); + } else { + LOG(ERROR) << "Unknown format type " << *format; + } + format++; + } + va_end(argptr); + delete src; + return size; +} +#endif + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/comm/socket.cc ---------------------------------------------------------------------- diff --git a/src/comm/socket.cc b/src/comm/socket.cc new file mode 100644 index 0000000..b9c7810 --- /dev/null +++ b/src/comm/socket.cc @@ -0,0 +1,180 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ +#include "comm/socket.h" + +#include <glog/logging.h> + +namespace singa { + +#ifdef USE_ZMQ +Poller::Poller() { + poller_ = zpoller_new(nullptr); +} + +Poller::Poller(SocketInterface* socket) { + poller_ = zpoller_new(nullptr); + Add(socket); +} + +void Poller::Add(SocketInterface* socket) { + zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID()); + zpoller_add(poller_, zsock); + zsock2Socket_[zsock] = socket; +} + +SocketInterface* Poller::Wait(int timeout) { + zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout)); + if (sock != nullptr) + return zsock2Socket_[sock]; + else + return nullptr; +} + +bool Poller::Terminated() { + return zpoller_terminated(poller_); +} + + +Dealer::Dealer() : Dealer(-1) {} + +Dealer::Dealer(int id) : id_(id) { + dealer_ = zsock_new(ZMQ_DEALER); + CHECK_NOTNULL(dealer_); +} + +Dealer::~Dealer() { + zsock_destroy(&dealer_); +} + +int Dealer::Connect(const std::string& endpoint) { + CHECK_GT(endpoint.length(), 0); + if (endpoint.length()) { + CHECK_EQ(zsock_connect(dealer_, "%s", endpoint.c_str()), 0); + return 1; + } + return 0; +} + +int Dealer::Send(Msg** msg) { + zmsg_t* zmsg = (*msg)->DumpToZmsg(); + zmsg_send(&zmsg, dealer_); + delete *msg; + *msg = nullptr; + return 1; +} + +Msg* Dealer::Receive() { + zmsg_t* zmsg = zmsg_recv(dealer_); + if (zmsg == nullptr) + return nullptr; + Msg* msg = new Msg(); + msg->ParseFromZmsg(zmsg); + return msg; +} + +void* Dealer::InternalID() const { + return dealer_; +} + +Router::Router() : Router(100) {} + +Router::Router(int bufsize) { + nBufmsg_ = 0; + bufsize_ = bufsize; + router_ = zsock_new(ZMQ_ROUTER); + CHECK_NOTNULL(router_); + poller_ = zpoller_new(router_); + CHECK_NOTNULL(poller_); +} + +Router::~Router() { + zsock_destroy(&router_); + for (auto it : id2addr_) + zframe_destroy(&it.second); + for (auto it : bufmsg_) { + for (auto *msg : it.second) + zmsg_destroy(&msg); + } +} +int Router::Bind(const std::string& endpoint) { + int port = -1; + if (endpoint.length()) { + port = zsock_bind(router_, "%s", endpoint.c_str()); + } + CHECK_NE(port, -1) << endpoint; + LOG(INFO) << "bind successfully to " << endpoint + ":" + std::to_string(port); + return port; +} + +int Router::Send(Msg **msg) { + zmsg_t* zmsg = (*msg)->DumpToZmsg(); + int dstid = (*msg)->dst(); + if (id2addr_.find(dstid) != id2addr_.end()) { + // the connection has already been set up + zframe_t* addr = zframe_dup(id2addr_[dstid]); + zmsg_prepend(zmsg, &addr); + zmsg_send(&zmsg, router_); + } else { + // the connection is not ready, buffer the message + if (bufmsg_.size() == 0) + nBufmsg_ = 0; + bufmsg_[dstid].push_back(zmsg); + ++nBufmsg_; + CHECK_LE(nBufmsg_, bufsize_); + } + delete *msg; + *msg = nullptr; + return 1; +} + +Msg* Router::Receive() { + zmsg_t* zmsg = zmsg_recv(router_); + if (zmsg == nullptr) { + LOG(ERROR) << "Connection broken!"; + exit(0); + } + zframe_t* dealer = zmsg_pop(zmsg); + Msg* msg = new Msg(); + msg->ParseFromZmsg(zmsg); + if (id2addr_.find(msg->src()) == id2addr_.end()) { + // new connection, store the sender's identfier and send buffered messages + // for it + id2addr_[msg->src()] = dealer; + if (bufmsg_.find(msg->src()) != bufmsg_.end()) { + for (auto& it : bufmsg_.at(msg->src())) { + zframe_t* addr = zframe_dup(dealer); + zmsg_prepend(it, &addr); + zmsg_send(&it, router_); + } + bufmsg_.erase(msg->src()); + } + } else { + zframe_destroy(&dealer); + } + return msg; +} + +void* Router::InternalID() const { + return router_; +} 
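The new comm sockets above wrap CZMQ: a Dealer connects out to an endpoint, while a Router binds, keys each peer by the integer src id carried in the Msg header, and buffers outbound messages until that peer's identity frame has been seen. Below is a minimal single-process sketch of one round trip, assuming USE_ZMQ is defined; the tcp endpoint and the ids 0/1 are invented purely for illustration and error handling is omitted.

#include "comm/msg.h"
#include "comm/socket.h"

void PingPong() {
  singa::Router router;
  router.Bind("tcp://*:5555");             // server side binds (port is illustrative)

  singa::Dealer dealer(1);                 // worker side, id 1 is illustrative
  dealer.Connect("tcp://localhost:5555");

  // worker (src 1) pings server (dst 0); the frame packs an int and a string
  singa::Msg* ping = new singa::Msg(1, 0);
  ping->AddFormatFrame("is", 42, "hello");
  dealer.Send(&ping);                      // Send() deletes the Msg and nulls the pointer

  // the router learns the dealer's identity frame from this first message
  singa::Msg* req = router.Receive();
  int val;
  char buf[64];
  req->FirstFrame();
  req->ParseFormatFrame("is", &val, buf);  // val == 42, buf == "hello"

  // reply by swapping src/dst; dst 1 is now known, so no buffering happens
  req->SwapAddr();
  router.Send(&req);

  singa::Msg* reply = dealer.Receive();
  delete reply;
}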
+#endif + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/communication/msg.cc ---------------------------------------------------------------------- diff --git a/src/communication/msg.cc b/src/communication/msg.cc deleted file mode 100644 index 6042057..0000000 --- a/src/communication/msg.cc +++ /dev/null @@ -1,215 +0,0 @@ -/************************************************************ -* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -* -*************************************************************/ - -#include "communication/msg.h" - -#include <glog/logging.h> - -namespace singa { - -#ifdef USE_ZMQ -Msg::~Msg() { - if (msg_ != nullptr) - zmsg_destroy(&msg_); - frame_ = nullptr; -} - -Msg::Msg() { - msg_ = zmsg_new(); -} - -Msg::Msg(const Msg& msg) { - src_ = msg.src_; - dst_ = msg.dst_; - type_ = msg.type_; - trgt_val_ = msg.trgt_val_; - trgt_version_ = msg.trgt_version_; - msg_ = zmsg_dup(msg.msg_); -} - -Msg::Msg(int src, int dst) { - src_ = src; - dst_ = dst; - msg_ = zmsg_new(); -} - -void Msg::SwapAddr() { - std::swap(src_, dst_); -} - -int Msg::size() const { - return zmsg_content_size(msg_); -} - -void Msg::AddFrame(const void* addr, int nBytes) { - zmsg_addmem(msg_, addr, nBytes); -} - -int Msg::FrameSize() { - return zframe_size(frame_); -} - -void* Msg::FrameData() { - return zframe_data(frame_); -} - -char* Msg::FrameStr() { - return zframe_strdup(frame_); -} -bool Msg::NextFrame() { - frame_ = zmsg_next(msg_); - return frame_ != nullptr; -} - -void Msg::FirstFrame() { - frame_ = zmsg_first(msg_); -} - -void Msg::LastFrame() { - frame_ = zmsg_last(msg_); -} - -void Msg::ParseFromZmsg(zmsg_t* msg) { - char* tmp = zmsg_popstr(msg); - sscanf(tmp, "%d %d %d %d %d", - &src_, &dst_, &type_, &trgt_val_, &trgt_version_); - frame_ = zmsg_first(msg); - msg_ = msg; -} - -zmsg_t* Msg::DumpToZmsg() { - zmsg_pushstrf(msg_, "%d %d %d %d %d", - src_, dst_, type_, trgt_val_, trgt_version_); - zmsg_t *tmp = msg_; - msg_ = nullptr; - return tmp; -} - -// frame marker indicating this frame is serialize like printf -#define FMARKER "*singa*" - -#define kMaxFrameLen 2048 - -int Msg::AddFormatFrame(const char *format, ...) 
{ - va_list argptr; - va_start(argptr, format); - int size = strlen(FMARKER); - char dst[kMaxFrameLen]; - memcpy(dst, FMARKER, size); - dst[size++] = 0; - while (*format) { - if (*format == 'i') { - int x = va_arg(argptr, int); - dst[size++] = 'i'; - memcpy(dst + size, &x, sizeof(x)); - size += sizeof(x); - } else if (*format == 'f') { - float x = static_cast<float> (va_arg(argptr, double)); - dst[size++] = 'f'; - memcpy(dst + size, &x, sizeof(x)); - size += sizeof(x); - } else if (*format == '1') { - uint8_t x = va_arg(argptr, int); - memcpy(dst + size, &x, sizeof(x)); - size += sizeof(x); - } else if (*format == '2') { - uint16_t x = va_arg(argptr, int); - memcpy(dst + size, &x, sizeof(x)); - size += sizeof(x); - } else if (*format == '4') { - uint32_t x = va_arg(argptr, uint32_t); - memcpy(dst + size, &x, sizeof(x)); - size += sizeof(x); - } else if (*format == 's') { - char* x = va_arg(argptr, char *); - dst[size++] = 's'; - memcpy(dst + size, x, strlen(x)); - size += strlen(x); - dst[size++] = 0; - } else if (*format == 'p') { - void* x = va_arg(argptr, void *); - dst[size++] = 'p'; - memcpy(dst + size, &x, sizeof(x)); - size += sizeof(x); - } else { - LOG(ERROR) << "Unknown format " << *format; - } - format++; - CHECK_LE(size, kMaxFrameLen); - } - va_end(argptr); - zmsg_addmem(msg_, dst, size); - return size; -} - -int Msg::ParseFormatFrame(const char *format, ...) { - va_list argptr; - va_start(argptr, format); - char* src = zframe_strdup(frame_); - CHECK_STREQ(FMARKER, src); - int size = strlen(FMARKER) + 1; - while (*format) { - if (*format == 'i') { - int *x = va_arg(argptr, int *); - CHECK_EQ(src[size++], 'i'); - memcpy(x, src + size, sizeof(*x)); - size += sizeof(*x); - } else if (*format == 'f') { - float *x = va_arg(argptr, float *); - CHECK_EQ(src[size++], 'f'); - memcpy(x, src + size, sizeof(*x)); - size += sizeof(*x); - } else if (*format == '1') { - uint8_t *x = va_arg(argptr, uint8_t *); - memcpy(x, src + size, sizeof(*x)); - size += sizeof(*x); - } else if (*format == '2') { - uint16_t *x = va_arg(argptr, uint16_t *); - memcpy(x, src + size, sizeof(*x)); - size += sizeof(*x); - } else if (*format == '4') { - uint32_t *x = va_arg(argptr, uint32_t *); - memcpy(x, src + size, sizeof(*x)); - size += sizeof(*x); - } else if (*format == 's') { - char* x = va_arg(argptr, char *); - CHECK_EQ(src[size++], 's'); - int len = strlen(src + size); - memcpy(x, src + size, len); - x[len] = 0; - size += len + 1; - } else if (*format == 'p') { - void** x = va_arg(argptr, void **); - CHECK_EQ(src[size++], 'p'); - memcpy(x, src + size, sizeof(*x)); - size += sizeof(*x); - } else { - LOG(ERROR) << "Unknown format type " << *format; - } - format++; - } - va_end(argptr); - delete src; - return size; -} -#endif - -} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/communication/socket.cc ---------------------------------------------------------------------- diff --git a/src/communication/socket.cc b/src/communication/socket.cc deleted file mode 100644 index 60e1cc1..0000000 --- a/src/communication/socket.cc +++ /dev/null @@ -1,180 +0,0 @@ -/************************************************************ -* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. 
The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -* -*************************************************************/ -#include "communication/socket.h" - -#include <glog/logging.h> - -namespace singa { - -#ifdef USE_ZMQ -Poller::Poller() { - poller_ = zpoller_new(nullptr); -} - -Poller::Poller(SocketInterface* socket) { - poller_ = zpoller_new(nullptr); - Add(socket); -} - -void Poller::Add(SocketInterface* socket) { - zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID()); - zpoller_add(poller_, zsock); - zsock2Socket_[zsock] = socket; -} - -SocketInterface* Poller::Wait(int timeout) { - zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout)); - if (sock != nullptr) - return zsock2Socket_[sock]; - else - return nullptr; -} - -bool Poller::Terminated() { - return zpoller_terminated(poller_); -} - - -Dealer::Dealer() : Dealer(-1) {} - -Dealer::Dealer(int id) : id_(id) { - dealer_ = zsock_new(ZMQ_DEALER); - CHECK_NOTNULL(dealer_); -} - -Dealer::~Dealer() { - zsock_destroy(&dealer_); -} - -int Dealer::Connect(const std::string& endpoint) { - CHECK_GT(endpoint.length(), 0); - if (endpoint.length()) { - CHECK_EQ(zsock_connect(dealer_, "%s", endpoint.c_str()), 0); - return 1; - } - return 0; -} - -int Dealer::Send(Msg** msg) { - zmsg_t* zmsg = (*msg)->DumpToZmsg(); - zmsg_send(&zmsg, dealer_); - delete *msg; - *msg = nullptr; - return 1; -} - -Msg* Dealer::Receive() { - zmsg_t* zmsg = zmsg_recv(dealer_); - if (zmsg == nullptr) - return nullptr; - Msg* msg = new Msg(); - msg->ParseFromZmsg(zmsg); - return msg; -} - -void* Dealer::InternalID() const { - return dealer_; -} - -Router::Router() : Router(100) {} - -Router::Router(int bufsize) { - nBufmsg_ = 0; - bufsize_ = bufsize; - router_ = zsock_new(ZMQ_ROUTER); - CHECK_NOTNULL(router_); - poller_ = zpoller_new(router_); - CHECK_NOTNULL(poller_); -} - -Router::~Router() { - zsock_destroy(&router_); - for (auto it : id2addr_) - zframe_destroy(&it.second); - for (auto it : bufmsg_) { - for (auto *msg : it.second) - zmsg_destroy(&msg); - } -} -int Router::Bind(const std::string& endpoint) { - int port = -1; - if (endpoint.length()) { - port = zsock_bind(router_, "%s", endpoint.c_str()); - } - CHECK_NE(port, -1) << endpoint; - LOG(INFO) << "bind successfully to " << endpoint + ":" + std::to_string(port); - return port; -} - -int Router::Send(Msg **msg) { - zmsg_t* zmsg = (*msg)->DumpToZmsg(); - int dstid = (*msg)->dst(); - if (id2addr_.find(dstid) != id2addr_.end()) { - // the connection has already been set up - zframe_t* addr = zframe_dup(id2addr_[dstid]); - zmsg_prepend(zmsg, &addr); - zmsg_send(&zmsg, router_); - } else { - // the connection is not ready, buffer the message - if (bufmsg_.size() == 0) - nBufmsg_ = 0; - bufmsg_[dstid].push_back(zmsg); - ++nBufmsg_; - CHECK_LE(nBufmsg_, bufsize_); - } - delete *msg; - *msg = nullptr; - return 1; -} - -Msg* Router::Receive() { - zmsg_t* zmsg = zmsg_recv(router_); - if (zmsg == nullptr) { - LOG(ERROR) << "Connection broken!"; - exit(0); - } - zframe_t* dealer = 
zmsg_pop(zmsg); - Msg* msg = new Msg(); - msg->ParseFromZmsg(zmsg); - if (id2addr_.find(msg->src()) == id2addr_.end()) { - // new connection, store the sender's identfier and send buffered messages - // for it - id2addr_[msg->src()] = dealer; - if (bufmsg_.find(msg->src()) != bufmsg_.end()) { - for (auto& it : bufmsg_.at(msg->src())) { - zframe_t* addr = zframe_dup(dealer); - zmsg_prepend(it, &addr); - zmsg_send(&it, router_); - } - bufmsg_.erase(msg->src()); - } - } else { - zframe_destroy(&dealer); - } - return msg; -} - -void* Router::InternalID() const { - return router_; -} -#endif - -} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/driver.cc ---------------------------------------------------------------------- diff --git a/src/driver.cc b/src/driver.cc index 6fa70ee..d3f0f3e 100644 --- a/src/driver.cc +++ b/src/driver.cc @@ -19,16 +19,17 @@ * *************************************************************/ -#include "driver.h" - #include <glog/logging.h> +#include <set> #include <string> #include "neuralnet/layer.h" -#include "trainer/trainer.h" #include "utils/common.h" #include "utils/tinydir.h" +#include "utils/cluster.h" +#include "./stub.h" +#include "./driver.h" -extern "C" void openblas_set_num_threads(int); +extern "C" void openblas_set_num_threads(int num); namespace singa { @@ -109,22 +110,192 @@ void Driver::Init(int argc, char **argv) { } -void Driver::Submit(bool resume, const JobProto& jobConf) { +void Driver::Train(bool resume, const JobProto& job_conf) { + Cluster::Setup(job_id_, singa_conf_, job_conf.cluster()); if (singa_conf_.has_log_dir()) - SetupLog(singa_conf_.log_dir(), std::to_string(job_id_) - + "-" + jobConf.name()); + SetupLog(singa_conf_.log_dir(), + std::to_string(job_id_) + "-" + job_conf.name()); tinydir_dir workspace; - if (tinydir_open(&workspace, jobConf.cluster().workspace().c_str()) == -1) - LOG(FATAL) << "workspace does not exist: " << jobConf.cluster().workspace(); - if (jobConf.num_openblas_threads() != 1) - LOG(WARNING) << "openblas with " - << jobConf.num_openblas_threads() << " threads"; - openblas_set_num_threads(jobConf.num_openblas_threads()); + if (tinydir_open(&workspace, job_conf.cluster().workspace().c_str()) == -1) + LOG(FATAL) << "workspace not exist: " << job_conf.cluster().workspace(); + if (job_conf.num_openblas_threads() != 1) + LOG(WARNING) << "openblas luanches " + << job_conf.num_openblas_threads() << " threads"; + openblas_set_num_threads(job_conf.num_openblas_threads()); + JobProto job; - job.CopyFrom(jobConf); + job.CopyFrom(job_conf); + if (resume) + SetupForResume(&job); job.set_id(job_id_); - Trainer trainer; - trainer.Start(resume, singa_conf_, &job); + Train(job); } +void Driver::Train(const JobProto& job_conf) { + auto cluster = Cluster::Get(); + int nserver_grps = cluster->nserver_groups(); + int grp_size = cluster->nworkers_per_group(); + Stub stub; + // no need to create Stub if there is only a single worker without servers, + // i.e., the training will be conducted by the single worker. 
+ if (grp_size > 1 || nserver_grps > 0) { + stub.Setup(); + // TODO(wangwei) register endpoint to zookeeper if > 1 procs; + cluster->Register(getpid(), stub.endpoint()); // getpid() is from unistd.h + } + + NeuralNet* net = NeuralNet::Create(job_conf.neuralnet(), kTrain, grp_size); + const vector<Worker*> workers = CreateWorkers(job_conf, net); + const vector<Server*> servers = CreateServers(job_conf, net); + +#ifdef USE_MPI + int nthreads = workers.size() + servers.size() + 1; + for (int i = 0; i < nthreads; i++) + MPIQueues.push_back(make_shared<SafeQueue>()); +#endif + + vector<std::thread> threads; + for (auto server : servers) + threads.push_back(std::thread(&Server::Run, server)); + for (auto worker : workers) + threads.push_back(std::thread(&Worker::Run, worker)); + if (grp_size > 1 || nserver_grps > 0) { + int nservers_per_grp = cluster->nservers_per_group(); + int lcm = LeastCommonMultiple(nservers_per_grp, nserver_grps); + auto slices = Param::ComputeSlices(lcm, net->params()); + auto slice2server = PartitionSlices(nservers_per_grp, slices); + stub.Run(slice2server, workers, servers); + } + + for (auto& thread : threads) + thread.join(); + for (auto server : servers) + delete server; + delete net; + std::set<NeuralNet*> deleted{net, nullptr}; + for (auto worker : workers) { + for (auto ptr : worker->GetNets()) + if (deleted.find(ptr) == deleted.end()) { + delete ptr; + deleted.insert(ptr); + } + delete worker; + } +} + +void Driver::SetupForResume(JobProto* job_conf) { + tinydir_dir dir; + std::string folder = Cluster::Get()->checkpoint_folder(); + tinydir_open(&dir, folder.c_str()); + int latest_step = 0; + // there would be multi checkpoint files (from diff workers) for one step + vector<std::string> ck_files; + // iterate all files to get the files for the last checkpoint + while (dir.has_next) { + tinydir_file file; + tinydir_readfile(&dir, &file); + tinydir_next(&dir); + char* ch = strstr(file.name, "step"); + if (ch == nullptr) { + if (file.name[0] != '.') + LOG(INFO) << "Irregular file in checkpoint folder: " << file.name; + continue; + } + LOG(INFO) << "Add checkpoint file for resume: " << ch; + int step = atoi(ch+4); + if (step == latest_step) { + ck_files.push_back(file.name); + } else if (step > latest_step) { + latest_step = step; + ck_files.clear(); + ck_files.push_back(std::string(file.name)); + } + } + if (latest_step > 0) { + job_conf->set_step(latest_step); + if (!job_conf->has_reset_param_version()) + job_conf->set_reset_param_version(false); + job_conf->clear_checkpoint_path(); + for (auto ck_file : ck_files) + job_conf->add_checkpoint_path(folder + "/" + ck_file); + } + tinydir_close(&dir); +} + +const vector<Worker*> Driver::CreateWorkers(const JobProto& job_conf, + NeuralNet* net) { + auto cluster = Cluster::Get(); + vector<Worker*> workers; + if (!cluster->has_worker()) return workers; + int wgrp_size = cluster->nworkers_per_group(); + int nservers_per_grp = cluster->nservers_per_group(); + int nserver_grps = cluster->nserver_groups(); + int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp); + const vector<int> rng = cluster->ExecutorRng(cluster->procs_id(), + cluster->nworkers_per_group(), cluster->nworkers_per_procs()); + int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3]; + for (int gid = gstart; gid < gend; gid++) { + NeuralNet* train_net = nullptr, *test_net = nullptr, *val_net = nullptr; + if (gid == gstart) { + train_net = net; + Param::SliceParams(lcm, train_net->params()); + // test and validation are performed by the 1st 
group. + if (gid == 0 && job_conf.test_steps() > 0) { + test_net = NeuralNet::Create(job_conf.neuralnet(), kTest, 1); + test_net->ShareParamsFrom(train_net); + } + if (gid == 0 && job_conf.validate_steps() > 0) { + val_net = NeuralNet::Create(job_conf.neuralnet(), kVal, 1); + val_net->ShareParamsFrom(train_net); + } + } else { + train_net = NeuralNet::Create(job_conf.neuralnet(), kTrain, wgrp_size); + if (cluster->share_memory()) { + train_net->ShareParamsFrom(net); + } else { + Param::SliceParams(lcm, train_net->params()); + } + } + for (int wid = wstart; wid < wend; wid++) { + auto *worker = Worker::Create(job_conf.train_one_batch()); + // TODO(wangwei) extend to test among workers in a grp + if (wid == 0) + worker->Setup(gid, wid, job_conf, train_net, val_net, test_net); + else + worker->Setup(gid, wid, job_conf, train_net, nullptr, nullptr); + workers.push_back(worker); + } + } + return workers; +} + +const vector<Server*> Driver::CreateServers(const JobProto& job_conf, + NeuralNet* net) { + auto cluster = Cluster::Get(); + vector<Server*> servers; + if (!cluster->has_server()) return servers; + int nservers_per_grp = cluster->nservers_per_group(); + int nserver_grps = cluster->nserver_groups(); + int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp); + auto slices = Param::ComputeSlices(lcm, net->params()); + // partition among server groups, each group maintains one sub-set for sync + auto slice2group = PartitionSlices(nserver_grps, slices); + // partition within one server group, each server updates for one sub-set + auto slice2server = PartitionSlices(nservers_per_grp, slices); + + int server_procs = cluster->procs_id(); + // if true, server procs (logical) id starts after worker procs + if (cluster->server_worker_separate()) + server_procs -= cluster->nworker_procs(); + const vector<int> rng = cluster->ExecutorRng(server_procs, + cluster->nservers_per_group(), cluster->nservers_per_procs()); + int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3]; + for (int gid = gstart; gid < gend; gid++) { + for (int sid = start; sid < end; sid++) { + auto server = new Server(gid, sid, job_conf, slice2group, slice2server); + servers.push_back(server); + } + } + return servers; +} } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/main.cc ---------------------------------------------------------------------- diff --git a/src/main.cc b/src/main.cc index 5d2ab2f..99c91b8 100644 --- a/src/main.cc +++ b/src/main.cc @@ -7,9 +7,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at -* +* * http://www.apache.org/licenses/LICENSE-2.0 -* +* * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,9 +19,9 @@ * *************************************************************/ -#include "singa.h" +#include "./singa.h" /** - * \file main.cc provides an example main func. + * \file main.cc provides an example main function. * * Like the main func of Hadoop, it prepares the job configuration and submit it * to the Driver which starts the training. @@ -31,19 +31,17 @@ * func must call Driver::Init at the beginning, and pass the job configuration * and resume option to the Driver for job submission. 
* - * Optionally, users can register their own implemented classes, e.g., layer, - * updater, through the registration func provided by the Driver. + * Optionally, users can register their own implemented subclasses of Layer, + * Updater, etc. through the registration function provided by the Driver. * * Users must pass at least one argument to the singa-run.sh, i.e., the job * configuration file which includes the cluster topology setting. Other fields * e.g, neuralnet, updater can be configured in main.cc. * * TODO - * Add helper functions for users to generate their configurations easily. - * e.g., AddLayer(layer_type, source_layers, meta_data), - * or, MLP(layer1_size, layer2_size, tanh, loss); + * Add helper functions for users to generate configurations for popular models + * easily, e.g., MLP(layer1_size, layer2_size, tanh, loss); */ - int main(int argc, char **argv) { // must create driver at the beginning and call its Init method. singa::Driver driver; @@ -58,7 +56,7 @@ int main(int argc, char **argv) { // get the job conf, and custmize it if need singa::JobProto jobConf = driver.job_conf(); - // submit the job - driver.Submit(resume, jobConf); + // submit the job for training + driver.Train(resume, jobConf); return 0; }
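For reference, a condensed view of the entry point after the rename from Submit() to Train(); how the resume flag is parsed from the command line is not shown in the hunk above, so it is hard-coded here purely for illustration.

#include "./singa.h"

int main(int argc, char **argv) {
  // must create the driver at the beginning and call its Init method
  singa::Driver driver;
  driver.Init(argc, argv);

  bool resume = false;  // placeholder; normally derived from the command-line arguments

  // get the job conf, and customize it if needed (e.g., neuralnet, updater fields)
  singa::JobProto jobConf = driver.job_conf();

  // submit the job for training (previously driver.Submit(resume, jobConf))
  driver.Train(resume, jobConf);
  return 0;
}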
