SINGA-120 - Implemented GRU and BPTT

Change the new-memory computation formula to follow char-rnn (i.e., the element-wise multiplication by the reset gate is applied before the matrix multiplication).
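For reference, a sketch of the GRU forward computation before and after this change, read off the gru.cc diff below. The notation is assumed for illustration only (x_t: input, h_{t-1}: context from the previous unit, r_t/z_t: reset/update gates, c_t: new memory, ⊙: element-wise multiplication); W_c_hx, W_c_hh and b_c are the new-memory parameters from the code.

  before: c_t = tanh(r_t ⊙ (x_t W_c_hx + b_c) + h_{t-1} W_c_hh)
          h_t = z_t ⊙ h_{t-1} + (1 - z_t) ⊙ c_t

  after:  c_t = tanh(x_t W_c_hx + (r_t ⊙ h_{t-1}) W_c_hh + b_c)
          h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ c_t

That is, the reset gate now multiplies the previous hidden state element-wise before the W_c_hh matrix multiplication (as in char-rnn), and the roles of the update gate in the final interpolation are swapped accordingly; ComputeGradient is updated to match.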
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/6a4c9960 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/6a4c9960 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/6a4c9960 Branch: refs/heads/master Commit: 6a4c9960e0795aeac6df282d7190b6f93b305c58 Parents: 959ef70 Author: Wei Wang <[email protected]> Authored: Tue Jan 5 18:14:46 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Wed Jan 6 01:55:08 2016 +0800 ---------------------------------------------------------------------- Makefile.am | 2 + include/singa/neuralnet/connection_layer.h | 28 +++++++++++++ include/singa/neuralnet/layer.h | 14 +++++++ include/singa/neuralnet/neuron_layer.h | 2 +- include/singa/neuralnet/output_layer.h | 15 +++++++ include/singa/utils/updater.h | 8 ++-- src/driver.cc | 6 ++- src/neuralnet/neuralnet.cc | 9 ++--- src/neuralnet/neuron_layer/gru.cc | 52 +++++++++++-------------- src/proto/job.proto | 14 +++++++ src/utils/updater.cc | 28 ++++++------- src/worker.cc | 16 ++++++-- 12 files changed, 133 insertions(+), 61 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/Makefile.am ---------------------------------------------------------------------- diff --git a/Makefile.am b/Makefile.am index d2b2aa8..7ae4537 100644 --- a/Makefile.am +++ b/Makefile.am @@ -75,6 +75,7 @@ SINGA_SRCS := src/driver.cc \ src/neuralnet/connection_layer/concate.cc \ src/neuralnet/connection_layer/slice.cc \ src/neuralnet/connection_layer/split.cc \ + src/neuralnet/connection_layer/rnn_dummy.cc \ src/neuralnet/input_layer/char_rnn.cc \ src/neuralnet/input_layer/onehot.cc \ src/neuralnet/input_layer/csv.cc \ @@ -88,6 +89,7 @@ SINGA_SRCS := src/driver.cc \ src/neuralnet/output_layer/argsort.cc \ src/neuralnet/output_layer/csv.cc \ src/neuralnet/output_layer/record.cc \ + src/neuralnet/output_layer/char_rnn.cc \ src/neuralnet/loss_layer/euclidean.cc \ src/neuralnet/loss_layer/softmax.cc \ src/neuralnet/neuron_layer/activation.cc \ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/neuralnet/connection_layer.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/connection_layer.h b/include/singa/neuralnet/connection_layer.h index a18f458..481d991 100644 --- a/include/singa/neuralnet/connection_layer.h +++ b/include/singa/neuralnet/connection_layer.h @@ -153,6 +153,34 @@ class SplitLayer : public ConnectionLayer { Layer2Index layer_idx_; }; +/** + * Dummy layer for RNN models, which provides input for other layers. + * + * Particularly, it is used in the test phase of RNN models to connect other + * layers and avoid cycles in the neural net config. 
+ */ +class RNNDummyLayer : public ConnectionLayer { + public: + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) { + LOG(FATAL) << "Not implemented"; + } + + const string srclayer(int step) const { + if (step > 0) + return dynamic_src_; + else + return ""; + } + + private: + string dynamic_src_; + float low_, high_; + bool integer_; + Layer* srclayer_; +}; + } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/neuralnet/layer.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/layer.h b/include/singa/neuralnet/layer.h index f4738fa..c1612a2 100644 --- a/include/singa/neuralnet/layer.h +++ b/include/singa/neuralnet/layer.h @@ -36,6 +36,20 @@ using std::string; // TODO(wangwei) make AuxType a template argument for Layer. using AuxType = int; + +inline const string AddUnrollingPrefix(int unroll_idx, const string& name) { + return std::to_string(unroll_idx) + "#" + name; +} +inline const string AddPartitionSuffix(int partition_idx, const string& name) { + return name + "@" + std::to_string(partition_idx); +} + + +inline const string AddPrefixSuffix(int unroll_idx, int partition_idx, + const string& name) { + return std::to_string(unroll_idx) + "#" + name + "@" + + std::to_string(partition_idx); +} /** * Base layer class. * http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/neuralnet/neuron_layer.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h index e587e38..e1a63a2 100644 --- a/include/singa/neuralnet/neuron_layer.h +++ b/include/singa/neuralnet/neuron_layer.h @@ -203,7 +203,7 @@ class GRULayer : public NeuronLayer { int batchsize_; // batch size int vdim_, hdim_; // dimensions - Blob<float> *update_gate, *reset_gate, *new_memory; + Blob<float> *update_gate, *reset_gate, *new_memory, *reset_context; //!< gru layer connect to two dst layers, hence need to grad blobs. Blob<float> aux_grad_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/neuralnet/output_layer.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/output_layer.h b/include/singa/neuralnet/output_layer.h index c7e5d6a..9071f33 100644 --- a/include/singa/neuralnet/output_layer.h +++ b/include/singa/neuralnet/output_layer.h @@ -80,5 +80,20 @@ class RecordOutputLayer : public OutputLayer { int inst_ = 0; //!< instance No. io::Store* store_ = nullptr; }; + +/** + * Output layer for char rnn model, which convert sample id back to char and + * dump to stdout. 
+ */ +class CharRNNOutputLayer : public OutputLayer { + public: + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + + private: + string vocab_; +}; + } // namespace singa #endif // SINGA_NEURALNET_OUTPUT_LAYER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/utils/updater.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/updater.h b/include/singa/utils/updater.h index 575ab86..7fec78c 100644 --- a/include/singa/utils/updater.h +++ b/include/singa/utils/updater.h @@ -118,18 +118,16 @@ class NesterovUpdater : public Updater { void Update(int step, Param* param, float grad_scale) override; }; -/* class RMSPropUpdater : public Updater { public: - virtual void Update(int step, Param* param, float grad_scale); + void Init(const UpdaterProto &proto) override; + void Update(int step, Param* param, float grad_scale) override; protected: - float base_lr_; - float delta_; float rho_; - float weight_decay_; }; +/* class AdaDeltaUpdater : public Updater { public: virtual void Update(int step, Param* param, float grad_scale); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/driver.cc ---------------------------------------------------------------------- diff --git a/src/driver.cc b/src/driver.cc index 21968bb..1e4929f 100644 --- a/src/driver.cc +++ b/src/driver.cc @@ -74,6 +74,7 @@ void Driver::Init(int argc, char **argv) { RegisterLayer<CharRNNInputLayer, int>(kCharRNN); RegisterLayer<RNNLabelLayer, int>(kRNNLabel); RegisterLayer<OneHotLayer, int>(kOneHot); + RegisterLayer<CharRNNOutputLayer, int>(kCharRNNOutput); // connection layers RegisterLayer<BridgeDstLayer, int>(kBridgeDst); @@ -81,6 +82,7 @@ void Driver::Init(int argc, char **argv) { RegisterLayer<ConcateLayer, int>(kConcate); RegisterLayer<SliceLayer, int>(kSlice); RegisterLayer<SplitLayer, int>(kSplit); + RegisterLayer<RNNDummyLayer, int>(kRNNDummy); RegisterLayer<AccuracyLayer, int>(kAccuracy); RegisterLayer<ArgSortLayer, int>(kArgSort); @@ -125,7 +127,7 @@ void Driver::Init(int argc, char **argv) { // register updaters RegisterUpdater<AdaGradUpdater>(kAdaGrad); RegisterUpdater<NesterovUpdater>(kNesterov); - // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp); + RegisterUpdater<RMSPropUpdater>(kRMSProp); RegisterUpdater<SGDUpdater>(kSGD); // register learning rate change methods @@ -198,6 +200,8 @@ void Driver::Test(const JobProto& job_conf) { auto worker = Worker::Create(job_conf.train_one_batch()); worker->Setup(0, 0, job_conf, nullptr, nullptr, nullptr); auto net = NeuralNet::Create(job_conf.neuralnet(), kTest, 1); + WriteStringToTextFile(Cluster::Get()->vis_folder() + "/test_net.json", + net->ToGraph(true).ToJson()); vector<string> paths; for (const auto& p : job_conf.checkpoint_path()) paths.push_back(p); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/neuralnet/neuralnet.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc index f9579b1..49978a1 100644 --- a/src/neuralnet/neuralnet.cc +++ b/src/neuralnet/neuralnet.cc @@ -144,7 +144,7 @@ const NetProto NeuralNet::Unrolling(const NetProto& net_conf) { for (int i = 0; i < org_layer.unroll_len(); i ++) { // unroll LayerProto* unroll_layer = conf.add_layer(); unroll_layer->CopyFrom(org_layer); // create a new layer conf - if (org_layer.unroll_len() 
> 1) { + // if (org_layer.unroll_len() > 1) { // update layer names std::stringstream sstm; sstm << i << '#' << unroll_layer->name(); @@ -160,7 +160,7 @@ const NetProto NeuralNet::Unrolling(const NetProto& net_conf) { sstm1 << i << '#' << param->name(); param->set_name(sstm1.str()); } - } + // } // clear unrolling related fields unroll_layer->clear_unroll_len(); unroll_layer->clear_unroll_conn_type(); @@ -257,6 +257,7 @@ void NeuralNet::Load(const vector<string>& paths, ReadProtoFromBinaryFile(path.c_str(), &bps); for (int i = 0; i < bps.name_size(); i++) { if (params.find(bps.name(i)) != params.end()) { + // LOG(ERROR) << "Loading param = " << bps.name(i); params.at(bps.name(i))->FromProto(bps.blob(i)); params.at(bps.name(i))->set_version(bps.version(i)); } @@ -458,12 +459,10 @@ Graph* NeuralNet::CreateGraph(const NetProto& netproto, int npartitions) { map<string, const LayerProto*> name2proto; for (const LayerProto& layer : net_w_connection.layer()) { vector<Node*> nodes; - char suffix[4]; for (int i = 0; i < npartitions; i++) { LayerProto *proto = new LayerProto(layer); - snprintf(suffix, sizeof(suffix), "%02d", i); // differentiate partitions - string nodename = layer.name() + "@" + string(suffix); + string nodename = layer.name() + "@" + std::to_string(i); proto->set_name(nodename); proto->set_type(layer.type()); proto->set_partition_dim(layer.partition_dim()); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/neuralnet/neuron_layer/gru.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuron_layer/gru.cc b/src/neuralnet/neuron_layer/gru.cc index 9ba5a50..cf7425b 100644 --- a/src/neuralnet/neuron_layer/gru.cc +++ b/src/neuralnet/neuron_layer/gru.cc @@ -98,6 +98,8 @@ void GRULayer::Setup(const LayerProto& conf, update_gate = new Blob<float>(batchsize_, hdim_); reset_gate = new Blob<float>(batchsize_, hdim_); + // reset gate x context + reset_context = new Blob<float>(batchsize_, hdim_); new_memory = new Blob<float>(batchsize_, hdim_); } @@ -130,24 +132,23 @@ void GRULayer::ComputeFeature(int flag, Map<op::Sigmoid<float>,float>(*update_gate, update_gate); // Compute the reset gate - GEMM(1.0f, 0.0f, src,*w_r_hx_t,reset_gate); + GEMM(1.0f, 0.0f, src, *w_r_hx_t, reset_gate); if (bias_r_ != nullptr) - MVAddRow(1.0f,1.0f,bias_r_->data(),reset_gate); + MVAddRow(1.0f,1.0f, bias_r_->data(),reset_gate); GEMM(1.0f, 1.0f, *context, *w_r_hh_t, reset_gate); Map<op::Sigmoid<float>,float>(*reset_gate, reset_gate); // Compute the new memory - GEMM(1.0f, 0.0f, src, *w_c_hx_t, new_memory); + Mult<float>(*reset_gate, *context, reset_context); + GEMM(1.0f, 0.0f, *reset_context, *w_c_hh_t, new_memory); + GEMM(1.0f, 1.0f, src, *w_c_hx_t, new_memory); if (bias_c_ != nullptr) - MVAddRow(1.0f,1.0f,bias_c_->data(), new_memory); - Mult<float>(*reset_gate, *new_memory, new_memory); - GEMM(1.0f, 1.0f, *context, *w_c_hh_t, new_memory); + MVAddRow(1.0f, 1.0f, bias_c_->data(), new_memory); Map<op::Tanh<float>,float>(*new_memory, new_memory); - - Sub(*context, *new_memory, &data_); + Sub(*new_memory, *context, &data_); Mult(data_, *update_gate, &data_); - Add(data_, *new_memory, &data_); + AXPY(1.0f, *context, &data_); // delete the pointers if (srclayers.size() == 1) @@ -192,24 +193,19 @@ void GRULayer::ComputeGradient(int flag, Map<singa::op::TanhGrad<float>, float>(*new_memory, &dnewmdc); Blob<float> dLdz (batchsize_, hdim_); - Sub<float>(*context, *new_memory, &dLdz); + Sub<float>(*new_memory, *context, &dLdz); Mult<float>(dLdz, grad_, 
&dLdz); Mult<float>(dLdz, dugatedz, &dLdz); Blob<float> dLdc (batchsize_,hdim_); - Blob<float> z1 (batchsize_,hdim_); - z1.SetValue(1.0f); - AXPY<float>(-1.0f, *update_gate, &z1); - Mult(grad_,z1,&dLdc); - Mult(dLdc,dnewmdc,&dLdc); + Mult(grad_, *update_gate, &dLdc); + Mult(dLdc, dnewmdc, &dLdc); Blob<float> reset_dLdc (batchsize_,hdim_); - Mult(dLdc, *reset_gate, &reset_dLdc); + GEMM(1.0f, 0.0f, dLdc, weight_c_hh_->data(), &reset_dLdc); Blob<float> dLdr (batchsize_, hdim_); - Blob<float> cprev (batchsize_, hdim_); - GEMM(1.0f, 0.0f, *context, weight_c_hh_->data().T(), &cprev); - Mult(dLdc, cprev, &dLdr); + Mult(reset_dLdc, *context, &dLdr); Mult(dLdr, drgatedr, &dLdr); // Compute gradients for parameters of update gate @@ -230,29 +226,25 @@ void GRULayer::ComputeGradient(int flag, // Compute gradients for parameters of new memory Blob<float> *dLdc_t = Transpose(dLdc); - GEMM(1.0f, beta, *dLdc_t, src,weight_c_hx_->mutable_grad()); + GEMM(1.0f, beta, *dLdc_t, src, weight_c_hx_->mutable_grad()); + GEMM(1.0f, beta, *dLdc_t, *reset_context, weight_c_hh_->mutable_grad()); if (bias_c_ != nullptr) MVSumRow(1.0f, beta, dLdc, bias_c_->mutable_grad()); delete dLdc_t; - Blob<float> *reset_dLdc_t = Transpose(reset_dLdc); - GEMM(1.0f, beta, *reset_dLdc_t, *context, weight_c_hh_->mutable_grad()); - delete reset_dLdc_t; - // Compute gradients for data input layer if (srclayers[0]->mutable_grad(this) != nullptr) { - GEMM(1.0f,0.0f,dLdc, weight_c_hx_->data(), ilayer->mutable_grad(this)); - GEMM(1.0f,1.0f,dLdz, weight_z_hx_->data(), ilayer->mutable_grad(this)); - GEMM(1.0f,1.0f,dLdr, weight_r_hx_->data(), ilayer->mutable_grad(this)); + GEMM(1.0f,0.0f, dLdc, weight_c_hx_->data(), ilayer->mutable_grad(this)); + GEMM(1.0f,1.0f, dLdz, weight_z_hx_->data(), ilayer->mutable_grad(this)); + GEMM(1.0f,1.0f, dLdr, weight_r_hx_->data(), ilayer->mutable_grad(this)); } if (clayer != nullptr && clayer->mutable_grad(this) != nullptr) { // Compute gradients for context layer - GEMM(1.0f, 0.0f, reset_dLdc, weight_c_hh_->data(), - clayer->mutable_grad(this)); + Mult(reset_dLdc, *reset_gate, clayer->mutable_grad(this)); GEMM(1.0f, 1.0f, dLdr, weight_r_hh_->data(), clayer->mutable_grad(this)); GEMM(1.0f, 1.0f, dLdz, weight_z_hh_->data(), clayer->mutable_grad(this)); - Add(clayer->grad(this), *update_gate, clayer->mutable_grad(this)); + AXPY(-1.0f, *update_gate, clayer->mutable_grad(this)); // LOG(ERROR) << "grad to prev gru " << Asum(clayer->grad(this)); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/proto/job.proto ---------------------------------------------------------------------- diff --git a/src/proto/job.proto b/src/proto/job.proto index e520eba..28a3a68 100644 --- a/src/proto/job.proto +++ b/src/proto/job.proto @@ -253,6 +253,7 @@ message LayerProto { optional ConcateProto concate_conf = 502; optional SliceProto slice_conf = 503; optional SplitProto split_conf = 504; + optional RNNDummyProto rnn_dummy_conf = 505; extensions 1001 to 1100; } @@ -456,6 +457,17 @@ message DummyProto { repeated int32 shape = 3; } +message RNNDummyProto { + optional string dynamic_srclayer = 1; + // if shape set, random generate the data blob + repeated int32 shape = 2; + // if integer is true, generate integer data + optional bool integer = 3 [default = false]; + // range of the random generation + optional float low = 4 [default = 0]; + optional float high = 5 [default = 0]; +} + // Message that stores parameters used by DropoutLayer message DropoutProto { // dropout ratio @@ -667,6 +679,7 @@ enum LayerType { 
kArgSort = 401; kCSVOutput = 402; kRecordOutput = 403; + kCharRNNOutput = 404; /* * Connection layers @@ -677,6 +690,7 @@ enum LayerType { kConcate = 502; kSlice = 503; kSplit = 504; + kRNNDummy = 505; /* * User defined layer http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/utils/updater.cc ---------------------------------------------------------------------- diff --git a/src/utils/updater.cc b/src/utils/updater.cc index a9f70c0..200670a 100644 --- a/src/utils/updater.cc +++ b/src/utils/updater.cc @@ -174,31 +174,27 @@ void AdaGradUpdater::Update(int step, Param* param, float grad_scale) { data -= lr * grad / (F<sqrtop>(history, proto_.delta())); } -/***********************RMSProp****************************** -void RMSPropUpdater::Init(const UpdaterProto& proto){ +/***********************RMSProp******************************/ +void RMSPropUpdater::Init(const UpdaterProto& proto) { Updater::Init(proto); - base_lr_ = proto.base_lr(); - CHECK_GT(base_lr_, 0); - delta_ = proto.delta(); rho_ = proto.rmsprop_conf().rho(); - weight_decay_ = proto.weight_decay(); } -void RMSPropUpdater::Update(int step, Param* param, float grad_scale){ +void RMSPropUpdater::Update(int step, Param* param, float grad_scale) { Shape<1> s=Shape1(param->size()); Tensor<cpu, 1> data(param->mutable_cpu_data(), s); Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); Tensor<cpu, 1> history(param->mutable_cpu_history(), s); - history=history*rho_+(1-rho_)*F<op::square>(grad*grad_scale); - float lr=GetLearningRate(step)*param->lr_scale(); - float wd=weight_decay_*param->wd_scale(); - if(wd>0){ // L2 regularization - grad+=data*wd; - } - data-=lr*grad/(F<op::sqrtop>(history,delta_)); + float lr = lr_gen_->Get(step) * param->lr_scale(); + float wd = weight_decay_ * param->wd_scale(); + if (grad_scale != 1.f) + grad *= grad_scale; + if (wd > 0) // L2 regularization, should be done after timing grad_scale + grad += data * wd; + history = history * rho_ + (1 - rho_) * F<square>(grad); + data -= lr * grad / (F<sqrtop>(history, proto_.delta())); } - -***********************AdaDelta****************************** +/***********************AdaDelta****************************** void AdaDeltaUpdater::Init(const UpdaterProto& proto){ Updater::Init(proto); delta_=proto.delta(); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/worker.cc ---------------------------------------------------------------------- diff --git a/src/worker.cc b/src/worker.cc index 4e1dc75..abe74e7 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -390,13 +390,24 @@ void BPTTWorker::Forward(int step, Phase phase, NeuralNet* net) { } } vector<Layer*> src = net->srclayers(layer); + if ((phase & kTest) && typeid(*layer) == typeid(RNNDummyLayer)) { + CHECK_LE(src.size(), 1); + auto dummy = dynamic_cast<RNNDummyLayer*>(layer); + Layer* srclayer = net->name2layer(dummy->srclayer(step)); + if (step > 0) + CHECK(srclayer != nullptr); + if (srclayer != nullptr) { + src.clear(); + src.push_back(srclayer); + } + } // if full state rnn and not the starting of a new passing of the dataset, // feed the hidden state of the last unit to the first unit. 
if (layer->unroll_index() == 0 && full_state_ && !begin_) { Layer* last = net->last_unroll_layer(layer); - if (last != layer) { + CHECK(last != nullptr); + if (last != layer || (phase & kTest)) src.push_back(last); - } } // LOG(ERROR) << layer->name() << " forward"; // int ret = @@ -405,7 +416,6 @@ void BPTTWorker::Forward(int step, Phase phase, NeuralNet* net) { if ((phase & Phase::kTrain) && ret == Status::kEnd) begin_ = true; */ - if (job_conf_.debug() && DisplayNow(step) && grp_id_ == 0) label[layer->name()] = layer->ToString(true, phase | kForward); }
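As a usage illustration, a minimal, hypothetical job.conf snippet for the new RNNDummyLayer (layer names and values here are made up; only the rnn_dummy_conf fields come from the RNNDummyProto added above). When shape is set, the layer generates a random data blob within the configured low/high range; at test time, for step > 0, BPTTWorker::Forward rewires its source to the layer named by dynamic_srclayer, so the model's previous output can be fed back as the next input.

  layer {
    name: "input-dummy"                 # hypothetical name
    type: kRNNDummy
    rnn_dummy_conf {
      dynamic_srclayer: "sampled-char"  # hypothetical layer fed back at step > 0
      shape: 1                          # random data generated with this shape
      integer: true                     # generate integer values (e.g., char ids)
      low: 0
      high: 100
    }
  }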
