SINGA-10 Add Support for Recurrent Neural Networks (RNN)

Draft upper layers for rnnlm; Compile using Makefile.example;
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/13b1c08a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/13b1c08a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/13b1c08a

Branch: refs/heads/master
Commit: 13b1c08ad2d65fa41d09d3b9fff05f2b58f925a2
Parents: 1791442
Author: Wei Wang <[email protected]>
Authored: Sun Sep 13 14:07:41 2015 +0800
Committer: Wei Wang <[email protected]>
Committed: Fri Sep 18 16:46:40 2015 +0800

----------------------------------------------------------------------
 examples/rnnlm/Makefile.example |   7 ++
 examples/rnnlm/main.cc          |  23 ++++
 examples/rnnlm/rnnlm.cc         | 209 +++++++++++++++++++++++++++++++++++
 examples/rnnlm/rnnlm.h          |  89 +++++++++++++++
 examples/rnnlm/rnnlm.proto      |  18 +++
 include/utils/common.h          |   1 +
 src/utils/common.cc             |   5 +-
 7 files changed, 351 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/13b1c08a/examples/rnnlm/Makefile.example
----------------------------------------------------------------------
diff --git a/examples/rnnlm/Makefile.example b/examples/rnnlm/Makefile.example
new file mode 100644
index 0000000..5eeca78
--- /dev/null
+++ b/examples/rnnlm/Makefile.example
@@ -0,0 +1,7 @@
+MSHADOW_FLAGS :=-DMSHADOW_USE_CUDA=0 -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0
+
+all:
+	protoc --proto_path=../../src/proto --proto_path=. --cpp_out=. rnnlm.proto
+	$(CXX) main.cc rnnlm.cc rnnlm.pb.cc $(MSHADOW_FLAGS) -std=c++11 -lsinga -lglog -lprotobuf -lopenblas -I../../include\
+		-I../../include/proto/ -L../../.libs/ -L/usr/local -Wl,-unresolved-symbols=ignore-in-shared-libs -Wl,-rpath=../../.libs/\
+		-o rnnlm.bin

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/13b1c08a/examples/rnnlm/main.cc
----------------------------------------------------------------------
diff --git a/examples/rnnlm/main.cc b/examples/rnnlm/main.cc
new file mode 100644
index 0000000..690c158
--- /dev/null
+++ b/examples/rnnlm/main.cc
@@ -0,0 +1,23 @@
+#include <string>
+#include "singa.h"
+#include "rnnlm.h"
+#include "rnnlm.pb.h"
+
+int main(int argc, char **argv) {
+  singa::Driver driver;
+  driver.Init(argc, argv);
+
+  //if -resume in argument list, set resume to true; otherwise false
+  int resume_pos = singa::ArgPos(argc, argv, "-resume");
+  bool resume = (resume_pos != -1);
+
+  // register all layers for rnnlm
+  driver.RegisterLayer<singa::EmbeddingLayer, std::string>("kEmbedding");
+  driver.RegisterLayer<singa::HiddenLayer, std::string>("kHidden");
+  driver.RegisterLayer<singa::OutputLayer, std::string>("kOutput");
+
+  singa::JobProto jobConf = driver.job_conf();
+
+  driver.Submit(resume, jobConf);
+  return 0;
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/13b1c08a/examples/rnnlm/rnnlm.cc
----------------------------------------------------------------------
diff --git a/examples/rnnlm/rnnlm.cc b/examples/rnnlm/rnnlm.cc
new file mode 100644
index 0000000..ddb0f63
--- /dev/null
+++ b/examples/rnnlm/rnnlm.cc
@@ -0,0 +1,209 @@
+#include "rnnlm.h"
+#include "rnnlm.pb.h"
+#include "mshadow/tensor.h"
+#include "mshadow/cxxnet_op.h"
+
+namespace singa {
+using namespace mshadow;
+using mshadow::cpu;
+
+using mshadow::Shape;
+using mshadow::Shape1;
+using mshadow::Shape2;
+using mshadow::Tensor;
+
+
+inline Tensor<cpu, 2> RTensor2(Blob<float>* blob) {
+  const vector<int>& shape = blob->shape();
+  Tensor<cpu, 2> tensor(blob->mutable_cpu_data(),
+      Shape2(shape[0], blob->count() / shape[0]));
+  return tensor;
+}
+
+inline Tensor<cpu, 1> RTensor1(Blob<float>* blob) {
+  Tensor<cpu, 1> tensor(blob->mutable_cpu_data(), Shape1(blob->count()));
+  return tensor;
+}
+
+/*******EmbeddingLayer**************/
+EmbeddingLayer::~EmbeddingLayer() {
+  delete embed_;
+}
+
+void EmbeddingLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 1);
+  int max_window = srclayers_[0]->data(this).shape()[0];
+  word_dim_ = proto.GetExtension(embedding_conf).word_dim();
+  data_.Reshape(vector<int>{max_window, word_dim_});
+  grad_.ReshapeLike(data_);
+  vocab_size_ = proto.GetExtension(embedding_conf).vocab_size();
+  embed_ = Param::Create(proto.param(0));
+  embed_->Setup(vector<int>{vocab_size_, word_dim_});
+}
+
+void EmbeddingLayer::ComputeFeature(int flag, Metric* perf) {
+  window_ = static_cast<RNNLayer*>(srclayers_[0])->window();
+  auto words = RTensor2(&data_);
+  auto embed = RTensor2(embed_->mutable_data());
+  auto word_idx = RTensor1(srclayers_[0]->mutable_data(this));
+
+  for (int t = 0; t < window_; t++) {
+    int idx = static_cast<int>(word_idx[t]);
+    CHECK_GE(idx, 0);
+    CHECK_LT(idx, vocab_size_);
+    Copy(words[t], embed[idx]);
+  }
+}
+
+void EmbeddingLayer::ComputeGradient(int flag, Metric* perf) {
+  auto grad = RTensor2(&grad_);
+  auto gembed = RTensor2(embed_->mutable_grad());
+  auto word_idx = RTensor1(srclayers_[0]->mutable_data(this));
+  gembed = 0;
+  for (int t = 0; t < window_; t++) {
+    int idx = static_cast<int>(word_idx[t]);
+    Copy(gembed[idx], grad[t]);
+  }
+}
+/***********HiddenLayer**********/
+HiddenLayer::~HiddenLayer() {
+  delete weight_;
+}
+
+void HiddenLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 1);
+  const auto& innerproductData = srclayers_[0]->data(this);
+  data_.ReshapeLike(srclayers_[0]->data(this));
+  grad_.ReshapeLike(srclayers_[0]->grad(this));
+  int word_dim = data_.shape()[1];
+  weight_ = Param::Create(proto.param(0));
+  weight_->Setup(std::vector<int>{word_dim, word_dim});
+}
+
+// hid[t] = sigmoid(hid[t-1] * W + src[t])
+void HiddenLayer::ComputeFeature(int flag, Metric* perf) {
+  window_ = static_cast<RNNLayer*>(srclayers_[0])->window();
+  auto data = RTensor2(&data_);
+  auto src = RTensor2(srclayers_[0]->mutable_data(this));
+  auto weight = RTensor2(weight_->mutable_data());
+  for (int t = 0; t < window_; t++) {  // Skip the 1st component
+    if (t == 0) {
+      data[t] = expr::F<op::sigmoid>(src[t]);
+    } else {
+      data[t] = dot(data[t - 1], weight);
+      data[t] += src[t];
+      data[t] = expr::F<op::sigmoid>(data[t]);
+    }
+  }
+}
+
+void HiddenLayer::ComputeGradient(int flag, Metric* perf) {
+  auto data = RTensor2(&data_);
+  auto grad = RTensor2(&grad_);
+  auto weight = RTensor2(weight_->mutable_data());
+  auto gweight = RTensor2(weight_->mutable_grad());
+  auto gsrc = RTensor2(srclayers_[0]->mutable_grad(this));
+  gweight = 0;
+  TensorContainer<cpu, 1> tmp(Shape1(data_.shape()[1]));
+  // Check!!
+  for (int t = window_ - 1; t >= 0; t--) {
+    if (t < window_ - 1) {
+      tmp = dot(grad[t + 1], weight.T());
+      grad[t] += tmp;
+    }
+    grad[t] = expr::F<op::sigmoid_grad>(data[t])* grad[t];
+  }
+  gweight = dot(data.Slice(0, window_-1).T(), grad.Slice(1, window_));
+  Copy(gsrc, grad);
+}
+
+/*********** 1-Implementation for OutputLayer **********/
+OutputLayer::~OutputLayer() {
+  delete word_weight_;
+  delete class_weight_;
+}
+
+void OutputLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 2);
+  const auto& src = srclayers_[0]->data(this);
+  int max_window = src.shape()[0];
+  int vdim = src.count() / max_window;   // Dimension of input
+  int vocab_size = proto.GetExtension(output_conf).vocab_size();
+  int nclass = proto.GetExtension(output_conf).nclass();
+  word_weight_ = Param::Create(proto.param(0));
+  word_weight_->Setup(vector<int>{vocab_size, vdim});
+  class_weight_ = Param::Create(proto.param(0));
+  class_weight_->Setup(vector<int>{nclass, vdim});
+
+  pword_.resize(max_window);
+  pclass_.Reshape(vector<int>{max_window, nclass});
+}
+
+void OutputLayer::ComputeFeature(int flag, Metric* perf) {
+  window_ = static_cast<RNNLayer*>(srclayers_[0])->window();
+  auto pclass = RTensor2(&pclass_);
+  auto src = RTensor2(srclayers_[0]->mutable_data(this));
+  auto word_weight = RTensor2(word_weight_->mutable_data());
+  auto class_weight = RTensor2(class_weight_->mutable_data());
+  const float * label = srclayers_[1]->data(this).cpu_data();
+
+  float loss = 0.f, ppl =0.f;
+  for (int t = 0; t < window_; t++) {
+    int start = static_cast<int>(label[t * 4 + 0]);
+    int end = static_cast<int>(label[t * 4 + 1]);
+
+    auto wordWeight = word_weight.Slice(start, end);
+    pword_[t].Reshape(vector<int>{end-start});
+    auto pword = RTensor1(&pword_[t]);
+    pword = dot(src[t], wordWeight.T());
+    Softmax(pword, pword);
+
+    pclass[t] = dot(src[t], class_weight.T());
+    Softmax(pclass[t], pclass[t]);
+
+    int wid = static_cast<int>(label[t * 4 + 2]);
+    int cid = static_cast<int>(label[t * 4 + 3]);
+    loss += -log(std::max(pword[wid - start] * pclass[t][cid], FLT_MIN));
+    ppl += log10(std::max(pword[wid - start] * pclass[t][cid], FLT_MIN));
+  }
+
+  perf->Add("loss", loss, window_);
+  perf->Add("ppl before exp", ppl, window_);
+}
+
+void OutputLayer::ComputeGradient(int flag, Metric* perf) {
+  auto pclass = RTensor2(&pclass_);
+  auto src = RTensor2(srclayers_[0]->mutable_data(this));
+  auto gsrc = RTensor2(srclayers_[0]->mutable_grad(this));
+  auto word_weight = RTensor2(word_weight_->mutable_data());
+  auto gword_weight = RTensor2(word_weight_->mutable_grad());
+  auto class_weight = RTensor2(class_weight_->mutable_data());
+  auto gclass_weight = RTensor2(class_weight_->mutable_grad());
+  const float * label = srclayers_[1]->data(this).cpu_data();
+  gclass_weight = 0;
+  gword_weight = 0;
+  for (int t = 0; t < window_; t++) {
+    int start = static_cast<int>(label[t * 4 + 0]);
+    int end = static_cast<int>(label[t * 4 + 1]);
+    int wid = static_cast<int>(label[t * 4 + 2]);
+    int cid = static_cast<int>(label[t * 4 + 3]);
+    auto pword = RTensor1(&pword_[t]);
+
+    // gL/gclass_act
+    pclass[t][cid] -= 1.0;
+    // gL/gword_act
+    pword[wid] -= 1.0;
+
+    // gL/gword_weight
+    gword_weight.Slice(start, end) += dot(pword.FlatTo2D().T(), src[t].FlatTo2D());
+    // gL/gclass_weight
+    gclass_weight += dot(pclass[t].FlatTo2D().T(), src[t].FlatTo2D());
+
+    gsrc[t] = dot(pword, word_weight.Slice(start, end));
+    gsrc[t] += dot(pclass[t], class_weight);
+  }
+}
+}
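The class-based output layer above factorizes the next-word probability as p(w) = p(class of w) * p(w | class), so the word softmax only runs over the words inside the target class ([start, end)) instead of the whole vocabulary. The following stand-alone sketch (plain C++ with made-up toy activations; the names Softmax, class_act, word_act, wid_in_class are illustrative and not part of the SINGA or mshadow API) reproduces the per-timestep arithmetic of OutputLayer::ComputeFeature after the two dot products: two softmaxes, the product of the two probabilities, and the accumulated loss and "ppl before exp" terms.

    #include <algorithm>
    #include <cfloat>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Numerically stable softmax over a small activation vector.
    static std::vector<float> Softmax(const std::vector<float>& act) {
      float max_v = *std::max_element(act.begin(), act.end());
      std::vector<float> prob(act.size());
      float sum = 0.f;
      for (size_t i = 0; i < act.size(); ++i) {
        prob[i] = std::exp(act[i] - max_v);
        sum += prob[i];
      }
      for (float& p : prob) p /= sum;
      return prob;
    }

    int main() {
      // Toy activations standing in for src[t]*class_weight.T() and
      // src[t]*word_weight[start:end].T(); the values are made up.
      std::vector<float> class_act = {1.2f, 0.3f, -0.5f};        // one entry per class
      std::vector<float> word_act = {0.9f, 0.1f, 0.4f, -0.2f};   // words inside the true class
      int cid = 0;            // ground-truth class id
      int wid_in_class = 2;   // ground-truth word index within its class, i.e. wid - start

      std::vector<float> pclass = Softmax(class_act);
      std::vector<float> pword = Softmax(word_act);

      // p(word) = p(class) * p(word | class), clamped away from zero as in rnnlm.cc
      float p = std::max(pword[wid_in_class] * pclass[cid], FLT_MIN);
      float loss_term = -std::log(p);   // summed over the window into "loss"
      float ppl_term = std::log10(p);   // summed over the window into "ppl before exp"
      std::printf("loss term = %f, ppl term = %f\n", loss_term, ppl_term);
      return 0;
    }

In ComputeFeature these terms are summed over window_ timesteps and reported through perf->Add("loss", loss, window_), so the metric records both the summed loss and the number of words it covers (see the Metric change below).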
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/13b1c08a/examples/rnnlm/rnnlm.h
----------------------------------------------------------------------
diff --git a/examples/rnnlm/rnnlm.h b/examples/rnnlm/rnnlm.h
new file mode 100644
index 0000000..14d947c
--- /dev/null
+++ b/examples/rnnlm/rnnlm.h
@@ -0,0 +1,89 @@
+#include "singa.h"
+namespace singa {
+
+/**
+ * Base RNN layer. May make it a base layer of SINGA.
+ */
+class RNNLayer : public NeuronLayer {
+ public:
+  /**
+   * The recurrent layers may be unrolled different times for different
+   * iterations, depending on the applications. For example, the ending word
+   * of a sentence may stop the unrolling; unrolling also stops when the max
+   * window size is reached. Every layer must reset window_ in its
+   * ComputeFeature function.
+   *
+   * @return the effective BPTT length, which is <= max_window.
+   */
+  inline int window() { return window_; }
+
+ protected:
+  //!< effect window size for BPTT
+  int window_;
+};
+
+/**
+ * Word embedding layer that get one row from the embedding matrix for each
+ * word based on the word index
+ */
+class EmbeddingLayer : public RNNLayer {
+ public:
+  ~EmbeddingLayer();
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(int flag, Metric *perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
+  const std::vector<Param*> GetParams() const override {
+    std::vector<Param*> params{embed_};
+    return params;
+  }
+
+
+ private:
+  int word_dim_;
+  int vocab_size_;
+  //!< word embedding matrix of size vocab_size_ x word_dim_
+  Param* embed_;
+};
+
+
+/**
+ * hid[t] = sigmoid(hid[t-1] * W + src[t])
+ */
+class HiddenLayer : public RNNLayer {
+ public:
+  ~HiddenLayer();
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(int flag, Metric *perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
+  const std::vector<Param*> GetParams() const override {
+    std::vector<Param*> params{weight_};
+    return params;
+  }
+
+
+ private:
+  Param* weight_;
+};
+
+/**
+ * p(word at t+1 is from class c) = softmax(src[t]*Wc)[c]
+ * p(w|c) = softmax(src[t]*Ww[Start(c):End(c)])
+ * p(word at t+1 is w) = p(word at t+1 is from class c) * p(w|c)
+ */
+class OutputLayer : public RNNLayer {
+ public:
+  ~OutputLayer();
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(int flag, Metric *perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
+  const std::vector<Param*> GetParams() const override {
+    std::vector<Param*> params{word_weight_, class_weight_};
+    return params;
+  }
+
+ private:
+  vector<Blob<float>> pword_;
+  Blob<float> pclass_;
+  Param* word_weight_, *class_weight_;
+};
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/13b1c08a/examples/rnnlm/rnnlm.proto
----------------------------------------------------------------------
diff --git a/examples/rnnlm/rnnlm.proto b/examples/rnnlm/rnnlm.proto
new file mode 100644
index 0000000..35b6bc2
--- /dev/null
+++ b/examples/rnnlm/rnnlm.proto
@@ -0,0 +1,18 @@
+package singa;
+import "job.proto";
+
+
+message EmbeddingProto {
+  optional int32 word_dim = 1;
+  optional int32 vocab_size = 2;
+}
+
+message OutputProto {
+  optional int32 nclass = 1;
+  optional int32 vocab_size = 2;
+}
+
+extend LayerProto {
+  optional EmbeddingProto embedding_conf = 101;
+  optional OutputProto output_conf = 102;
+}
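rnnlm.proto extends SINGA's LayerProto with the two layer-specific messages, and the layers read them back through protobuf's extension API (see proto.GetExtension(embedding_conf) in EmbeddingLayer::Setup). A minimal sketch of filling and reading such an extension is below; the word_dim/vocab_size values and the ConfigureEmbedding/ReadWordDim helpers are illustrative only, and the header name follows from the protoc invocation in Makefile.example.

    #include "rnnlm.pb.h"  // generated by protoc; pulls in LayerProto via the job.proto import

    // Fill the embedding-specific fields of a layer configuration (illustrative values).
    void ConfigureEmbedding(singa::LayerProto* proto) {
      singa::EmbeddingProto* conf = proto->MutableExtension(singa::embedding_conf);
      conf->set_word_dim(15);       // assumed embedding dimension
      conf->set_vocab_size(3720);   // assumed vocabulary size
    }

    // Mirrors what EmbeddingLayer::Setup does when it reads the extension back.
    int ReadWordDim(const singa::LayerProto& proto) {
      return proto.GetExtension(singa::embedding_conf).word_dim();
    }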
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/13b1c08a/include/utils/common.h
----------------------------------------------------------------------
diff --git a/include/utils/common.h b/include/utils/common.h
index 2be2715..3eb0bbd 100644
--- a/include/utils/common.h
+++ b/include/utils/common.h
@@ -95,6 +95,7 @@ class Metric {
    * @param value metric value
    */
   void Add(const std::string& name, float value);
+  void Add(const std::string& name, float value, int count);
   /**
    * reset all metric counter and value to 0
    */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/13b1c08a/src/utils/common.cc
----------------------------------------------------------------------
diff --git a/src/utils/common.cc b/src/utils/common.cc
index 75974d1..6dd40c8 100644
--- a/src/utils/common.cc
+++ b/src/utils/common.cc
@@ -267,11 +267,14 @@ Metric::Metric(const string& str) {
 }
 
 void Metric::Add(const string& name, float value) {
+  Add( name, value, 1);
+}
+void Metric::Add(const string& name, float value, int count) {
   if (entry_.find(name) == entry_.end()) {
     entry_[name] = std::make_pair(1, value);
   } else {
     auto& e = entry_.at(name);
-    e.first += 1;
+    e.first += count;
     e.second += value;
   }
 }
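The new Metric::Add(name, value, count) overload lets a layer report a value accumulated over several samples at once (here, the summed loss over a BPTT window), advancing the sample counter by count instead of 1. Below is a stand-alone sketch of the intended semantics, not the SINGA class itself; ToyMetric and Average are hypothetical names used only for illustration.

    #include <cstdio>
    #include <map>
    #include <string>
    #include <utility>

    // Toy metric: each named entry keeps (count, sum), mirroring entry_ in Metric.
    class ToyMetric {
     public:
      void Add(const std::string& name, float value, int count = 1) {
        auto it = entry_.find(name);
        if (it == entry_.end()) {
          entry_[name] = std::make_pair(count, value);
        } else {
          it->second.first += count;    // samples seen so far
          it->second.second += value;   // accumulated value
        }
      }
      float Average(const std::string& name) const {
        const auto& e = entry_.at(name);
        return e.second / e.first;
      }

     private:
      std::map<std::string, std::pair<int, float>> entry_;
    };

    int main() {
      ToyMetric perf;
      perf.Add("loss", 42.f, 10);   // summed loss over a 10-word window
      perf.Add("loss", 38.f, 10);   // the next window
      std::printf("average loss per word = %f\n", perf.Average("loss"));
      return 0;
    }

This is what makes perf->Add("loss", loss, window_) in OutputLayer::ComputeFeature meaningful: the loss is summed over window_ words, and the counter grows by window_, so the reported figure stays a per-word quantity.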
