Repository: incubator-singa
Updated Branches:
  refs/heads/master 06f85e23e -> f29d93ff7


Working on data partitioning within one group running on a single node.
TODO:
1. Update the performance collection by reporting performance to the stub.
2. Let workers pass requests to the stub without copying data (passing the
address or param id). Messages to servers are then generated by the stub, which
can aggregate gradients of shared parameters from all workers and collect the
updated parameters for them.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/39969f2b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/39969f2b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/39969f2b

Branch: refs/heads/master
Commit: 39969f2b2d917c79deb7b9e850926a0b7778c0b5
Parents: 48b8fea
Author: wang wei <[email protected]>
Authored: Tue May 12 09:50:23 2015 +0800
Committer: wang wei <[email protected]>
Committed: Tue May 12 09:50:23 2015 +0800

----------------------------------------------------------------------
 Makefile.example               |  91 --------------------------
 examples/cifar10/Makefile      |   2 +-
 examples/cifar10/cluster.conf  |   3 +-
 examples/cifar10/model.conf    |   5 +-
 include/neuralnet/base_layer.h |  23 ++-----
 include/neuralnet/neuralnet.h  |   5 +-
 include/trainer/pm_worker.h    |  15 +++--
 include/trainer/server.h       |   7 +-
 include/trainer/worker.h       |  10 +--
 include/utils/blob.h           |   9 ++-
 include/utils/graph.h          |   2 +-
 include/utils/param.h          |  46 ++++++-------
 script/graph.py                |  22 +++++++
 src/neuralnet/base_layer.cc    |  15 -----
 src/neuralnet/neuralnet.cc     |  41 ++++++++----
 src/proto/model.pb.h           |  42 ++++++------
 src/proto/model.proto          |   7 +-
 src/trainer/pm_worker.cc       | 102 ++++++++++++-----------------
 src/trainer/server.cc          |  22 +++++--
 src/trainer/trainer.cc         |  51 ++++++++-------
 src/trainer/worker.cc          | 126 ++++++++++++++++++++++++++----------
 src/utils/graph.cc             |   2 +-
 src/utils/param.cc             |  32 ++++-----
 23 files changed, 330 insertions(+), 350 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/Makefile.example
----------------------------------------------------------------------
diff --git a/Makefile.example b/Makefile.example
deleted file mode 100644
index 80dfc26..0000000
--- a/Makefile.example
+++ /dev/null
@@ -1,91 +0,0 @@
-###################User Config Varaibles #############################
-# third-party library installation folder
-HOME_DIR := /usr/
-# Lib folder for system and external libs. You may need to change it.
-LIBRARY_DIRS := $(HOME_DIR)/lib64 $(HOME_DIR)/lib $(HOME_DIR)/local/lib
-# Header folder for system and external libs. You may need to change it.
-INCLUDE_DIRS := $(HOME_DIR)/include ./include
-# g++ location, should support c++11, tested with 4.8.1
-CXX := g++
-
-######################Setting Varialbes#######################################
-LIBRARIES := glog gflags protobuf rt opencv_highgui opencv_imgproc opencv_core\
-       lmdb openblas zmq czmq
-
-LDFLAGS := $(foreach librarydir, $(LIBRARY_DIRS), -L$(librarydir))\
-       $(foreach library, $(LIBRARIES), -l$(library))
-# Folder to store compiled files
-BUILD_DIR := build
-MSHADOW_FLAGS :=-DMSHADOW_USE_CUDA=0 -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0
-CXXFLAGS := -O3 -Wall -pthread -fPIC -std=c++11 -Wno-unknown-pragmas \
-       $(MSHADOW_FLAGS) -DCPU_ONLY=1 \
-       -funroll-loops $(foreach includedir, $(INCLUDE_DIRS), -I$(includedir))
-
-# find user defined .proto file, and then compute the corresponding .h, .cc
-# files, which cannot be found by shell find, because they haven't been
-# generated currently
-PROTOS := $(shell find src/proto/ -name "*.proto")
-PROTO_SRCS :=$(PROTOS:.proto=.pb.cc)
-PROTO_HDRS :=$(patsubst src%, include%, $(PROTOS:.proto=.pb.h))
-PROTO_OBJS :=$(addprefix $(BUILD_DIR)/, $(PROTO_SRCS:.cc=.o))
-
-# each singa src file will generate a .o file
-SINGA_SRCS := $(shell find src/ \( -path "src/test" -o -path "src/main.cc" \) \
-       -prune -o \( -name "*.cc" -type f \) -print )
-SINGA_OBJS := $(sort $(addprefix $(BUILD_DIR)/, $(SINGA_SRCS:.cc=.o)) \
-       $(PROTO_OBJS) )
--include $(SINGA_OBJS:%.o=%.P)
-
-TEST_SRCS :=$(shell find src/test/ -maxdepth 1 -name "*.cc")
-TEST_OBJS := $(sort $(addprefix $(BUILD_DIR)/, $(TEST_SRCS:.cc=.o)))
--include $(TEST_OBJS:%.o=%.P)
-
-GTEST_SRC := include/gtest/gtest-all.cc
-GTEST_HDR := include/gtest/gtest.h
-GTEST_LIB := $(BUILD_DIR)/libgtest.a
-
-OBJS := $(sort $(SINGA_OBJS) $(TEST_OBJS) )
-
-########################Compilation Section###################################
-.PHONY: singa test
-
-singa: $(PROTO_OBJS) $(SINGA_OBJS)
-       $(CXX) $(SINGA_OBJS) src/main.cc -o $(BUILD_DIR)/singa $(CXXFLAGS) 
$(LDFLAGS)
-       @echo
-
-loader: proto $(LOADER_OBJS)
-       $(CXX) $(LOADER_OBJS) -o $(BUILD_DIR)/loader $(CXXFLAGS) $(LDFLAGS)
-       @echo
-
-test:  proto $(GTEST_LIB) $(TEST_OBJS) $(SINGA_OBJS)
-       $(CXX) $(TEST_OBJS) include/gtest/gtest_main.cc $(GTEST_LIB) \
-               $(SINGA_OBJS) -o $(BUILD_DIR)/test $(CXXFLAGS) $(LDFLAGS)
-       @echo
-
-$(GTEST_LIB): $(GTEST_HDR) $(GTEST_SRC)
-       $(CXX) $(GTEST_SRC) -c -o $(BUILD_DIR)/gtest-all.o $(CXXFLAGS)
-       ar -rv $(GTEST_LIB) $(BUILD_DIR)/gtest-all.o
-
-# compile all files
-$(OBJS):$(BUILD_DIR)/%.o : %.cc
-       @mkdir -p $(dir $@)
-       $(CXX) $<  $(CXXFLAGS) -MMD -c -o $@
-       cp $(BUILD_DIR)/$*.d $(BUILD_DIR)/$*.P; \
-       sed -e 's/#.*//' -e 's/^[^:]*: *//' -e 's/ *\\$$//' \
-               -e '/^$$/ d' -e 's/$$/ :/' < $(BUILD_DIR)/$*.d >> 
$(BUILD_DIR)/$*.P; \
-       rm -f $*.d
-
-proto: $(PROTO_OBJS)
-
-$(PROTO_SRCS): $(PROTOS)
-       protoc --proto_path=src/proto --cpp_out=src/proto $(PROTOS)
-       mkdir -p include/proto/
-       cp src/proto/*.pb.h include/proto/
-       @echo
-
-clean:
-       rm -rf *.a *.so
-       rm -rf include/proto/*
-       rm -rf src/proto/*.pb.h src/proto/*.pb.cc
-       rm -rf $(BUILD_DIR)
-       @echo

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/examples/cifar10/Makefile
----------------------------------------------------------------------
diff --git a/examples/cifar10/Makefile b/examples/cifar10/Makefile
index 40fece6..16c329f 100644
--- a/examples/cifar10/Makefile
+++ b/examples/cifar10/Makefile
@@ -5,7 +5,7 @@ libs :=singa glog protobuf
 download: cifar-10-binary-bin
 
 cifar-10-binary-bin:
-       wget http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
+       #wget http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
        tar xf cifar-10-binary.tar.gz
 
 create:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/examples/cifar10/cluster.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/cluster.conf b/examples/cifar10/cluster.conf
index 6b8a8e6..88c3d4b 100644
--- a/examples/cifar10/cluster.conf
+++ b/examples/cifar10/cluster.conf
@@ -1,5 +1,6 @@
 nworker_groups: 1
 nserver_groups: 1
 nservers_per_group: 1
-nworkers_per_group: 1
+nworkers_per_group: 2
+nworkers_per_procs: 2
 workspace: "examples/cifar10/"

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/examples/cifar10/model.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/model.conf b/examples/cifar10/model.conf
index a27486a..bace74d 100644
--- a/examples/cifar10/model.conf
+++ b/examples/cifar10/model.conf
@@ -1,8 +1,8 @@
 name: "cifar10-convnet"
 train_steps: 70000
-test_steps:100
+test_steps:5
 test_frequency:1000
-display_frequency:50
+display_frequency:1
 updater{
   momentum:0.9
   weight_decay:0.004
@@ -15,6 +15,7 @@ updater{
   step_lr:0.00001
 }
 neuralnet {
+partition_type: kDataPartition
 layer{
   name: "data"
   type: "kShardData"

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/include/neuralnet/base_layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/base_layer.h b/include/neuralnet/base_layer.h
index 9b8545a..8e49059 100644
--- a/include/neuralnet/base_layer.h
+++ b/include/neuralnet/base_layer.h
@@ -161,25 +161,12 @@ class Layer {
     return layer_proto_.partition_type();
   }
   /**
-   * Set location ID as the worker ID within a worker group.
-   * TODO(wangwei) merge location ID with partition ID
-   */
-  virtual void set_locationid(int id){
-    layer_proto_.set_locationid(id);
-  }
-  /**
-   * @return location ID
-   */
-  virtual int locationid() const {
-    return layer_proto_.locationid();
-  }
-  /**
    * partition id is the ID of the layer in the original layer.
    */
   virtual void set_partitionid(int id){
     layer_proto_.set_partitionid(id);
   }
