SINGA-54 Refactor job configuration: move fields out of ModelProto, reformat job.proto, and make the important enum types global
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/4dee7b9c Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/4dee7b9c Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/4dee7b9c Branch: refs/heads/master Commit: 4dee7b9cd0f07eff4906e2398b7ad7f23691a508 Parents: 1b574f3 Author: wang sheng <[email protected]> Authored: Fri Aug 14 21:56:35 2015 +0800 Committer: wang sheng <[email protected]> Committed: Fri Aug 14 22:17:16 2015 +0800 ---------------------------------------------------------------------- include/neuralnet/base_layer.h | 1 - src/neuralnet/neuralnet.cc | 14 +- src/proto/job.proto | 450 ++++++++++++++++++++---------------- src/utils/param.cc | 12 +- src/utils/updater.cc | 14 +- 5 files changed, 272 insertions(+), 219 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4dee7b9c/include/neuralnet/base_layer.h ---------------------------------------------------------------------- diff --git a/include/neuralnet/base_layer.h b/include/neuralnet/base_layer.h index 25df95f..508fe18 100644 --- a/include/neuralnet/base_layer.h +++ b/include/neuralnet/base_layer.h @@ -20,7 +20,6 @@ using std::vector; using std::string; using std::map; - class Layer; /** * Base layer class. http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4dee7b9c/src/neuralnet/neuralnet.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc index 4732a36..e2565e3 100644 --- a/src/neuralnet/neuralnet.cc +++ b/src/neuralnet/neuralnet.cc @@ -7,10 +7,10 @@ namespace singa { // macros to shorten the code -#define LayerT(x) LayerProto_LayerType_k##x +#define LayerT(x) LayerType::k##x #define RegisterLayer(factory, id) \ - factory->Register(LayerProto_LayerType_k##id, \ + factory->Register(LayerType::k##id, \ CreateInstance(id##Layer, Layer)) void NeuralNet::RegisterLayers() { @@ -195,7 +195,7 @@ Node* SliceNode(Graph* graph, Node* srcnode, string name = srcnode->name + "<"; LayerProto *proto = new LayerProto(); proto->set_name(name); - proto->set_type(LayerProto_LayerType_kSlice); + proto->set_type(LayerType::kSlice); proto->set_partition_id( static_cast<LayerProto*>(srcnode->proto)->partition_id()); auto conf = proto->mutable_slice_conf(); @@ -215,7 +215,7 @@ Node* ConcateNodes(Graph* graph, const vector<Node*>& srcnodes, Node* dstnode) { string name = ">" + dstnode->name; LayerProto *proto = new LayerProto(); proto->set_name(name); - proto->set_type(LayerProto_LayerType_kConcate); + proto->set_type(LayerType::kConcate); proto->set_partition_id( static_cast<LayerProto*>(dstnode->proto)->partition_id()); auto conf = proto->mutable_concate_conf(); @@ -234,7 +234,7 @@ Node* SplitNode(Graph* graph, Node* srcnode, const vector<Node*>& dstnodes) { string name = srcnode->name + "+"; LayerProto *proto = new LayerProto(); proto->set_name(name); - proto->set_type(LayerProto_LayerType_kSplit); + proto->set_type(LayerType::kSplit); proto->set_partition_id( static_cast<LayerProto*>(srcnode->proto)->partition_id()); Node* node = new Node(name, "##" + name, proto->partition_id(), proto); @@ -251,14 +251,14 @@ void BridgeNodes(Graph* graph, Node* srcnode, Node* dstnode) { string sname = srcnode->name + ":-"; LayerProto *sproto = new LayerProto(); sproto->set_name(sname); - sproto->set_type(LayerProto_LayerType_kBridgeSrc); + 
sproto->set_type(LayerType::kBridgeSrc); sproto->set_partition_id( static_cast<LayerProto*>(srcnode->proto)->partition_id()); auto sbridge = new Node(sname, "##" + sname, sproto->partition_id(), sproto); string dname = "-:" + dstnode->name; LayerProto *dproto = new LayerProto(); dproto->set_name(dname); - dproto->set_type(LayerProto_LayerType_kBridgeDst); + dproto->set_type(LayerType::kBridgeDst); dproto->set_partition_id( static_cast<LayerProto*>(dstnode->proto)->partition_id()); auto dbridge = new Node(dname, "##" + dname, dproto->partition_id(), dproto); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4dee7b9c/src/proto/job.proto ---------------------------------------------------------------------- diff --git a/src/proto/job.proto b/src/proto/job.proto index a67d330..7c734bf 100644 --- a/src/proto/job.proto +++ b/src/proto/job.proto @@ -1,11 +1,21 @@ package singa; -enum TrainOneBatchAlg { - // Back-propagation algorithm for feed-forward models, e.g., CNN, and RNN - kBP = 1; - // Contrastive Divergence algorithm for RBM, DBM etc. - kCD = 2; -} +// To start a training job, all we need is a JobProto object. +// It should contain following fields +// - Job Name (name) +// the name to identify the job +// - NeuralNet (neuralnet) +// the neural network structure contains a set of layers +// - Train One Batch (alg) +// the training algorithm +// - Updater (updater) +// the protocol for updating parameters at server side +// - Cluster Topology (cluster) +// the distributed topology of workers/servers +// - Training Steps (train_steps) +// the number of training iteration +// All other fields/functions are optional, e.g., test, checkpoint +// message JobProto { // job name, e.g., "cifar10-dcnn", "mnist-mlp" required string name = 1; @@ -28,7 +38,8 @@ message JobProto { // frequency of test, e.g., do test every 100 training steps optional int32 test_freq = 20 [default = 0]; - // total num of steps for testing all test data; todo set -1 for test forever + // total num of steps for testing all test data; + // TODO(wangwei): set -1 for test forever optional int32 test_steps = 21 [default = 0]; // frequency of validation, e.g., do validation every 100 training steps optional int32 valid_freq = 25 [default = 0]; @@ -57,7 +68,10 @@ message JobProto { // start validation after this num steps optional int32 valid_after = 83 [default = 0]; - // used by SINGA; uses typically do not touch these fields + // for internal use + // users typically do not touch following fields + + // resume flag optional bool resume = 90 [default = false]; // last snapshot step optional int32 step = 91 [default = 0]; @@ -65,9 +79,41 @@ message JobProto { optional int32 id = 92 [default = -1]; } -message CDProto { - //number of steps for gibbs sampling - optional int32 pcd_k = 1 [default = 1]; +// ----------------------- +// Protos used by JobProto +// ----------------------- + +message NetProto { + repeated LayerProto layer = 1; + // partitioning type for parallelism + optional int32 partition_dim = 20 [default = 0]; +} + +message UpdaterProto { + // updater type + required UpdaterType type = 1 [default = kSGD]; + // configuration for RMSProp algorithm + optional RMSPropProto rmsprop_conf = 50; + + // change method for learning rate + required ChangeMethod lr_change = 2 [default = kFixed]; + + // proto of change method + oneof change_conf { + FixedStepProto fixedstep_conf = 40; + StepProto step_conf = 41; + LinearProto linear_conf = 42; + ExponentialProto exponential_conf = 43; + InverseProto inverse_conf = 44; + 
InverseTProto inverset_conf = 45; + } + + optional float momentum = 31 [default = 0]; + optional float weight_decay = 32 [default = 0]; + // base learning rate + optional float base_lr = 34 [default = 0]; + // used to avoid divide by 0, i.e. x/(y+delta) + optional float delta = 35 [default = 0.00000001]; } message ClusterProto { @@ -83,64 +129,86 @@ message ClusterProto { // servers and workers in different processes? optional bool server_worker_separate = 20 [default = false]; - // port number is used by ZeroMQ + // port number used by ZeroMQ optional int32 start_port = 60 [default = 6723]; - // conduct updates at server side; otherwise do it at worker side + // conduct updates at server side; otherwise do it at worker side optional bool server_update = 61 [default = true]; // share memory space between worker groups in one procs optional bool share_memory = 62 [default = true]; // bandwidth of ethernet, Bytes per second, default is 1 Gbps - optional int32 bandwidth=80 [default=134217728]; + optional int32 bandwidth = 80 [default = 134217728]; // poll time in milliseconds - optional int32 poll_time=81 [default =100]; + optional int32 poll_time = 81 [default = 100]; } - -enum Phase { - kTrain = 0; - kValidation = 1; - kTest= 2; - // postivie phase for contrastive divergence algorithm - kPositive = 3; - // negative phase for contrastive divergence algorithm - kNegative = 4; - kForward = 5; - kBackward = 6; - kLoss = 7; +message CDProto { + //number of steps for gibbs sampling + optional int32 pcd_k = 1 [default = 1]; } -message NetProto { - repeated LayerProto layer = 1; - // partitioning type for parallelism - optional int32 partition_dim = 20 [default = 0]; +message LayerProto { + // the layer name used for identification + required string name = 1; + // source layer names + repeated string srclayers = 3; + // parameters, e.g., weight matrix or bias vector + repeated ParamProto param = 12; + // all layers are included in the net structure for training phase by default. + // some layers like data layer for loading test data are not used by training + // phase should be removed by setting the exclude field. 
+ repeated Phase exclude = 15; + // the layer type + required LayerType type = 20; + // proto for the specific layer + oneof layer_conf { + // configuration for convolution layer + ConvolutionProto convolution_conf = 30; + // configuration for concatenation layer + ConcateProto concate_conf = 31; + // configuration for dropout layer + DropoutProto dropout_conf = 33; + // configuration for inner product layer + InnerProductProto innerproduct_conf = 34; + // configuration for local response normalization layer + DataProto lmdbdata_conf = 35; + // configuration for local response normalization layer + LRNProto lrn_conf = 45; + // configuration for mnist parser layer + MnistProto mnist_conf = 36; + // configuration for pooling layer + PoolingProto pooling_conf = 37; + // configuration for prefetch layer + PrefetchProto prefetch_conf = 44; + // configuration for rectified linear unit layer + ReLUProto relu_conf = 38; + // configuration for rgb image parser layer + RGBImageProto rgbimage_conf = 39; + // configuration for data layer + DataProto sharddata_conf = 32; + // configuration for slice layer + SliceProto slice_conf = 41; + // configuration for softmax loss layer + SoftmaxLossProto softmaxloss_conf = 40; + // configuration for split layer + SplitProto split_conf = 42; + // configuration for tanh layer + TanhProto tanh_conf = 43; + // configuration for rbmvis layer + RBMVisProto rbmvis_conf = 48; + // configuration for rbmhid layer + RBMHidProto rbmhid_conf = 49; + } + + // overrides the partition dimension for neural net + optional int32 partition_dim = 60 [default = -1]; + // names of parameters shared from other layers + optional int32 partition_id = 90 [default = 0]; } -// weight matrix should be defined before bias vector; -// todo separate conf for diff init method +// weight matrix should be defined before bias vector +// TODO(wangwei): separate conf for diff init method message ParamProto { - enum InitMethod { - // fix the values of all parameters a constant in the value field - kConstant = 0; - // sample gaussian with std and mean - kGaussian = 1; - // uniform sampling between low and high - kUniform = 2; - // copy the content and history which are from previous training - kPretrained = 3; - // from Toronto Convnet, let a=1/sqrt(fan_in), w*=a after generating from - // Gaussian distribution - kGaussainSqrtFanIn = 4; - // from Toronto Convnet, rectified linear activation, let - // a=sqrt(3)/sqrt(fan_in), range is [-a, +a]; no need to set value=sqrt(3), - // the program will multiply it. - kUniformSqrtFanIn = 5; - // from Theano MLP tutorial, let a=sqrt(6/(fan_in+fan_out)). for tanh - // activation, range is [-a, +a], for sigmoid activation, range is - // [-4a, +4a], put the scale factor to value field. - // <a href="http://deeplearning.net/tutorial/mlp.html"> Theano MLP</a> - kUniformSqrtFanInOut = 6; - } // used for identifying the same params from diff models and display deug info optional string name = 1 [default = ""]; optional InitMethod init_method = 2 [default = kGaussian]; @@ -157,7 +225,7 @@ message ParamProto { // multiplied on the global weight decay. 
optional float weight_decay_multiplier = 16 [default = 1]; - // name of the owner param from which this param shares the values + // name of the owner param from which this param shares the values optional string share_from = 60; // used interally @@ -170,91 +238,9 @@ message ParamProto { repeated int32 shape = 93; } -enum PartitionType{ - kDataPartition=0; - kLayerPartition=1; - kNone=2; -} - -message LayerProto { - // the layer name used for identification - required string name = 1; - enum LayerType{ - kBridgeSrc = 15; - kBridgeDst = 16; - kConvolution = 1; - kConcate = 2; - kShardData = 3; - kDropout = 4; - kInnerProduct = 5; - kLabel = 18; - kLMDBData = 17; - kLRN = 6; - kMnist = 7; - kPooling = 8; - kPrefetch = 19; - kReLU = 9; - kRGBImage = 10; - kSoftmaxLoss = 11; - kSlice = 12; - kSplit = 13; - kTanh = 14; - kRBMVis = 23; - kRBMHid = 24; - } - // source layer names - repeated string srclayers = 3; - // parameters, e.g., weight matrix or bias vector - repeated ParamProto param = 12; - // all layers are included in the net structure for training phase by default. - // some layers like data layer for loading test data are not used by training - // phase should be removed by setting the exclude field. - repeated Phase exclude = 15; - // the layer type from the enum above - required LayerType type = 20; - // configuration for convolution layer - optional ConvolutionProto convolution_conf = 30; - // configuration for concatenation layer - optional ConcateProto concate_conf = 31; - // configuration for dropout layer - optional DropoutProto dropout_conf = 33; - // configuration for inner product layer - optional InnerProductProto innerproduct_conf = 34; - // configuration for local response normalization layer - optional DataProto lmdbdata_conf = 35; - // configuration for local response normalization layer - optional LRNProto lrn_conf = 45; - // configuration for mnist parser layer - optional MnistProto mnist_conf= 36; - // configuration for pooling layer - optional PoolingProto pooling_conf = 37; - // configuration for prefetch layer - optional PrefetchProto prefetch_conf = 44; - // configuration for rectified linear unit layer - optional ReLUProto relu_conf = 38; - // configuration for rgb image parser layer - optional RGBImageProto rgbimage_conf = 39; - // configuration for data layer - optional DataProto sharddata_conf = 32; - // configuration for slice layer - optional SliceProto slice_conf = 41; - // configuration for softmax loss layer - optional SoftmaxLossProto softmaxloss_conf = 40; - // configuration for split layer - optional SplitProto split_conf = 42; - // configuration for tanh layer - optional TanhProto tanh_conf = 43; - // configuration for rbmvis layer - optional RBMVisProto rbmvis_conf = 48; - // configuration for rbmhid layer - optional RBMHidProto rbmhid_conf = 49; - - - // overrides the partition dimension for neural net - optional int32 partition_dim =60 [default = -1]; - // names of parameters shared from other layers - optional int32 partition_id = 90 [default = 0]; -} +// --------------------------- +// protos for different layers +// --------------------------- message RGBImageProto { // scale factor for each pixel @@ -272,7 +258,7 @@ message PrefetchProto { } message SplitProto { - optional int32 num_splits = 1 [default =1]; + optional int32 num_splits = 1 [default = 1]; } // scaled tan: A*tan(B*x) @@ -287,14 +273,14 @@ message SoftmaxLossProto { // computing accuracy against topk results optional int32 topk = 1 [default = 1]; // loss scale factor - optional 
float scale= 30 [default = 1]; + optional float scale = 30 [default = 1]; } message ConvolutionProto { // The number of outputs for the layer required int32 num_filters = 1; // the kernel height/width - required int32 kernel= 2; + required int32 kernel = 2; // The padding height/width optional int32 pad = 30 [default = 0]; @@ -377,7 +363,7 @@ message LRNProto { // normalization objective optional NormRegion norm_region = 33 [default = ACROSS_CHANNELS]; // offset - optional float knorm =34 [default = 1.0]; + optional float knorm = 34 [default = 1.0]; } message PoolingProto { @@ -395,7 +381,7 @@ message PoolingProto { optional uint32 stride = 32 [default = 1]; } -message SliceProto{ +message SliceProto { required int32 slice_dim = 1; } @@ -406,83 +392,151 @@ message ReLUProto { optional float negative_slope = 1 [default = 0]; } -message UpdaterProto { - enum UpdaterType{ - // noraml SGD with momentum and weight decay - kSGD = 1; - // adaptive subgradient, http://www.magicbroom.info/Papers/DuchiHaSi10.pdf - kAdaGrad = 2; - // http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf - kRMSProp = 3; - // Nesterov first optimal gradient method - kNesterov = 4; - } - // updater type - required UpdaterType type = 1 [default=kSGD]; - // configuration for RMSProp algorithm - optional RMSPropProto rmsprop_conf = 50; - - enum ChangeMethod { - kFixed = 0; - kInverseT = 1; - kInverse = 2; - kExponential = 3; - kLinear = 4; - kStep = 5; - kFixedStep = 6; - } - // change method for learning rate - required ChangeMethod lr_change= 2 [default = kFixed]; - - optional FixedStepProto fixedstep_conf=40; - optional StepProto step_conf=41; - optional LinearProto linear_conf=42; - optional ExponentialProto exponential_conf=43; - optional InverseProto inverse_conf=44; - optional InverseTProto inverset_conf=45; - - optional float momentum = 31 [default = 0]; - optional float weight_decay = 32 [default = 0]; - // base learning rate - optional float base_lr = 34 [default = 0]; - // used to avoid divide by 0, i.e. 
x/(y+delta) - optional float delta = 35 [default = 0.00000001]; -} - -message RMSPropProto{ +message RMSPropProto { // history=history*rho_+(1-rho_)*(grad*grad_scale); required float rho = 1; } -message FixedStepProto{ +message FixedStepProto { repeated int32 step = 28; // lr = step_lr[i] if current step >= step[i] repeated float step_lr = 29; } -message StepProto{ +message StepProto { // lr = base_lr * gamma^(step/change_freq) required float gamma = 35 [default = 1]; // lr = base_lr * gamma^(step/change_freq) - required int32 change_freq= 40; + required int32 change_freq = 40; } -message LinearProto{ + +message LinearProto { // lr = (1 - step / freq) * base_lr + (step / freq) * final_lr required int32 change_freq= 40; // lr = (1 - step / freq) * base_lr + (step / freq) * final_lr required float final_lr = 39; } -message ExponentialProto{ + +message ExponentialProto { // lr = base / 2^(step/change_freq) - required int32 change_freq= 40; + required int32 change_freq = 40; } -message InverseTProto{ + +message InverseTProto { // lr = base_lr / (1+step/final_lr) required float final_lr = 39; } -message InverseProto{ +message InverseProto { // lr = base_lr*(1+gamma*step)^(-pow) required float gamma = 1 [default = 1]; // lr = base_lr*(1+gamma*step)^(-pow) required float pow = 2 [default = 0]; } + +// -------------- +// All Enum Types +// -------------- + +enum ChangeMethod { + kFixed = 0; + kInverseT = 1; + kInverse = 2; + kExponential = 3; + kLinear = 4; + kStep = 5; + kFixedStep = 6; +} + +enum InitMethod { + // fix the values of all parameters a constant in the value field + kConstant = 0; + // sample gaussian with std and mean + kGaussian = 1; + // uniform sampling between low and high + kUniform = 2; + // copy the content and history which are from previous training + kPretrained = 3; + // from Toronto Convnet, let a=1/sqrt(fan_in), w*=a after generating from + // Gaussian distribution + kGaussainSqrtFanIn = 4; + // from Toronto Convnet, rectified linear activation, let + // a=sqrt(3)/sqrt(fan_in), range is [-a, +a]; no need to set value=sqrt(3), + // the program will multiply it. + kUniformSqrtFanIn = 5; + // from Theano MLP tutorial, let a=sqrt(6/(fan_in+fan_out)). for tanh + // activation, range is [-a, +a], for sigmoid activation, range is + // [-4a, +4a], put the scale factor to value field. 
+ // <a href="http://deeplearning.net/tutorial/mlp.html"> Theano MLP</a> + kUniformSqrtFanInOut = 6; +} + +enum LayerType { + // Data layers + // - Load records from file, database + kLMDBData = 17; + kPrefetch = 19; + kShardData = 3; + // Parser layers + // - Parse features from records, e.g., pixels + kLabel = 18; + kMnist = 7; + kRGBImage = 10; + // Neuron layers + // - Feature transformation + kConcate = 2; + kConvolution = 1; + kDropout = 4; + kInnerProduct = 5; + kLRN = 6; + kPooling = 8; + kReLU = 9; + kRBMHid = 24; + kRBMVis = 23; + kTanh = 14; + // Loss layers + // - Compute objective loss + kSoftmaxLoss = 11; + // Other layers + // - Connect layers when neural net is partitioned + kBridgeDst = 16; + kBridgeSrc = 15; + kSlice = 12; + kSplit = 13; +} + +enum PartitionType { + kDataPartition = 0; + kLayerPartition = 1; + kNone = 2; +} + +enum Phase { + kTrain = 0; + kValidation = 1; + kTest= 2; + // postivie phase for contrastive divergence algorithm + kPositive = 3; + // negative phase for contrastive divergence algorithm + kNegative = 4; + kForward = 5; + kBackward = 6; + kLoss = 7; +} + +enum TrainOneBatchAlg { + // Back-propagation algorithm for feed-forward models, e.g., CNN and RNN + kBP = 1; + // Contrastive Divergence algorithm for RBM, DBM, etc. + kCD = 2; +} + +enum UpdaterType { + // noraml SGD with momentum and weight decay + kSGD = 1; + // adaptive subgradient, http://www.magicbroom.info/Papers/DuchiHaSi10.pdf + kAdaGrad = 2; + // http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf + kRMSProp = 3; + // Nesterov first optimal gradient method + kNesterov = 4; +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4dee7b9c/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index 2f43f66..b470ea2 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -43,15 +43,15 @@ void Param::InitValues(int version){ Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size())); auto random=TSingleton<Random<cpu>>::Instance(); switch (proto_.init_method()) { - case ParamProto::kConstant: + case InitMethod::kConstant: data = proto_.value(); break; - case ParamProto::kUniform: + case InitMethod::kUniform: random->SampleUniform(data, proto_.low(), proto_.high()); if(proto_.value() != 1) data *= proto_.value(); break; - case ParamProto::kUniformSqrtFanIn: + case InitMethod::kUniformSqrtFanIn: random->SampleUniform(data, proto_.low(), proto_.high()); // only valid for param matrix with dim 1 for fan in LOG(ERROR) << "init fan in"; @@ -59,17 +59,17 @@ void Param::InitValues(int version){ data *= proto_.value() / sqrt(data_->shape().at(1) / 3.0f); LOG(ERROR) << "end fan in"; break; - case ParamProto::kUniformSqrtFanInOut: + case InitMethod::kUniformSqrtFanInOut: random->SampleUniform(data, proto_.low(), proto_.high()); if(proto_.value()) data *= proto_.value()/ sqrt(data_->shape()[0] +data_->shape()[1]); break; - case ParamProto::kGaussian: + case InitMethod::kGaussian: random->SampleGaussian(data, proto_.mean(), proto_.std()); if(proto_.value() != 1) data *= proto_.value(); break; - case ParamProto::kGaussainSqrtFanIn: + case InitMethod::kGaussainSqrtFanIn: random->SampleGaussian(data, proto_.mean(), proto_.std()); if(proto_.value()) data *= proto_.value()/ sqrt(data_->shape()[0]); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4dee7b9c/src/utils/updater.cc ---------------------------------------------------------------------- diff --git a/src/utils/updater.cc 
b/src/utils/updater.cc index b85982e..c038ca7 100644 --- a/src/utils/updater.cc +++ b/src/utils/updater.cc @@ -13,38 +13,38 @@ float Updater::GetLearningRate(int step) { float ret = 0., r = 0., base = proto_.base_lr(); int freq = 0; switch (proto_.lr_change()) { - case UpdaterProto_ChangeMethod_kFixed: + case ChangeMethod::kFixed: ret = base; break; - case UpdaterProto_ChangeMethod_kLinear: + case ChangeMethod::kLinear: // a is init, b is the final freq = proto_.linear_conf().change_freq(); r = step * 1.0 / freq; ret = (1.0 - r) * base + r * proto_.linear_conf().final_lr(); break; - case UpdaterProto_ChangeMethod_kExponential: + case ChangeMethod::kExponential: // a is init, b is the final, from convnet freq = proto_.exponential_conf().change_freq(); ret = base / pow(2, step * 1. / freq); break; - case UpdaterProto_ChangeMethod_kInverseT: + case ChangeMethod::kInverseT: // a is init, b is the final, from convnet CHECK_EQ(base, 2 * proto_.inverset_conf().final_lr()) << "final value should be the half"; ret = base / (1. + step * 1. / proto_.inverset_conf().final_lr()); break; - case UpdaterProto_ChangeMethod_kInverse: + case ChangeMethod::kInverse: // a is init, b is gamma, c is pow ret = base * pow(1.f + proto_.inverse_conf().gamma() * step, - proto_.inverse_conf().pow()); break; - case UpdaterProto_ChangeMethod_kStep: + case ChangeMethod::kStep: // a is the base learning rate, b is gamma, from caffe // notice it is step/change_steps, not step*1.0/change_steps freq = proto_.step_conf().change_freq(); ret = base * pow(proto_.step_conf().gamma(), step / freq); break; - case UpdaterProto_ChangeMethod_kFixedStep: + case ChangeMethod::kFixedStep: for (int i = 0; i < proto_.fixedstep_conf().step_size(); i++) { if (step > proto_.fixedstep_conf().step(i)) ret = proto_.fixedstep_conf().step_lr(i);
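
----------------------------------------------------------------------
A minimal usage sketch (illustration only, not part of this commit) of what the enum refactoring means for client C++ code: values that protoc previously generated as nested constants such as LayerProto_LayerType_kConvolution or UpdaterProto_ChangeMethod_kStep are now top-level enums (LayerType, ChangeMethod, UpdaterType, ...) in the singa package, and the per-layer/per-updater configs sit in the new oneof fields. The include path and the concrete field values below are assumptions for illustration; C++11 is assumed.

// Sketch only: assumes the protoc-generated header for src/proto/job.proto
// is reachable as "proto/job.pb.h" (path is an assumption).
#include <iostream>
#include "proto/job.pb.h"

int main() {
  singa::LayerProto layer;
  layer.set_name("conv1");
  // Before this commit: layer.set_type(singa::LayerProto_LayerType_kConvolution);
  layer.set_type(singa::LayerType::kConvolution);          // global enum now
  layer.mutable_convolution_conf()->set_num_filters(32);   // member of oneof layer_conf
  layer.mutable_convolution_conf()->set_kernel(5);

  singa::UpdaterProto updater;
  // Before: UpdaterProto_UpdaterType_kSGD, UpdaterProto_ChangeMethod_kStep
  updater.set_type(singa::UpdaterType::kSGD);
  updater.set_lr_change(singa::ChangeMethod::kStep);
  updater.mutable_step_conf()->set_gamma(0.9f);            // member of oneof change_conf
  updater.mutable_step_conf()->set_change_freq(100);

  std::cout << layer.DebugString() << updater.DebugString();
  return 0;
}

Making the enums global shortens these references and lets LayerProto, ParamProto and UpdaterProto share them without the generated LayerProto_*/UpdaterProto_* prefixes, which is what the neuralnet.cc, param.cc and updater.cc hunks above switch over to.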
