SINGA-10 Add Support for Recurrent Neural Networks (RNN). Revise Makefile.example; add job.conf; update src/utils/tool.cc to properly parse job.conf.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3dc1eee6 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3dc1eee6 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3dc1eee6 Branch: refs/heads/master Commit: 3dc1eee640d258f36ffbd903b9ed735de5bff5c3 Parents: ad86f72 Author: chonho <[email protected]> Authored: Tue Sep 15 19:07:26 2015 +0800 Committer: Wei Wang <[email protected]> Committed: Fri Sep 18 16:46:41 2015 +0800 ---------------------------------------------------------------------- examples/rnnlm/Makefile.example | 16 ++--- examples/rnnlm/job.conf | 136 +++++++++++++++++++++++++++++++++++ examples/rnnlm/rnnlm.cc | 1 + examples/rnnlm/rnnlm.proto | 3 +- src/utils/tool.cc | 43 ++++++++++- 5 files changed, 185 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3dc1eee6/examples/rnnlm/Makefile.example ---------------------------------------------------------------------- diff --git a/examples/rnnlm/Makefile.example b/examples/rnnlm/Makefile.example index b4505cf..83e2522 100644 --- a/examples/rnnlm/Makefile.example +++ b/examples/rnnlm/Makefile.example @@ -9,25 +9,23 @@ numclass = 100 dirshards = train_shard valid_shard test_shard -.PHONY: all download create -download: rnnlm - -rnnlm: +download: wget $(filelink)/$(filename) tar zxf $(filename) rm $(filename) create: - $(CXX) create_shard.cc -std=c++11 -lsinga -lprotobuf -lzookeeper_mt -lglog -I../../include \ + protoc --proto_path=../../src/proto --proto_path=. --cpp_out=. 
rnnlm.proto + $(CXX) create_shard.cc rnnlm.pb.cc -std=c++11 -lsinga -lprotobuf -lzookeeper_mt -lglog -I../../include -I../../include/proto \ -L../../.libs/ -L/usr/local/lib -Wl,-unresolved-symbols=ignore-in-shared-libs -Wl,-rpath=../../.libs/ \ -o create_shard.bin for d in $(dirshards); do mkdir -p $${d}; done - ./create_shard.bin -train $(dirname)/train -class_size $(numclass) -test $(dirname)/test + ./create_shard.bin -train $(dirname)/train -test $(dirname)/test -valid $(dirname)/valid -class_size $(numclass) -all: +rnnlm: protoc --proto_path=../../src/proto --proto_path=. --cpp_out=. rnnlm.proto - $(CXX) main.cc rnnlm.cc rnnlm.pb.cc $(MSHADOW_FLAGS) -std=c++11 -lsinga -lglog -lprotobuf -lopenblas -I../../include\ - -I../../include/proto/ -L../../.libs/ -L/usr/local -Wl,-unresolved-symbols=ignore-in-shared-libs -Wl,-rpath=../../.libs/\ + $(CXX) main.cc rnnlm.cc rnnlm.pb.cc $(MSHADOW_FLAGS) -std=c++11 -lsinga -lglog -lprotobuf -lopenblas -I../../include -I../../include/proto \ + -L../../.libs/ -L/usr/local -Wl,-unresolved-symbols=ignore-in-shared-libs -Wl,-rpath=../../.libs/\ -o rnnlm.bin http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3dc1eee6/examples/rnnlm/job.conf ---------------------------------------------------------------------- diff --git a/examples/rnnlm/job.conf b/examples/rnnlm/job.conf new file mode 100644 index 0000000..81bdb94 --- /dev/null +++ b/examples/rnnlm/job.conf @@ -0,0 +1,136 @@ +cluster { + nworker_groups: 1 + nserver_groups: 1 + nservers_per_group: 1 + nworkers_per_group: 1 + nservers_per_procs: 1 + nworkers_per_procs: 1 + workspace: "examples/rnnlm/" +} + +name: "recurrent-neural-network-language-model" +#To scan the training file (71350) 30 times +train_steps:214050 +#To scan the validation file (5829) once +test_steps:583 +test_freq:7135 +#disp_freq is specific to training +disp_freq:7135 + +train_one_batch { + alg: kBP +} + +updater{ + type: kSGD + #weight_decay:0.0000001 + learning_rate { + type: kFixedStep + 
fixedstep_conf:{ + step:0 + step:42810 + step:49945 + step:57080 + step:64215 + step_lr:0.1 + step_lr:0.05 + step_lr:0.025 + step_lr:0.0125 + step_lr:0.00625 + } + } +} + +neuralnet { +layer { + name: "data" + user_type: "kRnnData" + [singa.input_conf] { + path: "examples/rnnlm/train_shard" + max_window: 10 + } + exclude: kTest +} + +layer { + name: "data" + user_type: "kRnnData" + [singa.input_conf] { + path: "examples/rnnlm/test_shard" + max_window: 10 + } + exclude: kTrain +} + +layer{ + name:"wordlayer" + user_type: "kWord" + srclayers: "data" +} + +layer{ + name:"labellayer" + user_type: "kRnnLabel" + srclayers: "data" +} + +layer{ + name: "embeddinglayer" + user_type: "kEmbedding" + [singa.embedding_conf] { + word_dim: 15 + vocab_size: 3720 + } + srclayers: "wordlayer" + param { + name: "w1" + init { + type: kUniform + low:-0.3 + high:0.3 + } + } +} + +layer{ + name: "hiddenlayer" + user_type: "kHidden" + srclayers:"embeddinglayer" + param{ + name: "w2" + init { + type: kUniform + low:-0.3 + high:0.3 + } + } +} +layer{ + name: "outputlayer" + user_type: "kOutput" + srclayers:"hiddenlayer" + srclayers:"labellayer" + [singa.output_conf] { + nclass:100 + vocab_size: 3720 + } + param{ + name: "w3" + init { + type: kUniform + low:-0.3 + high:0.3 + } + } + param{ + name: "w4" + init { + type: kUniform + low:-0.3 + high:0.3 + } + } +} + +} + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3dc1eee6/examples/rnnlm/rnnlm.cc ---------------------------------------------------------------------- diff --git a/examples/rnnlm/rnnlm.cc b/examples/rnnlm/rnnlm.cc index 180300f..4d39b5f 100644 --- a/examples/rnnlm/rnnlm.cc +++ b/examples/rnnlm/rnnlm.cc @@ -69,6 +69,7 @@ void WordLayer::Setup(const LayerProto& proto, int npartitions) { Layer::Setup(proto, npartitions); CHECK_EQ(srclayers_.size(), 1); int max_window = static_cast<RnnDataLayer*>(srclayers_[0])->max_window(); + LOG(ERROR) << "clee " << max_window; data_.Reshape(vector<int>{max_window}); } 
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3dc1eee6/examples/rnnlm/rnnlm.proto ---------------------------------------------------------------------- diff --git a/examples/rnnlm/rnnlm.proto b/examples/rnnlm/rnnlm.proto index 65c34ec..01580c1 100644 --- a/examples/rnnlm/rnnlm.proto +++ b/examples/rnnlm/rnnlm.proto @@ -1,6 +1,5 @@ package singa; import "job.proto"; -import "common.proto"; message EmbeddingProto { @@ -30,4 +29,4 @@ message WordRecord { optional int32 class_index = 3; optional int32 class_start = 4; optional int32 class_end = 5; -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3dc1eee6/src/utils/tool.cc ---------------------------------------------------------------------- diff --git a/src/utils/tool.cc b/src/utils/tool.cc index ba453b2..305cbe7 100644 --- a/src/utils/tool.cc +++ b/src/utils/tool.cc @@ -20,6 +20,7 @@ *************************************************************/ #include <glog/logging.h> +#include <google/protobuf/text_format.h> #include <algorithm> #include <fstream> #include <iostream> @@ -51,12 +52,48 @@ int create() { return SUCCESS; } +// extract cluster configuration part from the job config file +// TODO improve this function to make it robust +const std::string extract_cluster(const char* jobfile) { + std::ifstream fin; + fin.open(jobfile, std::ifstream::in); + CHECK(fin.is_open()) << "cannot open job conf file " << jobfile; + std::string line; + std::string cluster; + while (std::getline(fin, line)) { + // end of extraction (cluster config has not nested messages) + if (line.find("}") != std::string::npos && cluster.length()) { + cluster += line.substr(0, line.find("}")); + break; + } + unsigned int pos = 0; + while (pos < line.length() && line.at(pos) == ' ' ) pos++; + if (line.find("cluster", pos) == pos) { // start with <whitespace> cluster + pos += 7; + do { // looking for the first '{', which may be in the next lines + while (pos < line.length() && + 
(line.at(pos) == ' ' || line.at(pos) =='\t')) pos++; + if (pos < line.length()) { + CHECK_EQ(line.at(pos), '{') << "error around 'cluster' field"; + cluster = " "; // start extraction + break; + } else + pos = 0; + }while(std::getline(fin, line)); + } else if (cluster.length()) { + cluster += line + "\n"; + } + } + return cluster; +} + + // generate a host list int genhost(char* job_conf) { // compute required #process from job conf - singa::JobProto job; - singa::ReadProtoFromTextFile(job_conf, &job); - singa::ClusterProto cluster = job.cluster(); + singa::ClusterProto cluster; + google::protobuf::TextFormat::ParseFromString(extract_cluster(job_conf), + &cluster); int nworker_procs = cluster.nworker_groups() * cluster.nworkers_per_group() / cluster.nworkers_per_procs(); int nserver_procs = cluster.nserver_groups() * cluster.nservers_per_group()