-  virtual int partitiionid() const {
+  virtual int partitionid() const {
     return layer_proto_.partitionid();
   }
   virtual void set_name(string name){
@@ -301,10 +288,10 @@ class BridgeSrcLayer: public Layer {
 
   virtual void ComputeFeature(bool training, const vector<SLayer>& srclayers);
   virtual void ComputeGradient(const vector<SLayer>& srclayers);
+  int dst_partition() const;
   virtual bool is_bridgesrclayer() const {
     return true;
   }
-
   virtual void set_ready(bool a) {
     ready_=a;
   }
@@ -330,8 +317,10 @@ class BridgeDstLayer: public Layer {
       const vector<int> &shape,
       const vector<SLayer>& srclayers){}
 
-  virtual void ComputeFeature(bool training, const vector<SLayer>& srclayers);
-  virtual void ComputeGradient(const vector<SLayer>& srclayers);
+  virtual void ComputeFeature(bool training, const vector<SLayer>& srclayers){
+    ready_=false;
+  }
+  virtual void ComputeGradient(const vector<SLayer>& srclayers){}
   virtual bool is_bridgedstlayer() const {
     return true;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/include/neuralnet/neuralnet.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/neuralnet.h b/include/neuralnet/neuralnet.h
index 586a470..ec6797c 100644
--- a/include/neuralnet/neuralnet.h
+++ b/include/neuralnet/neuralnet.h
@@ -37,8 +37,11 @@ class NeuralNet {
    * setup (done outside of this funcion).
    *
    * @param np proto for the neural network.
+   * @param phase test/training/validation
+   * @param group_size partition the net among this num of workers
    */
-  static shared_ptr<NeuralNet> SetupNeuralNet(const NetProto& np, Phase phase);
+  static shared_ptr<NeuralNet> SetupNeuralNet(const NetProto& np, Phase phase,
+      int group_size);
 
  public:
   /**

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/include/trainer/pm_worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/pm_worker.h b/include/trainer/pm_worker.h
index 198f5bd..9b973d6 100644
--- a/include/trainer/pm_worker.h
+++ b/include/trainer/pm_worker.h
@@ -36,8 +36,10 @@ namespace singa {
 class ParamCounter{
   public:
   ParamCounter(shared_ptr<Param> p,int local, int owner):
-    nUpdate(0), nGet(0), nPut(0), nCollect(0), nLocal(local), nTotal(0),
-    owner_procs(owner), param(p){}
+    nUpdate(0), nGet(0), nPut(0), nCollect(0), nLocal(local), nTotal(1),
+    owner_procs(owner){
+      shares.push_back(p);
+    }
 
   /**
    * Associate the counter to a Param object.
@@ -50,10 +52,10 @@ class ParamCounter{
   void AddParam(shared_ptr<Param> p, int local, int owner){
     nLocal+=local;
     nTotal+=1;
-    if(owner_procs>-1)
+    if(owner>-1)
       owner_procs=owner;
-    if(nLocal>1){
-      // TODO copy p->param;
+    if(local>0){
+      shares.push_back(p);
     }
   }
   std::atomic<int> nUpdate, nGet, nPut, nCollect; //!< all counters are atomic
@@ -61,10 +63,9 @@ class ParamCounter{
   int nLocal; //!< # local workers uses the shared parameter
   int nTotal; //!< # total workers uses the shared parameter
   int owner_procs; //!< the procs id of the worker that owns the parameter
-  shared_ptr<Param> param;
+  vector<shared_ptr<Param>> shares;
 };
 
-
 /**
  * Parameter manager at the worker side.
  */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
index d113c7d..6ae09e4 100644
--- a/include/trainer/server.h
+++ b/include/trainer/server.h
@@ -8,13 +8,12 @@ using std::shared_ptr;
 namespace singa {
 class Server{
  public:
-  Server(int group_id, int server_id);
-  void Setup(const UpdaterProto& proto, shared_ptr<PMServer::ParamShard> shard,
-    shared_ptr<Dealer> dealer);
+  Server(int thread_id, int group_id, int server_id);
+  void Setup(const UpdaterProto& proto, shared_ptr<PMServer::ParamShard> 
shard);
   void Run();
 
  protected:
-  int group_id_, server_id_;
+  int thread_id_,group_id_, server_id_;
   shared_ptr<PMServer> pmserver_;
   shared_ptr<Dealer> dealer_;
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/include/trainer/worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/worker.h b/include/trainer/worker.h
index 0e9f356..13f7798 100644
--- a/include/trainer/worker.h
+++ b/include/trainer/worker.h
@@ -39,11 +39,10 @@ class Performance{
  */
 class Worker {
  public:
-  Worker(int group_id, int worker_id);
+  Worker(int thread_id, int group_id, int worker_id);
   ~Worker(){}
   void Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net,
-      shared_ptr<PMWorker::ParamShard> shard, shared_ptr<Dealer> layer_dealer,
-    shared_ptr<Dealer> param_dealer);
+      shared_ptr<PMWorker::ParamShard> shard);
   void set_test_net(shared_ptr<NeuralNet> test_net){
     test_net_=test_net;
   }
@@ -55,6 +54,7 @@ class Worker {
   int Get(shared_ptr<Param> param, int step);
   int Update(shared_ptr<Param> param, int step);
   int Collect(shared_ptr<Param> param, int step);
+  int CollectAll(shared_ptr<NeuralNet> net, int step);
   /**
     * check validation/test firstly, then TrainOneBatch
     * Performance collects performance for the whole neuralnet.
@@ -160,7 +160,7 @@ class Worker {
   void ReceiveBlobs(shared_ptr<NeuralNet> net);
   void SendBlob();
  protected:
-  int group_id_, worker_id_;
+  int thread_id_,group_id_, worker_id_;
   int step_;
   ModelProto modelproto_;
   shared_ptr<PMWorker> pmworker_;
@@ -172,7 +172,7 @@ class Worker {
 class BPWorker: public Worker{
  public:
   ~BPWorker(){}
-  BPWorker(int group_id, int worker_id):Worker(group_id, worker_id){}
+  BPWorker(int thread_id, int group_id, int worker_id):Worker(thread_id, 
group_id, worker_id){}
   virtual void TrainOneBatch(int step);
   virtual void TestOneBatch(shared_ptr<NeuralNet> net, int step, Phase phase);
   void Forward(shared_ptr<NeuralNet> net, int step, bool training);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/include/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/utils/blob.h b/include/utils/blob.h
index 08068eb..8234b28 100644
--- a/include/utils/blob.h
+++ b/include/utils/blob.h
@@ -95,7 +95,7 @@ class SyncedMemory {
 template <typename Dtype>
 class Blob {
  public:
-  Blob(): count_(0), capacity_(0) {}
+  Blob(): count_(0), capacity_(0) , version_(-1){}
   Blob(const vector<int>&shape);
   /**
    * @brief Change the dimensions of the blob, allocating new memory if
@@ -117,6 +117,12 @@ class Blob {
     return shape_;
   }
   inline int count() const { return count_; }
+  void set_version(int v){
+    version_=v;
+  }
+  const int version() const {
+    return version_;
+  }
   /**
    * @brief Copy from a source Blob.
    *
@@ -161,6 +167,7 @@ class Blob {
   vector<int> shape_;
   int count_;
   int capacity_;
+  int version_;
 };  // class Blob
 
 #endif // INCLUDE_UTILS_BLOB_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/include/utils/graph.h
----------------------------------------------------------------------
diff --git a/include/utils/graph.h b/include/utils/graph.h
index ca582b5..93348dd 100644
--- a/include/utils/graph.h
+++ b/include/utils/graph.h
@@ -18,7 +18,7 @@ using std::make_shared;
 typedef struct _LayerInfo{
   // origin identifies the origin of this node, i.e., the corresponding layer
   string origin;
-  int locationid;// locationidation id;
+  //int locationid;// locationidation id;
   int partitionid;
   int slice_dimension;
   int concate_dimension;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index 907ef8c..0574b2c 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -11,8 +11,8 @@
 namespace singa {
 class Param {
  public:
-  Param();
-  virtual ~Param();
+  Param():data_(nullptr){}
+  virtual ~Param(){};
 
   virtual Msg* GenGetMsg(void* arg=nullptr);
   virtual Msg* GenPutMsg(void* arg=nullptr);
@@ -39,10 +39,11 @@ class Param {
    */
   virtual void Init(int v=0);
   void ShareData(shared_ptr<Param> other){
-    owner_=other->id();
-    CHECK(std::equal(data_.shape().begin(), data_.shape().end(),
-          other->data_.shape().begin()));
-    data_.ShareData(other->data_);
+    proto_.set_owner(other->owner());
+    if(data_!=nullptr)
+      CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
+          other->data_->shape().begin()));
+    data_=other->data_;
   }
   float learning_rate_multiplier() {
     return proto_.learning_rate_multiplier();
@@ -55,44 +56,44 @@ class Param {
     return proto_.split_threshold();
   }
   */
+  const std::string& name() {
+    return proto_.name();
+  }
   /**
-   * if the Param shares data with others, then point to the owner.
-   * otherwise points to itself.
+   * if the Param shares data with others, then owner is the id of that param.
+   * otherwise it is itself's id.
    */
   const int owner() const{
-    return owner_;
-  }
-  const std::string& name() {
-    return proto_.name();
+    return proto_.owner();
   }
-
   int id() const{
     return proto_.id();
   }
   void set_id(int id){
     proto_.set_id(id);
+    proto_.set_owner(id);
   }
 
   int version() const {
-    return proto_.version(); // TODO store version in data blob
+    return data_->version(); // TODO store version in data blob
   }
   void set_version(int v) {
-    proto_.set_version(v); // TODO read version from data blob
+    data_->set_version(v); // TODO read version from data blob
   }
    /**
     * @return num of floats.
     */
   int size() const {
-    return data_.count();
+    return data_->count();
   }
   /**
    * Return const mem address for the content of this parameter
    */
   const Blob<float> &data() {
-    return data_;
+    return *data_;
   }
   Blob<float> *mutable_data() {
-    return &data_;
+    return data_.get();
   }
   /**
    * Return gradient of this parameter
@@ -112,7 +113,7 @@ class Param {
   }
 
   float* mutable_cpu_data(){
-    return data_.mutable_cpu_data();
+    return data_->mutable_cpu_data();
   }
   float* mutable_cpu_grad(){
     return grad_.mutable_cpu_data();
@@ -125,10 +126,9 @@ class Param {
    * name of the parameter used to share wights between neuralnets
    */
   std::string name_;
-  //! content, gradient, history gradient of this parameter
-  Blob<float> data_, grad_, history_;
-  int owner_;
-
+  shared_ptr<Blob<float>> data_;
+  //! gradient, history gradient of this parameter
+  Blob<float> grad_, history_;
   ParamProto proto_;
   int fan_in_;
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/script/graph.py
----------------------------------------------------------------------
diff --git a/script/graph.py b/script/graph.py
new file mode 100644
index 0000000..17aaea7
--- /dev/null
+++ b/script/graph.py
@@ -0,0 +1,22 @@
+import sys
+import pygraphviz
+import networkx as nx
+from networkx.readwrite import json_graph
+import json
+
+
+if __name__=='__main__':
+  print sys.argv
+  if len(sys.argv)<3:
+    print 'usage: draw the network graph\npython graph.py JSON_DAT FIG_FILE'
+    sys.exit()
+
+  with open(sys.argv[1]) as fd:
+    nodelink=json.load(fd)
+    G=json_graph.node_link_graph(nodelink)
+    A = nx.to_agraph(G)
+    A.layout('dot', args='-Nfontsize=10 -Nwidth=".2" -Nheight=".2" -Nmargin=0 \
+        -Gfontsize=8')
+    A.draw(sys.argv[2])
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/neuralnet/base_layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/base_layer.cc b/src/neuralnet/base_layer.cc
index 44bed3f..e7702e9 100644
--- a/src/neuralnet/base_layer.cc
+++ b/src/neuralnet/base_layer.cc
@@ -53,13 +53,8 @@ void BridgeSrcLayer::SetupAfterPartition(){
 
 void BridgeSrcLayer::ComputeFeature(bool training,
     const vector<SLayer>& srclayers){
-  if(training)
-    ready_=false;
-  else
-    ready_=true;
 }
 void BridgeSrcLayer::ComputeGradient(const vector<SLayer>& srclayers){
-
 }
 
 /********* Implementation for BridgeDstLayer **************/
@@ -74,16 +69,6 @@ void BridgeDstLayer::SetupAfterPartition(){
   //LOG(ERROR)<<name()<<":"<<IntVecToString(shape_);
 }
 
-void BridgeDstLayer::ComputeFeature(bool training,
-    const vector<SLayer>& srclayers){
-  if(training)
-    ready_=true;
-  else
-    ready_=false;
-}
-void BridgeDstLayer::ComputeGradient(const vector<shared_ptr<Layer>>& 
srclayers){
-
-}
 
 /************* Implementation for ConcateLayer ***********/
 void ConcateLayer::Setup(const LayerProto& proto,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index accd619..4dac512 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -32,7 +32,8 @@ void NeuralNet::RegisterLayers(){
   factory->Register("kSplit", CreateLayer(SplitLayer));
   factory->Register("kTanh", CreateLayer(TanhLayer));
 }
-shared_ptr<NeuralNet> NeuralNet::SetupNeuralNet(const NetProto& np, Phase 
phase){
+shared_ptr<NeuralNet> NeuralNet::SetupNeuralNet(const NetProto& np, Phase 
phase,
+    int group_size){
   NetProto proto;
   proto.set_partition_type(np.partition_type());
   // exclude layers if necessary
@@ -48,8 +49,7 @@ shared_ptr<NeuralNet> NeuralNet::SetupNeuralNet(const 
NetProto& np, Phase phase)
     }
   }
   LOG(INFO)<<"NeuralNet config is "<<proto.DebugString();
-  shared_ptr<NeuralNet> net(new NeuralNet(proto));
-  return net;
+  return make_shared<NeuralNet>(proto, group_size);
 }
 NeuralNet::NeuralNet(NetProto net_proto, int group_size) {
   group_size_=group_size;
@@ -66,15 +66,11 @@ NeuralNet::NeuralNet(NetProto net_proto, int group_size) {
   for(auto layer: layers_){
     DLOG(INFO)<<layer->name();
   }
-  // assign id for params;
-  int paramid=0;
   for(auto& layer: layers_){
     for(shared_ptr<Param> p: layer->GetParams()){
       params_.push_back(p);
-      p->set_id(paramid++);
     }
   }
-
   LOG(INFO)<<"Neural Net constructed";
 }
 
@@ -112,8 +108,11 @@ void NeuralNet::ConstructNeuralNet(const NetProto& 
net_proto){
       layer->AddSrcLayer(name2layer_[src->name()]);
   }
   // setup layer properties, e.g., shapes
+  int paramid=0;
   for(auto& layer: layers_){
       layer->Setup();
+      for(auto param: layer->GetParams())
+        param->set_id(paramid++);
   }
   LOG(INFO)<<"network graph witout partition\n"<<ToString();
 }
@@ -122,6 +121,7 @@ void NeuralNet::PartitionNeuralNet(){
   graph_=CreatePartitonedGraph(layers_, name2layer_);
   //DLOG(ERROR)<<"pure graph after partition\n"<<graph_.ToString();
   map<string, shared_ptr<Layer>> name2layer(name2layer_);
+  map<string, vector<shared_ptr<Layer>>> share_param_layers;
   name2layer_.clear();
   layers_.clear();
   int gsize=group_size_;
@@ -130,7 +130,6 @@ void NeuralNet::PartitionNeuralNet(){
   for(SNode node: graph_.nodes()){
     LayerProto proto;
     proto.set_name(node->name());
-    proto.set_locationid(node->val().locationid);
     proto.set_partitionid(node->val().partitionid);
     const string& origin=node->val().origin;
     if (origin=="kSlice"){
@@ -173,7 +172,10 @@ void NeuralNet::PartitionNeuralNet(){
         layer->Init(*oldlayer, shape);
         layer->set_name(node->name());
         newlayer=layer;
+        if(oldlayer->partition_type()==kDataPartition)
+          share_param_layers[node->val().origin].push_back(newlayer);
       }
+      newlayer->set_partitionid(node->val().partitionid);
     }
     layers_.push_back(newlayer);
     name2layer_[node->name()]=newlayer;
@@ -193,14 +195,30 @@ void NeuralNet::PartitionNeuralNet(){
   LOG(INFO)<<"Adjacency matrix\n"<<ToAdjacency();
 
   // set up layers after
+  int paramid=0;
   for(shared_ptr<Layer> layer: layers_){
     const vector<int>& shape=layer->shape(nullptr);
     layer->SetupAfterPartition();
+    for(auto param: layer->GetParams())
+      param->set_id(paramid++);
     const vector<int>& newshape=layer->shape(nullptr);
     if(shape.size())
       CHECK(std::equal(shape.begin(),shape.end(),newshape.begin()));
   }
 
+  // share Params for layers generated from the same origin layer due to
+  // data partition
+  for(auto & entry: share_param_layers){
+    auto layers= entry.second;
+    auto owner=layers.begin();
+    auto owner_params=(*owner)->GetParams();
+    for(auto it=owner+1; it!=layers.end();it++){
+      auto params=(*it)->GetParams();
+      CHECK_EQ(params.size(), owner_params.size());
+      for(size_t i=0;i<params.size();i++)
+        params.at(i)->ShareData(owner_params.at(i));
+    }
+  }
   LOG(INFO)<<"network graph after partition layers\n"<<ToString();
 }
 
@@ -219,13 +237,12 @@ Graph NeuralNet::CreatePartitonedGraph(const 
vector<shared_ptr<Layer>>& layers,
         sprintf(suffix, "%02d", i);
         // differentiate partitions
         string nodename=layer->name()+"-"+string(suffix);
-        LayerInfo info;
-        auto node=graph.AddNode(nodename, LayerInfo{layer->name(),i, i,-1,-1});
+        auto node=graph.AddNode(nodename, LayerInfo{layer->name(), i,-1,-1});
         nodes.push_back(node);
       }
     }else if(layer->partition_type()==kNone){
       auto node=graph.AddNode(layer->name(),
-          LayerInfo{layer->name(), layer->locationid(), 0,-1,-1});
+          LayerInfo{layer->name(), 0,-1,-1});
       nodes.push_back(node);
     }else{
       LOG(FATAL)<<"Unknown partition type "<<layer->partition_type();
@@ -321,7 +338,7 @@ Graph NeuralNet::CreatePartitonedGraph(const 
vector<shared_ptr<Layer>>& layers,
     vector<SNode> dstnodes=node->dstnodes();
     for(size_t i=0;i<dstnodes.size();i++){
       SNode dstnode=dstnodes.at(i);
-      if(node->val().locationid!=dstnode->val().locationid){
+      if(node->val().partitionid!=dstnode->val().partitionid){
         graph.RemoveEdge(node, dstnode);
         graph.InsertBridgeNode(node, dstnode);
       }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/proto/model.pb.h
----------------------------------------------------------------------
diff --git a/src/proto/model.pb.h b/src/proto/model.pb.h
index b111a2f..4f68462 100644
--- a/src/proto/model.pb.h
+++ b/src/proto/model.pb.h
@@ -877,12 +877,12 @@ class ParamProto : public ::google::protobuf::Message {
   inline ::google::protobuf::int32 partition_dim() const;
   inline void set_partition_dim(::google::protobuf::int32 value);
 
-  // optional int32 version = 6;
-  inline bool has_version() const;
-  inline void clear_version();
-  static const int kVersionFieldNumber = 6;
-  inline ::google::protobuf::int32 version() const;
-  inline void set_version(::google::protobuf::int32 value);
+  // optional int32 owner = 6;
+  inline bool has_owner() const;
+  inline void clear_owner();
+  static const int kOwnerFieldNumber = 6;
+  inline ::google::protobuf::int32 owner() const;
+  inline void set_owner(::google::protobuf::int32 value);
 
   // optional .singa.ParamProto.InitMethod init_method = 7 [default = 
kConstant];
   inline bool has_init_method() const;
@@ -950,8 +950,8 @@ class ParamProto : public ::google::protobuf::Message {
   inline void clear_has_split_threshold();
   inline void set_has_partition_dim();
   inline void clear_has_partition_dim();
-  inline void set_has_version();
-  inline void clear_has_version();
+  inline void set_has_owner();
+  inline void clear_has_owner();
   inline void set_has_init_method();
   inline void clear_has_init_method();
   inline void set_has_value();
@@ -976,7 +976,7 @@ class ParamProto : public ::google::protobuf::Message {
   ::google::protobuf::int32 id_;
   ::google::protobuf::int32 split_threshold_;
   ::google::protobuf::int32 partition_dim_;
-  ::google::protobuf::int32 version_;
+  ::google::protobuf::int32 owner_;
   int init_method_;
   float value_;
   float low_;
@@ -4761,26 +4761,26 @@ inline void 
ParamProto::set_partition_dim(::google::protobuf::int32 value) {
   partition_dim_ = value;
 }
 
-// optional int32 version = 6;
-inline bool ParamProto::has_version() const {
+// optional int32 owner = 6;
+inline bool ParamProto::has_owner() const {
   return (_has_bits_[0] & 0x00000020u) != 0;
 }
-inline void ParamProto::set_has_version() {
+inline void ParamProto::set_has_owner() {
   _has_bits_[0] |= 0x00000020u;
 }
-inline void ParamProto::clear_has_version() {
+inline void ParamProto::clear_has_owner() {
   _has_bits_[0] &= ~0x00000020u;
 }
-inline void ParamProto::clear_version() {
-  version_ = 0;
-  clear_has_version();
+inline void ParamProto::clear_owner() {
+  owner_ = 0;
+  clear_has_owner();
 }
-inline ::google::protobuf::int32 ParamProto::version() const {
-  return version_;
+inline ::google::protobuf::int32 ParamProto::owner() const {
+  return owner_;
 }
-inline void ParamProto::set_version(::google::protobuf::int32 value) {
-  set_has_version();
-  version_ = value;
+inline void ParamProto::set_owner(::google::protobuf::int32 value) {
+  set_has_owner();
+  owner_ = value;
 }
 
 // optional .singa.ParamProto.InitMethod init_method = 7 [default = kConstant];

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index dc45313..950bc2e 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -94,16 +94,13 @@ message ParamProto {
   // the program will calculate it
   repeated int32 shape = 3;
 
-  // split the parameter into multiple DAryProtos for serialzation and
+  // split the parameter into multiple sub params for serialzation and
   // transferring (Google Protobuf has size limit)
   optional int32 split_threshold=4 [default=5000000];
   // partition dimension, -1 for no partition
   optional int32 partition_dim=5 [default =-1];
 
-  optional int32 version=6;
-
-  // value of the parameter
-  //repeated DAryProto ary = 6;
+  optional int32 owner=6;
 
   enum InitMethod {
     kConstant = 0;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/trainer/pm_worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/pm_worker.cc b/src/trainer/pm_worker.cc
index 7269578..d2531e0 100644
--- a/src/trainer/pm_worker.cc
+++ b/src/trainer/pm_worker.cc
@@ -43,18 +43,15 @@ Msg* PMWorker::Put(Msg** msg){
 }
 
 Msg* PMWorker::Put(shared_ptr<Param> param, int step){
-  param->set_version(step);
-  // only owner can put shared parameter
-  if(param->owner()<0||param->owner()==param->id()){
-    Msg* msg= param->GenPutMsg(&step);
-    msg->set_src(group_id_, worker_id_, kWorkerParam);
-    msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
-        Sharding(param->id()), kServer);
-    msg->set_type(kPut);
-    msg->set_target(param->id());
-    return msg;
-  }else
-    return nullptr;
+  int id=param->owner();
+  auto entry=shard_->at(id);
+  Msg* msg= param->GenPutMsg(&step);
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+      Sharding(id), kServer);
+  msg->set_type(kPut);
+  msg->set_target(id);
+  return msg;
 }
 
 Msg* PMWorker::Get(Msg** msg){
@@ -62,79 +59,62 @@ Msg* PMWorker::Get(Msg** msg){
 }
 
 Msg* PMWorker::Get(shared_ptr<Param> param, int step){
-  param->set_version(step);
-  bool send=false;
-  int id=param->id();
-  shared_ptr<ParamCounter> entry=nullptr;
-  if(param->owner()>=0){
-    entry=shard_->at(id);
-    entry->nGet++;
-    send=entry->nGet/entry->nLocal==step;
-  }
-  if(param->owner()<0||send){
-    Msg* msg=nullptr;
-    if(param->owner()<0){
-      msg=param->GenGetMsg(&step);
-      msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
-          Sharding(id), kServer);
-    } else {
-      msg=entry->param->GenGetMsg(&step);
-      msg->set_dst(entry->owner_procs,kStub);
-    }
+  int id=param->owner();
+  shared_ptr<ParamCounter> entry=shard_->at(id);
+  Msg *msg=nullptr;
+  if((entry->nGet+1)%entry->nLocal==0&&param->version()<step){
+    msg=param->GenGetMsg(&step);
+    msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+        Sharding(id), kServer);
     msg->set_src(group_id_, worker_id_, kWorkerParam);
     msg->set_type(kGet);
     msg->set_target(id);
-    return msg;
-  }else
-    return nullptr;
+  }
+  entry->nGet++;
+  return msg;
 }
 
 Msg* PMWorker::Update(Msg** msg){
   return *msg;
 }
 Msg* PMWorker::Update(shared_ptr<Param> param, int step){
-  param->set_version(step);
-  bool send=false;
-  int id=param->id();
-  shared_ptr<ParamCounter> entry;
-  if(param->owner()>=0){
-    entry=shard_->at(param->id());
-    entry->nGet++;
-    send=entry->nGet/entry->nLocal==step;
+  int id=param->owner();
+  shared_ptr<ParamCounter> entry=shard_->at(id);
+  Msg* msg=nullptr;
+  if((entry->nUpdate+1)%entry->nLocal==0){
     auto shape=mshadow::Shape1(param->size());
-    mshadow::Tensor<mshadow::cpu,1> grad(param->mutable_cpu_grad(), shape);
-    mshadow::Tensor<mshadow::cpu,1> agg(entry->param->mutable_cpu_grad(), shape);
-    agg+=grad;
-  }
-  if(param->owner()<0||send){
-    Msg* msg=nullptr;
-    if(param->owner()<0){
-      msg=param->GenUpdateMsg(&step);
-      msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
-          Sharding(id), kServer);
-    } else {
-      entry->param->GenUpdateMsg(&step);
-      msg->set_dst(entry->owner_procs,kStub);
-      memset(param->mutable_cpu_data(), 0, sizeof(float)*param->size());
+    auto it=entry->shares.begin();
+    mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape);
+    for(++it;it!=entry->shares.end();it++){
+      mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+      agg+=grad/entry->nTotal;
     }
+    msg=entry->shares.at(0)->GenUpdateMsg(&step);
+    msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+        Sharding(id), kServer);
+    /*
+       entry->param->GenUpdateMsg(&step);
+       msg->set_dst(entry->owner_procs,kStub);
+       memset(param->mutable_cpu_data(), 0, sizeof(float)*param->size());
+       */
     msg->set_type(kUpdate);
     msg->set_target(id);
     msg->set_src(group_id_, worker_id_, kWorkerParam);
-    return msg;
-  }else
-    return nullptr;
+  }
+  entry->nUpdate++;
+  return msg;
 }
 
 Msg* PMWorker::Collect(Msg** msg){
   int id=(*msg)->target();
   int type=(*msg)->type();
-  auto pp=shard_->at(id)->param;
+  auto pp=shard_->at(id)->shares.at(0);
   if(type==kRGet){
     pp->ParseGetResponseMsg(msg);
   }else if(type==kRUpdate){
     pp->ParseUpdateResponseMsg(msg);
   }
-  if(pp->owner()>=0){
+  if(pp->owner()!=pp->id()){
     // forwarding to workers on other procs
   }
   delete (*msg);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index f5877c5..36c1302 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -9,23 +9,25 @@
 
 
 namespace singa {
-Server::Server(int group_id, int server_id):
-  group_id_(group_id), server_id_(server_id){}
+Server::Server(int thread_id,int group_id, int server_id):
+  thread_id_(thread_id),group_id_(group_id), server_id_(server_id){}
 
 void Server::Setup(const UpdaterProto& proto,
-    shared_ptr<PMServer::ParamShard> shard,
-    shared_ptr<Dealer> dealer){
+    shared_ptr<PMServer::ParamShard> shard){
        //VLOG(3) << "Parsing config file for host "<<hosts[id_] << " server id = " <<id_;
   pmserver_=shared_ptr<PMServer>(Singleton<Factory<PMServer>>::Instance()
       ->Create("PMServer"));
   pmserver_->Setup(group_id_, server_id_, shard, proto);
-  dealer_=dealer;
 }
 
 void Server::Run(){
+  dealer_=std::make_shared<Dealer>(2*thread_id_);
+  dealer_->Connect(kInprocRouterEndpoint);
+
   Msg* ping=new Msg();
   ping->set_src(group_id_, server_id_, kServer);
   ping->set_dst(0,0,kStub);
+  ping->add_frame("PING", 4);
   ping->set_type(kConnect);
   dealer_->Send(ping);
   Poller poller;
@@ -38,6 +40,12 @@ void Server::Run(){
     Msg* response=nullptr;
     int type=msg->type();
     switch (type){
+      case kConnect:{
+        string pong((char*)msg->frame_data(), msg->frame_size());
+        CHECK_STREQ("PONG", pong.c_str());
+        delete msg;
+        break;
+                    }
       case kPut:
         response = pmserver_->HandlePut(&msg);
         break;
@@ -57,8 +65,10 @@ void Server::Run(){
         break;
     }
 
-    if (response!=nullptr)
+    if (response!=nullptr){
+      //LOG(ERROR)<<"type: "<<type<<" response to "<<response->dst_id();
       dealer_->Send(response);
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 0a1edc8..4ac51ce 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -48,7 +48,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
   auto cluster=Cluster::Get(cproto, procs_id);
   // create servers
   vector<shared_ptr<Server>> servers;
-  int nSocket=1; // the first socket is the router
+  int nthreads=1; // thread id 0 is reserved for the router run by the main thread
   if(cluster->has_server()){
     int pid=cluster->procs_id();
     if(cluster->server_worker_separate())
@@ -59,10 +59,8 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     // the ParamShard for servers consists of a dictionary of Param objects
     auto shard=make_shared<PMServer::ParamShard>();
     for(int sid=start;sid<end;sid++){
-      auto server=make_shared<Server>(gid, sid);
-      auto dealer=make_shared<Dealer>(nSocket++);
-      dealer->Connect(kInprocRouterEndpoint);
-      server->Setup(mproto.updater(), shard, dealer);
+      auto server=make_shared<Server>(nthreads++,gid, sid);
+      server->Setup(mproto.updater(), shard);
       servers.push_back(server);
     }
   }
@@ -70,7 +68,9 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
   // create workers
   vector<shared_ptr<Worker>> workers;
   if(cluster->has_worker()){
-    auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain);
+    auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
+        cluster->nworkers_per_group());
+    //LOG(ERROR)<<net->ToString();
     int pid=cluster->procs_id();
     int gstart, gend, wstart, wend;
     if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){
@@ -94,7 +94,8 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
       if(gid==gstart)
         train_net=net;
       else{
-        train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain);
+        train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
+            cluster->nworkers_per_group());
         // the train net for other groups may share parameter values from the
         // first group
         if(mproto.hogwild())
@@ -103,12 +104,14 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
       if(gid==0){
         // validation and test are performed only by the first group
         if(mproto.test_steps()){
-          test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest);
+          test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest,
+              cluster->nworkers_per_group());
           if(test_net!=nullptr)
             test_net->ShareParams(train_net, kValueOnly);
         }
         if(mproto.validation_steps()){
-          validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kValidation);
+          validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kValidation,
+              cluster->nworkers_per_group());
           if(validation_net!=nullptr)
             validation_net->ShareParams(train_net, kValueOnly);
         }
@@ -116,28 +119,24 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
       // create ParamShard for the workers
       auto shard=make_shared<PMWorker::ParamShard>();
       for(auto layer: train_net->layers()){
-        int procsid=ProcsIDOf(gid, layer->locationid(),kWorkerParam);
+        int procsid=ProcsIDOf(gid, layer->partitionid(),kWorkerParam);
         int local=procsid==cluster->procs_id();
         for(auto param: layer->GetParams()){
           int owner=param->owner()<0||param->owner()==param->id()?procsid:-1;
-          if(shard->find(param->id())==shard->end())
-            (*shard)[param->id()]=make_shared<ParamCounter>(param, local, owner);
+          if(shard->find(param->owner())==shard->end())
+            (*shard)[param->owner()]=make_shared<ParamCounter>(param, local, owner);
           else
-            shard->at(param->id())->AddParam(param, local, owner);
+            shard->at(param->owner())->AddParam(param, local, owner);
         }
       }
       for(int wid=wstart;wid<wend;wid++){
         shared_ptr<Worker> worker=nullptr;
         if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation)
-          worker=make_shared<BPWorker>(gid, wid);
+          worker=make_shared<BPWorker>(nthreads++,gid, wid);
         else{
         // TODO add CDWorker
         }
-        auto layer_dealer=make_shared<Dealer>(nSocket++);
-        auto param_dealer=make_shared<Dealer>(nSocket++);
-        layer_dealer->Connect(kInprocRouterEndpoint);
-        param_dealer->Connect(kInprocRouterEndpoint);
-        worker->Setup(mproto, train_net, shard, layer_dealer, param_dealer);
+        worker->Setup(mproto, train_net, shard);
         worker->set_test_net(test_net);
         worker->set_validation_net(validation_net);
         workers.push_back(worker);
@@ -152,9 +151,9 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
 #endif
   vector<std::thread> threads;
   for(auto server: servers)
-    threads.push_back(std::thread(&Server::Run,server));
+    threads.push_back(std::thread(&Server::Run,server.get()));
   for(auto worker: workers)
-    threads.push_back(std::thread(&Worker::Run,worker));
+    threads.push_back(std::thread(&Worker::Run,worker.get()));
   Run();
   for(auto& thread: threads)
     thread.join();
@@ -168,8 +167,6 @@ void Trainer::Run(){
     router->Bind(cluster->endpoint());
 
   map<int, shared_ptr<Dealer>> interprocs_dealers;
-  Poller poller;
-  poller.Add(router.get());
   while(true){
     Msg* msg=router->Receive();
     if(msg==nullptr){
@@ -182,7 +179,15 @@ void Trainer::Run(){
     switch (dst_flag){ // TODO process other requests, e.g. RESTful
       case kStub:
         if(type==kConnect){
+          string ping((char*)msg->frame_data(), msg->frame_size());
+          CHECK_STREQ("PING", ping.c_str());
+          msg->SwapAddr();
+          Msg* reply=new Msg();
+          reply->SetAddr(msg);
+          reply->add_frame("PONG", 4);
+          reply->set_type(kConnect);
           delete msg;
+          router->Send(reply);
         }else{
           // TODO processing requests for worker group spanning multiple procs.
           LOG(ERROR)<<"Unkown message type ("<<type<<") to stub";

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index a290996..3f1c83f 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -8,40 +8,66 @@
 #include "proto/model.pb.h"
 using std::thread;
 namespace singa {
-Worker::Worker( int group_id, int worker_id):
-  group_id_(group_id), worker_id_(worker_id){
+Worker::Worker(int thread_id, int group_id, int worker_id):
+  thread_id_(thread_id),group_id_(group_id), worker_id_(worker_id){
   }
 
 void Worker::Setup(const ModelProto& model,
     shared_ptr<NeuralNet> train_net,
-    shared_ptr<PMWorker::ParamShard> shard,
-    shared_ptr<Dealer> layer_dealer,
-    shared_ptr<Dealer> param_dealer){
+    shared_ptr<PMWorker::ParamShard> shard){
   train_net_=train_net;
   modelproto_=model;
-  layer_dealer_=layer_dealer;
-  param_dealer_=param_dealer;
-  if(layer_dealer_!=nullptr)
-    layer_poller_.Add(layer_dealer_.get());
-  if(param_dealer_!=nullptr)
-    param_poller_.Add(param_dealer_.get());
   pmworker_=shared_ptr<PMWorker>(Singleton<Factory<PMWorker>>::Instance()
       ->Create("PMWorker"));
   pmworker_->Setup(group_id_, worker_id_, shard);
+}
+
+void Worker::Run(){
+  param_dealer_=make_shared<Dealer>(2*thread_id_);
+  param_dealer_->Connect(kInprocRouterEndpoint);
+  param_poller_.Add(param_dealer_.get());
+  layer_dealer_=make_shared<Dealer>(2*thread_id_+1);
+  layer_dealer_->Connect(kInprocRouterEndpoint);
+
+  {
+  Msg* ping=new Msg();
+  ping->set_src(group_id_, worker_id_, kWorkerParam);
+  ping->set_dst(0,0,kStub);
+  ping->set_type(kConnect);
+  ping->add_frame("PING", 4);
+  param_dealer_->Send(ping);
+  ping=param_dealer_->Receive();
+  string pong((char*)ping->frame_data(), ping->frame_size());
+  CHECK_STREQ("PONG", pong.c_str());
+  delete ping;
+  }
+
+  {
+  Msg* ping=new Msg();
+  ping->set_src(group_id_, worker_id_, kWorkerLayer);
+  ping->set_dst(0,0,kStub);
+  ping->set_type(kConnect);
+  ping->add_frame("PING", 4);
+  layer_dealer_->Send(ping);
+  ping=layer_dealer_->Receive();
+  string pong((char*)ping->frame_data(), ping->frame_size());
+  CHECK_STREQ("PONG", pong.c_str());
+  delete ping;
+  }
   step_=modelproto_.step();
   // init params
-  for(auto layer: train_net->layers())
-    if(group_id_==0&&layer->locationid()==worker_id_)
+  for(auto layer: train_net_->layers()){
+    //LOG(ERROR)<<layer->partitionid()<<" : "<<layer->name();
+    if(layer->partitionid()==worker_id_)
       for(auto param: layer->GetParams()){
-        if(param->owner()<0||param->owner()==param->id()){
-          param->Init();
+        if(group_id_==0&&param->owner()==param->id()){
+          param->Init(0);
           Put(param, step_);
+        }else{
+          Get(param, step_);
         }
-        Get(param, step_);
       }
-}
-
-void Worker::Run(){
+  }
   step_=modelproto_.step();
   Performance perf(train_net_);
   while(!StopNow(step_)){
@@ -56,11 +82,9 @@ int Worker::Put(shared_ptr<Param> param, int step){
   return 1;
 }
 int Worker::Get(shared_ptr<Param> param, int step){
-  if(param->version()<step){
-    auto msg=pmworker_->Get(param, step);
-    if(msg!=nullptr)
-      param_dealer_->Send(msg);
-  }
+  auto msg=pmworker_->Get(param, step);
+  if(msg!=nullptr)
+    param_dealer_->Send(msg);
   return 1;
 }
 int Worker::Update(shared_ptr<Param> param, int step){
@@ -69,12 +93,26 @@ int Worker::Update(shared_ptr<Param> param, int step){
     param_dealer_->Send(msg);
   return 1;
 }
+
+int Worker::CollectAll(shared_ptr<NeuralNet> net, int step){
+  auto& layers=net->layers();
+  for(auto& layer: layers){
+    if(layer->partitionid()==worker_id_)
+      for(shared_ptr<Param> p: layer->GetParams()){
+        Collect(p, step);
+      }
+  }
+  return 1;
+}
 int Worker::Collect(shared_ptr<Param> param, int step){
   while(param->version()<step){
-    Msg* msg=param_dealer_->Receive();
-    if(msg==nullptr)
-      return 0;
-    pmworker_->Collect(&msg);
+    Socket* which=param_poller_.Wait(10);
+    if(which!=nullptr){
+      Msg* msg=param_dealer_->Receive();
+      if(msg==nullptr)
+        return 0;
+      pmworker_->Collect(&msg);
+    }
   }
   return 1;
 }
@@ -86,14 +124,17 @@ void Worker::RunOneBatch(int step, Performance* perf){
   //float tSyncData=tSyncData_, tSyncParam=tSyncParam_;
   if(ValidateNow(step)){
     LOG(ERROR)<<"Validation at step "<<step;
+    CollectAll(validation_net_, step);
     Test(validation_net_, modelproto_.validation_steps(), perf!=nullptr);
   }
   if(TestNow(step)){
     LOG(ERROR)<<"Test at step "<<step;
+    CollectAll(test_net_, step);
     Test(test_net_, modelproto_.test_steps(), perf!=nullptr);
   }
   //tSyncData_=tSyncData; tSyncParam_=tSyncParam;
 
+  CollectAll(train_net_, step);
   TrainOneBatch(step);
   if(perf!=nullptr){
     perf->Update();
@@ -158,10 +199,22 @@ void Worker::Test(shared_ptr<NeuralNet> net, int nsteps, bool disperf){
 void BPWorker::Forward(shared_ptr<NeuralNet> net, int step,  bool training){
   auto& layers=net->layers();
   for(auto& layer: layers){
-    if(layer->locationid()==worker_id_){
+    if(layer->partitionid()==worker_id_){
       if(layer->is_bridgedstlayer()){
-        //auto* dst=static_cast<BridgeDstLayer*>(layer.get());
-        // receive fea blobs
+        auto* dst=static_cast<BridgeDstLayer*>(layer.get());
+        while(!dst->ready()){
+          auto msg=layer_dealer_->Receive();
+          CHECK_EQ(msg->src_group_id(), group_id_);
+          string name((char*)msg->frame_data(), msg->frame_size());
+          auto tmp=net->name2layer(name);
+          CHECK(tmp->is_bridgedstlayer());
+          auto* dstlayer=static_cast<BridgeDstLayer*>(tmp.get());
+          auto data=dstlayer->mutable_data(nullptr);
+          msg->next_frame();
+          memcpy(data->mutable_cpu_data(), msg->frame_data(), msg->frame_size());
+          dstlayer->set_ready(true);
+          delete msg;
+        }
       }
       if(training){
         for(shared_ptr<Param> p: layer->GetParams()){
@@ -172,7 +225,14 @@ void BPWorker::Forward(shared_ptr<NeuralNet> net, int step,  bool training){
       layer->ComputeFeature(training);
       //LOG(ERROR)<<layer->name()<<":"<<(clock()-s)*1.0/CLOCKS_PER_SEC;
       if(layer->is_bridgesrclayer()){
-        // send fea blobs
+        auto dst=layer->dstlayers().at(0);
+        Msg *msg=new Msg();
+        msg->set_src(group_id_, worker_id_, kWorkerLayer);
+        msg->set_dst(group_id_, dst->partitionid(), kWorkerLayer);
+        msg->add_frame(dst->name().c_str(), dst->name().length());
+        auto const & blob=layer->data(nullptr);
+        msg->add_frame(blob.cpu_data(), blob.count()*sizeof(float));
+        layer_dealer_->Send(msg);
       }
       if(training&&DisplayDebugInfo(step)&&layer->mutable_data(nullptr)!=nullptr){
         LOG(INFO)<<StringPrintf("Forward layer  %10s data norm1 %13.9f",
@@ -186,7 +246,7 @@ void BPWorker::Backward(shared_ptr<NeuralNet> net, int step){
   auto& layers=net->layers();
   for (auto it = layers.rbegin(); it != layers.rend(); it++){
     shared_ptr<Layer> layer=*it;
-    if(layer->locationid()==worker_id_){
+    if(layer->partitionid()==worker_id_){
       if(layer->is_bridgesrclayer()){
         //auto* src=static_cast<BridgeSrcLayer*>(layer.get());
         // receive grad blobs

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/utils/graph.cc
----------------------------------------------------------------------
diff --git a/src/utils/graph.cc b/src/utils/graph.cc
index d1cece6..076c971 100644
--- a/src/utils/graph.cc
+++ b/src/utils/graph.cc
@@ -20,7 +20,7 @@ const string Graph::ToString(const map<string, string>& info) const {
   for(auto node: nodes_){
     char str[1024];
     string name=node->name();
-    string color=colors[(node->val().locationid)%colors.size()];
+    string color=colors[(node->val().partitionid)%colors.size()];
     string shape;
     string origin=node->val().origin;
     if(origin=="kSlice"||origin=="kConcate"||origin=="kSplit"

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/39969f2b/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index d64c65d..3d46ee6 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -10,14 +10,6 @@ using std::vector;
 using std::string;
 namespace singa {
 
-Param::Param(){
-  owner_=-1;
-  fan_in_=0;
-  set_version(-1);
-}
-
-Param::~Param(){}
-
 Msg* Param::GenPutMsg(void* arg){
   char buf[256];
   int v=*(int*)arg;
@@ -31,9 +23,10 @@ Msg* Param::GenPutMsg(void* arg){
 }
 
 Msg* Param::GenGetMsg(void* arg){
-  char buf[10];
+  char buf[12];
   int v=*(int*)arg;
   sprintf(buf, "%d", v);
+  LOG(ERROR)<<"gen get version "<<v;
   Msg* msg=new Msg();
   msg->set_type(kGet);
   msg->add_frame(buf, strlen(buf));
@@ -61,16 +54,16 @@ Msg* Param::HandlePutMsg(Msg** msg){
   float lr, wc;
   sscanf(static_cast<char*>((*msg)->frame_data()), "%d %d %f %f",
       &v, &size, &lr, &wc);
-  set_version(v);
   proto_.set_learning_rate_multiplier(lr);
   proto_.set_weight_decay_multiplier(wc);
   CHECK((*msg)->next_frame());
   vector<int> shape{size};
-  data_.Reshape(shape);
+  data_=std::make_shared<Blob<float>>(shape);
+  data_->set_version(v);
   grad_.Reshape(shape);
   history_.Reshape(shape);
   CHECK_EQ(size* sizeof(float), (*msg)->frame_size());
-  memcpy(data_.mutable_cpu_data(), (*msg)->frame_data(), size*sizeof(float));
+  memcpy(mutable_cpu_data(), (*msg)->frame_data(), size*sizeof(float));
   delete (*msg);
   *msg=nullptr;
   return nullptr;
@@ -81,7 +74,7 @@ Msg* Param::HandleGetMsg(Msg** msg){
   sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
   CHECK_LE(v, version());
   CHECK(!(*msg)->next_frame());
-  (*msg)->add_frame(data_.mutable_cpu_data(), sizeof(float)*size());
+  (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
   (*msg)->SwapAddr();
   (*msg)->set_type(kRGet);
   return *msg;
@@ -127,9 +120,10 @@ int Param::ParsePutResponseMsg(Msg **msg){
 int Param::ParseGetResponseMsg(Msg **msg){
   int v;
   sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
-  set_version(v);
   CHECK((*msg)->next_frame());
   memcpy(mutable_cpu_data(), (*msg)->frame_data(), (*msg)->frame_size());
+  // must be set after all other settings are done!
+  set_version(v);
   return 1;
 }
 int Param::ParseUpdateResponseMsg(Msg **msg){
@@ -138,7 +132,7 @@ int Param::ParseUpdateResponseMsg(Msg **msg){
 
 void Param::Setup(const ParamProto& proto, const vector<int>& shape,
     int fan_in){
-  data_.Reshape(shape);
+  data_=std::make_shared<Blob<float>>(shape);
   grad_.Reshape(shape);
   history_.Reshape(shape);
   proto_=proto;
@@ -146,8 +140,8 @@ void Param::Setup(const ParamProto& proto, const vector<int>& shape,
 }
 
 void Param::Init(int v){
-  proto_.set_version(v);
-  Tensor<cpu, 1> data(data_.mutable_cpu_data(), Shape1(data_.count()));
+  set_version(v);
+  Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
   unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
   auto random=ASingleton<Random<cpu>>::Instance(seed);
   switch (proto_.init_method()) {
@@ -168,7 +162,7 @@ void Param::Init(int v){
   case ParamProto::kUniformSqrtFanInOut:
     random->SampleUniform(data, proto_.low(), proto_.high());
     if(proto_.value())
-      data*= proto_.value()/ sqrt(data_.shape()[0] +data_.shape()[1]);
+      data*= proto_.value()/ sqrt(data_->shape()[0] +data_->shape()[1]);
     break;
   case ParamProto::kGaussian:
     random->SampleGaussian(data, proto_.mean(), proto_.std());
@@ -178,7 +172,7 @@ void Param::Init(int v){
   case ParamProto::kGaussainSqrtFanIn:
     random->SampleGaussian(data, proto_.mean(), proto_.std());
     if(proto_.value())
-      data*= proto_.value()/ sqrt(data_.shape()[0]);
+      data*= proto_.value()/ sqrt(data_->shape()[0]);
     break;
   default:
     LOG(ERROR) << "Illegal parameter init method ";

Reply via email to