Repository: incubator-singa
Updated Branches:
  refs/heads/master 97141e2e0 -> 6afa895b8
SINGA-60 Make learning rate and param init modular

Created a base class for getting the learning rate, which changes during training. Created a base class for initializing parameter values. SINGA comes with a couple of built-in implementations for the two base classes. Users can also implement their own learning rate changing methods and parameter initializing methods by extending the corresponding base classes.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/6afa895b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/6afa895b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/6afa895b

Branch: refs/heads/master
Commit: 6afa895b8ea060a532ea01f1f4484c9db11a2496
Parents: 97141e2
Author: Wei Wang <[email protected]>
Authored: Wed Aug 19 17:36:39 2015 +0800
Committer: Wei Wang <[email protected]>
Committed: Wed Aug 19 19:19:38 2015 +0800

----------------------------------------------------------------------
 examples/cifar10/job.conf |  71 +++++++++++++----------
 examples/mnist/conv.conf  |  71 +++++++++++++----------
 examples/mnist/job.conf   | 108 +++++++++++++++++++++--------------
 include/driver.h          |  33 +++++++++++
 include/trainer/worker.h  |   2 -
 include/utils/param.h     |  46 ++++++++++++++-
 include/utils/updater.h   |  89 +++++++++++++++++++----------
 src/driver.cc             |  23 +++++++-
 src/proto/job.proto       |  67 ++++++++++++----------
 src/trainer/server.cc     |   1 -
 src/trainer/worker.cc     |  25 +++------
 src/utils/param.cc        |  93 +++++++++++++++++++------------
 src/utils/updater.cc      | 124 ++++++++++++++++++-----------------
 13 files changed, 467 insertions(+), 286 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/examples/cifar10/job.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/job.conf b/examples/cifar10/job.conf index b294f03..f44ca50 100644 --- a/examples/cifar10/job.conf +++ b/examples/cifar10/job.conf @@ -5,16 +5,18 @@ test_freq:300 disp_freq:30 alg: kBP updater{ - weight_decay:0.004 - lr_change: kFixedStep type: kSGD - fixedstep_conf:{ - step:0 - step:60000 - step:65000 - step_lr:0.001 - step_lr:0.0001 - step_lr:0.00001 + weight_decay:0.004 + learning_rate { + type: kFixedStep + fixedstep_conf:{ + step:0 + step:60000 + step:65000 + step_lr:0.001 + step_lr:0.0001 + step_lr:0.00001 + } } } neuralnet { @@ -63,15 +65,18 @@ neuralnet { } param { name: "w1" - init_method:kGaussian - std:0.0001 - lr_scale:1.0 + init { + type:kGaussian + std:0.0001 + } } param { name: "b1" - init_method: kConstant lr_scale:2.0 - value:0 + init { + type: kConstant + value:0 + } } } @@ -112,15 +117,18 @@ neuralnet { } param { name: "w2" - init_method:kGaussian - std:0.01 - lr_scale:1.0 + init { + type:kGaussian + std:0.01 + } } param { name: "b2" - init_method: kConstant lr_scale:2.0 - value:0 + init { + type: kConstant + value:0 + } } } layer { @@ -160,13 +168,17 @@ neuralnet { } param { name: "w3" - init_method:kGaussian - std:0.01 + init { + type:kGaussian + std:0.01 + } } param { name: "b3" - init_method: kConstant - value:0 + init { + type: kConstant + value:0 + } } } layer { @@ -193,17 +205,20 @@ neuralnet { } param { name: "w4" - init_method:kGaussian - std:0.01 - lr_scale:1.0 wd_scale:250 + init { + type:kGaussian + std:0.01 + } } param { name: "b4" - init_method: kConstant lr_scale:2.0 wd_scale:0 - value:0 + init { + type: kConstant + value:0 + }
} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/examples/mnist/conv.conf ---------------------------------------------------------------------- diff --git a/examples/mnist/conv.conf b/examples/mnist/conv.conf index 3509a36..1d4d740 100644 --- a/examples/mnist/conv.conf +++ b/examples/mnist/conv.conf @@ -4,15 +4,17 @@ test_steps:100 test_freq:500 disp_freq:50 alg: kBP -updater{ - base_lr:0.01 +updater { momentum:0.9 weight_decay:0.0005 - lr_change: kInverse type: kSGD - inverse_conf { - gamma:0.0001 - pow:0.75 + learning_rate { + type : kInverse + base_lr:0.01 + inverse_conf { + gamma:0.0001 + pow:0.75 + } } } neuralnet { @@ -61,16 +63,19 @@ neuralnet { stride: 1 } param{ - name: "w1" - init_method:kUniformSqrtFanIn - lr_scale:1.0 + name: "w1" + init { + type : kUniformSqrtFanIn } + } param{ - name: "b1" - init_method: kConstant - lr_scale:2.0 + name: "b1" + init { + type : kConstant value:0 } + lr_scale:2.0 + } } layer { name: "pool1" @@ -92,16 +97,19 @@ neuralnet { stride: 1 } param{ - name: "w2" - init_method:kUniformSqrtFanIn - lr_scale:1.0 + name: "w2" + init { + type :kUniformSqrtFanIn } + } param{ - name: "b2" - init_method: kConstant - lr_scale:2.0 + name: "b2" + init { + type : kConstant value:0 } + lr_scale:2.0 + } } layer { name: "pool2" @@ -121,17 +129,19 @@ neuralnet { num_output: 500 } param{ - name: "w3" - init_method:kUniformSqrtFanIn - lr_scale:1.0 + name: "w3" + init { + type :kUniformSqrtFanIn } + } param{ - name: "b3" - init_method: kConstant - lr_scale:2.0 + name: "b3" + init { + type : kConstant value:0 + } + lr_scale:2.0 } - } layer { @@ -149,14 +159,17 @@ neuralnet { } param { name: "w4" - init_method:kUniformSqrtFanIn - lr_scale:1 + init { + type :kUniformSqrtFanIn + } } param { name: "b4" - init_method: kConstant + init { + type : kConstant + value:0 + } lr_scale:2 - value:0 } } layer{ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/examples/mnist/job.conf ---------------------------------------------------------------------- diff --git a/examples/mnist/job.conf b/examples/mnist/job.conf index 34fbca2..360e1ec 100644 --- a/examples/mnist/job.conf +++ b/examples/mnist/job.conf @@ -5,12 +5,14 @@ test_freq:60 disp_freq:10 alg: kBP updater{ - base_lr: 0.001 - lr_change: kStep type: kSGD - step_conf{ - change_freq: 60 - gamma: 0.997 + learning_rate{ + type : kStep + base_lr: 0.001 + step_conf{ + change_freq: 60 + gamma: 0.997 + } } } @@ -61,15 +63,19 @@ neuralnet { } param{ name: "w1" - init_method: kUniform - low:-0.05 - high:0.05 + init { + type: kUniform + low:-0.05 + high:0.05 + } } param{ name: "b1" - init_method: kUniform - low: -0.05 - high:0.05 + init { + type : kUniform + low: -0.05 + high:0.05 + } } } @@ -87,15 +93,19 @@ neuralnet { } param{ name: "w2" - init_method: kUniform - low:-0.05 - high:0.05 + init { + type: kUniform + low:-0.05 + high:0.05 + } } param{ name: "b2" - init_method: kUniform - low: -0.05 - high:0.05 + init { + type: kUniform + low: -0.05 + high:0.05 + } } } @@ -113,15 +123,19 @@ neuralnet { } param{ name: "w3" - init_method: kUniform - low:-0.05 - high:0.05 + init{ + type: kUniform + low:-0.05 + high:0.05 + } } param{ name: "b3" - init_method: kUniform - low: -0.05 - high:0.05 + init { + type : kUniform + low: -0.05 + high:0.05 + } } } @@ -140,15 +154,19 @@ neuralnet { } param{ name: "w4" - init_method: kUniform - low:-0.05 - high:0.05 + init { + type : kUniform + low:-0.05 + high:0.05 + } } param{ name: "b4" - init_method: kUniform - low: -0.05 - high:0.05 + init { + type : kUniform + low: -0.05 
+ high:0.05 + } } } @@ -167,15 +185,19 @@ neuralnet { } param{ name: "w5" - init_method: kUniform - low:-0.05 - high:0.05 + init { + type : kUniform + low:-0.05 + high:0.05 + } } param{ name: "b5" - init_method: kUniform - low: -0.05 - high:0.05 + init { + type : kUniform + low: -0.05 + high:0.05 + } } } @@ -194,15 +216,19 @@ neuralnet { } param{ name: "w6" - init_method: kUniform - low:-0.05 - high:0.05 + init { + type : kUniform + low:-0.05 + high:0.05 + } } param{ name: "b6" - init_method: kUniform - low: -0.05 - high:0.05 + init { + type : kUniform + low: -0.05 + high:0.05 + } } } layer{ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/include/driver.h ---------------------------------------------------------------------- diff --git a/include/driver.h b/include/driver.h index fcaab12..5a9ddfc 100644 --- a/include/driver.h +++ b/include/driver.h @@ -34,6 +34,16 @@ class Driver { template<typename Subclass, typename Type> int RegisterUpdater(const Type& type); /** + * Register a learning rate generator subclasses. + * + * @param type ID of the subclass. If called to register built-in subclasses, + * it is from ChangeMethod; if called to register user-defined + * subclass, it is a string; + * @return 0 if success; otherwise -1. + */ + template<typename Subclass, typename Type> + int RegisterLRGenerator(const Type& type); + /** * Register a Worker subclass. * * @param type ID of the subclass. If called to register built-in subclasses, @@ -54,6 +64,17 @@ class Driver { template<typename Subclass, typename Type> int RegisterParam(const Type& type); /** + * Register ParamGenerator subclasses for initalizing Param objects. + * + * @param type ID of the subclass. If called to register built-in subclasses, + * it is from InitMethod; if called to register user-defined + * subclass, it is a string; + * @return 0 if success; otherwise -1. + */ + template<typename Subclass, typename Type> + int RegisterParamGenerator(const Type& type); + + /** * Submit the job configuration for starting the job. * @param resume resume from last checkpoint if true. 
* @param job job configuration @@ -90,12 +111,24 @@ int Driver::RegisterParam(const Type& type) { return 1; } template<typename Subclass, typename Type> +int Driver::RegisterParamGenerator(const Type& type) { + auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance(); + factory->Register(type, CreateInstance(Subclass, ParamGenerator)); + return 1; +} +template<typename Subclass, typename Type> int Driver::RegisterUpdater(const Type& type) { auto factory = Singleton<Factory<singa::Updater>>::Instance(); factory->Register(type, CreateInstance(Subclass, Updater)); return 1; } template<typename Subclass, typename Type> +int Driver::RegisterLRGenerator(const Type& type) { + auto factory = Singleton<Factory<singa::LRGenerator>>::Instance(); + factory->Register(type, CreateInstance(Subclass, LRGenerator)); + return 1; +} +template<typename Subclass, typename Type> int Driver::RegisterWorker(const Type& type) { auto factory = Singleton<Factory<singa::Worker>>::Instance(); factory->Register(type, CreateInstance(Subclass, Worker)); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/include/trainer/worker.h ---------------------------------------------------------------------- diff --git a/include/trainer/worker.h b/include/trainer/worker.h index c50b54f..cc5a745 100644 --- a/include/trainer/worker.h +++ b/include/trainer/worker.h @@ -2,7 +2,6 @@ #define SINGA_TRAINER_WORKER_H_ #include "neuralnet/neuralnet.h" #include "proto/job.pb.h" -#include "utils/updater.h" #include "communication/socket.h" namespace singa { @@ -177,7 +176,6 @@ class Worker { JobProto job_conf_; shared_ptr<NeuralNet> train_net_, test_net_, validation_net_; Dealer* layer_dealer_, *dealer_; - Updater* updater_; }; class BPWorker: public Worker{ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/include/utils/param.h ---------------------------------------------------------------------- diff --git a/include/utils/param.h b/include/utils/param.h index 83f64ed..f7a0982 100644 --- a/include/utils/param.h +++ b/include/utils/param.h @@ -6,6 +6,51 @@ #include "utils/blob.h" #include "communication/msg.h" +namespace singa { + +/** + * Base parameter generator which intializes parameter values. + */ + +class ParamGenerator { + public: + static ParamGenerator* Create(const ParamGenProto& proto); + virtual ~ParamGenerator() {} + + virtual void Init(const ParamGenProto& proto) { + proto_ = proto; + } + + virtual void Fill(Blob<float>* data); + + protected: + ParamGenProto proto_; +}; + +class GaussianGen: public ParamGenerator { + public: + void Fill(Blob<float>* data) override; +}; + +class UniformGen: public ParamGenerator { + public: + void Fill(Blob<float>* data) override; +}; + +class GaussianSqrtFanInGen: public GaussianGen { + public: + void Fill(Blob<float>* data) override; +}; + +class UniformSqrtFanInGen: public UniformGen { + public: + void Fill(Blob<float>* data) override; +}; + +class UniformSqrtFanInOutGen: public UniformGen { + public: + void Fill(Blob<float>* data) override; +}; /** * Base paramter class. * @@ -24,7 +69,6 @@ * load-balance among servers. Hence, we slice large Param objects into small * pieces. At the server side, one slice is a Param object. 
*/ -namespace singa { class Param { public: static Param* Create(const ParamProto& proto); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/include/utils/updater.h ---------------------------------------------------------------------- diff --git a/include/utils/updater.h b/include/utils/updater.h index 92ddf6c..46d2c53 100644 --- a/include/utils/updater.h +++ b/include/utils/updater.h @@ -6,60 +6,94 @@ namespace singa { /** + * Base learning rate generator. + * + * Generate learning rate for a give training step/iteration. + * There are many different ways to change the learning rate through time/step. + * Users can inherint this class to implment their own change method. + */ +class LRGenerator { + public: + static LRGenerator* Create(const LRGenProto& proto); + virtual ~LRGenerator() {} + + virtual void Init(const LRGenProto& proto) { + proto_ = proto; + } + + /** + * @param step training step/iteration. + * @return base learning rate regardless of step + */ + virtual float Get(int step) { + return proto_.base_lr(); + } + + protected: + LRGenProto proto_; +}; + +class FixedStepLRGen : public LRGenerator { + public: + float Get(int step) override; + private: + int last_idx_ = 0; +}; +class StepLRGen : public LRGenerator { + public: + float Get(int step) override; +}; +class LinearLRGen : public LRGenerator { + public: + float Get(int step) override; +}; +class ExpLRGen : public LRGenerator { + public: + float Get(int step) override; +}; +class InvLRGen : public LRGenerator { + public: + float Get(int step) override; +}; +class InvTLRGen : public LRGenerator { + public: + float Get(int step) override; +}; +/** * Updater for Param. */ class Updater{ public: static Updater* Create(const UpdaterProto& proto); virtual ~Updater() {} - virtual void Init(const UpdaterProto &proto) { - proto_ = proto; - } + virtual void Init(const UpdaterProto &proto); virtual void Update(int step, Param* param, float grad_scale = 1.0f) = 0; - float GetLearningRate(int step); - protected: UpdaterProto proto_; + LRGenerator* lr_gen_; + float weight_decay_; + float momentum_; }; class SGDUpdater : public Updater { public: - virtual void Init(const UpdaterProto& proto); - virtual void Update(int step, Param* param, float grad_scale = 1.0f); - - protected: - float base_lr_; - float momentum_; - float weight_decay_; + void Update(int step, Param* param, float grad_scale = 1.0f); }; class AdaGradUpdater : public Updater{ public: - virtual void Init(const UpdaterProto& proto); - virtual void Update(int step, Param* param, float grad_scale = 1.0f); - - protected: - float base_lr_; - float delta_; - float weight_decay_; + void Update(int step, Param* param, float grad_scale = 1.0f) override; }; class NesterovUpdater : public Updater { public: - virtual void Init(const UpdaterProto& proto); - virtual void Update(int step, Param* param, float grad_scale = 1.0f); - - protected: - float base_lr_; - float momentum_; - float weight_decay_; + void Update(int step, Param* param, float grad_scale = 1.0f) override; }; /* class RMSPropUpdater : public Updater{ public: - virtual void Init(const UpdaterProto& proto); virtual void Update(int step, Param* param, float grad_scale=1.0f); protected: @@ -71,7 +105,6 @@ class RMSPropUpdater : public Updater{ class AdaDeltaUpdater : public Updater{ public: - virtual void Init(const UpdaterProto& proto); virtual void Update(int step, Param* param, float grad_scale=1.0f); protected: http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/driver.cc 
---------------------------------------------------------------------- diff --git a/src/driver.cc b/src/driver.cc index b79b609..1bc712d 100644 --- a/src/driver.cc +++ b/src/driver.cc @@ -50,18 +50,35 @@ void Driver::Init(int argc, char **argv) { RegisterLayer<LMDBDataLayer, int>(kLMDBData); #endif - // register updater + // register updaters RegisterUpdater<AdaGradUpdater>(kAdaGrad); RegisterUpdater<NesterovUpdater>(kNesterov); // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp); RegisterUpdater<SGDUpdater>(kSGD); - // register worker + // register learning rate change methods + RegisterLRGenerator<LRGenerator>(kFixed); + RegisterLRGenerator<FixedStepLRGen>(kFixedStep); + RegisterLRGenerator<StepLRGen>(kStep); + RegisterLRGenerator<LinearLRGen>(kLinear); + RegisterLRGenerator<ExpLRGen>(kExponential); + RegisterLRGenerator<InvLRGen>(kInverse); + RegisterLRGenerator<InvTLRGen>(kInverseT); + + // register workers RegisterWorker<BPWorker>(kBP); RegisterWorker<CDWorker>(kCD); - // register param + // register params RegisterParam<Param>(0); + + // register param init methods + RegisterParamGenerator<ParamGenerator>(kConstant); + RegisterParamGenerator<GaussianGen>(kGaussian); + RegisterParamGenerator<UniformGen>(kUniform); + RegisterParamGenerator<GaussianSqrtFanInGen>(kGaussianSqrtFanIn); + RegisterParamGenerator<UniformSqrtFanInGen>(kUniformSqrtFanIn); + RegisterParamGenerator<UniformSqrtFanInOutGen>(kUniformSqrtFanInOut); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/proto/job.proto ---------------------------------------------------------------------- diff --git a/src/proto/job.proto b/src/proto/job.proto index 80c8b65..b4abe68 100644 --- a/src/proto/job.proto +++ b/src/proto/job.proto @@ -101,21 +101,11 @@ message UpdaterProto { // configuration for RMSProp algorithm optional RMSPropProto rmsprop_conf = 3; - // built-in change method for learning rate - optional ChangeMethod lr_change = 10 [default = kUserChange]; - // user-defined change method - optional string user_lr_change = 11; - - optional FixedStepProto fixedstep_conf = 40; - optional StepProto step_conf = 41; - optional LinearProto linear_conf = 42; - optional ExponentialProto exponential_conf = 43; - optional InverseProto inverse_conf = 44; - optional InverseTProto inverset_conf = 45; + // learning rate generator + optional LRGenProto learning_rate = 11; optional float momentum = 31 [default = 0]; optional float weight_decay = 32 [default = 0]; - // base learning rate - optional float base_lr = 34 [default = 0]; + // used to avoid divide by 0, i.e. x/(y+delta) optional float delta = 35 [default = 0.00000001]; @@ -220,24 +210,13 @@ message LayerProto { message ParamProto { // used for identifying the same params from diff models and display deug info optional string name = 1 [default = ""]; - optional InitMethod init_method = 2 [default = kGaussian]; // for built-in Param optional ParamType type = 3 [default = kParam]; // for user-defined Param optional string user_type = 4; - // constant init - optional float value = 5 [default = 1]; - // for uniform sampling - optional UniformProto uniform_conf = 6; - optional float low = 7 [default = -1]; - optional float high = 8 [default = 1]; - - // for gaussian sampling - optional GaussianProto gaussian_conf = 9; - optional float mean = 10 [default = 0]; - optional float std = 11 [default = 1]; - // multiplied on the global learning rate. + optional ParamGenProto init =5; + // multiplied on the global learning rate. 
optional float lr_scale = 15 [default = 1]; // multiplied on the global weight decay. optional float wd_scale = 16 [default = 1]; @@ -260,6 +239,38 @@ message ParamProto { // --------------------------- // protos for different layers // --------------------------- +// learning rate generator proto +message LRGenProto { + // user-defined change method + optional ChangeMethod type = 1 [default = kUserChange]; + optional string user_type = 2; + + optional float base_lr = 3 [default = 0.01]; + + optional FixedStepProto fixedstep_conf = 40; + optional StepProto step_conf = 41; + optional LinearProto linear_conf = 42; + optional ExponentialProto exponential_conf = 43; + optional InverseProto inverse_conf = 44; + optional InverseTProto inverset_conf = 45; + + extensions 101 to 200; +} + +message ParamGenProto { + optional InitMethod type = 1 [default = kUserInit]; + optional string user_type =2; + // constant init + optional float value = 3 [default = 1]; + // for gaussian sampling + optional float mean = 4 [default = 0]; + optional float std = 5 [default = 1]; + // for uniform sampling + optional float low = 8 [default = -1]; + optional float high = 9 [default = 1]; + + extensions 101 to 200; +} message RGBImageProto { // scale factor for each pixel @@ -476,11 +487,9 @@ enum InitMethod { kGaussian = 1; // uniform sampling between low and high kUniform = 2; - // copy the content and history which are from previous training - kPretrained = 3; // from Toronto Convnet, let a=1/sqrt(fan_in), w*=a after generating from // Gaussian distribution - kGaussainSqrtFanIn = 4; + kGaussianSqrtFanIn = 4; // from Toronto Convnet, rectified linear activation, let // a=sqrt(3)/sqrt(fan_in), range is [-a, +a]; no need to set value=sqrt(3), // the program will multiply it. http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/trainer/server.cc ---------------------------------------------------------------------- diff --git a/src/trainer/server.cc b/src/trainer/server.cc index a8483de..1fda336 100644 --- a/src/trainer/server.cc +++ b/src/trainer/server.cc @@ -21,7 +21,6 @@ void Server::Setup(const UpdaterProto& proto, std::unordered_map<int, ParamEntry*>* shard, const vector<int>& slice2group) { updater_ = Updater::Create(proto); - updater_->Init(proto); shard_ = shard; slice2group_ = slice2group; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/trainer/worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc index 25fea7c..e047367 100644 --- a/src/trainer/worker.cc +++ b/src/trainer/worker.cc @@ -24,7 +24,6 @@ void Worker::Init(int thread_id, int grp_id, int id) { grp_id_ = grp_id; id_ = id; layer_dealer_ = dealer_ = nullptr; - updater_ = nullptr; } void Worker::Setup( @@ -141,10 +140,8 @@ void ConnectStub(int grp, int id, Dealer* dealer, EntityType entity) { void Worker::Run() { LOG(ERROR) << "Worker (group = " << grp_id_ <<", id = " << id_ << ") start"; auto cluster = Cluster::Get(); - if (updater_==nullptr) { - int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group(); - CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp)); - } + int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group(); + CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp)); dealer_ = new Dealer(2*thread_id_); ConnectStub(grp_id_, id_, dealer_, kWorkerParam); for (auto layer : train_net_->layers()) { @@ -190,10 +187,7 @@ void Worker::Run() { Checkpoint(step_, train_net_); // clean up - 
if(updater_ == nullptr) { - int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group(); - cluster->runtime()->LeaveSGroup(grp_id_, id_, svr_grp); - } + cluster->runtime()->LeaveSGroup(grp_id_, id_, svr_grp); // notify the stub on worker stop Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1,-1, kStub)); msg->set_type(kStop); @@ -224,15 +218,10 @@ int Worker::Get(Param* param, int step) { int Worker::Update(Param* param, int step) { param->set_local_version(param->version()); - if (updater_) { - updater_->Update(step, param); - param->set_version(param->version() + 1); - } else { - Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub)); - msg->set_trgt(ParamTrgt(param->owner(), 0), step); - msg->set_type(kUpdate); - dealer_->Send(&msg); - } + Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub)); + msg->set_trgt(ParamTrgt(param->owner(), 0), step); + msg->set_type(kUpdate); + dealer_->Send(&msg); return 1; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index e658631..67f14ab 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -12,6 +12,60 @@ using namespace mshadow; using std::vector; using std::string; +ParamGenerator* ParamGenerator::Create(const ParamGenProto& proto) { + auto factory = Singleton<Factory<ParamGenerator>>::Instance(); + ParamGenerator * gen = nullptr; + if (proto.has_user_type()) + gen = factory->Create(proto.user_type()); + else + gen = factory->Create(proto.type()); + gen->Init(proto); + return gen; +} + +void ParamGenerator::Fill (Blob<float>* blob) { + Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count())); + data = proto_.value(); +} +void GaussianGen::Fill (Blob<float>* blob) { + Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count())); + auto random = TSingleton<Random<cpu>>::Instance(); + random->SampleGaussian(data, proto_.mean(), proto_.std()); + if(proto_.value() != 1) + data *= proto_.value(); +} +void GaussianSqrtFanInGen::Fill (Blob<float>* blob) { + // only valid for param matrix with num of cols as fan in + CHECK_EQ(blob->shape().size(), 2); + Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count())); + GaussianGen::Fill(blob); + data /= sqrt(blob->shape().at(1)); +} + +void UniformGen::Fill (Blob<float>* blob) { + Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count())); + auto random = TSingleton<Random<cpu>>::Instance(); + random->SampleUniform(data, proto_.low(), proto_.high()); + if(proto_.value() != 1) + data *= proto_.value(); +} + +void UniformSqrtFanInGen::Fill (Blob<float>* blob) { + // only valid for param matrix with num of cols as fan in + CHECK_EQ(blob->shape().size(), 2); + Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count())); + UniformGen::Fill(blob); + data /= sqrt(blob->shape().at(1) / 3.0f); +} + +void UniformSqrtFanInOutGen::Fill (Blob<float>* blob) { + // only valid for param matrix with num of cols as fan in + CHECK_EQ(blob->shape().size(), 2); + Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count())); + UniformGen::Fill(blob); + data /= sqrt(blob->shape()[0] + blob->shape()[1]); +} +/*****************Param***********************************/ Param* Param::Create(const ParamProto& proto) { Factory<Param>* factory=Singleton<Factory<Param>>::Instance(); Param* p = nullptr; @@ -51,43 +105,8 @@ void Param::AddSlice(int slice_id, int size) { 
} void Param::InitValues(int version) { - Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size())); - auto random = TSingleton<Random<cpu>>::Instance(); - switch (proto_.init_method()) { - case InitMethod::kConstant: - data = proto_.value(); - break; - case InitMethod::kUniform: - random->SampleUniform(data, proto_.low(), proto_.high()); - if(proto_.value() != 1) - data *= proto_.value(); - break; - case InitMethod::kUniformSqrtFanIn: - // only valid for param matrix with num of cols as fan in - CHECK_EQ(data_->shape().size(), 2); - random->SampleUniform(data, proto_.low(), proto_.high()); - data *= proto_.value() / sqrt(data_->shape().at(1) / 3.0f); - break; - case InitMethod::kUniformSqrtFanInOut: - random->SampleUniform(data, proto_.low(), proto_.high()); - if (proto_.value()) - data *= proto_.value() / sqrt(data_->shape()[0] + data_->shape()[1]); - break; - case InitMethod::kGaussian: - random->SampleGaussian(data, proto_.mean(), proto_.std()); - if(proto_.value() != 1) - data *= proto_.value(); - break; - case InitMethod::kGaussainSqrtFanIn: - // only valid for param matrix with num of cols as fan in - CHECK_EQ(data_->shape().size(), 2); - random->SampleGaussian(data, proto_.mean(), proto_.std()); - data *= proto_.value() / sqrt(data_->shape().at(1)); - break; - default: - LOG(ERROR) << "Illegal parameter init method "; - break; - } + ParamGenerator* gen = ParamGenerator::Create(proto_.init()); + gen->Fill(data_.get()); set_version(version); } void Param::FromProto(const BlobProto& blob) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/utils/updater.cc ---------------------------------------------------------------------- diff --git a/src/utils/updater.cc b/src/utils/updater.cc index 7d80844..24487d3 100644 --- a/src/utils/updater.cc +++ b/src/utils/updater.cc @@ -10,6 +10,53 @@ namespace singa { using namespace mshadow; using namespace mshadow::expr; +/**********************Learning rate generator******************************/ +LRGenerator* LRGenerator::Create(const LRGenProto& proto) { + auto factory = Singleton<Factory<LRGenerator>>::Instance(); + LRGenerator* gen = nullptr; + if (proto.has_user_type()) + gen = factory->Create(proto.user_type()); + else + gen = factory->Create(proto.type()); + gen->Init(proto); + return gen; +} + +float FixedStepLRGen::Get(int step) { + if (last_idx_ < proto_.fixedstep_conf().step_size() -1 + && step >= proto_.fixedstep_conf().step(last_idx_ + 1)) { + last_idx_ ++; + } + return proto_.fixedstep_conf().step_lr(last_idx_); +} + +float StepLRGen::Get(int step) { + // do not cast int to float + int freq = proto_.step_conf().change_freq(); + return proto_.base_lr() * pow(proto_.step_conf().gamma(), step / freq); +} + +float LinearLRGen::Get(int step) { + int freq = proto_.linear_conf().change_freq(); + float r = step * 1.0 / freq; + return (1.0 - r) * proto_.base_lr() + r * proto_.linear_conf().final_lr(); +} + +float ExpLRGen::Get(int step) { + int freq = proto_.exponential_conf().change_freq(); + return proto_.base_lr() / pow(2, step * 1. / freq); +} + +float InvLRGen::Get(int step) { + return proto_.base_lr() * pow(1.f + proto_.inverse_conf().gamma() * step, + - proto_.inverse_conf().pow()); +} + +float InvTLRGen::Get(int step) { + return proto_.base_lr() / (1 + step * 1. 
/ proto_.inverset_conf().final_lr()); +} + +/***********************Updater********************************/ Updater* Updater::Create(const UpdaterProto& proto) { auto factory = Singleton<Factory<Updater>>::Instance(); @@ -18,69 +65,23 @@ Updater* Updater::Create(const UpdaterProto& proto) { updater = factory->Create(proto.user_type()); else updater = factory->Create(proto.type()); + updater->Init(proto); return updater; } -float Updater::GetLearningRate(int step) { - float ret = 0., r = 0., base = proto_.base_lr(); - int freq = 0; - switch (proto_.lr_change()) { - case ChangeMethod::kFixed: - ret = base; - break; - case ChangeMethod::kLinear: - // a is init, b is the final - freq = proto_.linear_conf().change_freq(); - r = step * 1.0 / freq; - ret = (1.0 - r) * base + r * proto_.linear_conf().final_lr(); - break; - case ChangeMethod::kExponential: - // a is init, b is the final, from convnet - freq = proto_.exponential_conf().change_freq(); - ret = base / pow(2, step * 1. / freq); - break; - case ChangeMethod::kInverseT: - // a is init, b is the final, from convnet - CHECK_EQ(base, 2 * proto_.inverset_conf().final_lr()) - << "final value should be the half"; - ret = base / (1. + step * 1. / proto_.inverset_conf().final_lr()); - break; - case ChangeMethod::kInverse: - // a is init, b is gamma, c is pow - ret = base * pow(1.f + proto_.inverse_conf().gamma() * step, - - proto_.inverse_conf().pow()); - break; - case ChangeMethod::kStep: - // a is the base learning rate, b is gamma, from caffe - // notice it is step/change_steps, not step*1.0/change_steps - freq = proto_.step_conf().change_freq(); - ret = base * pow(proto_.step_conf().gamma(), step / freq); - break; - case ChangeMethod::kFixedStep: - for (int i = 0; i < proto_.fixedstep_conf().step_size(); i++) { - if (step > proto_.fixedstep_conf().step(i)) - ret = proto_.fixedstep_conf().step_lr(i); - } - break; - default: - LOG(ERROR) << "Wrong hyper-parameter update method"; - } - return ret; -} /***********************SGD with momentum******************************/ -void SGDUpdater::Init(const UpdaterProto& proto) { - Updater::Init(proto); - base_lr_ = proto.base_lr(); +void Updater::Init(const UpdaterProto& proto) { momentum_ = proto.momentum(); weight_decay_ = proto.weight_decay(); + lr_gen_ = LRGenerator::Create(proto.learning_rate()); } void SGDUpdater::Update(int step, Param* param, float grad_scale) { Shape<1> s = Shape1(param->size()); Tensor<cpu, 1> data(param->mutable_cpu_data(), s); Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); - float lr = GetLearningRate(step)*param->lr_scale(); - float wd = weight_decay_*param->wd_scale(); + float lr = lr_gen_->Get(step) * param->lr_scale(); + float wd = weight_decay_ * param->wd_scale(); if (grad_scale != 1.f) grad *= grad_scale; if (wd > 0) { // L2 regularization, should be done after timing grad_scale @@ -97,20 +98,13 @@ void SGDUpdater::Update(int step, Param* param, float grad_scale) { } /***********************Nesterov******************************/ -void NesterovUpdater::Init(const UpdaterProto& proto) { - Updater::Init(proto); - base_lr_ = proto.base_lr(); - CHECK_GT(base_lr_, 0); - weight_decay_ = proto.weight_decay(); -} - void NesterovUpdater::Update(int step, Param* param, float grad_scale) { Shape<1> s = Shape1(param->size()); Tensor<cpu, 1> data(param->mutable_cpu_data(), s); Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); Tensor<cpu, 1> history(param->mutable_cpu_history(), s); TensorContainer<cpu, 1> tmp(s); - float lr = GetLearningRate(step)*param->lr_scale(); + 
float lr = lr_gen_->Get(step)*param->lr_scale(); float wd = weight_decay_*param->wd_scale(); if (grad_scale != 1.f) grad *= grad_scale; @@ -123,20 +117,12 @@ void NesterovUpdater::Update(int step, Param* param, float grad_scale) { data -= tmp; } /***********************AdaGrad******************************/ -void AdaGradUpdater::Init(const UpdaterProto& proto) { - Updater::Init(proto); - base_lr_ = proto.base_lr(); - CHECK_GT(base_lr_, 0); - delta_ = proto.delta(); - weight_decay_ = proto.weight_decay(); -} - void AdaGradUpdater::Update(int step, Param* param, float grad_scale) { Shape<1> s = Shape1(param->size()); Tensor<cpu, 1> data(param->mutable_cpu_data(), s); Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); Tensor<cpu, 1> history(param->mutable_cpu_history(), s); - float lr = GetLearningRate(step)*param->lr_scale(); + float lr = lr_gen_->Get(step)*param->lr_scale(); float wd = weight_decay_*param->wd_scale(); if (grad_scale != 1.f) grad *= grad_scale; @@ -144,7 +130,7 @@ void AdaGradUpdater::Update(int step, Param* param, float grad_scale) { grad += data * wd; } history += F<op::square>(grad); - data -= lr * grad / (F<op::sqrtop>(history, delta_)); + data -= lr * grad / (F<op::sqrtop>(history, proto_.delta())); } /***********************RMSProp******************************
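
As the commit message notes, the learning rate logic is now pluggable: a new schedule subclasses LRGenerator (declared in include/utils/updater.h above), overrides Get(int step), and is registered through Driver::RegisterLRGenerator with a string type so that job.conf can select it via learning_rate { user_type: ... }. Below is a minimal sketch of such an extension; the class name HalvingLRGen, the "halving" registration string, and the hard-coded 10000-step period are illustrative assumptions, not part of this commit.

// Sketch of a user-defined learning rate generator.
// Assumptions (not from this commit): the class name, the "halving" string,
// and the fixed 10000-step period. LRGenerator, proto_ and Get() are from
// include/utils/updater.h above.
#include <cmath>
#include "utils/updater.h"

namespace singa {

class HalvingLRGen : public LRGenerator {
 public:
  float Get(int step) override {
    // halve base_lr every 10000 steps; base_lr comes from the
    // learning_rate { base_lr: ... } field in job.conf
    return proto_.base_lr() / std::pow(2.0f, step / 10000);
  }
};

}  // namespace singa

// Registration mirrors the built-in calls in src/driver.cc, but uses a string
// type so job.conf can select it with learning_rate { user_type: "halving" }:
//
//   singa::Driver driver;
//   driver.Init(argc, argv);
//   driver.RegisterLRGenerator<singa::HalvingLRGen, std::string>("halving");

A real subclass would typically read its period from an LRGenProto extension (the proto reserves extensions 101 to 200) rather than hard-coding it as above.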

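Parameter initialization follows the same pattern: a custom initializer subclasses ParamGenerator (include/utils/param.h above), overrides Fill(Blob<float>*), and is registered through Driver::RegisterParamGenerator so that a param's init { user_type: ... } block can select it. The sketch below is illustrative only; the class name, the "noisy_constant" string, and the 0.01 noise amplitude are assumptions.

// Sketch of a user-defined parameter initializer.
// Assumptions (not from this commit): the class name, the "noisy_constant"
// string, and the noise amplitude. ParamGenerator, Fill() and proto_.value()
// are from include/utils/param.h and src/proto/job.proto above.
#include <cstdlib>
#include "utils/param.h"

namespace singa {

class NoisyConstantGen : public ParamGenerator {
 public:
  void Fill(Blob<float>* blob) override {
    float* data = blob->mutable_cpu_data();
    for (int i = 0; i < blob->count(); ++i) {
      // value comes from init { value: ... } in job.conf, as in the built-in
      // kConstant case; a small uniform perturbation is added on top
      data[i] = proto_.value()
              + 0.01f * (std::rand() / static_cast<float>(RAND_MAX) - 0.5f);
    }
  }
};

}  // namespace singa

// Registration and selection, mirroring src/driver.cc and the job.conf
// examples above:
//
//   driver.RegisterParamGenerator<singa::NoisyConstantGen, std::string>(
//       "noisy_constant");
//
//   param { name: "w1"  init { user_type: "noisy_constant"  value: 1 } }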