Repository: incubator-singa
Updated Branches:
  refs/heads/master ea7cfea49 -> 9a6e09fa2


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9a6e09fa/src/utils/graph.cc
----------------------------------------------------------------------
diff --git a/src/utils/graph.cc b/src/utils/graph.cc
index b1f5b9f..d92e241 100644
--- a/src/utils/graph.cc
+++ b/src/utils/graph.cc
@@ -1,166 +1,202 @@
+
+#include "utils/graph.h"
+#include <glog/logging.h>
 #include <algorithm>
 #include <queue>
 #include <unordered_set>
-#include "utils/graph.h"
 
-const string Graph::ToString() const {
+namespace singa {
+/************************Node********************************/
+
+Node::~Node() {
+  // proto is not freed here; other functions release it outside Node
+}
+
+Node::Node(string name) {
+  this->name = name;
+}
+
+Node::Node(const string& name, const string& origin, int id, void* proto) {
+  this->name = name;
+  this->origin = origin;
+  this->proto = proto;
+  this->partition_id = id;
+}
+
+void Node::AddDstNode(Node* dstnode) {
+  dstnodes.push_back(dstnode);
+}
+
+void Node::AddSrcNode(Node* srcnode) {
+  srcnodes.push_back(srcnode);
+}
+
+// Drops the first dstnodes entry named dst->name; CHECK-fails if none.
+void Node::RemoveDstNode(Node* dst) {
+  auto iter = std::find_if(dstnodes.begin(), dstnodes.end(),
+      [dst](const Node* node) { return node->name == dst->name; });
+  CHECK(iter != dstnodes.end()) << "can't find dst node " << dst->name;
+  dstnodes.erase(iter);
+}
+
+// Drops the first srcnodes entry named src->name; CHECK-fails if none.
+void Node::RemoveSrcNode(Node* src) {
+  auto iter = std::find_if(srcnodes.begin(), srcnodes.end(),
+      [src](const Node* node) { return node->name == src->name; });
+  CHECK(iter != srcnodes.end()) << "can't find src node " << src->name;
+  srcnodes.erase(iter);
+}
+
+/*************************Graph****************************/
+Graph::~Graph() {
+  // the graph owns every node in nodes_ and frees them all here
+  for (auto* node : nodes_) delete node;
+}
+
+// Registers node in nodes_ and indexes it by name; ~Graph will delete it.
+void Graph::AddNode(Node* node) {
+  nodes_.push_back(node);
+  name2node_[node->name] = node;  // a repeated name re-points this entry
+}
+
+// Creates a Node called `name`, registers it, and returns it.
+Node* Graph::AddNode(const string& name) {
+  Node* created = new Node(name);
+  AddNode(created);
+  return created;
+}
+
+void Graph::AddEdge(Node* srcnode, Node* dstnode) {
+  srcnode->AddDstNode(dstnode);  // outgoing side
+  dstnode->AddSrcNode(srcnode);  // incoming side
+}
+
+// Resolves both names (CHECK-failing on unknown ones), then links the nodes.
+void Graph::AddEdge(const string& src, const string& dst) {
+  const auto src_it = name2node_.find(src);
+  CHECK(src_it != name2node_.end()) << "can't find src node " << src;
+  const auto dst_it = name2node_.find(dst);
+  CHECK(dst_it != name2node_.end()) << "can't find dst node " << dst;
+  AddEdge(src_it->second, dst_it->second);
+}
+
+void Graph::RemoveEdge(Node* src, Node* dst) {
+  src->RemoveDstNode(dst);  // outgoing side
+  dst->RemoveSrcNode(src);  // incoming side
+}
+
+// Resolves both names (CHECK-failing on unknown ones), then unlinks the nodes.
+void Graph::RemoveEdge(const string &src, const string& dst) {
+  const auto src_it = name2node_.find(src);
+  CHECK(src_it != name2node_.end()) << "can't find src node " << src;
+  const auto dst_it = name2node_.find(dst);
+  CHECK(dst_it != name2node_.end()) << "can't find dst node " << dst;
+  RemoveEdge(src_it->second, dst_it->second);
+}
+
+const string Graph::ToJson() const {  // ToJson(info) with an empty info map
   map<string, string> info;
-  return ToString(info);
+  return ToJson(info);
 }
-const string Graph::ToString(const map<string, string>& info) const {
+// JSON dump of nodes and links; info[name], if present, is appended to node ids.
+const string Graph::ToJson(const map<string, string>& info) const {
   map<string, int> nodeid;
-  string disp="{\"directed\":1,\n";
+  string disp = "{\"directed\":1,\n";
 
   // add nodes
-  disp+="\"nodes\":[\n";
-  bool first=true;
+  disp += "\"nodes\":[\n";
+  bool first = true;
 
-  vector<string> colors={"red", "blue", "black", "green"};
+  vector<string> colors = {"red", "blue", "black", "green"};  // by partition_id
   // see for more shapes at http://www.graphviz.org/doc/info/shapes.html
-  vector<string> shapes={"box", "ellipse"};
-  int id=0;
-  for(auto node: nodes_){
+  vector<string> shapes = {"box", "ellipse"};
+  int id = 0;
+  for (auto node : nodes_) {
     char str[1024];
-    string name=node->name();
-    string color=colors[(node->val().partitionid)%colors.size()];
+    string name = node->name;
+    string color = colors[(node->partition_id)%colors.size()];
     string shape;
-    string origin=node->val().origin;
-    if(origin=="kSlice"||origin=="kConcate"||origin=="kSplit"
-        ||origin=="kBridgeSrc"||origin=="kBridgeDst")
-      shape=shapes[1];
+    string origin = node->origin;
+    if (origin.find("##") != string::npos)
+      shape = shapes[1];
     else
-      shape=shapes[0];
-    sprintf(str, "{\"id\":\"%s%s\", \"color\":\"%s\",\"shape\":\"%s\"}\n",
-        name.c_str(), info.find(name)!=info.end()?info.at(name).c_str():"",
+      shape = shapes[0];
+    snprintf(str, sizeof(str),  // NOTE(review): truncates >1023 chars, no JSON escaping
+        "{\"id\":\"%s%s\", \"color\":\"%s\",\"shape\":\"%s\"}\n", name.c_str(),
+        info.find(name) != info.end() ? info.at(name).c_str() : "",
         color.c_str(), shape.c_str());
-    if(!first)
-      disp+=",";
+    if (!first)
+      disp += ",";
     else
-      first=false;
-    disp+=string(str);
-    nodeid[name]=id++;
+      first = false;
+    disp += string(str);
+    nodeid[name] = id++;
   }
-  disp+="]\n,";
+  disp += "]\n,";
 
   // add edges
-  disp+="\"links\":[\n";
-  first=true;
-  for(auto src: nodes_)
-    for(auto dst: src->dstnodes()){
-    char str[1024];
-    sprintf(str, "{\"source\":%d, \"target\":%d, \"color\":\"%s\"}\n",
-        nodeid[src->name()], nodeid[dst->name()], "black");
-    if(!first)
-      disp+=",";
-    else
-      first=false;
-    disp+=string(str);
+  disp += "\"links\":[\n";
+  first = true;
+  for (auto src : nodes_) {
+    for (auto dst : src->dstnodes) {  // a dst absent from nodes_ gets id 0
+      char str[1024];
+      snprintf(str, sizeof(str),
+          "{\"source\":%d, \"target\":%d, \"color\":\"%s\"}\n",
+          nodeid[src->name], nodeid[dst->name], "black");
+      if (!first)
+        disp += ",";
+      else
+        first = false;
+      disp += string(str);
+    }
   }
-  disp+="]\n";
+  disp += "]\n";
   return disp+"}";
 }
-bool Graph::Check() const {
-  return true;
-}
-
-
-// visited all dst nodes and then push current node into the stack
-void Graph::topology_sort_inner(SNode node,
-    map<string, bool> *visited,
-    std::stack<string> *stack) {
-  (*visited)[node->name()] = true;
-  const vector<SNode>& dstnodes=node->dstnodes();
-  for (auto it=dstnodes.rbegin();it!=dstnodes.rend();it++) {
-    if ((*visited)[(*it)->name()])
-      continue;
-    topology_sort_inner((*it),visited, stack);
-  }
-  stack->push(node->name());
-}
 
 // sort to make `bottom' nodes be placed in the front positions
 void Graph::Sort() {
-  SNode start=nullptr;
-  map<string, bool> visited;
-  for(auto node: nodes_){
-    if(node->srcnodes().size()==0){
-      CHECK(start==nullptr);
-      start=node;
+  std::queue<Node*> visiting_nodes;       // nodes waiting to be visited
+  std::unordered_set<Node*> visited_set;  // nodes already placed in order
+  std::unordered_set<Node*> visit_set;    // visiting_nodes + visited_set
+  size_t stalled = 0;  // failed visits since the last node was placed
+  for (auto node : nodes_) {
+    // start from the nodes without source nodes
+    if (node->srcnodes.size() == 0) {
+      visiting_nodes.push(node);
+      visit_set.insert(node);
     }
-    visited[node->name()]=false;
   }
-  int n=nodes_.size();
-  std::unordered_set<SNode> pushed;
-  std::queue<SNode> tmp;
-  tmp.push(start);
-  pushed.insert(start);
+  const size_t n = nodes_.size();
   nodes_.clear();
-  while(!tmp.empty()){
-    auto node=tmp.front();
-    tmp.pop();
-    bool visit=true;
-    for(auto src: node->srcnodes())
-      if(visited[src->name()]==false){
-        visit=false;
+  while (!visiting_nodes.empty()) {
+    // fails once every queued node was retried without progress (a cycle)
+    CHECK_LT(stalled, visiting_nodes.size()) << "cycle detected in graph";
+    auto node = visiting_nodes.front();
+    visiting_nodes.pop();
+    bool visit = true;
+    for (auto src : node->srcnodes) {
+      if (visited_set.find(src) == visited_set.end()) {  // src not placed yet
+        visit = false;
         break;
       }
-    if(visit){
+    }
+    if (visit) {
       nodes_.push_back(node);
-      visited[node->name()]=true;
-      for(auto dst: node->dstnodes()){
-        if(pushed.find(dst) == pushed.end()){
-          tmp.push(dst);
-          pushed.insert(dst);
+      visited_set.insert(node);
+      stalled = 0;
+      for (auto dst : node->dstnodes) {
+        if (visit_set.find(dst) == visit_set.end()) {  // queue each dst once
+          visiting_nodes.push(dst);
+          visit_set.insert(dst);
         }
       }
-    }else
-      tmp.push(node);
+    } else {
+      ++stalled;  // some source is still unplaced; retry this node later
+      visiting_nodes.push(node);
+    }
   }
   CHECK_EQ(nodes_.size(), n);
 }
 
-
-SNode Graph::InsertSliceNode(SNode srcnode, const vector<SNode>& dstnodes,
-    const V& info, bool connect_dst){
-  V myinfo=info;
-  myinfo.origin="kSlice";
-  SNode node=AddNode("slice-"+srcnode->name(),myinfo);
-  AddEdge(srcnode, node);
-  if(connect_dst)
-    for(SNode dst: dstnodes)
-      AddEdge(node, dst);
-  return node;
-}
-SNode Graph::InsertConcateNode(const vector<SNode>&srcnodes, SNode dstnode,
-    const V& info){
-  V myinfo=info;
-  myinfo.origin="kConcate";
-  SNode node=AddNode("concate-"+dstnode->name(),myinfo);
-  AddEdge(node, dstnode);
-  for(SNode src: srcnodes)
-    AddEdge(src, node);
-  return node;
-}
-SNode Graph::InsertSplitNode(SNode srcnode, const vector<SNode>& dstnodes){
-  V myinfo=srcnode->val();
-  myinfo.origin="kSplit";
-  SNode node=AddNode("split-"+srcnode->name(), myinfo);
-  AddEdge(srcnode, node);
-  for(SNode dst: dstnodes)
-    AddEdge(node, dst);
-  return node;
-}
-std::pair<SNode, SNode> Graph::InsertBridgeNode(SNode srcnode, SNode dstnode){
-  LayerInfo info=srcnode->val();
-  info.origin="kBridgeSrc";
-  SNode src=AddNode("s-"+srcnode->name()+"-"+dstnode->name(), info);
-  info=dstnode->val();
-  info.origin="kBridgeDst";
-  SNode dst=AddNode("d-"+srcnode->name()+"-"+dstnode->name(), info);
-  AddEdge(srcnode, src);
-  AddEdge(src, dst);
-  AddEdge(dst, dstnode);
-  return pair<SNode, SNode>{src, dst};
-}
-
-
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9a6e09fa/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 02d80a1..24a0541 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -27,8 +27,7 @@ void Param::AddSlice(int slice_id, int size){
     //must be added in order
     CHECK_EQ(slice_start_+num_slices_, slice_id);
     offset=slice_offset_.back()+slice_size_.back();
-  }
-  else{
+  } else {
     slice_start_=slice_id;
     offset=0;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9a6e09fa/src/utils/updater.cc
----------------------------------------------------------------------
diff --git a/src/utils/updater.cc b/src/utils/updater.cc
index 80e3619..8e949ef 100644
--- a/src/utils/updater.cc
+++ b/src/utils/updater.cc
@@ -64,7 +64,7 @@ void SGDUpdater::Init(const UpdaterProto& proto){
   weight_decay_=proto.weight_decay();
 }
 
-void SGDUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+void SGDUpdater::Update(int step, Param* param, float grad_scale){
   Shape<1> s=Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
@@ -92,7 +92,7 @@ void NesterovUpdater::Init(const UpdaterProto& proto){
   weight_decay_=proto.weight_decay();
 }
 
-void NesterovUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+void NesterovUpdater::Update(int step, Param* param, float grad_scale){
   Shape<1> s=Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
@@ -118,7 +118,7 @@ void AdaGradUpdater::Init(const UpdaterProto& proto){
   weight_decay_=proto.weight_decay();
 }
 
-void AdaGradUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+void AdaGradUpdater::Update(int step, Param* param, float grad_scale){
   Shape<1> s=Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
@@ -143,7 +143,7 @@ void RMSPropUpdater::Init(const UpdaterProto& proto){
   weight_decay_=proto.weight_decay();
 }
 
-void RMSPropUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+void RMSPropUpdater::Update(int step, Param* param, float grad_scale){
   Shape<1> s=Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
@@ -166,7 +166,7 @@ void AdaDeltaUpdater::Init(const UpdaterProto& proto){
   weight_decay_=proto.weight_decay();
 }
 
-void AdaDeltaUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+void AdaDeltaUpdater::Update(int step, Param* param, float grad_scale){
   Shape<1> s=Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);

Reply via email to