http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/common.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/common.h b/include/singa/utils/common.h
new file mode 100644
index 0000000..ef47b6c
--- /dev/null
+++ b/include/singa/utils/common.h
@@ -0,0 +1,155 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_COMMON_H_
+#define SINGA_UTILS_COMMON_H_
+
+#include <google/protobuf/message.h>
+#include <unordered_map>
+#include <sstream>
+#include <string>
+#include <vector>
+#include <utility>
+#include "singa/proto/common.pb.h"
+
+namespace singa {
+
+std::string IntVecToString(const std::vector<int>& vec);
+std::string VStringPrintf(std::string fmt, va_list l);
+std::string StringPrintf(std::string fmt, ...);
+
+/**
+ * Locate the position of the arg in arglist.
+ *
+ * @param argc total num of arguments
+ * @param arglist all arguments
+ * @param arg the searched argument
+ * @return the position of arg in the arglist; -1 if not found.
+ */
+int ArgPos(int argc, char** arglist, const char* arg);
+void CreateFolder(const std::string name);
+/**
+ * Slice a set of large Params into small pieces such that they can be roughly
+ * equally partitioned into a fixed number of boxes.
+ *
+ * @param num total number of boxes to store the small pieces
+ * @param sizes size of all Params
+ * @return all slices for each Param
+ */
+const std::vector<std::vector<int>> Slice(int num,
+    const std::vector<int>& sizes);
+/**
+ * Partition slices into boxes.
+ *
+ * @param num number of boxes
+ * @param slices slice sizes
+ * @return box id for each slice
+ */
+const std::vector<int> PartitionSlices(int num, const std::vector<int>& slices);
+/*
+inline void Sleep(int millisec=1){
+  std::this_thread::sleep_for(std::chrono::milliseconds(millisec));
+}
+*/
+int gcd(int a, int b);
+int LeastCommonMultiple(int a, int b);
+/*
+inline float rand_real() {
+  return static_cast<float>(rand_r())/(RAND_MAX+1.0f);
+}
+*/
+std::string GetHostIP();
+void SetupLog(const std::string& workspace, const std::string& model);
+
+/**
+ * Performance metrics.
+ */
+class Metric {
+ public:
+  Metric() {}
+  explicit Metric(const std::string& str);
+  /**
+   * Add one metric.
+   *
+   * If the metric exists, aggregate the new value; otherwise create a new
+   * entry for it.
+   *
+   * @param name metric name, e.g., 'loss'
+   * @param value metric value
+   */
+  void Add(const std::string& name, float value);
+  void Add(const std::string& name, float value, int count);
+  /**
+   * Reset all metric counters and values to 0.
+   */
+  void Reset();
+  /**
+   * Generate a one-line string for logging.
+   */
+  std::string ToLogString() const;
+  /**
+   * Serialize the object into a string.
+   */
+  std::string ToString() const;
+  /**
+   * Parse the metric from a string.
+   */
+  void ParseFrom(const std::string& msg);
+
+ private:
+  std::unordered_map<std::string, std::pair<int, float>> entry_;
+};
+
+using google::protobuf::Message;
+void Im2col(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    float* data_col);
+void Col2im(const float* data_col, const int channels,
+    const int height, const int width, const int patch_h, const int patch_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    float* data_im);
+void ForwardMaxPooling(const float* bottom, const int num, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    float* top, float* mask);
+void BackwardMaxPooling(const float* top, const float* mask, const int num,
+    const int channels, const int height, const int width,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    float* bottom);
+void ForwardAvgPooling(const float* bottom, const int num, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    float* top);
+void BackwardAvgPooling(const float* top, const int num, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    float* bottom);
+
+void ReadProtoFromTextFile(const char* filename, Message* proto);
+void WriteProtoToTextFile(const Message& proto, const char* filename);
+void ReadProtoFromBinaryFile(const char* filename, Message* proto);
+void WriteProtoToBinaryFile(const Message& proto, const char* filename);
+
+
+} // namespace singa
+
+#endif // SINGA_UTILS_COMMON_H_
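The Slice()/PartitionSlices() pair above is easiest to see end to end. A minimal sketch, assuming three Params and two server boxes; the sizes are illustrative, the exact slice lengths depend on the balancing heuristic inside Slice(), and flattening the per-Param lists before PartitionSlices() is my reading of its "slice sizes" parameter:

    #include <vector>
    #include "singa/utils/common.h"

    void SliceExample() {
      std::vector<int> sizes = {1000, 600, 400};  // float counts of three Params
      auto slices = singa::Slice(2, sizes);       // slice lengths, one list per Param
      std::vector<int> flat;                      // flatten to feed PartitionSlices
      for (const auto& s : slices)
        flat.insert(flat.end(), s.begin(), s.end());
      auto boxes = singa::PartitionSlices(2, flat);  // boxes[i] is the box id of slice i
    }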
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/data_shard.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/data_shard.h b/include/singa/utils/data_shard.h
new file mode 100644
index 0000000..7d69ae5
--- /dev/null
+++ b/include/singa/utils/data_shard.h
@@ -0,0 +1,171 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_DATA_SHARD_H_
+#define SINGA_UTILS_DATA_SHARD_H_
+
+#include <google/protobuf/message.h>
+#include <fstream>
+#include <string>
+#include <unordered_set>
+
+namespace singa {
+
+/**
+ * Data shard stores training/validation/test tuples.
+ * Every worker node should have a training shard (the validation/test shard
+ * is optional). The shard file for training is
+ * singa::Cluster::workspace()/train/shard.dat; the shard file for validation
+ * is singa::Cluster::workspace()/validation/shard.dat; similar for test.
+ *
+ * shard.dat consists of a set of unordered tuples. Each tuple is
+ * encoded as [key_len key record_len val] (key_len and record_len are of type
+ * uint32, which indicate the bytes of key and record respectively).
+ *
+ * When a DataShard object is created, it removes the last tuple if its key
+ * size and record size do not match, i.e., the last write of that tuple
+ * crashed.
+ *
+ * TODO
+ * 1. split one shard into multiple shards.
+ * 2. add threading to prefetch and parse records
+ *
+ */
+class DataShard {
+ public:
+  enum {
+    // read only mode used in training
+    kRead = 0,
+    // write mode used in creating shard (will overwrite previous one)
+    kCreate = 1,
+    // append mode, e.g. used when previous creating crashes
+    kAppend = 2
+  };
+
+  /**
+   * Init the shard object.
+   *
+   * @param folder Shard folder (path excluding shard.dat) on worker node
+   * @param mode Shard open mode: DataShard::kRead, DataShard::kCreate or
+   * DataShard::kAppend
+   * @param capacity Batch capacity bytes of data for every disk op (read or
+   * write), default is 100MB
+   */
+  DataShard(const std::string& folder, int mode);
+  DataShard(const std::string& folder, int mode, int capacity);
+  ~DataShard();
+
+  /**
+   * Read the next tuple from the shard.
+   *
+   * @param key Tuple key
+   * @param val Record of type Message
+   * @return false if the read fails, e.g., the tuple was not inserted
+   * completely.
+   */
+  bool Next(std::string* key, google::protobuf::Message* val);
+  /**
+   * Read the next tuple from the shard.
+   *
+   * @param key Tuple key
+   * @param val Record of type string
+   * @return false if the read fails, e.g., the tuple was not inserted
+   * completely.
+   */
+  bool Next(std::string* key, std::string* val);
+  /**
+   * Append one tuple to the shard.
+   *
+   * @param key e.g., image path
+   * @param tuple the record of type Message
+   * @return false if unsuccessful, e.g., the key was inserted before
+   */
+  bool Insert(const std::string& key, const google::protobuf::Message& tuple);
+  /**
+   * Append one tuple to the shard.
+   *
+   * @param key e.g., image path
+   * @param tuple the record of type string
+   * @return false if unsuccessful, e.g., the key was inserted before
+   */
+  bool Insert(const std::string& key, const std::string& tuple);
+  /**
+   * Move the read pointer to the head of the shard file.
+   * Used for repeated reading.
+   */
+  void SeekToFirst();
+  /**
+   * Flush buffered data to disk.
+   * Used only for kCreate or kAppend.
+   */
+  void Flush();
+  /**
+   * Iterate through all tuples to get the num of all tuples.
+   *
+   * @return num of tuples
+   */
+  int Count();
+  /**
+   * @return path to shard file
+   */
+  inline std::string path() { return path_; }
+
+ protected:
+  /**
+   * Read the next key and prepare buffer for reading value.
+   *
+   * @param key
+   * @return length (i.e., bytes) of value field.
+   */
+  int Next(std::string* key);
+  /**
+   * Setup the disk pointer to the right position for append in case that
+   * the previous write crashed.
+   *
+   * @param path shard path.
+   * @return offset (end pos) of the last successfully written record.
+   */
+  int PrepareForAppend(const std::string& path);
+  /**
+   * Read data from disk if the current data in the buffer is not a full field.
+   *
+   * @param size size of the next field.
+   */
+  bool PrepareNextField(int size);
+
+ private:
+  char mode_ = 0;
+  std::string path_ = "";
+  // either ifstream or ofstream
+  std::fstream fdat_;
+  // to avoid replicated records
+  std::unordered_set<std::string> keys_;
+  // internal buffer
+  char* buf_ = nullptr;
+  // offset inside the buf_
+  int offset_ = 0;
+  // allocated bytes for the buf_
+  int capacity_ = 0;
+  // bytes in buf_, used in reading
+  int bufsize_ = 0;
+};
+
+} // namespace singa
+
+#endif // SINGA_UTILS_DATA_SHARD_H_
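A minimal read-loop sketch for the class above, using the string-valued Next() so no concrete record proto is assumed; the shard folder is illustrative:

    #include <glog/logging.h>
    #include <string>
    #include "singa/utils/data_shard.h"

    void DumpKeys(const std::string& folder) {
      singa::DataShard shard(folder, singa::DataShard::kRead);
      std::string key, val;
      while (shard.Next(&key, &val))  // false at the end or on a partial tuple
        LOG(INFO) << key << " (" << val.size() << " bytes)";
    }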
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/factory.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/factory.h b/include/singa/utils/factory.h
new file mode 100644
index 0000000..3af25f0
--- /dev/null
+++ b/include/singa/utils/factory.h
@@ -0,0 +1,100 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_FACTORY_H_
+#define SINGA_UTILS_FACTORY_H_
+
+#include <glog/logging.h>
+#include <functional>
+#include <map>
+#include <string>
+
+/**
+ * Macro that creates a function which instantiates a subclass instance and
+ * returns a pointer to the base class.
+ */
+#define CreateInstance(SubClass, BaseClass) \
+  [](void)->BaseClass* {return new SubClass();}
+
+/**
+ * Factory template to generate class (or sub-class) objects based on id.
+ * 1. Register a class creation function that generates a class
+ * object based on id.
+ * 2. Call Create() to invoke the creation function and return
+ * a pointer to the base class.
+ */
+template<typename T>
+class Factory {
+ public:
+  /**
+   * Register functions to create user defined classes.
+   * This function is called by the REGISTER_FACTORY macro.
+   *
+   * @param id Identifier of the creating function/class
+   * @param func a function that creates a layer instance
+   */
+  inline void Register(const std::string& id,
+                       const std::function<T*(void)>& func) {
+    CHECK(str2func_.find(id) == str2func_.end())
+      << "The id has been registered by another function";
+    str2func_[id] = func;
+  }
+  /**
+   * Register functions to create user defined classes.
+   * This function is called by the REGISTER_FACTORY macro.
+   *
+   * @param id Identifier of the creating function/class
+   * @param func a function that creates a layer instance
+   */
+  inline void Register(int id,
+                       const std::function<T*(void)>& func) {
+    CHECK(id2func_.find(id) == id2func_.end())
+      << "The id has been registered by another function";
+    id2func_[id] = func;
+  }
+  /**
+   * Create an instance by providing its id.
+   *
+   * @param id
+   */
+  inline T* Create(const std::string& id) {
+    CHECK(str2func_.find(id) != str2func_.end())
+      << "The creation function for " << id << " has not been registered";
+    return str2func_[id]();
+  }
+  /**
+   * Create an instance by providing its id.
+   *
+   * @param id
+   */
+  inline T* Create(int id) {
+    CHECK(id2func_.find(id) != id2func_.end())
+      << "The creation function for " << id << " has not been registered";
+    return id2func_[id]();
+  }
+
+ private:
+  // Maps that store the registered creation functions
+  std::map<std::string, std::function<T*(void)>> str2func_;
+  std::map<int, std::function<T*(void)>> id2func_;
+};
+
+#endif // SINGA_UTILS_FACTORY_H_
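A short sketch of how the factory and the CreateInstance macro compose; Base and Sub are illustrative stand-ins for real SINGA classes (e.g., Layer subclasses):

    #include "singa/utils/factory.h"

    class Base {
     public:
      virtual ~Base() {}
    };
    class Sub : public Base {};

    void FactoryExample() {
      Factory<Base> factory;
      factory.Register("Sub", CreateInstance(Sub, Base));
      Base* obj = factory.Create("Sub");  // new Sub, returned as Base*
      delete obj;
    }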
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/graph.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/graph.h b/include/singa/utils/graph.h
new file mode 100644
index 0000000..bad7b19
--- /dev/null
+++ b/include/singa/utils/graph.h
@@ -0,0 +1,118 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_GRAPH_H_
+#define SINGA_UTILS_GRAPH_H_
+
+#include <stack>
+#include <string>
+#include <map>
+#include <vector>
+
+namespace singa {
+
+class Node {
+ public:
+  /**
+   * Node constructor.
+   *
+   * @param name name of the corresponding layer
+   */
+  explicit Node(std::string name);
+  /**
+   * Node constructor.
+   *
+   * This node is a partition of some node.
+   * @param name node name
+   * @param origin name of the original node
+   * @param id partition id of this node
+   * @param proto conf of the corresponding layer
+   */
+  Node(const std::string& name, const std::string& origin, int id, void* proto);
+  ~Node() {}  // the proto field is deleted outside by other functions
+  void AddDstNode(Node* dstnode);
+  void AddSrcNode(Node* srcnode);
+  void RemoveDstNode(Node* dst);
+  void RemoveSrcNode(Node* src);
+
+  std::string name = "";
+  //! name of the origin node/layer from which this node is derived
+  std::string origin = "";
+  //! partition id
+  int partition_id = 0;
+  //! proto of the corresponding layer
+  void* proto = nullptr;
+  std::vector<Node*> srcnodes;
+  std::vector<Node*> dstnodes;
+};
+
+/**
+ * A neural net is constructed by first creating a graph, with each node
+ * representing one layer. After topological sorting of the graph nodes,
+ * layers are created and connected.
+ */
+class Graph {
+ public:
+  Graph() {}
+  ~Graph();
+  /**
+   * @return all nodes of the graph
+   */
+  inline const std::vector<Node*>& nodes() const {
+    return nodes_;
+  }
+  /**
+   * @param name node name
+   * @return the node with the given name
+   */
+  inline Node* node(const std::string& name) const {
+    return name2node_.at(name);
+  }
+  void AddNode(Node* node);
+  Node* AddNode(const std::string& name);
+  void AddEdge(Node* srcnode, Node* dstnode);
+  void AddEdge(const std::string& src, const std::string& dst);
+  void RemoveEdge(Node* src, Node* dst);
+  void RemoveEdge(const std::string& src, const std::string& dst);
+  /**
+   * Dump the graph into a JSON string, which can be used to draw a picture
+   * by graphviz.
+   */
+  std::string ToJson() const;
+  /**
+   * \copybrief ToJson()
+   *
+   * @param info info associated with each node
+   */
+  std::string ToJson(const std::map<std::string, std::string>& info) const;
+  /**
+   * Do topological sort for all nodes of the graph.
+   */
+  void Sort();
+
+ private:
+  std::vector<Node*> nodes_;
+  std::map<std::string, Node*> name2node_;
+};
+
+} // namespace singa
+
+#endif // SINGA_UTILS_GRAPH_H_
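A small sketch of the Graph API above; the layer names are illustrative:

    #include <glog/logging.h>
    #include "singa/utils/graph.h"

    void GraphExample() {
      singa::Graph g;
      g.AddNode("data");
      g.AddNode("hidden");
      g.AddNode("loss");
      g.AddEdge("data", "hidden");
      g.AddEdge("hidden", "loss");
      g.Sort();  // afterwards g.nodes() is in topological order
      for (auto* n : g.nodes())
        LOG(INFO) << n->name;
    }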
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/image_transform.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/image_transform.h b/include/singa/utils/image_transform.h
new file mode 100644
index 0000000..2867ad2
--- /dev/null
+++ b/include/singa/utils/image_transform.h
@@ -0,0 +1,35 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_IMAGE_TRANSFORM_H_
+#define SINGA_UTILS_IMAGE_TRANSFORM_H_
+
+#include <glog/logging.h>
+// TODO(wangwei) provide image transformation API, the implementation can be
+// done by opencv, manual transform, or mshadow.
+namespace singa {
+
+void ImageTransform(const float* in, const float* mean, bool mirror, int h_crop,
+    int w_crop, int h_offset, int w_offset, int channel, int height, int width,
+    float scale, float* out);
+} // namespace singa
+
+#endif // SINGA_UTILS_IMAGE_TRANSFORM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/param.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/param.h b/include/singa/utils/param.h
new file mode 100644
index 0000000..bcfc3f9
--- /dev/null
+++ b/include/singa/utils/param.h
@@ -0,0 +1,397 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_PARAM_H_
+#define SINGA_UTILS_PARAM_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "singa/comm/msg.h"
+#include "singa/proto/job.pb.h"
+#include "singa/utils/blob.h"
+
+namespace singa {
+using std::vector;
+/**
+ * Base parameter generator which initializes parameter values.
+ */
+class ParamGenerator {
+ public:
+  static ParamGenerator* Create(const ParamGenProto& proto);
+
+  virtual ~ParamGenerator() {}
+
+  virtual void Init(const ParamGenProto& proto) { proto_ = proto; }
+  virtual void Fill(Blob<float>* data);
+
+ protected:
+  ParamGenProto proto_;
+};
+
+class GaussianGen : public ParamGenerator {
+ public:
+  void Fill(Blob<float>* data) override;
+};
+
+class GaussianSqrtFanInGen : public GaussianGen {
+ public:
+  void Fill(Blob<float>* data) override;
+};
+
+class UniformGen : public ParamGenerator {
+ public:
+  void Fill(Blob<float>* data) override;
+};
+
+class UniformSqrtFanInGen : public UniformGen {
+ public:
+  void Fill(Blob<float>* data) override;
+};
+
+class UniformSqrtFanInOutGen : public UniformGen {
+ public:
+  void Fill(Blob<float>* data) override;
+};
+
+/**
+ * Base parameter class.
+ *
+ * The Param object is a set of parameters, e.g., the (sub) weight matrix or
+ * (sub) bias vector.
+ *
+ * It has a gradient Blob and a data Blob for gradients and parameter values.
+ * Since some layers (or neural nets) share parameter values, the data Blob is
+ * a shared pointer which can be assigned to many Param objects' data field.
+ *
+ * It provides access methods like data(), grad().
+ * It also provides functions for generating and parsing messages for
+ * transferring the Param objects among worker-worker, worker-server and
+ * server-server.
+ *
+ * Param objects are of different sizes, which makes it hard to achieve
+ * load-balance among servers. Hence, we slice large Param objects into small
+ * pieces. At the server side, one slice is a Param object.
+ */
+class Param {
+ public:
+  /**
+   * Create an instance of a (sub) Param class based on the type from the
+   * configuration.
+   *
+   * @param[in] conf configuration
+   * @return a pointer to an instance
+   */
+  static Param* Create(const ParamProto& conf);
+
+  /**
+   * Try to slice the Param objects (from a neural net) into a given number of
+   * servers (groups) evenly. This is to achieve load-balance among servers.
+   *
+   * It does not change the Param objects, but just computes the length of each
+   * slice.
+   *
+   * @param num number of servers (groups) for maintaining the Param objects.
+   * @param params all Param objects from a neural net.
+   * @return the length of each slice.
+   */
+  static const vector<int> ComputeSlices(int num, const vector<Param*>& params);
+  /**
+   * It computes the length of each slice and slices the Param objects by
+   * adding the slicing information into every Param object.
+   *
+   * @copydetails ComputeSlices()
+   */
+  static void SliceParams(int num, const vector<Param*>& params);
+
+  Param() {}
+  virtual ~Param() {}
+  void Init(const ParamProto& proto) { proto_ = proto; }
+  /**
+   * Setup param object.
+   *
+   * @param conf param configuration, including learning rate multiplier etc.
+   * @param shape one value per dimension
+   */
+  virtual void Setup(const std::vector<int>& shape);
+  /*
+   * Fill the values according to the init method, e.g., gaussian distribution.
+   *
+   * @param version initial version
+   */
+  virtual void InitValues();
+  virtual void InitValues(int version);
+  /**
+   * Share the data blob from other Param objects.
+   *
+   * @param other the Param object whose owner owns the data blob
+   */
+  void ShareFrom(const Param& other);
+  /**
+   * Init param values from checkpoint blob.
+   */
+  void FromProto(const BlobProto& blob);
+  /**
+   * Dump param values to blob.
+   */
+  void ToProto(BlobProto* blob);
+  /**
+   * Add a slice.
+   *
+   * @param slice_id
+   * @param size num of floats for this slice
+   */
+  void AddSlice(int slice_id, int size);
+  /**
+   * Scale the learning rate when updating parameters in the Param object.
+   */
+  inline float lr_scale() const { return proto_.lr_scale(); }
+  /**
+   * Scale the weight decay when updating parameters in the Param object.
+   */
+  inline float wd_scale() const { return proto_.wd_scale(); }
+  /**
+   * Parameter name used for Param re-use in other models or sharing between
+   * layers.
+   */
+  inline const std::string& name() const { return proto_.name(); }
+  inline void set_name(const std::string& name) { proto_.set_name(name); }
+  /**
+   * If it shares data from others, then owner is the id of that Param;
+   * otherwise it is its own id.
+   */
+  inline int owner() const { return proto_.owner(); }
+  /**
+   * IDs start from 0 and are ordered for all Params from the same neural net.
+   */
+  inline int id() const { return proto_.id(); }
+  /**
+   * Set ID.
+   */
+  inline void set_id(int id) {
+    proto_.set_id(id);
+    proto_.set_owner(id);
+  }
+  /**
+   * Param version is stored inside the data blob to enable all Param objects
+   * sharing the same values to have the same version.
+   *
+   * @return the param version
+   */
+  inline int version() const { return data_->version(); }
+  inline void set_version(int v) { data_->set_version(v); }
+  /**
+   * @return the version of the parameter value local to a worker
+   */
+  inline int local_version() const { return local_version_; }
+  inline void set_local_version(int v) { local_version_ = v; }
+  inline const std::string& share_from() const { return proto_.share_from(); }
+  /**
+   * @return num of floats.
+   */
+  inline int size() const { return data_->count(); }
+  inline const Blob<float>& data() const { return *data_; }
+  inline Blob<float>* mutable_data() { return data_.get(); }
+  inline const Blob<float>& grad() const { return grad_; }
+  inline Blob<float>* mutable_grad() { return &grad_; }
+  inline float* mutable_cpu_data() { return data_->mutable_cpu_data(); }
+  inline float* mutable_cpu_grad() { return grad_.mutable_cpu_data(); }
+  inline float* mutable_cpu_history() { return history_.mutable_cpu_data(); }
+  /**
+   * @return slice start ID
+   */
+  inline int slice_start() const { return slice_start_; }
+  inline int num_slices() const { return num_slices_; }
+
+  /**
+   * Below are message/request related functions.
+   * The basic communication workflows are as follows:
+   * ------------------------------------------------------------------------
+   *          |Put         |Get           |Update           |Sync
+   * ------------------------------------------------------------------------
+   * Generate |(stub)      |(stub)        |(stub)           |(server)
+   * Message  |GenPutMsg   |GenGetMsg     |GenUpdateMsg     |GenSyncMsg
+   * ------------------------------------------------------------------------
+   * Handle   |(server)    |(server)      |(server)         |(server)
+   * Message  |HandlePutMsg|HandleGetMsg  |ParseUpdateMsg   |HandleSyncMsg
+   *          |            |              |GenUpdateResMsg  |
+   * ------------------------------------------------------------------------
+   * Handle   |            |(stub)        |(stub)           |(server)
+   * Response |            |ParseGetResMsg|ParseUpdateResMsg|ParseSyncResMsg
+   * ------------------------------------------------------------------------
+   */
+
+  /**
+   * Generate the message for a put request, i.e., put parameters to a server.
+   *
+   * This function is called at the worker/stub side.
+   * @param copy decides whether to copy the parameter values from the server.
+   * @param slice_idx index of the slice from which the message is generated.
+   * @return generated message without setting src, dst, target fields.
+   */
+  virtual Msg* GenPutMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for a get request, i.e., get parameters from a
+   * server.
+   * \copydetails GenPutMsg(bool, int);
+   */
+  virtual Msg* GenGetMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for an update request, i.e., pass info to the server
+   * for parameter update.
+   * \copydetails GenPutMsg(bool, int);
+   */
+  virtual Msg* GenUpdateMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for a synchronization request between server groups.
+   *
+   * This function is called at the server side where the Param is actually a
+   * slice of an original Param object.
+   */
+  virtual Msg* GenSyncMsg(int offset, int size);
+  /**
+   * Server handling function for put request.
+   *
+   * @param msg request
+   * @param reserve if true reserve the msg space for the calling function;
+   * otherwise the msg should be freed inside the function.
+   * @return response message
+   */
+  virtual Msg* HandlePutMsg(Msg** msg, bool reserve);
+  /**
+   * Server handling function for get request.
+   *
+   * \copydetails HandlePutMsg(Msg**, bool reserve)
+   */
+  virtual Msg* HandleGetMsg(Msg** msg, bool reserve);
+  /**
+   * Server parsing function for update requests.
+   * \copydetails GenUpdateResponseMsgs(const std::vector<Msg*>& msgs);
+   */
+  virtual void ParseUpdateMsgs(const std::vector<Msg*>& msgs);
+  /**
+   * Generate the messages to respond to the update requests.
+   *
+   * This function is called at the server side, where the Param is actually a
+   * slice of an original Param object.
+   *
+   * @param msgs for synchronous training, there would be multiple procs in
+   * which workers share the same Param (slice) objects. Their update requests
+   * are buffered and handled together. For asynchronous training, there is
+   * only one request in msgs.
+   * @return response messages
+   */
+  virtual const std::vector<Msg*>
+  GenUpdateResponseMsgs(std::vector<Msg*>* msgs, bool reserve);
+  /**
+   * Server handling function for synchronization message.
+   *
+   * \copydetails HandleGetMsg(Msg**, bool reserve)
+   */
+  virtual Msg* HandleSyncMsg(Msg** msg, bool reserve);
+  /**
+   * Worker/Stub parsing function for get response.
+   *
+   * @param msg
+   * @param slice_idx index for the slice
+   */
+  virtual int ParseGetResponseMsg(Msg* msg, int slice_idx);
+  /**
+   * Worker/Server parsing function for update response.
+   *
+   * \copydetails ParseGetResponseMsg(Msg**, int);
+   */
+  virtual int ParseUpdateResponseMsg(Msg* msg, int slice_idx);
+  /**
+   * Server parsing function for synchronization response.
+   *
+   * \copydetails ParseGetResponseMsg(Msg**, int);
+   */
+  virtual int ParseSyncResponseMsg(Msg* msg, int slice_idx);
+
+ protected:
+  /**
+   * Implement the common code of ParseGetResponseMsg and
+   * ParseUpdateResponseMsg.
+   * \copydetails ParseSyncResponseMsg(Msg* msg, int slice_idx);
+   */
+  void ParseResponseMsg(Msg* msg, int slice_idx);
+
+ protected:
+  int local_version_ = -1;
+  // the ID of the first slice
+  int slice_start_ = 0;
+  int num_slices_ = 0;
+  // offset and size of each slice
+  std::vector<int> slice_offset_;
+  std::vector<int> slice_size_;
+  // for debug checking
+  // since put request has no feedback, we do not track its pending status
+  std::vector<bool> pending_get_;
+  std::vector<bool> pending_update_;
+  int num_pending_requests_ = 0;
+  // data field
+  std::shared_ptr<Blob<float>> data_ = nullptr;
+  // gradient, history gradient of this parameter
+  Blob<float> grad_, history_;
+  ParamProto proto_;
+};
+
+/**
+ * ParamEntry is used for aggregating gradients of Params shared by workers
+ * from the same group.
+ *
+ * For each worker group, every unique Param object has a ParamEntry object.
+ * Param objects sharing the same values are associated with the same
+ * ParamEntry.
+ */
+class ParamEntry {
+ public:
+  ParamEntry() {}
+  ParamEntry(int total, Param* p);
+  /**
+   * Associate the counter with a Param object.
+   *
+   * @param p
+   * @param local 1 if it is used by workers in this process, 0 otherwise
+   */
+  void AddParam(bool local, Param* p);
+  int next_version = -1;  // next_version & num_update are directly used by stub
+  int num_update = 0;
+  int num_local = 0;   //!< # local workers using the shared parameter
+  int num_total = 0;   //!< # total workers using the shared parameter
+  //!< Shares are deleted by neuralnet's destructor
+  std::vector<Param*> shares;
+};
+
+inline int ParamTrgt(int param_id, int slice_id) {
+  return (param_id << 16) | slice_id;
+}
+
+inline int ParamID(int param_trgt) {
+  return param_trgt >> 16;
+}
+
+inline int SliceID(int param_trgt) {
+  static const int mask = (1 << 16) - 1;
+  return param_trgt & mask;
+}
+
+} // namespace singa
+
+#endif // SINGA_UTILS_PARAM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/singleton.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/singleton.h b/include/singa/utils/singleton.h
new file mode 100644
index 0000000..4cf487e
--- /dev/null
+++ b/include/singa/utils/singleton.h
@@ -0,0 +1,52 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_SINGLETON_H_
+#define SINGA_UTILS_SINGLETON_H_
+
+/**
+ * Thread-safe implementation for C++11 according to
+ * http://stackoverflow.com/questions/2576022/efficient-thread-safe-singleton-in-c
+ */
+template<typename T>
+class Singleton {
+ public:
+  static T* Instance() {
+    static T data_;
+    return &data_;
+  }
+};
+
+/**
+ * Thread Specific Singleton.
+ *
+ * Each thread will have its own data_ storage.
+ */
+template<typename T>
+class TSingleton {
+ public:
+  static T* Instance() {
+    static thread_local T data_;
+    return &data_;
+  }
+};
+
+#endif // SINGA_UTILS_SINGLETON_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/tinydir.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/tinydir.h b/include/singa/utils/tinydir.h
new file mode 100644
index 0000000..abb7000
--- /dev/null
+++ b/include/singa/utils/tinydir.h
@@ -0,0 +1,562 @@
+/*
+Copyright (c) 2013-2014, Cong Xu, Baudouin Feildel
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+2.
Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#ifndef TINYDIR_H +#define TINYDIR_H + +#include <errno.h> +#include <stdlib.h> +#include <string.h> +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include <windows.h> +#ifdef _MSC_VER +#pragma warning (disable : 4996) +#endif +#else +#include <dirent.h> +#include <libgen.h> +#include <sys/stat.h> +#endif + + +/* types */ + +#define _TINYDIR_PATH_MAX 4096 +#ifdef _WIN32 +/* extra chars for the "\\*" mask */ +#define _TINYDIR_PATH_EXTRA 2 +#else +#define _TINYDIR_PATH_EXTRA 0 +#endif +#define _TINYDIR_FILENAME_MAX 256 + +#ifdef _MSC_VER +#define _TINYDIR_FUNC static __inline +#else +#define _TINYDIR_FUNC static __inline__ +#endif + +/* Allow user to use a custom allocator by defining _TINYDIR_MALLOC and _TINYDIR_FREE. */ +#if defined(_TINYDIR_MALLOC) && defined(_TINYDIR_FREE) +#elif !defined(_TINYDIR_MALLOC) && !defined(_TINYDIR_FREE) +#else +#error "Either define both alloc and free or none of them!" 
+#endif + +#if !defined(_TINYDIR_MALLOC) + #define _TINYDIR_MALLOC(_size) malloc(_size) + #define _TINYDIR_FREE(_ptr) free(_ptr) +#endif //!defined(_TINYDIR_MALLOC) + +typedef struct +{ + char path[_TINYDIR_PATH_MAX]; + char name[_TINYDIR_FILENAME_MAX]; + char *extension; + int is_dir; + int is_reg; + +#ifdef _WIN32 +#else + struct stat _s; +#endif +} tinydir_file; + +typedef struct +{ + char path[_TINYDIR_PATH_MAX]; + int has_next; + size_t n_files; + + tinydir_file *_files; +#ifdef _WIN32 + HANDLE _h; + WIN32_FIND_DATAA _f; +#else + DIR *_d; + struct dirent *_e; +#endif +} tinydir_dir; + + +/* declarations */ + +_TINYDIR_FUNC +int tinydir_open(tinydir_dir *dir, const char *path); +_TINYDIR_FUNC +int tinydir_open_sorted(tinydir_dir *dir, const char *path); +_TINYDIR_FUNC +void tinydir_close(tinydir_dir *dir); + +_TINYDIR_FUNC +int tinydir_next(tinydir_dir *dir); +_TINYDIR_FUNC +int tinydir_readfile(const tinydir_dir *dir, tinydir_file *file); +_TINYDIR_FUNC +int tinydir_readfile_n(const tinydir_dir *dir, tinydir_file *file, size_t i); +_TINYDIR_FUNC +int tinydir_open_subdir_n(tinydir_dir *dir, size_t i); + +_TINYDIR_FUNC +void _tinydir_get_ext(tinydir_file *file); +_TINYDIR_FUNC +int _tinydir_file_cmp(const void *a, const void *b); + + +/* definitions*/ + +_TINYDIR_FUNC +int tinydir_open(tinydir_dir *dir, const char *path) +{ + if (dir == NULL || path == NULL || strlen(path) == 0) + { + errno = EINVAL; + return -1; + } + if (strlen(path) + _TINYDIR_PATH_EXTRA >= _TINYDIR_PATH_MAX) + { + errno = ENAMETOOLONG; + return -1; + } + + /* initialise dir */ + dir->_files = NULL; +#ifdef _WIN32 + dir->_h = INVALID_HANDLE_VALUE; +#else + dir->_d = NULL; +#endif + tinydir_close(dir); + + strcpy(dir->path, path); +#ifdef _WIN32 + strcat(dir->path, "\\*"); + dir->_h = FindFirstFileA(dir->path, &dir->_f); + dir->path[strlen(dir->path) - 2] = '\0'; + if (dir->_h == INVALID_HANDLE_VALUE) +#else + dir->_d = opendir(path); + if (dir->_d == NULL) +#endif + { + errno = ENOENT; + goto bail; + } + + /* read first file */ + dir->has_next = 1; +#ifndef _WIN32 + dir->_e = readdir(dir->_d); + if (dir->_e == NULL) + { + dir->has_next = 0; + } +#endif + + return 0; + +bail: + tinydir_close(dir); + return -1; +} + +_TINYDIR_FUNC +int tinydir_open_sorted(tinydir_dir *dir, const char *path) +{ + /* Count the number of files first, to pre-allocate the files array */ + size_t n_files = 0; + if (tinydir_open(dir, path) == -1) + { + return -1; + } + while (dir->has_next) + { + n_files++; + if (tinydir_next(dir) == -1) + { + goto bail; + } + } + tinydir_close(dir); + + if (tinydir_open(dir, path) == -1) + { + return -1; + } + + dir->n_files = 0; + dir->_files = (tinydir_file *)_TINYDIR_MALLOC(sizeof *dir->_files * n_files); + if (dir->_files == NULL) + { + errno = ENOMEM; + goto bail; + } + while (dir->has_next) + { + tinydir_file *p_file; + dir->n_files++; + + p_file = &dir->_files[dir->n_files - 1]; + if (tinydir_readfile(dir, p_file) == -1) + { + goto bail; + } + + if (tinydir_next(dir) == -1) + { + goto bail; + } + + /* Just in case the number of files has changed between the first and + second reads, terminate without writing into unallocated memory */ + if (dir->n_files == n_files) + { + break; + } + } + + qsort(dir->_files, dir->n_files, sizeof(tinydir_file), _tinydir_file_cmp); + + return 0; + +bail: + tinydir_close(dir); + return -1; +} + +_TINYDIR_FUNC +void tinydir_close(tinydir_dir *dir) +{ + if (dir == NULL) + { + return; + } + + memset(dir->path, 0, sizeof(dir->path)); + dir->has_next = 0; + dir->n_files = 0; 
+ if (dir->_files != NULL) + { + _TINYDIR_FREE(dir->_files); + } + dir->_files = NULL; +#ifdef _WIN32 + if (dir->_h != INVALID_HANDLE_VALUE) + { + FindClose(dir->_h); + } + dir->_h = INVALID_HANDLE_VALUE; +#else + if (dir->_d) + { + closedir(dir->_d); + } + dir->_d = NULL; + dir->_e = NULL; +#endif +} + +_TINYDIR_FUNC +int tinydir_next(tinydir_dir *dir) +{ + if (dir == NULL) + { + errno = EINVAL; + return -1; + } + if (!dir->has_next) + { + errno = ENOENT; + return -1; + } + +#ifdef _WIN32 + if (FindNextFileA(dir->_h, &dir->_f) == 0) +#else + dir->_e = readdir(dir->_d); + if (dir->_e == NULL) +#endif + { + dir->has_next = 0; +#ifdef _WIN32 + if (GetLastError() != ERROR_SUCCESS && + GetLastError() != ERROR_NO_MORE_FILES) + { + tinydir_close(dir); + errno = EIO; + return -1; + } +#endif + } + + return 0; +} + +_TINYDIR_FUNC +int tinydir_readfile(const tinydir_dir *dir, tinydir_file *file) +{ + if (dir == NULL || file == NULL) + { + errno = EINVAL; + return -1; + } +#ifdef _WIN32 + if (dir->_h == INVALID_HANDLE_VALUE) +#else + if (dir->_e == NULL) +#endif + { + errno = ENOENT; + return -1; + } + if (strlen(dir->path) + + strlen( +#ifdef _WIN32 + dir->_f.cFileName +#else + dir->_e->d_name +#endif + ) + 1 + _TINYDIR_PATH_EXTRA >= + _TINYDIR_PATH_MAX) + { + /* the path for the file will be too long */ + errno = ENAMETOOLONG; + return -1; + } + if (strlen( +#ifdef _WIN32 + dir->_f.cFileName +#else + dir->_e->d_name +#endif + ) >= _TINYDIR_FILENAME_MAX) + { + errno = ENAMETOOLONG; + return -1; + } + + strcpy(file->path, dir->path); + strcat(file->path, "/"); + strcpy(file->name, +#ifdef _WIN32 + dir->_f.cFileName +#else + dir->_e->d_name +#endif + ); + strcat(file->path, file->name); +#ifndef _WIN32 + if (stat(file->path, &file->_s) == -1) + { + return -1; + } +#endif + _tinydir_get_ext(file); + + file->is_dir = +#ifdef _WIN32 + !!(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY); +#else + S_ISDIR(file->_s.st_mode); +#endif + file->is_reg = +#ifdef _WIN32 + !!(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_NORMAL) || + ( + !(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_DEVICE) && + !(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) && + !(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_ENCRYPTED) && +#ifdef FILE_ATTRIBUTE_INTEGRITY_STREAM + !(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_INTEGRITY_STREAM) && +#endif +#ifdef FILE_ATTRIBUTE_NO_SCRUB_DATA + !(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_NO_SCRUB_DATA) && +#endif + !(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_OFFLINE) && + !(dir->_f.dwFileAttributes & FILE_ATTRIBUTE_TEMPORARY)); +#else + S_ISREG(file->_s.st_mode); +#endif + + return 0; +} + +_TINYDIR_FUNC +int tinydir_readfile_n(const tinydir_dir *dir, tinydir_file *file, size_t i) +{ + if (dir == NULL || file == NULL) + { + errno = EINVAL; + return -1; + } + if (i >= dir->n_files) + { + errno = ENOENT; + return -1; + } + + memcpy(file, &dir->_files[i], sizeof(tinydir_file)); + _tinydir_get_ext(file); + + return 0; +} + +_TINYDIR_FUNC +int tinydir_open_subdir_n(tinydir_dir *dir, size_t i) +{ + char path[_TINYDIR_PATH_MAX]; + if (dir == NULL) + { + errno = EINVAL; + return -1; + } + if (i >= dir->n_files || !dir->_files[i].is_dir) + { + errno = ENOENT; + return -1; + } + + strcpy(path, dir->_files[i].path); + tinydir_close(dir); + if (tinydir_open_sorted(dir, path) == -1) + { + return -1; + } + + return 0; +} + +/* Open a single file given its path */ +_TINYDIR_FUNC +int tinydir_file_open(tinydir_file *file, const char *path) +{ + tinydir_dir dir; + int result = 0; + int found = 0; + char 
dir_name_buf[_TINYDIR_PATH_MAX]; + char file_name_buf[_TINYDIR_FILENAME_MAX]; + char *dir_name; + char *base_name; +#ifdef _WIN32 + char drive_buf[_TINYDIR_PATH_MAX]; + char ext_buf[_TINYDIR_FILENAME_MAX]; +#endif + + if (file == NULL || path == NULL || strlen(path) == 0) + { + errno = EINVAL; + return -1; + } + if (strlen(path) + _TINYDIR_PATH_EXTRA >= _TINYDIR_PATH_MAX) + { + errno = ENAMETOOLONG; + return -1; + } + + /* Get the parent path */ +#ifdef _WIN32 + if (_splitpath_s( + path, + drive_buf, sizeof drive_buf, + dir_name_buf, sizeof dir_name_buf, + file_name_buf, sizeof file_name_buf, + ext_buf, sizeof ext_buf)) + { + errno = EINVAL; + return -1; + } + /* Concatenate the drive letter and dir name to form full dir name */ + strcat(drive_buf, dir_name_buf); + dir_name = drive_buf; + /* Concatenate the file name and extension to form base name */ + strcat(file_name_buf, ext_buf); + base_name = file_name_buf; +#else + strcpy(dir_name_buf, path); + dir_name = dirname(dir_name_buf); + strcpy(file_name_buf, path); + base_name = basename(file_name_buf); +#endif + + /* Open the parent directory */ + if (tinydir_open(&dir, dir_name) == -1) + { + return -1; + } + + /* Read through the parent directory and look for the file */ + while (dir.has_next) + { + if (tinydir_readfile(&dir, file) == -1) + { + result = -1; + goto bail; + } + if (strcmp(file->name, base_name) == 0) + { + /* File found */ + found = 1; + goto bail; + } + tinydir_next(&dir); + } + if (!found) + { + result = -1; + errno = ENOENT; + } + +bail: + tinydir_close(&dir); + return result; +} + +_TINYDIR_FUNC +void _tinydir_get_ext(tinydir_file *file) +{ + char *period = strrchr(file->name, '.'); + if (period == NULL) + { + file->extension = &(file->name[strlen(file->name)]); + } + else + { + file->extension = period + 1; + } +} + +_TINYDIR_FUNC +int _tinydir_file_cmp(const void *a, const void *b) +{ + const tinydir_file *fa = (const tinydir_file *)a; + const tinydir_file *fb = (const tinydir_file *)b; + if (fa->is_dir != fb->is_dir) + { + return -(fa->is_dir - fb->is_dir); + } + return strncmp(fa->name, fb->name, _TINYDIR_FILENAME_MAX); +} + +#endif http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/tokenizer.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/tokenizer.h b/include/singa/utils/tokenizer.h new file mode 100644 index 0000000..c66e0af --- /dev/null +++ b/include/singa/utils/tokenizer.h @@ -0,0 +1,64 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_TOKENIZER_H_
+#define SINGA_UTILS_TOKENIZER_H_
+
+#include <glog/logging.h>
+#include <string>
+
+namespace singa {
+/**
+ * Tokenize a string.
+ *
+ * Example:
+ *   Tokenizer t("assa,asf;wes", ",;");
+ *   string x;
+ *   t >> x;  // x is assa
+ *   t >> x;  // x is asf
+ *   t >> x;  // x is wes
+ *   cout << t.Valid();  // prints 0, i.e., no tokens left
+ */
+class Tokenizer {
+ public:
+  Tokenizer(const std::string& str, const std::string& sep): start_(0),
+    sep_(sep), buf_(str) {}
+  Tokenizer& operator>>(std::string& out) {
+    CHECK_LT(start_, buf_.length());
+    int start = start_;
+    auto pos = buf_.find_first_of(sep_, start);
+    if (pos == std::string::npos)
+      pos = buf_.length();
+    start_ = pos + 1;
+    out = buf_.substr(start, pos - start);
+    return *this;
+  }
+  bool Valid() { return start_ < buf_.length(); }
+
+ private:
+  unsigned start_;
+  std::string sep_;
+  const std::string& buf_;
+};
+
+} // namespace singa
+
+#endif // SINGA_UTILS_TOKENIZER_H_
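A short driver for the Tokenizer above; the input string and separators are illustrative:

    #include <glog/logging.h>
    #include <string>
    #include "singa/utils/tokenizer.h"

    void TokenizerExample() {
      singa::Tokenizer t("conv1,relu1;pool1", ",;");
      std::string token;
      while (t.Valid()) {
        t >> token;          // yields conv1, relu1, pool1 in turn
        LOG(INFO) << token;
      }
    }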
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/utils/updater.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/updater.h b/include/singa/utils/updater.h
new file mode 100644
index 0000000..6413a80
--- /dev/null
+++ b/include/singa/utils/updater.h
@@ -0,0 +1,145 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_UPDATER_H_
+#define SINGA_UTILS_UPDATER_H_
+
+#include "singa/proto/job.pb.h"
+#include "singa/utils/param.h"
+
+namespace singa {
+/**
+ * Base learning rate generator.
+ *
+ * Generate the learning rate for a given training step/iteration.
+ * There are many different ways to change the learning rate through time/step.
+ * Users can inherit this class to implement their own schedule.
+ */
+class LRGenerator {
+ public:
+  static LRGenerator* Create(const LRGenProto& proto);
+
+  virtual ~LRGenerator() {}
+
+  virtual void Init(const LRGenProto& proto) { proto_ = proto; }
+  /**
+   * @param step training step/iteration.
+   * @return base learning rate regardless of step
+   */
+  virtual float Get(int step) { return proto_.base_lr(); }
+
+ protected:
+  LRGenProto proto_;
+};
+
+class FixedStepLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+ private:
+  int last_idx_ = 0;
+};
+
+class StepLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+
+class LinearLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+
+class ExpLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+
+class InvLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+
+class InvTLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+
+/**
+ * Updater for Param.
+ */
+class Updater {
+ public:
+  static Updater* Create(const UpdaterProto& proto);
+
+  virtual ~Updater() {}
+
+  virtual void Init(const UpdaterProto& proto);
+  virtual void Update(int step, Param* param, float grad_scale) = 0;
+
+ protected:
+  UpdaterProto proto_;
+  LRGenerator* lr_gen_;
+  float weight_decay_;
+  float momentum_;
+};
+
+class SGDUpdater : public Updater {
+ public:
+  void Update(int step, Param* param, float grad_scale) override;
+};
+
+class AdaGradUpdater : public Updater {
+ public:
+  void Update(int step, Param* param, float grad_scale) override;
+};
+
+
+class NesterovUpdater : public Updater {
+ public:
+  void Update(int step, Param* param, float grad_scale) override;
+};
+
+/*
+class RMSPropUpdater : public Updater {
+ public:
+  virtual void Update(int step, Param* param, float grad_scale);
+
+ protected:
+  float base_lr_;
+  float delta_;
+  float rho_;
+  float weight_decay_;
+};
+
+class AdaDeltaUpdater : public Updater {
+ public:
+  virtual void Update(int step, Param* param, float grad_scale);
+
+ protected:
+  float rho_;
+  float delta_;
+  float weight_decay_;
+};
+*/
+
+} // namespace singa

+#endif // SINGA_UTILS_UPDATER_H_
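A hedged sketch of a custom schedule built on the LRGenerator base class above; the halving factor and step size are invented for illustration, and the hookup to LRGenerator::Create()/LRGenProto is omitted:

    #include <cmath>
    #include "singa/utils/updater.h"

    // Hypothetical schedule: halve the base learning rate every 10000 steps.
    class HalvingLRGen : public singa::LRGenerator {
     public:
      float Get(int step) override {
        // proto_.base_lr() is the base rate, as used by the default Get()
        return proto_.base_lr() * std::pow(0.5f, step / 10000);
      }
    };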
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/singa/worker.h
----------------------------------------------------------------------
diff --git a/include/singa/worker.h b/include/singa/worker.h
new file mode 100644
index 0000000..d8ab61c
--- /dev/null
+++ b/include/singa/worker.h
@@ -0,0 +1,313 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_WORKER_H_
+#define SINGA_WORKER_H_
+
+#include <string>
+#include <vector>
+#include "singa/comm/socket.h"
+#include "singa/neuralnet/neuralnet.h"
+#include "singa/proto/job.pb.h"
+#include "singa/neuralnet/connection_layer/bridge.h"
+#include "singa/neuralnet/neuron_layer/rbm.h"
+
+namespace singa {
+
+//!< sleep 5 milliseconds if the Param is not updated to the expected version
+const int kCollectSleepTime = 5;
+/**
+ * The Worker class runs the training algorithm.
+ * The first worker group will initialize parameters of the Net,
+ * and put them into the distributed memory/table.
+ * The virtual functions TrainOneBatch and TestOneBatch implement the
+ * training and test algorithms for one mini-batch of data.
+ *
+ * Child workers override the two functions to implement their training
+ * algorithms, e.g., the BPWorker/CDWorker/BPTTWorker implement the BP/CD/BPTT
+ * algorithms respectively.
+ */
+class Worker {
+ public:
+  /**
+   * Create an instance of the subclass of Worker.
+   *
+   * @param[in] conf configuration of the TrainOneBatch algorithm. Different
+   * Worker subclasses implement different algorithms. Hence the creation is
+   * based on the TrainOneBatch algorithm type. Currently SINGA
+   * provides two algorithms:
+   * -# Back-propagation for the feed-forward models, e.g., CNN and MLP, and
+   * the recurrent neural networks.
+   * -# Contrastive divergence for the energy models, e.g., RBM.
+   *
+   * @return a pointer to the instance of the Worker subclass.
+   */
+  static Worker* Create(const AlgProto& conf);
+  virtual ~Worker();
+  /**
+   * @param[in] grp_id global worker group ID
+   * @param[in] id worker ID within the group
+   * @param[in] conf job configuration
+   * @param[in] train_net pointer to the training neural net, which could be
+   * shared with other workers from the same group. Different workers run over
+   * different subsets of layers.
+   * @param[in] val_net pointer to the validation neural net. Currently only
+   * the first worker from the first group would have a validation neural net.
+   * All other workers receive nullptr for this argument.
+   * @param[in] test_net pointer to the test neural net. Currently only the
+   * first worker from the first group would have a test neural net. All other
+   * workers receive nullptr for this argument.
+   */
+  virtual void Setup(int grp_id, int id, const JobProto& conf,
+      NeuralNet* train_net, NeuralNet* val_net, NeuralNet* test_net);
+
+  /**
+   * Main function of Worker.
+   *
+   * Train the neural net step by step; test/validation is done periodically.
+   */
+  void Run();
+
+  /**
+   * Init values of Param instances associated with local layers (i.e., layers
+   * dispatched to this worker).
+   *
+   * If one Param is owned by the worker, then it should be initialized and
+   * put to servers. Otherwise Get() should be called to get the Param. The
+   * Get() may not send get requests if the Param owner is in the same process,
+   * in which case the memory space of the Param objects is shared. But if this
+   * worker and the Param owner worker run on different devices (e.g., GPUs),
+   * then the get request would be sent.
+   *
+   * If the training starts from scratch, every Param object is initialized
+   * using ParamGenerator. After that, the worker may
+   * train for a couple of steps to warm up the params before putting
+   * them to servers (warmup of JobProto controls this).
+
+  /**
+   * Checkpoint all Param objects owned by the worker onto disk.
+   * The serialization is done using BlobProtos, which include the name,
+   * version and values of each Param object.
+   * Different workers generate different checkpoint files. The file path is
+   * <workspace>/checkpoint-<jobname>-step<step>-worker<worker_id>.bin
+   * @param[in] step training step
+   * @param[in] folder directory to put the checkpoint file
+   * @param net the training net whose Param objects will be dumped.
+   */
+  void Checkpoint(int step, const std::string& folder, NeuralNet* net);
+
+  /**
+   * Train one mini-batch.
+   * Test/Validation is done before training.
+   *
+   * @param[in] step training step.
+   * @param[in] net neural net to be trained.
+   */
+  virtual void TrainOneBatch(int step, NeuralNet* net) = 0;
+
+  /**
+   * Test/validate one mini-batch of data.
+   *
+   * @param[in] step test step.
+   * @param[in] phase test could be done for validation or test phase.
+   * @param[in] net neural net for test
+   */
+  virtual void TestOneBatch(int step, Phase phase, NeuralNet* net) = 0;
+
+  /**
+   * Display information from layers.
+   *
+   * @param flag could be a combination of multiple phases, e.g.,
+   * kTest|kForward; it is passed to the Layer::ToString() function of each
+   * layer to decide what to display.
+   * @param prefix display prefix, e.g., 'Train step 100', 'Test step 90'.
+   * @param net display layers from this neural net.
+   */
+  void Display(int flag, const std::string& prefix, NeuralNet* net);
+
+  /**
+   * Put Param values to the server.
+   *
+   * @param step used as current param version for the put request
+   * @param param
+   */
+  int Put(int step, Param* param);
+
+  /**
+   * Get Param with a specific version from the server.
+   * If the current version >= the requested version, return immediately.
+   * Otherwise send a get request to the stub, which forwards it to the
+   * servers.
+   * @param step requested param version
+   * @param param
+   */
+  int Get(int step, Param* param);
+
+  /**
+   * Update Param.
+   *
+   * @param step training step used for updating (e.g., deciding the
+   * learning rate).
+   * @param param
+   */
+  int Update(int step, Param* param);
+
+  /**
+   * Wait for the response of the update/get requests.
+   *
+   * @param step not used now.
+   * @param param
+   */
+  int Collect(int step, Param* param);
+
+  /**
+   * Call Collect() for every param of the net.
+   */
+  int CollectAll(int step, NeuralNet* net);
+
+  /**
+   * Receive blobs from other workers due to model partitions.
+   */
+  void ReceiveBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
+
+  /**
+   * Send blobs to other workers due to model partitions.
+   */
+  void SendBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
+
+
+  /**
+   * @param[in] step
+   * @return true if it is time to display training info, e.g., loss;
+   * otherwise false.
+   */
+  inline bool DisplayNow(int step) const {
+    return job_conf_.disp_freq() > 0
+           && step >= job_conf_.disp_after()
+           && ((step - job_conf_.disp_after()) % job_conf_.disp_freq() == 0);
+  }
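+  // Worked example: with disp_after = 100 and disp_freq = 50, DisplayNow()
+  // is true at steps 100, 150, 200, ...; a non-positive disp_freq disables
+  // display. CheckpointNow/TestNow/ValidateNow below follow the same
+  // after/freq pattern on their respective config fields.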
+  /**
+   * @param[in] step
+   * @return true if it is time to finish the training; otherwise false.
+   */
+  inline bool StopNow(int step) const {
+    return step >= job_conf_.train_steps();
+  }
+  /**
+   * @param[in] step
+   * @return true if it is time to checkpoint Param objects; otherwise false.
+   */
+  inline bool CheckpointNow(int step) const {
+    return job_conf_.checkpoint_freq() > 0
+           && step >= job_conf_.checkpoint_after()
+           && ((step - job_conf_.checkpoint_after())
+              % job_conf_.checkpoint_freq() == 0);
+  }
+  /**
+   * @param[in] step
+   * @return true if it is time to test over the test dataset.
+   */
+  inline bool TestNow(int step) const {
+    return job_conf_.test_freq() > 0
+           && job_conf_.test_steps() > 0
+           && step >= job_conf_.test_after()
+           && ((step - job_conf_.test_after()) % job_conf_.test_freq() == 0);
+  }
+  /**
+   * @param[in] step
+   * @return true if it is time to test over the validation dataset.
+   */
+  inline bool ValidateNow(int step) const {
+    return job_conf_.validate_freq() > 0
+           && job_conf_.validate_steps() > 0
+           && step >= job_conf_.validate_after()
+           && ((step - job_conf_.validate_after())
+              % job_conf_.validate_freq() == 0);
+  }
+  /**
+   * @return a vector with pointers to all neural nets.
+   */
+  const std::vector<NeuralNet*> GetNets() const {
+    return std::vector<NeuralNet*> {train_net_, val_net_, test_net_};
+  }
+  /**
+   * @return training net.
+   */
+  inline NeuralNet* train_net() const {
+    return train_net_;
+  }
+  /**
+   * @return group ID
+   */
+  inline int grp_id() const { return grp_id_; }
+  /**
+   * @return worker ID within the worker group.
+   */
+  inline int id() const { return id_; }
+
+ protected:
+  int grp_id_ = -1, id_ = -1;
+  int step_ = 0;
+  JobProto job_conf_;
+  NeuralNet* train_net_ = nullptr;
+  NeuralNet* test_net_ = nullptr;
+  NeuralNet* val_net_ = nullptr;
+  Dealer* layer_dealer_ = nullptr;
+  Dealer* dealer_ = nullptr;
+};
+
+class BPWorker: public Worker {
+ public:
+  void TrainOneBatch(int step, NeuralNet* net) override;
+  void TestOneBatch(int step, Phase phase, NeuralNet* net) override;
+  void Forward(int step, Phase phase, NeuralNet* net);
+  void Backward(int step, NeuralNet* net);
+};
+
+class CDWorker: public Worker {
+ public:
+  void TrainOneBatch(int step, NeuralNet* net) override;
+  void TestOneBatch(int step, Phase phase, NeuralNet* net) override;
+};
+
+inline int BlobTrgt(int grp, int layer) {
+  return (grp << 16) | layer;
+}
+
+inline int BlobGrp(int blob_trgt) {
+  return blob_trgt >> 16;
+}
+
+inline int BlobLayer(int blob_trgt) {
+  static int mask = (1 << 16) - 1;
+  return blob_trgt & mask;
+}
+
+} // namespace singa
+
+#endif // SINGA_WORKER_H_
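BlobTrgt packs a worker group ID and a layer ID into a single int via the high and low 16 bits, so each ID must stay below 65536 (2^16). A self-contained round-trip check, mirroring the three helpers above:

#include <cassert>

int main() {
  // grp = 3, layer = 5 -> (3 << 16) | 5 = 196613
  int trgt = (3 << 16) | 5;
  assert(trgt == 196613);
  assert((trgt >> 16) == 3);              // BlobGrp recovers the group
  assert((trgt & ((1 << 16) - 1)) == 5);  // BlobLayer recovers the layer
  return 0;
}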
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/stub.h
----------------------------------------------------------------------
diff --git a/include/stub.h b/include/stub.h
deleted file mode 100644
index 719f033..0000000
--- a/include/stub.h
+++ /dev/null
@@ -1,109 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_STUB_H_
-#define SINGA_STUB_H_
-
-#include <queue>
-#include <unordered_map>
-#include <vector>
-#include <string>
-#include "comm/socket.h"
-#include "neuralnet/neuralnet.h"
-#include "proto/job.pb.h"
-#include "proto/singa.pb.h"
-#include "utils/factory.h"
-#include "utils/param.h"
-#include "utils/singleton.h"
-#include "./server.h"
-#include "./worker.h"
-
-namespace singa {
-
-class Stub {
- public:
-  ~Stub();
-  /**
-   * Find an endpoint to bind.
-   */
-  void Setup();
-  /**
-   * The Stub instance runs this function in the main thread to handle (e.g.,
-   * forward) messages from workers and servers.
-   *
-   * @param[in] slice2server the k-th value is the ID of the server that is in
-   * charge of updating the Param slice with ID k. Large Param objects are
-   * sliced into subsets for load balancing. Different subsets are updated by
-   * different servers.
-   */
-  void Run(const vector<int>& slice2server,
-      const std::vector<Worker*>& workers,
-      const std::vector<Server*>& servers);
-
-  const std::string& endpoint() const {
-    return endpoint_;
-  }
-
- protected:
-  /**
-   * Create a socket to send messages to the specified process.
-   * @param dst_procs the destination process (logical) ID
-   * @return the newly created socket
-   */
-  Dealer* CreateInterProcsDealer(int dst_procs);
-  /**
-   * Generate a request message to Get the parameter object.
-   */
-  const std::vector<Msg*> HandleGetRequest(ParamEntry* entry, Msg** msg);
-  void HandleGetResponse(ParamEntry* entry, Msg** msg);
-  /**
-   * Generate a request message to Update the parameter object.
-   */
-  const std::vector<Msg*> HandleUpdateRequest(ParamEntry* entry, Msg** msg);
-  /**
-   * Handle response msg from servers for the update requests.
-   */
-  void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
-  /**
-   * Generate a request message to Put the parameter object.
-   */
-  const std::vector<Msg*> HandlePutRequest(ParamEntry* entry, Msg** msg);
-  /**
-   * Called by the HandlePut, HandleUpdate and HandleGet functions.
-   * @param type message type
-   * @param version param version
-   * @param entry
-   * @param msg
-   * @param ret generated messages
-   */
-  void GenMsgs(int type, int version, ParamEntry* entry,
-      Msg* msg, std::vector<Msg*> *ret);
-
-
- protected:
-  Router *router_ = nullptr;
-  std::string endpoint_;
-  std::vector<int> slice2server_;
-};
-
-} // namespace singa
-
-#endif // SINGA_STUB_H_
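The slice2server table makes routing a lookup: the destination server for a request on Param slice k is simply the k-th entry. A schematic of that dispatch (illustrative only, not the committed GenMsgs logic):

#include <vector>

// With slice2server = {0, 0, 1, 1}, requests for slices 0 and 1 go to
// server 0, and requests for slices 2 and 3 go to server 1.
int DestServer(const std::vector<int>& slice2server, int slice_id) {
  return slice2server.at(slice_id);
}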
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/utils/blob.h b/include/utils/blob.h
deleted file mode 100644
index 91db095..0000000
--- a/include/utils/blob.h
+++ /dev/null
@@ -1,198 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-/**
- * The code is adapted from that of Caffe, whose license is attached.
- *
- * COPYRIGHT
- * All contributions by the University of California:
- * Copyright (c) 2014, The Regents of the University of California (Regents)
- * All rights reserved.
- * All other contributions:
- * Copyright (c) 2014, the respective contributors
- * All rights reserved.
- * Caffe uses a shared copyright model: each contributor holds copyright over
- * their contributions to Caffe. The project versioning records all such
- * contribution and copyright details. If a contributor wants to further mark
- * their specific copyright on a particular contribution, they should indicate
- * their copyright solely in the commit message of the change when it is
- * committed.
- * LICENSE
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * 1. Redistributions of source code must retain the above copyright notice, this
- * list of conditions and the following disclaimer.
- * 2. Redistributions in binary form must reproduce the above copyright notice,
- * this list of conditions and the following disclaimer in the documentation
- * and/or other materials provided with the distribution.
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
- * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- * CONTRIBUTION AGREEMENT
- * By contributing to the BVLC/caffe repository through pull-request, comment,
- * or otherwise, the contributor releases their content to the
- * license and copyright terms herein.
- *
- */
-#ifndef SINGA_UTILS_BLOB_H_
-#define SINGA_UTILS_BLOB_H_
-
-#include <glog/logging.h>
-#include <memory>
-#include <vector>
-#include "proto/common.pb.h"
-
-namespace singa {
-
-inline void MallocHost(void** ptr, size_t size) {
-  *ptr = malloc(size);
-}
-
-inline void FreeHost(void* ptr) {
-  free(ptr);
-}
-
-/**
- * @brief Manages memory allocation and synchronization between the host (CPU)
- * and device (GPU).
- *
- * TODO(dox): more thorough description.
- */
-class SyncedMemory {
- public:
-  enum SyncedHead { UNINITIALIZED,
-                    HEAD_AT_CPU,
-                    HEAD_AT_GPU,
-                    SYNCED };
-
-  SyncedMemory() {}
-  explicit SyncedMemory(size_t size) : size_(size) {}
-  ~SyncedMemory();
-
-  const void* cpu_data();
-  const void* gpu_data();
-  void* mutable_cpu_data();
-  void* mutable_gpu_data();
-  void set_cpu_data(void* data);
-  inline SyncedHead head() { return head_; }
-  inline size_t size() { return size_; }
-
- private:
-  void to_cpu();
-  void to_gpu();
-
-  void* cpu_ptr_ = nullptr;
-  void* gpu_ptr_ = nullptr;
-  size_t size_ = 0;
-  SyncedHead head_ = UNINITIALIZED;
-  bool own_cpu_data_ = false;
-}; // class SyncedMemory
-
-
-template <typename Dtype>
-class Blob {
- public:
-  Blob() {}
-  explicit Blob(const std::vector<int>& shape) { Reshape(shape); }
-  /**
-   * @brief Change the dimensions of the blob, allocating new memory if
-   * necessary.
-   *
-   * This function can be called both to create an initial allocation
-   * of memory, and to adjust the dimensions of a top blob during
-   * Layer::Reshape or Layer::Forward. When changing the size of the blob,
-   * memory will only be reallocated if sufficient memory does not already
-   * exist, and excess memory will never be freed.
-   *
-   * Note that reshaping an input blob and immediately calling Net::Backward
-   * is an error; either Net::Forward or Net::Reshape needs to be called to
-   * propagate the new input shape to higher layers.
-   */
-  void Reshape(const std::vector<int>& shape);
-  void ReshapeLike(const Blob& other);
-  /**
-   * @brief Copy from a source Blob.
-   *
-   * @param source the Blob to copy from
-   * @param reshape if false, require this Blob to be pre-shaped to the shape
-   * of other (and die otherwise); if true, Reshape this Blob to other's
-   * shape if necessary
-   */
-  void CopyFrom(const Blob<Dtype>& source);
-  void CopyFrom(const Blob<Dtype>& source, bool reshape);
-  void FromProto(const singa::BlobProto& proto);
-  void ToProto(singa::BlobProto* proto) const;
-  /**
-   * @brief Set the data_ shared_ptr to point to the SyncedMemory holding the
-   * data_ of Blob other -- useful in Layers which simply perform a copy
-   * in their Forward pass.
-   *
-   * This deallocates the SyncedMemory holding this Blob's data_, as
-   * shared_ptr calls its destructor when reset with the "=" operator.
-   */
-  void ShareData(const Blob& other);
-  void Swap(Blob& other);
-  inline const std::vector<int>& shape() const { return shape_; }
-  inline int count() const { return count_; }
-  inline const int version() const { return version_; }
-  inline void set_version(int v) { version_ = v; }
-  inline const Dtype* cpu_data() const {
-    CHECK(data_);
-    return static_cast<const Dtype*>(data_->cpu_data());
-  }
-  inline void set_cpu_data(Dtype* data) {
-    CHECK(data);
-    data_->set_cpu_data(data);
-  }
-  inline const Dtype* gpu_data() const {
-    CHECK(data_);
-    return static_cast<const Dtype*>(data_->gpu_data());
-  }
-  inline Dtype* mutable_cpu_data() {
-    CHECK(data_);
-    return static_cast<Dtype*>(data_->mutable_cpu_data());
-  }
-  inline Dtype* mutable_gpu_data() {
-    CHECK(data_);
-    return static_cast<Dtype*>(data_->mutable_gpu_data());
-  }
-  /// @brief Compute the sum of absolute values (L1 norm) of the data.
-  Dtype asum_data() const;
-  Dtype sum_data() const;
-
- protected:
-  std::shared_ptr<SyncedMemory> data_ = nullptr;
-  std::vector<int> shape_;
-  int count_ = 0;
-  int capacity_ = 0;
-  int version_ = -1;
-}; // class Blob
-
-} // namespace singa
-
-#endif // SINGA_UTILS_BLOB_H_
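A minimal usage sketch against the declarations above (CPU path only; assumes Blob and SyncedMemory exactly as declared in this header):

#include <vector>

// Fill a 2x3 blob on the host and zero it; mutable_cpu_data() leaves the
// head at HEAD_AT_CPU, so a later gpu_data() call would trigger a
// host-to-device copy inside SyncedMemory.
void FillExample() {
  singa::Blob<float> blob(std::vector<int>{2, 3});
  float* data = blob.mutable_cpu_data();
  for (int i = 0; i < blob.count(); ++i)  // count() == 2 * 3 == 6
    data[i] = 0.0f;
}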
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/239ed217/include/utils/cluster.h
----------------------------------------------------------------------
diff --git a/include/utils/cluster.h b/include/utils/cluster.h
deleted file mode 100644
index afeb947..0000000
--- a/include/utils/cluster.h
+++ /dev/null
@@ -1,163 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_UTILS_CLUSTER_H_
-#define SINGA_UTILS_CLUSTER_H_
-
-#include <glog/logging.h>
-#include <string>
-#include <unordered_map>
-#include <memory>
-#include <vector>
-#include "proto/job.pb.h"
-#include "proto/singa.pb.h"
-#include "utils/cluster_rt.h"
-#include "utils/common.h"
-#include "utils/singleton.h"
-
-namespace singa {
-
-/**
- * Cluster is a singleton object, which provides cluster configurations,
- * e.g., the topology of the cluster.
- * All IDs start from 0.
- */
-class Cluster {
- public:
-  // Cluster is a global singleton in a process
-  static Cluster* Setup(int job_id, const SingaProto& singaConf,
-      const ClusterProto& clusterConf);
-  static Cluster* Get();
-
-  inline int nserver_groups() const { return cluster_.nserver_groups(); }
-  inline int nworker_groups() const { return cluster_.nworker_groups(); }
-  inline int nworkers_per_group() const { return cluster_.nworkers_per_group(); }
-  inline int nservers_per_group() const { return cluster_.nservers_per_group(); }
-  inline int nworkers_per_procs() const { return cluster_.nworkers_per_procs(); }
-  inline int nservers_per_procs() const { return cluster_.nservers_per_procs(); }
-  inline int nworker_groups_per_server_group() const {
-    if (nserver_groups() == 0 || nservers_per_group() == 0)
-      return 1;
-    else
-      return cluster_.nworker_groups() / cluster_.nserver_groups();
-  }
-  /**
-   * @return true if the calling process has server threads, otherwise false.
-   */
-  inline bool has_server() const {
-    if (server_worker_separate()) {
-      CHECK_LT(procs_id_, nprocs_);
-      return procs_id_ >= nworker_procs();
-    } else {
-      return procs_id_ < nserver_procs();
-    }
-  }
-  /**
-   * @return true if the calling process has worker threads.
-   */
-  inline bool has_worker() const {
-    return procs_id_ < nworker_procs();
-  }
-  /**
-   * @return global procs ID, which starts from 0.
-   */
-  inline int procs_id() const { return procs_id_; }
-  inline void set_procs_id(int procs_id) { procs_id_ = procs_id; }
-  inline bool server_worker_separate() const {
-    return cluster_.server_worker_separate();
-  }
-  inline int nworker_procs() const {
-    return nworker_groups() * nworkers_per_group() / nworkers_per_procs();
-  }
-  inline int nserver_procs() const {
-    return nserver_groups() * nservers_per_group() / nservers_per_procs();
-  }
-  inline int nprocs() const { return nprocs_; }
-  /**
-   * @return endpoint of the router of the process with the specified ID
-   */
-  inline std::string endpoint(int procs_id) const {
-    CHECK_LT(procs_id, nprocs());
-    CHECK_GE(procs_id, 0);
-    return cluster_rt_->GetProcHost(procs_id);
-  }
-  inline std::string workspace() const { return cluster_.workspace(); }
-  inline std::string vis_folder() const {
-    return cluster_.workspace() + "/visualization";
-  }
-  inline std::string checkpoint_folder() const {
-    return cluster_.workspace() + "/checkpoint";
-  }
-  /*
-  const int stub_timeout() const { return cluster_.stub_timeout(); }
-  const int worker_timeout() const { return cluster_.worker_timeout(); }
-  const int server_timeout() const { return cluster_.server_timeout(); }
-  */
-  inline bool share_memory() const { return cluster_.share_memory(); }
-  inline int sync_freq() const { return cluster_.sync_freq(); }
-  inline int poll_time() const { return cluster_.poll_time(); }
-  ClusterRuntime* runtime() const { return cluster_rt_; }
-
-  /**
-   * @return logical procs ID
-   */
-  inline int ProcsIDOf(int group_id, int id, int flag) {
-    return procs_ids_.at(Hash(group_id, id, flag));
-  }
-  inline std::string hostip() const { return hostip_; }
-
-  /**
-   * @param pid process ID
-   * @param group_size num of executors in a group
-   * @param procs_size num of executors in a procs
-   *
-   * @return a vector with 4 integers:
-   * [group start, group end), [start executor, end executor)
-   */
-  const std::vector<int> ExecutorRng(int pid, int group_size, int procs_size);
-  /**
-   * Register this process.
-   *
-   * @param pid physical process ID obtained from the OS; all other procs IDs
-   * refer to logical process IDs.
-   * @param endpoint unique string for other processes to connect
-   */
-  void Register(int pid, const std::string& endpoint);
-
- private:
-  void Init(int job, const SingaProto& singaConf,
-      const ClusterProto& clusterConf);
-  void SetupFolders(const ClusterProto &cluster);
-  int Hash(int gid, int id, int flag);
-
-  int procs_id_ = -1;
-  int nprocs_ = 0;
-  std::string hostip_ = "";
-  // cluster config proto
-  ClusterProto cluster_;
-  SingaProto singa_;
-  ClusterRuntime* cluster_rt_ = nullptr;
-  std::unordered_map<int, int> procs_ids_;
-};
-
-} // namespace singa
-
-#endif // SINGA_UTILS_CLUSTER_H_
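The process-count arithmetic above composes directly from the cluster config. A worked example with illustrative numbers (not SINGA defaults):

#include <cassert>

int main() {
  // 4 worker groups of 2 workers, 2 workers per process; 1 server group of
  // 2 servers, 2 servers per process.
  const int nworker_groups = 4, nworkers_per_group = 2, nworkers_per_procs = 2;
  const int nserver_groups = 1, nservers_per_group = 2, nservers_per_procs = 2;
  const int nworker_procs =
      nworker_groups * nworkers_per_group / nworkers_per_procs;
  const int nserver_procs =
      nserver_groups * nservers_per_group / nservers_per_procs;
  assert(nworker_procs == 4);  // procs 0..3 run workers (has_worker() true)
  assert(nserver_procs == 1);  // with server_worker_separate(), procs 4
                               // is the server process (procs_id >= 4)
  return 0;
}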
