Repository: incubator-singa
Updated Branches:
refs/heads/master fbbcaafdb -> ae2030362
SINGA-21 Code review 4
review param.h, param.cc
- ShareFrom(): init vectors by resizing instead of assigning
- add a summary for put/get/update/sync request workflows
- implement reserve flags correctly for all HandleXX functions (see the sketch below)
TODO: remove all reserve-flag checks later
- format the code
format updater.h, updater.cc
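
For reviewers, a minimal sketch of the reserve contract the HandleXX functions
now follow (illustrative only, not part of the patch; it mirrors the body of
Param::HandleGetMsg in the diff below). If reserve is true the request stays
with the caller and the handler answers with a copy; otherwise the handler owns
the request and may recycle it as the reply:

    // hypothetical handler body, showing only the ownership handling
    Msg* reply = nullptr;
    if (reserve) {
      reply = new Msg(**msg);  // caller keeps *msg; reply is built from a copy
    } else {
      reply = *msg;            // reuse the request message as the response
      *msg = nullptr;          // caller must not free or touch it afterwards
    }
    reply->SwapAddr();
    reply->set_type(kRGet);
    return reply;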
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f99246e7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f99246e7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f99246e7
Branch: refs/heads/master
Commit: f99246e7840eb5fcec2918b37736cba0344db19e
Parents: fbbcaaf
Author: wang sheng <[email protected]>
Authored: Fri Aug 21 14:50:32 2015 +0800
Committer: wangwei <[email protected]>
Committed: Fri Aug 28 17:58:40 2015 +0800
----------------------------------------------------------------------
include/utils/param.h | 269 ++++++++++++++++++-----------------------
include/utils/singleton.h | 2 +-
include/utils/updater.h | 42 ++++---
src/trainer/server.cc | 7 +-
src/utils/param.cc | 211 ++++++++++++++++++--------------
src/utils/updater.cc | 40 +++---
6 files changed, 287 insertions(+), 284 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index 0d24e95..4909026 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -1,56 +1,56 @@
#ifndef SINGA_UTILS_PARAM_H_
#define SINGA_UTILS_PARAM_H_
-#include <vector>
+
+#include <memory>
#include <string>
+#include <vector>
+#include "communication/msg.h"
#include "proto/job.pb.h"
#include "utils/blob.h"
-#include "communication/msg.h"
namespace singa {
/**
* Base parameter generator which intializes parameter values.
*/
-
class ParamGenerator {
public:
static ParamGenerator* Create(const ParamGenProto& proto);
- virtual ~ParamGenerator() {}
- virtual void Init(const ParamGenProto& proto) {
- proto_ = proto;
- }
+ virtual ~ParamGenerator() {}
+ virtual void Init(const ParamGenProto& proto) { proto_ = proto; }
virtual void Fill(Blob<float>* data);
protected:
ParamGenProto proto_;
};
-class GaussianGen: public ParamGenerator {
+class GaussianGen : public ParamGenerator {
public:
void Fill(Blob<float>* data) override;
};
-class UniformGen: public ParamGenerator {
+class GaussianSqrtFanInGen : public GaussianGen {
public:
void Fill(Blob<float>* data) override;
};
-class GaussianSqrtFanInGen: public GaussianGen {
+class UniformGen : public ParamGenerator {
public:
void Fill(Blob<float>* data) override;
};
-class UniformSqrtFanInGen: public UniformGen {
+class UniformSqrtFanInGen : public UniformGen {
public:
void Fill(Blob<float>* data) override;
};
-class UniformSqrtFanInOutGen: public UniformGen {
+class UniformSqrtFanInOutGen : public UniformGen {
public:
void Fill(Blob<float>* data) override;
};
+
/**
* Base paramter class.
*
@@ -72,7 +72,8 @@ class UniformSqrtFanInOutGen: public UniformGen {
class Param {
public:
static Param* Create(const ParamProto& proto);
- Param();
+
+ Param() {}
virtual ~Param() {}
void Init(const ParamProto& proto) { proto_ = proto; }
/**
@@ -87,160 +88,125 @@ class Param {
*
* @param version initial version
*/
- virtual void InitValues(int version = 0);
+ virtual void InitValues();
+ virtual void InitValues(int version);
/**
* Share the data blob from other Param objects.
*
* @param other the Param object whose owner owns the data blob
*/
void ShareFrom(const Param& other);
-
+ /**
+ * Init param values from checkpoint blob.
+ */
+ void FromProto(const BlobProto& blob);
+ /**
+ * Dump param values to blob.
+ */
+ void ToProto(BlobProto* blob);
+ /**
+ * Add a slice
+ *
+ * @param slice_id
+ * @param size num of floats for this slice
+ */
+ void AddSlice(int slice_id, int size);
/**
* Scale the learning rate when updating parameters in the Param object
*/
- float lr_scale() {
- return proto_.lr_scale();
- }
+ inline float lr_scale() const { return proto_.lr_scale(); }
/**
* Scale the weight decay when updating parameters in the Param object
*/
- float wd_scale() {
- return proto_.wd_scale();
- }
+ inline float wd_scale() const { return proto_.wd_scale(); }
/**
* Parameter name used for Param re-use in other model or sharing between
* layers
*/
- const std::string& name() {
- return proto_.name();
- }
- void set_name(const std::string& name) {
- proto_.set_name(name);
- }
+ inline const std::string& name() const { return proto_.name(); }
+ inline void set_name(const std::string& name) { proto_.set_name(name); }
/**
* If it shares data from others, then owner is the id of that Param,
* otherwise it is itself's id.
*/
- const int owner() const {
- return proto_.owner();
- }
+ inline int owner() const { return proto_.owner(); }
/**
* ID start from 0 and ordered for all Param from the same neuralnet
*/
- int id() const {
- return proto_.id();
- }
+ inline int id() const { return proto_.id(); }
/**
* Set ID
*/
- void set_id(int id) {
+ inline void set_id(int id) {
proto_.set_id(id);
proto_.set_owner(id);
}
-
/**
* Param version is stored inside the data blob to enable all Param objs
* sharing the same values have the same version.
* @return the param version
*/
- int version() const {
- return data_->version();
- }
-
- void set_version(int v) {
- data_->set_version(v);
- }
-
+ inline int version() const { return data_->version(); }
+ inline void set_version(int v) { data_->set_version(v); }
/**
* @return the version of the parameter value local to a worker
*/
- int local_version() const {
- return local_version_;
- }
-
- void set_local_version(int v) {
- local_version_ = v;
- }
- const std::string& share_from() const {
- return proto_.share_from();
- }
+ inline int local_version() const { return local_version_; }
+ inline void set_local_version(int v) { local_version_ = v; }
+ inline const std::string& share_from() const { return proto_.share_from(); }
/**
* @return num of floats.
*/
- int size() const {
- return data_->count();
- }
- const Blob<float> &data() {
- return *data_;
- }
- Blob<float> *mutable_data() {
- return data_.get();
- }
- /**
- * Return gradient of this parameter
- */
- const Blob<float> &grad() {
- return grad_;
- }
- Blob<float> *mutable_grad() {
- return &grad_;
- }
- float* mutable_cpu_data() {
- return data_->mutable_cpu_data();
- }
- float* mutable_cpu_grad() {
- return grad_.mutable_cpu_data();
- }
- float* mutable_cpu_history() {
- return history_.mutable_cpu_data();
- }
-
+ inline int size() const { return data_->count(); }
+ inline const Blob<float>& data() const { return *data_; }
+ inline Blob<float>* mutable_data() { return data_.get(); }
+ inline const Blob<float> &grad() const { return grad_; }
+ inline Blob<float> *mutable_grad() { return &grad_; }
+ inline float* mutable_cpu_data() { return data_->mutable_cpu_data(); }
+ inline float* mutable_cpu_grad() { return grad_.mutable_cpu_data(); }
+ inline float* mutable_cpu_history() { return history_.mutable_cpu_data(); }
/**
* @return slice start ID
*/
- int slice_start() const {
- return slice_start_;
- }
+ inline int slice_start() const { return slice_start_; }
+ inline int num_slices() const { return num_slices_; }
- int num_slices() const {
- return num_slices_;
- }
-
- /**
- * Add a slice
- *
- * @param slice_id
- * @param size num of floats for this slice
- */
- void AddSlice(int slice_id, int size);
/**
- * Init param values from checkpoint blob.
+ * Below are message/request related functions.
+ * The basic communication workflows are as follow:
+ *------------------------------------------------------------------------
+ * |Put |Get |Update |Sync
+ *------------------------------------------------------------------------
+ * Generate|(stub) |(stub) |(stub) |(server)
+ * Message |GenPutMsg |GenGetMsg |GenUpdateMsg |GenSyncMsg
+ *------------------------------------------------------------------------
+ * Handle |(server) |(server) |(server) |(server)
+ * Message |HandlePutMsg|HandleGetMsg |ParseUpdateMsg |HandleSyncMsg
+ * | | |GenUpdateResMsg |
+ *------------------------------------------------------------------------
+ * Handle | |(stub) |(stub) |(server)
+ * Response| |ParseGetResMsg|ParseUpdateResMsg|ParseSyncResMsg
+ *------------------------------------------------------------------------
*/
- void FromProto(const BlobProto& blob);
- /**
- * Dump param values to blob.
- */
- void ToProto(BlobProto* blob);
- /**********************Msg related functions***************************/
/**
- * Generate the message for a get request, i.e., get parameters from a server
+ * Generate the message for a put request, i.e., put parameters to a server
*
* This function is called at worker/stub side.
* @param copy decides whether to copy the parameter values from the server.
* @param slice_idx index of the slice from which the message is generated.
* @return generated message without setting src, dst, target fields.
*/
- virtual Msg* GenGetMsg(bool copy, int slice_idx);
+ virtual Msg* GenPutMsg(bool copy, int slice_idx);
/**
- * Generate the message for a put request, i.e., put parameters to a server.
- * \copydetails GenGetMsg(bool, int);
+ * Generate the message for a get request, i.e., get parameters from a server
+ * \copydetails GenPutMsg(bool, int);
*/
- virtual Msg* GenPutMsg(bool copy, int slice_idx);
+ virtual Msg* GenGetMsg(bool copy, int slice_idx);
/**
* Generate the message for a update request, i.e., pass info to server for
* parameter update.
- * \copydetails GenGetMsg(bool, int);
+ * \copydetails GenPutMsg(bool, int);
*/
virtual Msg* GenUpdateMsg(bool copy, int slice_idx);
/**
@@ -251,6 +217,26 @@ class Param {
* */
virtual Msg* GenSyncMsg(int offset, int size);
/**
+ * Server handling function for put request.
+ *
+ * @param msg request
+ * @param reserve if true reserve the msg space for the calling function;
+ * otherwise the msg should be freed inside the function.
+ * @return resposne message
+ */
+ virtual Msg* HandlePutMsg(Msg** msg, bool reserve);
+ /**
+ * Server handling function for put request.
+ *
+ * \copydetails HandleGetMsg(Msg**, bool reserve)
+ */
+ virtual Msg* HandleGetMsg(Msg** msg, bool reserve);
+ /**
+ * Server parse update requests.
+ * \copydetails GenUpdateResponseMsgs(const std::vector<Msg*>& msgs);
+ */
+ virtual void ParseUpdateMsgs(const std::vector<Msg*>& msgs);
+ /**
* Generate the messages to response the update requests.
*
* This function is called at the server side, where the Param is actually a
@@ -263,29 +249,13 @@ class Param {
* @return response messages
*/
virtual const std::vector<Msg*>
- GenUpdateResponseMsgs(const std::vector<Msg*>& msgs);
-
- /**
- * Server handling function for get request.
- *
- * @param msg request
- * @param reserve if true reserve the msg space for the calling function;
- * otherwise the msg should be freed inside the function.
- * @return resposne message
- */
- virtual Msg* HandleGetMsg(Msg** msg, bool reserve = false);
- /**
- * Server handling function for put request.
- *
- * \copydetails HandleGetMsg(Msg**, bool reserve)
- */
- virtual Msg* HandlePutMsg(Msg** msg, bool reserve = false);
+ GenUpdateResponseMsgs(std::vector<Msg*>* msgs, bool reserve);
/**
* Server handling function for synchronization message
*
* \copydetails HandleGetMsg(Msg**, bool reserve)
*/
- virtual Msg* HandleSyncMsg(Msg** msg, bool reserve = false);
+ virtual Msg* HandleSyncMsg(Msg** msg, bool reserve);
/**
* Worker/Stub parsing function for get response.
*
@@ -300,11 +270,6 @@ class Param {
*/
virtual int ParseUpdateResponseMsg(Msg* msg, int slice_idx);
/**
- * Server parse update requests.
- * \copydetails GenUpdateResponseMsgs(const std::vector<Msg*>& msgs);
- */
- virtual void ParseUpdateMsgs(const std::vector<Msg*>& msgs);
- /**
* Server parsing function for synchronization response.
*
* \copydetails ParseGetResponseMsg(Msg** , int);
@@ -319,19 +284,21 @@ class Param {
void ParseResponseMsg(Msg* msg, int slice_idx);
protected:
- int local_version_;
- //!< the ID of the first slice
- int slice_start_;
- int num_slices_;
- //!< offset and size of each slice
- std::vector<int> slice_offset_, slice_size_;
-
- //!< for debug checking
- std::vector<bool> pending_put_, pending_get_, pending_update_;
- int num_pending_requests_;
-
- std::shared_ptr<Blob<float>> data_;
- //! gradient, history gradient of this parameter
+ int local_version_ = -1;
+ // the ID of the first slice
+ int slice_start_ = 0;
+ int num_slices_ = 0;
+ // offset and size of each slice
+ std::vector<int> slice_offset_;
+ std::vector<int> slice_size_;
+ // for debug checking
+ // since put request has no feedback, we do not track its pending status
+ std::vector<bool> pending_get_;
+ std::vector<bool> pending_update_;
+ int num_pending_requests_ = 0;
+ // data field
+ std::shared_ptr<Blob<float>> data_ = nullptr;
+ // gradient, history gradient of this parameter
Blob<float> grad_, history_;
ParamProto proto_;
};
@@ -344,9 +311,9 @@ class Param {
* Param objects sharing the same values are associated with the same
* ParamEntry.
*/
-class ParamEntry{
+class ParamEntry {
public:
- ParamEntry();
+ ParamEntry() {}
ParamEntry(int total, Param* p);
/**
* Associate the counter to a Param object.
@@ -355,9 +322,10 @@ class ParamEntry{
* @param local 1 if it is used by workers in this procs, 0 otherwise
*/
void AddParam(bool local, Param* p);
- int num_update, next_version;
- int num_local; //!< # local workers using the shared parameter
- int num_total; //!< # total workers using the shared parameter
+  int next_version = -1; // next_version & num_update are directly used by stub
+ int num_update = 0;
+ int num_local = 0; //!< # local workers using the shared parameter
+ int num_total = 0; //!< # total workers using the shared parameter
//!< Shares are deleted by neuralnet's destructor
std::vector<Param*> shares;
};
@@ -371,9 +339,10 @@ inline int ParamID(int param_trgt) {
}
inline int SliceID(int param_trgt) {
- static int mask = (1 << 16) -1;
+ static const int mask = (1 << 16) -1;
return param_trgt & mask;
}
+
} // namespace singa
#endif // SINGA_UTILS_PARAM_H_
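
For context, a small sketch of how the workflow table added to param.h pairs up
in practice (a hypothetical stub-side helper, not part of this commit; it only
uses the public Param API above):

    // Stub side: one get request per slice of the Param.
    void RequestAllSlices(singa::Param* param, bool copy) {
      for (int i = 0; i < param->num_slices(); ++i) {
        singa::Msg* req = param->GenGetMsg(copy, i);  // src/dst/target not set yet
        // a real stub fills in src, dst and the target field and dispatches req;
        // this sketch stops at message generation.
        delete req;
      }
    }
    // Server side: HandleGetMsg() turns each request into a kRGet response.
    // Stub side again: ParseGetResponseMsg(reply, i) consumes the response for
    // slice i and returns 1 once responses for all slices have arrived.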
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/include/utils/singleton.h
----------------------------------------------------------------------
diff --git a/include/utils/singleton.h b/include/utils/singleton.h
index 5048266..f02c595 100644
--- a/include/utils/singleton.h
+++ b/include/utils/singleton.h
@@ -22,7 +22,7 @@ class Singleton {
template<typename T>
class TSingleton {
public:
- static T* Instance(){
+ static T* Instance() {
static thread_local T data_;
return &data_;
}
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/include/utils/updater.h
----------------------------------------------------------------------
diff --git a/include/utils/updater.h b/include/utils/updater.h
index 46d2c53..25fc31f 100644
--- a/include/utils/updater.h
+++ b/include/utils/updater.h
@@ -10,24 +10,20 @@ namespace singa {
*
* Generate learning rate for a give training step/iteration.
* There are many different ways to change the learning rate through time/step.
- * Users can inherint this class to implment their own change method.
+ * Users can inherint this class to implement their own change method.
*/
class LRGenerator {
public:
static LRGenerator* Create(const LRGenProto& proto);
- virtual ~LRGenerator() {}
- virtual void Init(const LRGenProto& proto) {
- proto_ = proto;
- }
+ virtual ~LRGenerator() {}
+ virtual void Init(const LRGenProto& proto) { proto_ = proto; }
/**
* @param step training step/iteration.
* @return base learning rate regardless of step
*/
- virtual float Get(int step) {
- return proto_.base_lr();
- }
+ virtual float Get(int step) { return proto_.base_lr(); }
protected:
LRGenProto proto_;
@@ -39,35 +35,43 @@ class FixedStepLRGen : public LRGenerator {
private:
int last_idx_ = 0;
};
+
class StepLRGen : public LRGenerator {
public:
float Get(int step) override;
};
+
class LinearLRGen : public LRGenerator {
public:
float Get(int step) override;
};
+
class ExpLRGen : public LRGenerator {
public:
float Get(int step) override;
};
+
class InvLRGen : public LRGenerator {
public:
float Get(int step) override;
};
+
class InvTLRGen : public LRGenerator {
public:
float Get(int step) override;
};
+
/**
* Updater for Param.
*/
-class Updater{
+class Updater {
public:
static Updater* Create(const UpdaterProto& proto);
+
virtual ~Updater() {}
+
virtual void Init(const UpdaterProto &proto);
- virtual void Update(int step, Param* param, float grad_scale = 1.0f) = 0;
+ virtual void Update(int step, Param* param, float grad_scale) = 0;
protected:
UpdaterProto proto_;
@@ -78,23 +82,24 @@ class Updater{
class SGDUpdater : public Updater {
public:
- void Update(int step, Param* param, float grad_scale = 1.0f);
+ void Update(int step, Param* param, float grad_scale) override;
};
-class AdaGradUpdater : public Updater{
+class AdaGradUpdater : public Updater {
public:
- void Update(int step, Param* param, float grad_scale = 1.0f) override;
+ void Update(int step, Param* param, float grad_scale) override;
};
class NesterovUpdater : public Updater {
public:
- void Update(int step, Param* param, float grad_scale = 1.0f) override;
+ void Update(int step, Param* param, float grad_scale) override;
};
+
/*
-class RMSPropUpdater : public Updater{
+class RMSPropUpdater : public Updater {
public:
- virtual void Update(int step, Param* param, float grad_scale=1.0f);
+ virtual void Update(int step, Param* param, float grad_scale);
protected:
float base_lr_;
@@ -103,9 +108,9 @@ class RMSPropUpdater : public Updater{
float weight_decay_;
};
-class AdaDeltaUpdater : public Updater{
+class AdaDeltaUpdater : public Updater {
public:
- virtual void Update(int step, Param* param, float grad_scale=1.0f);
+ virtual void Update(int step, Param* param, float grad_scale);
protected:
float rho_;
@@ -113,6 +118,7 @@ class AdaDeltaUpdater : public Updater{
float weight_decay_;
};
*/
+
} // namespace singa
#endif // SINGA_UTILS_UPDATER_H_
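
Note on the Update signature above: the grad_scale = 1.0f default argument is
gone from Updater and all subclasses, so call sites must pass the scale
explicitly (server.cc in this commit already does). Illustrative before/after,
assuming an updater obtained from Updater::Create:

    // before this patch: the default argument supplied the scale
    //   updater_->Update(step, param);
    // after this patch: the caller states the scale itself
    updater_->Update(step, param, 1.0f);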
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index 1fda336..09bc75c 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -177,7 +177,7 @@ Msg* Server::HandleGet(Msg **msg) {
return *msg;
else {
// LOG(ERROR) << "get " << slice << " from "<<(*msg)->src_first();
- auto reply = param->HandleGetMsg(msg);
+ auto reply = param->HandleGetMsg(msg, false);
reply->set_trgt(val, param->version());
return reply;
}
@@ -203,14 +203,13 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
auto param = entry->shares.at(0);
// extract and aggregate gradients
param->ParseUpdateMsgs(request);
- updater_->Update(step, param);
+ updater_->Update(step, param, 1.0f);
param->set_local_version(param->local_version() + 1);
// response to all shares of this param
- for (auto response : param->GenUpdateResponseMsgs(request)) {
+ for (auto response : param->GenUpdateResponseMsgs(&request, false)) {
response->set_trgt((*msg)->trgt_val(), param->local_version());
ret.push_back(response);
}
- request.clear();
entry->num_update = 0;
}
*msg = nullptr;
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 69f697b..a7c1897 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -1,14 +1,18 @@
+#include "utils/param.h"
+
#include <glog/logging.h>
#include <cmath>
-#include <chrono>
#include <random>
-#include "utils/param.h"
-#include "proto/job.pb.h"
#include "mshadow/tensor.h"
-#include "utils/singleton.h"
#include "utils/factory.h"
+#include "utils/singleton.h"
+
namespace singa {
-using namespace mshadow;
+
+using mshadow::cpu;
+using mshadow::Random;
+using mshadow::Shape1;
+using mshadow::Tensor;
using std::vector;
using std::string;
@@ -23,18 +27,20 @@ ParamGenerator* ParamGenerator::Create(const ParamGenProto& proto) {
return gen;
}
-void ParamGenerator::Fill (Blob<float>* blob) {
+void ParamGenerator::Fill(Blob<float>* blob) {
Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
data = proto_.value();
}
-void GaussianGen::Fill (Blob<float>* blob) {
+
+void GaussianGen::Fill(Blob<float>* blob) {
Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
auto random = TSingleton<Random<cpu>>::Instance();
random->SampleGaussian(data, proto_.mean(), proto_.std());
- if(proto_.value() != 1)
+ if (proto_.value() != 1)
data *= proto_.value();
}
-void GaussianSqrtFanInGen::Fill (Blob<float>* blob) {
+
+void GaussianSqrtFanInGen::Fill(Blob<float>* blob) {
// only valid for param matrix with num of cols as fan in
CHECK_EQ(blob->shape().size(), 2);
Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
@@ -42,15 +48,15 @@ void GaussianSqrtFanInGen::Fill (Blob<float>* blob) {
data /= sqrt(blob->shape().at(1));
}
-void UniformGen::Fill (Blob<float>* blob) {
+void UniformGen::Fill(Blob<float>* blob) {
Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
auto random = TSingleton<Random<cpu>>::Instance();
random->SampleUniform(data, proto_.low(), proto_.high());
- if(proto_.value() != 1)
+ if (proto_.value() != 1)
data *= proto_.value();
}
-void UniformSqrtFanInGen::Fill (Blob<float>* blob) {
+void UniformSqrtFanInGen::Fill(Blob<float>* blob) {
// only valid for param matrix with num of cols as fan in
CHECK_EQ(blob->shape().size(), 2);
Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
@@ -58,16 +64,16 @@ void UniformSqrtFanInGen::Fill (Blob<float>* blob) {
data /= sqrt(blob->shape().at(1) / 3.0f);
}
-void UniformSqrtFanInOutGen::Fill (Blob<float>* blob) {
+void UniformSqrtFanInOutGen::Fill(Blob<float>* blob) {
// only valid for param matrix with num of cols as fan in
CHECK_EQ(blob->shape().size(), 2);
Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
UniformGen::Fill(blob);
data /= sqrt(blob->shape()[0] + blob->shape()[1]);
}
-/*****************Param***********************************/
+
Param* Param::Create(const ParamProto& proto) {
- Factory<Param>* factory=Singleton<Factory<Param>>::Instance();
+ Factory<Param>* factory = Singleton<Factory<Param>>::Instance();
Param* p = nullptr;
if (proto.has_user_type())
p = factory->Create(proto.user_type());
@@ -77,16 +83,46 @@ Param* Param::Create(const ParamProto& proto) {
return p;
}
-Param::Param():local_version_(-1), slice_start_(0), num_slices_(0),
- num_pending_requests_(0), data_(nullptr) {
-}
-
void Param::Setup(const vector<int>& shape) {
data_ = std::make_shared<Blob<float>>(shape);
grad_.Reshape(shape);
history_.Reshape(shape);
}
+void Param::InitValues() {
+ InitValues(0);
+}
+
+void Param::InitValues(int version) {
+ ParamGenerator* gen = ParamGenerator::Create(proto_.init());
+ gen->Fill(data_.get());
+ set_version(version);
+}
+
+void Param::ShareFrom(const Param& other) {
+ proto_.set_owner(other.owner());
+ if (data_ != nullptr)
+ CHECK(data_->shape() == other.data_->shape());
+ data_ = other.data_;
+ if (grad_.count() == 0)
+ grad_.Reshape(data_->shape());
+ slice_start_ = other.slice_start_;
+ num_slices_ = other.num_slices_;
+ slice_offset_ = other.slice_offset_;
+ slice_size_ = other.slice_size_;
+ // change pending list size equal to slice size
+ pending_get_.resize(other.pending_get_.size());
+ pending_update_.resize(other.pending_update_.size());
+}
+
+void Param::FromProto(const BlobProto& blob) {
+ data_->FromProto(blob);
+}
+
+void Param::ToProto(BlobProto* blob) {
+ data_->ToProto(blob);
+}
+
void Param::AddSlice(int slice_id, int size) {
int offset = 0;
if (slice_size_.size() > 0) {
@@ -101,37 +137,20 @@ void Param::AddSlice(int slice_id, int size) {
slice_size_.push_back(size);
pending_get_.push_back(false);
pending_update_.push_back(false);
- pending_put_.push_back(false);
num_slices_++;
}
-void Param::InitValues(int version) {
- ParamGenerator* gen = ParamGenerator::Create(proto_.init());
- gen->Fill(data_.get());
- set_version(version);
-}
-void Param::FromProto(const BlobProto& blob) {
- data_->FromProto(blob);
-}
-void Param::ToProto(BlobProto* blob) {
- data_->ToProto(blob);
-}
-
-/**************Message related functions********/
Msg* Param::GenPutMsg(bool copy, int idx) {
CHECK_LT(idx, num_slices_);
Msg* msg = new Msg();
msg->set_type(kPut);
- void *ptr = mutable_cpu_data() + slice_offset_[idx];
- void *p = ptr;
+ void* ptr = mutable_cpu_data() + slice_offset_[idx];
+ void* p = ptr;
if (copy) p = nullptr;
- msg->AddFormatFrame("iffp", slice_size_[idx],
- lr_scale(), wd_scale(), p);
+ msg->AddFormatFrame("iffp", slice_size_[idx], lr_scale(), wd_scale(), p);
if (copy) {
msg->AddFrame(ptr, slice_size_[idx] * sizeof(float));
}
- // pending_put_[idx]=true;
- // num_pending_requests_++;
return msg;
}
@@ -171,6 +190,8 @@ Msg* Param::GenSyncMsg(int offset, int size) {
}
Msg* Param::HandlePutMsg(Msg** msg, bool reserve) {
+ // TODO(wangsheng) remove the check later
+ CHECK(reserve);
int size;
float lr, wc;
float* ptr;
@@ -183,43 +204,59 @@ Msg* Param::HandlePutMsg(Msg** msg, bool reserve) {
Setup(shape);
if (ptr == nullptr) {
CHECK((*msg)->NextFrame());
- CHECK_EQ(size* sizeof(float), (*msg)->FrameSize());
- memcpy(mutable_cpu_data(), (*msg)->FrameData(), size*sizeof(float));
+ CHECK_EQ(size * sizeof(float), (*msg)->FrameSize());
+ memcpy(mutable_cpu_data(), (*msg)->FrameData(), size * sizeof(float));
} else {
data_->set_cpu_data(ptr);
}
- if (!reserve)
- DeleteMsg(msg);
+ if (!reserve) DeleteMsg(msg);
return nullptr;
}
Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
+ // TODO(wangsheng) remove the check later
+ CHECK(!reserve);
int copy;
float* ptr;
(*msg)->ParseFormatFrame("ip", &copy, &ptr);
if (copy) {
- (*msg)->AddFrame(mutable_cpu_data(), sizeof(float)*size());
+ (*msg)->AddFrame(mutable_cpu_data(), sizeof(float) * size());
} else if (ptr != data_->cpu_data()) {
- memcpy(ptr, data_->cpu_data(), sizeof(float)*size());
+ // this case reflects following situation:
+ // worker 0 and server are in the same process, while worker 1 is not.
+ // worker 1 "put" data into server, so server need to allocate memory.
+ // then worker 0 "get" data from server, so server need:
+ // 1. copy the data to the worker0 provided space
+ // 2. change its own pointer to that space in order to share memory
+ // in this case, the server always points to last worker's space
+ memcpy(ptr, data_->cpu_data(), sizeof(float) * size());
data_->set_cpu_data(ptr);
}
// else the mem space is shared among all worker and servers
- (*msg)->SwapAddr();
- (*msg)->set_type(kRGet);
- return *msg;
+ Msg* ret = nullptr;
+ if (reserve) {
+ ret = new Msg(**msg);
+ } else {
+ // if not reserve the msg, we reuse it as return value
+ ret = *msg;
+ *msg = nullptr;
+ }
+ ret->SwapAddr();
+ ret->set_type(kRGet);
+ return ret;
}
void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
CHECK_GT(msgs.size(), 0);
float* server_grad = nullptr;
vector<float*> worker_grad;
- for (auto *msg : msgs) {
+ for (auto* msg : msgs) {
int copy;
msg->ParseFormatFrame("i", &copy);
msg->NextFrame();
float* ptr = nullptr;
if (copy) {
- ptr = static_cast<float*> (msg->FrameData());
+ ptr = static_cast<float*>(msg->FrameData());
CHECK_EQ(size() * sizeof(float), msg->FrameSize());
} else {
msg->ParseFormatFrame("p", &ptr);
@@ -231,7 +268,8 @@ void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
server_grad = worker_grad.at(0);
for (float* grad : worker_grad) {
if (grad != server_grad) {
- for (int i =0; i < size(); i++) {
+ // TODO(wangsh) think about optimize it later?
+ for (int i = 0; i < size(); i++) {
server_grad[i] += grad[i];
}
}
@@ -239,39 +277,41 @@ void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
grad_.set_cpu_data(server_grad);
}
-const vector<Msg*> Param::GenUpdateResponseMsgs(const vector<Msg*>& msgs) {
+const vector<Msg*> Param::GenUpdateResponseMsgs(vector<Msg*>* msgs,
+ bool reserve) {
+ // TODO(wangsheng) remove the check later
+ CHECK(!reserve);
vector<Msg*> ret;
- for (auto msg : msgs) {
- msg->FirstFrame();
- msg->SwapAddr();
- msg->set_type(kRUpdate);
+ for (Msg* msg : *msgs) {
+ Msg* ptr = reserve ? new Msg(*msg) : msg;
+ ptr->FirstFrame();
+ ptr->SwapAddr();
+ ptr->set_type(kRUpdate);
int copy;
-    msg->ParseFormatFrame("i", &copy);
+    ptr->ParseFormatFrame("i", &copy);
if (copy) {
- msg->NextFrame();
- CHECK_EQ(msg->FrameSize(), sizeof(float) * size());
- memcpy(msg->FrameData(), mutable_cpu_data(), msg->FrameSize());
+ ptr->NextFrame();
+ CHECK_EQ(ptr->FrameSize(), sizeof(float) * size());
+ memcpy(ptr->FrameData(), mutable_cpu_data(), ptr->FrameSize());
}
- ret.push_back(msg);
+ ret.push_back(ptr);
}
+ // if not reserved, we remove all pointers
+ if (!reserve) msgs->clear();
return ret;
}
Msg* Param::HandleSyncMsg(Msg** msg, bool reserve) {
- if (!reserve)
- DeleteMsg(msg);
+ // TODO(wangwei) handle it later
+ if (!reserve) DeleteMsg(msg);
return nullptr;
}
-int Param::ParseSyncResponseMsg(Msg* msg, int slice_idx) {
- return 1;
-}
-
int Param::ParseGetResponseMsg(Msg *msg, int slice_idx) {
CHECK_EQ(pending_get_[slice_idx], true);
pending_get_[slice_idx] = false;
ParseResponseMsg(msg, slice_idx);
- return (--num_pending_requests_)%num_slices_ == 0;
+ return (--num_pending_requests_) % num_slices_ == 0;
}
int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
@@ -281,48 +321,37 @@ int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
return (--num_pending_requests_) % num_slices_ == 0;
}
+int Param::ParseSyncResponseMsg(Msg* msg, int slice_idx) {
+ // TODO(wangwei) handle it later
+ return 1;
+}
+
void Param::ParseResponseMsg(Msg* msg, int slice_idx) {
int copy;
msg->ParseFormatFrame("i", &copy);
msg->NextFrame();
if (copy) {
- CHECK_EQ(msg->FrameSize(), slice_size_[slice_idx]*sizeof(float));
- memcpy(mutable_cpu_data()+slice_offset_[slice_idx],
+ CHECK_EQ(msg->FrameSize(), slice_size_[slice_idx] * sizeof(float));
+ memcpy(mutable_cpu_data() + slice_offset_[slice_idx],
msg->FrameData(), msg->FrameSize());
}
// LOG(ERROR)<<"parse response norm "<<data_->asum_data()<<" of "<<id();
}
-void Param::ShareFrom(const Param& other) {
- proto_.set_owner(other.owner());
- if (data_ != nullptr) {
- CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
- other.data_->shape().begin()));
- }
- data_ = other.data_;
- if (grad_.count() == 0)
- grad_.Reshape(data_->shape());
- slice_offset_ = other.slice_offset_;
- slice_size_ = other.slice_size_;
- slice_start_ = other.slice_start_;
- num_slices_ = other.num_slices_;
- pending_get_ = other.pending_get_;
- pending_put_ = other.pending_put_;
- pending_update_ = other.pending_update_;
-}
-
/************************ParamEntry***************************/
ParamEntry::ParamEntry():
num_update(0), next_version(-1), num_local(0), num_total(0) {
}
-ParamEntry::ParamEntry(int total, Param* p) : num_update(0), num_total(total) {
+ParamEntry::ParamEntry(int total, Param* p) {
+ num_total = total;
shares.push_back(p);
}
+
void ParamEntry::AddParam(bool local, Param* p) {
num_local += local;
num_total += 1;
- if (local)
- shares.push_back(p);
+ if (local) shares.push_back(p);
}
+
} // namespace singa
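
A last ownership note on the new GenUpdateResponseMsgs(msgs, reserve) (a sketch
of the caller's view, assuming the Server::HandleUpdate pattern in this commit):
with reserve == false the request messages are recycled as responses and the
input vector is cleared inside the call, which is why server.cc no longer calls
request.clear() itself.

    // hypothetical caller, mirroring Server::HandleUpdate
    std::vector<singa::Msg*> request;  // filled with kUpdate requests elsewhere
    for (singa::Msg* response : param->GenUpdateResponseMsgs(&request, false)) {
      // send or queue 'response'; these pointers now belong to the sender
    }
    // 'request' is empty here; freeing its former elements again would be a
    // double free.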
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/src/utils/updater.cc
----------------------------------------------------------------------
diff --git a/src/utils/updater.cc b/src/utils/updater.cc
index 24487d3..2c4b6c3 100644
--- a/src/utils/updater.cc
+++ b/src/utils/updater.cc
@@ -1,16 +1,21 @@
-
#include "utils/updater.h"
-#include "mshadow/tensor.h"
+
#include "mshadow/cxxnet_op.h"
+#include "mshadow/tensor.h"
#include "utils/singleton.h"
#include "utils/factory.h"
-#include "proto/job.pb.h"
-namespace singa {
-using namespace mshadow;
-using namespace mshadow::expr;
+namespace singa {
+
+using mshadow::cpu;
+using mshadow::expr::F;
+using mshadow::op::sqrtop;
+using mshadow::op::square;
+using mshadow::Shape;
+using mshadow::Shape1;
+using mshadow::Tensor;
+using mshadow::TensorContainer;
-/**********************Learning rate generator******************************/
LRGenerator* LRGenerator::Create(const LRGenProto& proto) {
auto factory = Singleton<Factory<LRGenerator>>::Instance();
LRGenerator* gen = nullptr;
@@ -23,9 +28,9 @@ LRGenerator* LRGenerator::Create(const LRGenProto& proto) {
}
float FixedStepLRGen::Get(int step) {
- if (last_idx_ < proto_.fixedstep_conf().step_size() -1
+ if (last_idx_ < proto_.fixedstep_conf().step_size() - 1
&& step >= proto_.fixedstep_conf().step(last_idx_ + 1)) {
- last_idx_ ++;
+ last_idx_++;
}
return proto_.fixedstep_conf().step_lr(last_idx_);
}
@@ -38,7 +43,7 @@ float StepLRGen::Get(int step) {
float LinearLRGen::Get(int step) {
int freq = proto_.linear_conf().change_freq();
- float r = step * 1.0 / freq;
+ float r = step * 1.0 / freq;
return (1.0 - r) * proto_.base_lr() + r * proto_.linear_conf().final_lr();
}
@@ -56,8 +61,6 @@ float InvTLRGen::Get(int step) {
return proto_.base_lr() / (1 + step * 1. /
proto_.inverset_conf().final_lr());
}
-/***********************Updater********************************/
-
Updater* Updater::Create(const UpdaterProto& proto) {
auto factory = Singleton<Factory<Updater>>::Instance();
Updater* updater = nullptr;
@@ -84,9 +87,8 @@ void SGDUpdater::Update(int step, Param* param, float grad_scale) {
float wd = weight_decay_ * param->wd_scale();
if (grad_scale != 1.f)
grad *= grad_scale;
- if (wd > 0) { // L2 regularization, should be done after timing grad_scale
+ if (wd > 0) // L2 regularization, should be done after timing grad_scale
grad += data * wd;
- }
if (momentum_ > 0) {
Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
history = history * momentum_ - lr * grad;
@@ -108,9 +110,8 @@ void NesterovUpdater::Update(int step, Param* param, float grad_scale) {
float wd = weight_decay_*param->wd_scale();
if (grad_scale != 1.f)
grad *= grad_scale;
- if (wd > 0) { // L2 regularization, should be done after timing grad_scale
+ if (wd > 0) // L2 regularization, should be done after timing grad_scale
grad += data * wd;
- }
Copy(tmp, history);
history = history * momentum_ + lr * grad;
tmp = history * (1 + momentum_) - tmp * momentum_;
@@ -126,11 +127,10 @@ void AdaGradUpdater::Update(int step, Param* param, float grad_scale) {
float wd = weight_decay_*param->wd_scale();
if (grad_scale != 1.f)
grad *= grad_scale;
- if (wd > 0) { // L2 regularization, should be done after timing grad_scale
+ if (wd > 0) // L2 regularization, should be done after timing grad_scale
grad += data * wd;
- }
- history += F<op::square>(grad);
- data -= lr * grad / (F<op::sqrtop>(history, proto_.delta()));
+ history += F<square>(grad);
+ data -= lr * grad / (F<sqrtop>(history, proto_.delta()));
}
/***********************RMSProp******************************