SINGA-10 Add Support for Recurrent Neural Networks (RNN)
improve extract_cluster function in tool.cc
- the last occurrence of "cluster" in the job file is the one that is used
- the extracted content is the text between the first "{" and the first "}" after "cluster"
format rnnlm.h and common.cc
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/79a241c8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/79a241c8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/79a241c8
Branch: refs/heads/master
Commit: 79a241c8bb6c520979d52c895f4958c2b901cd47
Parents: ba3b1a5
Author: wang sheng <[email protected]>
Authored: Sun Sep 20 17:14:22 2015 +0800
Committer: wang sheng <[email protected]>
Committed: Sun Sep 20 17:14:22 2015 +0800
----------------------------------------------------------------------
examples/rnnlm/rnnlm.h | 2 +-
src/utils/common.cc | 2 +-
src/utils/tool.cc | 43 ++++++++++++++++++++-----------------------
3 files changed, 22 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/79a241c8/examples/rnnlm/rnnlm.h
----------------------------------------------------------------------
diff --git a/examples/rnnlm/rnnlm.h b/examples/rnnlm/rnnlm.h
index 9fc0bcf..b848fa4 100644
--- a/examples/rnnlm/rnnlm.h
+++ b/examples/rnnlm/rnnlm.h
@@ -145,5 +145,5 @@ class LossLayer : public RNNLayer {
Blob<float> pclass_;
Param* word_weight_, *class_weight_;
};
-} // namespace singa
+} // namespace rnnlm
#endif // EXAMPLES_RNNLM_RNNLM_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/79a241c8/src/utils/common.cc
----------------------------------------------------------------------
diff --git a/src/utils/common.cc b/src/utils/common.cc
index 6dd40c8..65b2ec2 100644
--- a/src/utils/common.cc
+++ b/src/utils/common.cc
@@ -267,7 +267,7 @@ Metric::Metric(const string& str) {
}
void Metric::Add(const string& name, float value) {
- Add( name, value, 1);
+ Add(name, value, 1);
}
void Metric::Add(const string& name, float value, int count) {
if (entry_.find(name) == entry_.end()) {
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/79a241c8/src/utils/tool.cc
----------------------------------------------------------------------
diff --git a/src/utils/tool.cc b/src/utils/tool.cc
index 305cbe7..295dcf0 100644
--- a/src/utils/tool.cc
+++ b/src/utils/tool.cc
@@ -53,41 +53,38 @@ int create() {
}
// extract cluster configuration part from the job config file
-// TODO improve this function to make it robust
+// TODO(wangsh) improve this function to make it robust
const std::string extract_cluster(const char* jobfile) {
std::ifstream fin;
fin.open(jobfile, std::ifstream::in);
CHECK(fin.is_open()) << "cannot open job conf file " << jobfile;
std::string line;
std::string cluster;
+ bool in_cluster = false;
while (std::getline(fin, line)) {
- // end of extraction (cluster config has not nested messages)
- if (line.find("}") != std::string::npos && cluster.length()) {
- cluster += line.substr(0, line.find("}"));
- break;
+ if (in_cluster == false) {
+ size_t pos = line.find("cluster");
+ if (pos == std::string::npos) continue;
+ in_cluster = true;
+ line = line.substr(pos);
+ cluster = "";
}
- unsigned int pos = 0;
- while (pos < line.length() && line.at(pos) == ' ' ) pos++;
- if (line.find("cluster", pos) == pos) { // start with <whitespace> cluster
- pos += 7;
- do { // looking for the first '{', which may be in the next lines
- while (pos < line.length() &&
- (line.at(pos) == ' ' || line.at(pos) =='\t')) pos++;
- if (pos < line.length()) {
- CHECK_EQ(line.at(pos), '{') << "error around 'cluster' field";
- cluster = " "; // start extraction
- break;
- } else
- pos = 0;
- }while(std::getline(fin, line));
- } else if (cluster.length()) {
- cluster += line + "\n";
+ if (in_cluster == true) {
+ cluster += line + "\n";
+ if (line.find("}") != std::string::npos)
+ in_cluster = false;
}
}
- return cluster;
+ LOG(INFO) << "cluster configure: " << cluster;
+ size_t s_pos = cluster.find("{");
+ size_t e_pos = cluster.find("}");
+ if (s_pos == std::string::npos || e_pos == std::string::npos) {
+ LOG(FATAL) << "cannot extract valid cluster configuration in file: "
+ << jobfile;
+ }
+ return cluster.substr(s_pos+1, e_pos-s_pos-1);
}
-
// generate a host list
int genhost(char* job_conf) {
// compute required #process from job conf