SINGA-10 Add Support for Recurrent Neural Networks (RNN)

improve extract_cluster function in tool.cc
  - we use the last occurrence of "cluster" in the job file
  - the content is the text between the first "{" and the first "}" after "cluster"
format rnnlm.h and common.cc


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/79a241c8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/79a241c8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/79a241c8

Branch: refs/heads/master
Commit: 79a241c8bb6c520979d52c895f4958c2b901cd47
Parents: ba3b1a5
Author: wang sheng <[email protected]>
Authored: Sun Sep 20 17:14:22 2015 +0800
Committer: wang sheng <[email protected]>
Committed: Sun Sep 20 17:14:22 2015 +0800

----------------------------------------------------------------------
 examples/rnnlm/rnnlm.h |  2 +-
 src/utils/common.cc    |  2 +-
 src/utils/tool.cc      | 43 ++++++++++++++++++++-----------------------
 3 files changed, 22 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/79a241c8/examples/rnnlm/rnnlm.h
----------------------------------------------------------------------
diff --git a/examples/rnnlm/rnnlm.h b/examples/rnnlm/rnnlm.h
index 9fc0bcf..b848fa4 100644
--- a/examples/rnnlm/rnnlm.h
+++ b/examples/rnnlm/rnnlm.h
@@ -145,5 +145,5 @@ class LossLayer : public RNNLayer {
   Blob<float> pclass_;
   Param* word_weight_, *class_weight_;
 };
-}  // namespace singa
+}  // namespace rnnlm
 #endif  // EXAMPLES_RNNLM_RNNLM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/79a241c8/src/utils/common.cc
----------------------------------------------------------------------
diff --git a/src/utils/common.cc b/src/utils/common.cc
index 6dd40c8..65b2ec2 100644
--- a/src/utils/common.cc
+++ b/src/utils/common.cc
@@ -267,7 +267,7 @@ Metric::Metric(const string& str) {
 }
 
 void Metric::Add(const string& name, float value) {
-  Add( name, value, 1);
+  Add(name, value, 1);
 }
 void Metric::Add(const string& name, float value, int count) {
   if (entry_.find(name) == entry_.end()) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/79a241c8/src/utils/tool.cc
----------------------------------------------------------------------
diff --git a/src/utils/tool.cc b/src/utils/tool.cc
index 305cbe7..295dcf0 100644
--- a/src/utils/tool.cc
+++ b/src/utils/tool.cc
@@ -53,41 +53,38 @@ int create() {
 }
 
 // extract cluster configuration part from the job config file
-// TODO improve this function to make it robust
+// TODO(wangsh) improve this function to make it robust
 const std::string extract_cluster(const char* jobfile) {
   std::ifstream fin;
   fin.open(jobfile, std::ifstream::in);
   CHECK(fin.is_open()) << "cannot open job conf file " << jobfile;
   std::string line;
   std::string cluster;
+  bool in_cluster = false;
   while (std::getline(fin, line)) {
-    // end of extraction (cluster config has not nested messages)
-    if (line.find("}") != std::string::npos && cluster.length()) {
-      cluster += line.substr(0, line.find("}"));
-      break;
+    if (in_cluster == false) {
+      size_t pos = line.find("cluster");
+      if (pos == std::string::npos) continue;
+      in_cluster = true;
+      line = line.substr(pos);
+      cluster = "";
     }
-    unsigned int pos = 0;
-    while (pos < line.length() && line.at(pos) == ' ' ) pos++;
-    if (line.find("cluster", pos) == pos) {  // start with <whitespace> cluster
-      pos += 7;
-      do {  // looking for the first '{', which may be in the next lines
-        while (pos < line.length() &&
-            (line.at(pos) == ' ' || line.at(pos) =='\t')) pos++;
-        if (pos < line.length()) {
-          CHECK_EQ(line.at(pos), '{') << "error around 'cluster' field";
-          cluster =  " ";  // start extraction
-          break;
-        } else
-          pos = 0;
-      }while(std::getline(fin, line));
-    } else if (cluster.length()) {
-        cluster += line + "\n";
+    if (in_cluster == true) {
+      cluster += line + "\n";
+      if (line.find("}") != std::string::npos)
+        in_cluster = false;
     }
   }
-  return cluster;
+  LOG(INFO) << "cluster configure: " << cluster;
+  size_t s_pos = cluster.find("{");
+  size_t e_pos = cluster.find("}");
+  if (s_pos == std::string::npos || e_pos == std::string::npos) {
+    LOG(FATAL) << "cannot extract valid cluster configuration in file: "
+               << jobfile;
+  }
+  return cluster.substr(s_pos+1, e_pos-s_pos-1);
 }
 
-
 // generate a host list
 int genhost(char* job_conf) {
   // compute required #process from job conf

Reply via email to