Repository: incubator-singa
Updated Branches:
  refs/heads/master 7993a7867 -> 654d733ba


add PrefetchLayer; Prefetching is done by moving DataLayer and ParserLayer as 
sublayers/members of PrefetchLayer, whose ComputeFeature function launches a 
thread to call the ComputeFeature functions of all DataLayer and ParserLayer.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/831efef0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/831efef0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/831efef0

Branch: refs/heads/master
Commit: 831efef03cc2f508866f69a9768b82c21b56af5f
Parents: 7993a78
Author: wang wei <[email protected]>
Authored: Sat May 9 20:57:36 2015 +0800
Committer: wang wei <[email protected]>
Committed: Sat May 9 20:57:36 2015 +0800

----------------------------------------------------------------------
 Makefile.example                     |  91 +++++++++++
 examples/cifar10/model-prefetch.conf | 241 ++++++++++++++++++++++++++++++
 examples/cifar10/model.conf          |  11 +-
 include/neuralnet/base_layer.h       | 232 ++++++++++++++--------------
 include/neuralnet/layer.h            |   2 +
 src/neuralnet/base_layer.cc          | 101 +++++++++++--
 src/neuralnet/layer.cc               |  46 +++---
 src/neuralnet/neuralnet.cc           |  11 +-
 src/proto/model.pb.h                 | 209 ++++++++++++++++++++------
 src/proto/model.proto                |   3 +
 src/trainer/worker.cc                |  10 +-
 11 files changed, 754 insertions(+), 203 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/Makefile.example
----------------------------------------------------------------------
diff --git a/Makefile.example b/Makefile.example
new file mode 100644
index 0000000..80dfc26
--- /dev/null
+++ b/Makefile.example
@@ -0,0 +1,91 @@
+###################User Config Variables #############################
+# third-party library installation folder
+HOME_DIR := /usr/
+# Lib folder for system and external libs. You may need to change it.
+LIBRARY_DIRS := $(HOME_DIR)/lib64 $(HOME_DIR)/lib $(HOME_DIR)/local/lib
+# Header folder for system and external libs. You may need to change it.
+INCLUDE_DIRS := $(HOME_DIR)/include ./include
+# g++ location, should support c++11, tested with 4.8.1
+CXX := g++
+
+######################Setting Variables#######################################
+LIBRARIES := glog gflags protobuf rt opencv_highgui opencv_imgproc opencv_core\
+       lmdb openblas zmq czmq
+
+LDFLAGS := $(foreach librarydir, $(LIBRARY_DIRS), -L$(librarydir))\
+       $(foreach library, $(LIBRARIES), -l$(library))
+# Folder to store compiled files
+BUILD_DIR := build
+MSHADOW_FLAGS :=-DMSHADOW_USE_CUDA=0 -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0
+CXXFLAGS := -O3 -Wall -pthread -fPIC -std=c++11 -Wno-unknown-pragmas \
+       $(MSHADOW_FLAGS) -DCPU_ONLY=1 \
+       -funroll-loops $(foreach includedir, $(INCLUDE_DIRS), -I$(includedir))
+
+# find user defined .proto file, and then compute the corresponding .h, .cc
+# files, which cannot be found by shell find, because they haven't been
+# generated currently
+PROTOS := $(shell find src/proto/ -name "*.proto")
+PROTO_SRCS :=$(PROTOS:.proto=.pb.cc)
+PROTO_HDRS :=$(patsubst src%, include%, $(PROTOS:.proto=.pb.h))
+PROTO_OBJS :=$(addprefix $(BUILD_DIR)/, $(PROTO_SRCS:.cc=.o))
+
+# each singa src file will generate a .o file
+SINGA_SRCS := $(shell find src/ \( -path "src/test" -o -path "src/main.cc" \) \
+       -prune -o \( -name "*.cc" -type f \) -print )
+SINGA_OBJS := $(sort $(addprefix $(BUILD_DIR)/, $(SINGA_SRCS:.cc=.o)) \
+       $(PROTO_OBJS) )
+-include $(SINGA_OBJS:%.o=%.P)
+
+TEST_SRCS :=$(shell find src/test/ -maxdepth 1 -name "*.cc")
+TEST_OBJS := $(sort $(addprefix $(BUILD_DIR)/, $(TEST_SRCS:.cc=.o)))
+-include $(TEST_OBJS:%.o=%.P)
+
+GTEST_SRC := include/gtest/gtest-all.cc
+GTEST_HDR := include/gtest/gtest.h
+GTEST_LIB := $(BUILD_DIR)/libgtest.a
+
+OBJS := $(sort $(SINGA_OBJS) $(TEST_OBJS) )
+
+########################Compilation Section###################################
+.PHONY: singa test
+
+singa: $(PROTO_OBJS) $(SINGA_OBJS)
+       $(CXX) $(SINGA_OBJS) src/main.cc -o $(BUILD_DIR)/singa $(CXXFLAGS) 
$(LDFLAGS)
+       @echo
+
+loader: proto $(LOADER_OBJS)
+       $(CXX) $(LOADER_OBJS) -o $(BUILD_DIR)/loader $(CXXFLAGS) $(LDFLAGS)
+       @echo
+
+test:  proto $(GTEST_LIB) $(TEST_OBJS) $(SINGA_OBJS)
+       $(CXX) $(TEST_OBJS) include/gtest/gtest_main.cc $(GTEST_LIB) \
+               $(SINGA_OBJS) -o $(BUILD_DIR)/test $(CXXFLAGS) $(LDFLAGS)
+       @echo
+
+$(GTEST_LIB): $(GTEST_HDR) $(GTEST_SRC)
+       $(CXX) $(GTEST_SRC) -c -o $(BUILD_DIR)/gtest-all.o $(CXXFLAGS)
+       ar -rv $(GTEST_LIB) $(BUILD_DIR)/gtest-all.o
+
+# compile all files
+$(OBJS):$(BUILD_DIR)/%.o : %.cc
+       @mkdir -p $(dir $@)
+       $(CXX) $<  $(CXXFLAGS) -MMD -c -o $@
+       cp $(BUILD_DIR)/$*.d $(BUILD_DIR)/$*.P; \
+       sed -e 's/#.*//' -e 's/^[^:]*: *//' -e 's/ *\\$$//' \
+               -e '/^$$/ d' -e 's/$$/ :/' < $(BUILD_DIR)/$*.d >> 
$(BUILD_DIR)/$*.P; \
+       rm -f $*.d
+
+proto: $(PROTO_OBJS)
+
+$(PROTO_SRCS): $(PROTOS)
+       protoc --proto_path=src/proto --cpp_out=src/proto $(PROTOS)
+       mkdir -p include/proto/
+       cp src/proto/*.pb.h include/proto/
+       @echo
+
+clean:
+       rm -rf *.a *.so
+       rm -rf include/proto/*
+       rm -rf src/proto/*.pb.h src/proto/*.pb.cc
+       rm -rf $(BUILD_DIR)
+       @echo

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/examples/cifar10/model-prefetch.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/model-prefetch.conf 
b/examples/cifar10/model-prefetch.conf
new file mode 100644
index 0000000..220a4b9
--- /dev/null
+++ b/examples/cifar10/model-prefetch.conf
@@ -0,0 +1,241 @@
+name: "cifar10-convnet"
+train_steps: 70000
+test_steps:100
+test_frequency:1000
+display_frequency:50
+updater{
+  momentum:0.9
+  weight_decay:0.004
+  learning_rate_change_method:kFixedStep
+  step:0
+  step:60000
+  step:65000
+  step_lr:0.001
+  step_lr:0.0001
+  step_lr:0.00001
+}
+neuralnet {
+layer{
+  name: "prefetch"
+  type: "kPrefetch"
+  sublayers {
+    name: "data"
+    type: "kShardData"
+    data_param {
+      path: "examples/cifar10/cifar10_train_shard"
+      batchsize: 100
+    }
+  }
+  sublayers{
+    name:"rgb"
+    type: "kRGBImage"
+    srclayers: "data"
+    rgbimage_param {
+      meanfile: "examples/cifar10/image_mean.bin"
+    }
+  }
+  sublayers{
+    name: "label"
+    type: "kLabel"
+    srclayers: "data"
+  }
+  exclude: kTest
+}
+
+layer{
+  name: "prefetch"
+  type: "kPrefetch"
+  sublayers {
+    name: "data"
+    type: "kShardData"
+    data_param {
+      path: "examples/cifar10/cifar10_test_shard"
+      batchsize: 100
+    }
+  }
+  sublayers{
+    name:"rgb"
+    type: "kRGBImage"
+    srclayers: "data"
+    rgbimage_param {
+      meanfile: "examples/cifar10/image_mean.bin"
+    }
+  }
+  sublayers{
+    name: "label"
+    type: "kLabel"
+    srclayers: "data"
+  }
+  exclude: kTrain
+}
+
+layer {
+  name: "conv1"
+  type: "kConvolution"
+  srclayers: "prefetch"
+  datablob: "rgb"
+  convolution_param {
+    num_filters: 32
+    kernel: 5
+    stride: 1
+    pad:2
+  }
+  param{
+      name: "weight"
+      init_method:kGaussian
+      std:0.0001
+      learning_rate_multiplier:1.0
+    }
+  param{
+      name: "bias"
+      init_method: kConstant
+      learning_rate_multiplier:2.0
+      value:0
+    }
+}
+
+layer {
+  name: "pool1"
+  type: "kPooling"
+  srclayers: "conv1"
+  pooling_param {
+    pool: MAX
+    kernel: 3
+    stride: 2
+  }
+}
+layer {
+  name: "relu1"
+  type: "kReLU"
+  srclayers:"pool1"
+}
+layer {
+  name: "norm1"
+  type: "kLRN"
+  lrn_param {
+    norm_region: WITHIN_CHANNEL
+    local_size: 3
+    alpha: 5e-05
+    beta: 0.75
+  }
+  srclayers:"relu1"
+}
+layer {
+  name: "conv2"
+  type: "kConvolution"
+  srclayers: "norm1"
+  convolution_param {
+    num_filters: 32
+    kernel: 5
+    stride: 1
+    pad:2
+  }
+  param{
+      name: "weight"
+      init_method:kGaussian
+      std:0.01
+      learning_rate_multiplier:1.0
+    }
+  param{
+      name: "bias"
+      init_method: kConstant
+      learning_rate_multiplier:2.0
+      value:0
+    }
+}
+layer {
+  name: "relu2"
+  type: "kReLU"
+  srclayers:"conv2"
+}
+layer {
+  name: "pool2"
+  type: "kPooling"
+  srclayers: "relu2"
+  pooling_param {
+    pool: MAX
+    kernel: 3
+    stride: 2
+  }
+}
+layer {
+  name: "norm2"
+  type: "kLRN"
+  lrn_param {
+    norm_region: WITHIN_CHANNEL
+    local_size: 3
+    alpha: 5e-05
+    beta: 0.75
+  }
+  srclayers:"pool2"
+}
+layer {
+  name: "conv3"
+  type: "kConvolution"
+  srclayers: "norm2"
+  convolution_param {
+    num_filters: 64
+    kernel: 5
+    stride: 1
+    pad:2
+  }
+  param{
+      name: "weight"
+      init_method:kGaussian
+      std:0.01
+    }
+  param{
+      name: "bias"
+      init_method: kConstant
+      value:0
+    }
+}
+layer {
+  name: "relu3"
+  type: "kReLU"
+  srclayers:"conv3"
+}
+layer {
+  name: "pool3"
+  type: "kPooling"
+  srclayers: "relu3"
+  pooling_param {
+    pool: AVE
+    kernel: 3
+    stride: 2
+  }
+}
+layer {
+  name: "ip1"
+  type: "kInnerProduct"
+  srclayers:"pool3"
+  inner_product_param {
+    num_output: 10
+  }
+  param{
+      name: "weight"
+      init_method:kGaussian
+      std:0.01
+      learning_rate_multiplier:1.0
+      weight_decay_multiplier:250
+    }
+  param{
+      name: "bias"
+      init_method: kConstant
+      learning_rate_multiplier:2.0
+      weight_decay_multiplier:0
+      value:0
+  }
+}
+
+layer{
+  name: "loss"
+  type:"kSoftmaxLoss"
+  softmaxloss_param{
+    topk:1
+  }
+  srclayers:"ip1"
+  srclayers:"prefetch"
+  datablob: "label"
+}
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/examples/cifar10/model.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/model.conf b/examples/cifar10/model.conf
index 09e64aa..a27486a 100644
--- a/examples/cifar10/model.conf
+++ b/examples/cifar10/model.conf
@@ -15,7 +15,7 @@ updater{
   step_lr:0.00001
 }
 neuralnet {
-layer {
+layer{
   name: "data"
   type: "kShardData"
   data_param {
@@ -24,8 +24,7 @@ layer {
   }
   exclude: kTest
 }
-
-layer {
+layer{
   name: "data"
   type: "kShardData"
   data_param {
@@ -34,7 +33,6 @@ layer {
   }
   exclude: kTrain
 }
-
 layer{
   name:"rgb"
   type: "kRGBImage"
@@ -43,12 +41,12 @@ layer{
     meanfile: "examples/cifar10/image_mean.bin"
   }
 }
-
 layer{
   name: "label"
   type: "kLabel"
   srclayers: "data"
 }
+
 layer {
   name: "conv1"
   type: "kConvolution"
@@ -72,6 +70,7 @@ layer {
       value:0
     }
 }
+
 layer {
   name: "pool1"
   type: "kPooling"
@@ -213,6 +212,6 @@ layer{
     topk:1
   }
   srclayers:"ip1"
-  srclayers:"label"
+  srclayers: "label"
 }
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/include/neuralnet/base_layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/base_layer.h b/include/neuralnet/base_layer.h
index 863c223..e4fc174 100644
--- a/include/neuralnet/base_layer.h
+++ b/include/neuralnet/base_layer.h
@@ -6,11 +6,10 @@
 #include <map>
 #include <functional>
 #include <utility>
-#include <condition_variable>
-#include <mutex>
 #include <memory>
 #include <chrono>
 #include <algorithm>
+#include <thread>
 
 #include "proto/model.pb.h"
 #include "utils/param.h"
@@ -31,36 +30,44 @@ typedef shared_ptr<Layer> SLayer;
  * Base layer class.
  * Children should implement at least Layer::Setup, Layer::ComputeFeature(),
  * Layer::ComputGradient() functions for backpropagation method;
- * TODO(wangwei) implement children layers to support contrastive divergence,
+ * TODO(zhaojing) subclass the base layer class to support contrastive 
divergence,
  * The identifier of each layer is the literal string of the class name without
  * the suffix "Layer", which is used in layer registration and creation.
  */
 class Layer {
  public:
   Layer(){}
+  virtual ~Layer(){}
   /**
-   * simply save the proto configuation.
-   * most initializations are done by Setup().
-   * @param layer_proto user defined layer configuration
+   * Layer initialization.
+   *
+   * It simply saves the proto configuation, most initializations are done by
+   * Setup().
+   *
+   * @param proto user defined layer configuration
    */
   virtual void Init(const LayerProto &proto);
   /**
-   * copy layer configuration from the other Layer, and set the shape.
+   * Copy layer configuration from the other Layer, and use the shape argument
+   * to as its data shape.
    */
   void Init(const Layer& other, const vector<int>& shape);
-  virtual ~Layer(){}
   /**
-   * Marshal layer properties and data into google protobuf object
-   * (i.e., snapshot).
+   * TODO(wangsheng) Marshal layer properties and data into google protobuf
+   * object (i.e., snapshot).
+   *
    * Parameters are marshalled separately into another object (i.e., model).
+   *
    * @param layer_proto
-   * @param copyData if true marshal data of DArray
+   * @param copyData if true marshal layer data, e.g., feature value
    */
   virtual void ToProto(LayerProto *layer_proto, bool copyData);
   /**
    * Setup layer properties.
+   *
    * Setup the shapes for data and parameters, also setup some properties
    * based on the layer configuration and connected src layers.
+   *
    * @param srclayers layers connecting to this layer
    */
   virtual void Setup(const LayerProto& proto,
@@ -71,8 +78,9 @@ class Layer {
   virtual void Setup();
   /**
    * Setup the layer properties except shape.
-   * the shape is already set and passed in to set other properties.
-   * perperties are set according to shapes of itself and connected layers, and
+   *
+   * The shape is already set and passed in to set other properties.
+   * properties are set according to shapes of itself and connected layers, and
    * configuration. this should not change the current shape_(
    * shape check is done outside the function).
    */
@@ -86,6 +94,7 @@ class Layer {
   virtual void SetupAfterPartition();
   /**
    * Layers that have paramters must overload this function.
+   *
    * @return parameters associated with this layer
    */
   virtual vector<shared_ptr<Param>> GetParams(){
@@ -93,8 +102,11 @@ class Layer {
   }
   /**
    * Compute features of this layer based on connected layers.
-   * Implement forward propagation for BP; TODO Implement both postive phase
-   * and negative phase for CD.
+   *
+   * Implement forward propagation for BP.
+   * TODO(zhaojing) Implement both positive phase and negative phase for CD.
+   *
+   * @param training true if in training phase
    * @param srclayers layers connecting to this layer
    */
   virtual void ComputeFeature(bool training, const vector<SLayer>& 
srclayers)=0;
@@ -104,8 +116,10 @@ class Layer {
   virtual void ComputeFeature(bool training);
   /**
    * Compute gradients for parameters and connecting layers.
-   * Implement backward propagation for BP; TODO Calculate gradients for
-   * parameters for CD.
+   *
+   * Implement backward propagation for BP.
+   * TODO(zhaojing) Calculate gradients for parameters for CD.
+   *
    * @param srclayers layers connecting to this layer.
    */
   virtual void ComputeGradient(const vector<SLayer>& srclayers)=0;
@@ -114,7 +128,8 @@ class Layer {
    */
   virtual void ComputeGradient();
   /**
-   * decide on which dimension to do the partitioning.
+   * Decide on which dimension to do the partitioning.
+   *
    * @mode kLayer, kData, kNone (no partition)
    * @return the partition dimension, -1 for no partition
    */
@@ -128,7 +143,8 @@ class Layer {
   }
 
   /**
-   * return connection type between two layers.
+   * Return connection type between two layers.
+   *
    * Currently support two connections: kOneToOne, and kOneToAll.
    * kOneToOne indicates the dst neuron depends on only one neuron from src
    * layer. kOneToAll indicates the dst neuron depends on all neurons from src
@@ -139,18 +155,21 @@ class Layer {
     return kOneToOne;
   }
   /**
-   * return partition type of this layer.
-   * E.g., kNone, kLayer or kData
+   * @return partition type of this layer, e.g., kNone, kLayer or kData.
    */
   virtual PartitionType partition_type() const {
     return layer_proto_.partition_type();
   }
   /**
-   * location id is the execution unit (i.e., thread from the working group) 
ID.
+   * Set location ID as the worker ID within a worker group.
+   * TODO(wangwei) merge location ID with partition ID
    */
   virtual void set_locationid(int id){
     layer_proto_.set_locationid(id);
   }
+  /**
+   * @return location ID
+   */
   virtual int locationid() const {
     return layer_proto_.locationid();
   }
@@ -176,27 +195,36 @@ class Layer {
   const std::string &name() const {
     return layer_proto_.name();
   }
-  const vector<int>& shape(const Layer* layer=nullptr) const{
+  /**
+   * @return name of src data blob, used by prefetch layer to locate the data
+   * blob in parser layers; The default value is "unknown"; If the
+   * src layer is the prefetch layer and there are more than one parser layers,
+   * this value must be set.
+   */
+  const std::string &datablob() const {
+    return layer_proto_.datablob();
+  }
+  const vector<int>& shape(const Layer* layer) const{
     return data(layer).shape();
   }
 
   /**
    * @return a const ref for Blob storing neuron values of this layer for BP
    */
-  virtual const Blob<float>& data(const Layer* from=nullptr) const {
+  virtual const Blob<float>& data(const Layer* from) const {
     return data_;
   }
-  virtual Blob<float>* mutable_data(const Layer* from=nullptr){
+  virtual Blob<float>* mutable_data(const Layer* from){
     return &data_;
   }
 
-  virtual const Blob<float>& grad(const Layer* from=nullptr) const {
+  virtual const Blob<float>& grad(const Layer* from) const {
     return grad_;
   }
   /**
    * @return a pointer to storing neuron grads of this layer for BP
    */
-  virtual Blob<float>* mutable_grad(const Layer* from=nullptr) {
+  virtual Blob<float>* mutable_grad(const Layer* from) {
     return &grad_;
   }
 
@@ -250,9 +278,7 @@ class Layer {
   }
 protected:
   string name_;
-  //vector<shared_ptr<SyncedMem>> memblobs_;
   Blob<float> data_, grad_;
-  // DArray pos_, neg_;//for CD
   LayerProto layer_proto_;
   vector<SLayer> srclayers_, dstlayers_;
 };
@@ -328,8 +354,8 @@ class ConcateLayer: public Layer {
 
 
 /**
- * base layer for prefetching records from local Shard, HDFS, lmdb, etc.
- * cannot be partitioned, always returns kNone for partition type.
+ * Base layer for reading records from local Shard, HDFS, lmdb, etc.
+ * Cannot be partitioned, always returns kNone for partition type.
  */
 
 class DataLayer: public Layer{
@@ -346,14 +372,14 @@ class DataLayer: public Layer{
   virtual void Setup(){
     vector<SLayer> dummy;
     Setup(layer_proto_,dummy);
-    has_set_=true;
+    has_setup_=true;
   }
   virtual void SetupAfterPartition(const LayerProto& proto,
       const vector<int> &shape,
       const vector<SLayer>& srclayers){}
 
   virtual void SetupAfterPartition(){
-    if(!has_set_)
+    if(!has_setup_)
     Setup();
   }
   virtual PartitionType partition_type () const {
@@ -367,36 +393,59 @@ class DataLayer: public Layer{
     return sample_;
   }
 
-  virtual Blob<float>* mutable_data(const Layer* layer=nullptr) {
+  virtual Blob<float>* mutable_data(const Layer* layer) {
     return nullptr;
   }
-  virtual Blob<float>* mutable_grad(const Layer* layer=nullptr) {
+  virtual Blob<float>* mutable_grad(const Layer* layer) {
     return nullptr;
   }
-  void set_prefetch(bool prefetch){
-    prefetch_=prefetch;
-  }
+ protected:
+  bool has_setup_;
+  int random_skip_, batchsize_;
+  Record sample_;
+  vector<Record> records_;
+};
 
-  virtual void ComputeFeature(bool training) {
-    if(!prefetch_)
-      ComputeFeature(training, srclayers_);
+/**
+ * Layer for prefetching data records and parsing them.
+ *
+ * The data loading and parsing work is done by internal DataLayer and
+ * ParserLayer respectively. This layer controls the prefetching thread, i.e.,
+ * creating and joining the prefetching thread.
+ */
+class PrefetchLayer : public Layer {
+ public:
+  virtual ~PrefetchLayer();
+  virtual void ComputeFeature(bool training, const vector<SLayer>& srclayers);
+  virtual void Setup(const LayerProto& proto, const vector<SLayer>& srclayers);
+  virtual const Blob<float>& data(const Layer* from) const ;
+  virtual Blob<float>* mutable_data(const Layer* layer) ;
+  virtual void ComputeGradient(const vector<SLayer>& srclayers){};
+  virtual Blob<float>* mutable_grad(const Layer* layer){
+    return nullptr;
+  }
+  virtual const Blob<float>& grad(const Layer* from) const {
+    CHECK(false)<<"Loss layer has not gradient blob";
+    return grad_;
   }
 
-  virtual void Prefetching(bool training){
-    CHECK(prefetch_);
-    ComputeFeature(training, srclayers_);
+  virtual void SetupAfterPartition(const LayerProto& proto,
+      const vector<int> &shape,
+      const vector<SLayer>& srclayers){}
+
+  virtual PartitionType partition_type () const {
+    return kNone;
   }
 
+  void Prefetch(bool training);
  protected:
-  bool has_set_;
-  bool prefetch_;
-  int random_skip_, batchsize_;
-  Record sample_;
-  vector<Record> records_;
+  vector<shared_ptr<Layer>> sublayers_;
+  map<string, Blob<float>> datablobs_;
+  std::thread thread_;
 };
 
 /**
- * Slice this layer into multiple dst layers on one dimension
+ * Slice the source layer into multiple dst layers on one dimension
  */
 class SliceLayer: public Layer {
  public:
@@ -407,10 +456,10 @@ class SliceLayer: public Layer {
       const vector<SLayer>& srclayers){}
 
 
-  virtual const Blob<float>& data(const Layer* layer=nullptr) const;
-  virtual const Blob<float>& grad(const Layer* layer=nullptr) const;
-  virtual Blob<float>* mutable_data(const Layer* layer=nullptr);
-  virtual Blob<float>* mutable_grad(const Layer* layer=nullptr);
+  virtual const Blob<float>& data(const Layer* layer) const;
+  virtual const Blob<float>& grad(const Layer* layer) const;
+  virtual Blob<float>* mutable_data(const Layer* layer);
+  virtual Blob<float>* mutable_grad(const Layer* layer);
   virtual void ComputeFeature(bool training, const vector<shared_ptr<Layer>>& 
srclayers);
   virtual void ComputeGradient(const vector<shared_ptr<Layer>>& srclayers);
 
@@ -421,6 +470,7 @@ class SliceLayer: public Layer {
 
 /**
  * Replciate this layer into multiple dst layers
+ * TODO change name to ReplicateLayer.
  */
 class SplitLayer: public Layer {
  public:
@@ -445,10 +495,10 @@ class LossLayer: public Layer{
   virtual void SetupAfterPartition(const LayerProto& proto,
       const vector<int> &shape,
       const vector<SLayer>& srclayers)=0;
-  virtual Blob<float>* mutable_grad(const Layer* layer=nullptr){
+  virtual Blob<float>* mutable_grad(const Layer* layer){
     return nullptr;
   }
-  virtual const Blob<float>& grad(const Layer* from=nullptr) const {
+  virtual const Blob<float>& grad(const Layer* from) const {
     CHECK(false)<<"Loss layer has not gradient blob";
     return grad_;
   }
@@ -468,30 +518,34 @@ class LossLayer: public Layer{
  */
 class ParserLayer: public Layer {
  public:
-  virtual void Setup(const LayerProto& proto, const vector<SLayer>& 
srclayers)=0;
+  virtual void Setup(const LayerProto& proto,
+      const vector<SLayer>& srclayers)=0;
   /**
    * Parse records from DataLayer into blob.
    * This function is called by
    * ComputeFeature(bool, const vector<SLayer>& srclayers)  or Prefetch(bool).
    */
-  virtual void ParseRecords(bool training, const vector<Record>& records, 
Blob<float>* blob)=0;
+  virtual void ParseRecords(bool training, const vector<Record>& records,
+      Blob<float>* blob)=0;
+
   virtual bool is_parserlayer() const {
     return true;
   }
+
+  virtual void ComputeFeature(bool training, const vector<SLayer>& srclayers);
   /**
    * Dummy function. ParserLayer does not compute gradients.
    */
   virtual void ComputeGradient(const vector<SLayer>& srclayers){};
   virtual void Setup(){
     Setup(layer_proto_,srclayers_);
-    has_set_=true;
-    ready_=true;
-    prefetch_=false;
+    has_setup_=true;
   }
   virtual void SetupAfterPartition(){
-    if(!has_set_)
+    if(!has_setup_)
       Setup();
   }
+
   virtual void SetupAfterPartition(const LayerProto& proto,
       const vector<int> &shape,
       const vector<SLayer>& srclayers){}
@@ -499,64 +553,16 @@ class ParserLayer: public Layer {
   virtual PartitionType partition_type () const{
     return kNone;
   }
-  virtual Blob<float>* mutable_grad(const Layer* layer=nullptr) {
+  virtual Blob<float>* mutable_grad(const Layer* layer) {
     return nullptr;
   }
-  virtual const Blob<float>& grad(const Layer* from=nullptr) const {
+  virtual const Blob<float>& grad(const Layer* from) const {
     CHECK(false)<<"Parser layer has not gradient blob";
     return grad_;
   }
 
-  virtual void ComputeFeature(bool training, const vector<SLayer>& srclayers){
-    if(!prefetch_){
-      DataLayer* datalayer=static_cast<DataLayer*>(srclayers[0].get());
-      ParseRecords(training, datalayer->records(), &data_);
-    }else{
-      std::unique_lock<std::mutex> lck(mtx_);
-      while(!ready_) cv_.wait(lck);
-      data_.CopyFrom(prefetch_data_);
-      ready_=false;
-      cv_.notify_all();
-    }
-  }
-  /**
-   * prefetching is transparent to parsing logics.
-   * users implement parsing logics in ParseRecords
-   * worker/training algorithm calls this function to do prefetching in a
-   * separate thread. Records are in fact parsed into prefetch_data_, and later
-   * copied into data_.
-   */
-  void Prefetching(bool training){
-    std::unique_lock<std::mutex> lck(mtx_);
-    while(ready_) cv_.wait(lck);
-    //data_.Swap(prefetch_data_);
-    DataLayer* datalayer=static_cast<DataLayer*>(srclayers_[0].get());
-    ParseRecords(training, datalayer->records(), &prefetch_data_);
-    ready_=true;
-    cv_.notify_all();
-  }
-
-  /**
-   * must be called before calling ComputeFeature(bool) if Prefetching runs in 
a
-   * separate thread
-   */
-  void set_prefetch(bool prefetch) {
-    if(prefetch){
-      if(prefetch_data_.count()==0)
-        prefetch_data_.ReshapeLike(data_);
-      ready_=false;
-    }
-    prefetch_=prefetch;
-  }
-
  private:
-  std::mutex mtx_;
-  std::condition_variable cv_;
-  bool ready_;
-  bool has_set_;
-  bool prefetch_;
-  //!< prefetch_data_ is invisible to layer logics, i.e., parsing.
-  Blob<float> prefetch_data_;
+  bool has_setup_;
 };
 } // singa
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/include/neuralnet/layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/layer.h b/include/neuralnet/layer.h
index 263d249..318f295 100644
--- a/include/neuralnet/layer.h
+++ b/include/neuralnet/layer.h
@@ -115,6 +115,8 @@ class LabelLayer: public ParserLayer {
   virtual void Setup(const LayerProto& proto, const vector<SLayer>& srclayers);
   virtual void ParseRecords(bool training, const vector<Record>& records,
       Blob<float>* blob);
+
+
 };
 
 class LRNLayer: public Layer {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/src/neuralnet/base_layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/base_layer.cc b/src/neuralnet/base_layer.cc
index 50fc396..44bed3f 100644
--- a/src/neuralnet/base_layer.cc
+++ b/src/neuralnet/base_layer.cc
@@ -3,11 +3,13 @@
 #include <cblas.h>
 #include <math.h>
 #include <cfloat>
+#include <glog/logging.h>
+#include "utils/singleton.h"
+#include "utils/factory.h"
 #include "neuralnet/base_layer.h"
+
 namespace singa {
-/*****************************************************************************
- * Implementation for Layer
- *****************************************************************************/
+/********* Implementation for Layer **************/
 void Layer::Init(const LayerProto &proto) {
   layer_proto_=proto;
 }
@@ -36,6 +38,8 @@ void Layer::ComputeGradient(){
 
 void Layer::ToProto(LayerProto *proto, bool copyData) {
 }
+
+/********* Implementation for BridgeSrcLayer **************/
 void BridgeSrcLayer::Setup(const LayerProto& proto,
     const vector<SLayer>& srclayers){
   CHECK_EQ(srclayers.size(),1);
@@ -57,6 +61,8 @@ void BridgeSrcLayer::ComputeFeature(bool training,
 void BridgeSrcLayer::ComputeGradient(const vector<SLayer>& srclayers){
 
 }
+
+/********* Implementation for BridgeDstLayer **************/
 void BridgeDstLayer::Setup(const LayerProto& proto,
     const vector<SLayer>& srclayers){
   CHECK_EQ(srclayers.size(),1);
@@ -79,9 +85,7 @@ void BridgeDstLayer::ComputeGradient(const 
vector<shared_ptr<Layer>>& srclayers)
 
 }
 
-/*******************************
- * Implementation for ConcateLayer
- *******************************/
+/************* Implementation for ConcateLayer ***********/
 void ConcateLayer::Setup(const LayerProto& proto,
     const vector<SLayer>& srclayers){
   size_t concate_dim=proto.concate_param().concate_dimension();
@@ -108,9 +112,87 @@ void ConcateLayer::SetupAfterPartition(){
 void ConcateLayer::ComputeFeature(bool training, const vector<SLayer>& 
srclayers){}
 
 void ConcateLayer::ComputeGradient(const vector<shared_ptr<Layer>>& 
srclayers){}
-/*****************************************************************************
- * Implementation for SliceLayer
- *****************************************************************************/
+
+/************* Implementation for ParserLayer ***********/
+void ParserLayer::ComputeFeature(bool training, const vector<SLayer>& 
srclayers){
+  CHECK_EQ(srclayers.size(),1);
+  auto datalayer=static_cast<DataLayer*>(srclayers.begin()->get());
+  ParseRecords(training, datalayer->records(), &data_);
+}
+
+/************* Implementation for PrefetchLayer ***********/
+void PrefetchLayer::Prefetch(bool training){
+  //clock_t s=clock();
+  for(auto layer: sublayers_)
+    layer->ComputeFeature(training);
+  //LOG(ERROR)<<(clock()-s)*1.0/CLOCKS_PER_SEC;
+}
+
+void PrefetchLayer::ComputeFeature(bool training,
+    const vector<SLayer>& srclayers){
+  if(thread_.joinable())
+    thread_.join();
+  else{
+    Prefetch(training);
+  }
+  for(auto layer: sublayers_){
+    if(layer->is_parserlayer())
+      // TODO replace CopyFrom with Swap?
+      datablobs_.at(layer->name()).CopyFrom(layer->data(this));
+  }
+  thread_=std::thread(&PrefetchLayer::Prefetch, this, training);
+}
+
+void PrefetchLayer::Setup(const LayerProto& proto,
+    const vector<SLayer>& srclayers){
+  Factory<Layer>* factory=Singleton<Factory<Layer>>::Instance();
+  CHECK_GE(proto.sublayers_size(), 1);
+  map<string, SLayer> layers;
+  for(auto const &p:proto.sublayers()){
+    auto layer=shared_ptr<Layer>(factory->Create(p.type()));
+    layer->Init(p);
+    sublayers_.push_back(layer);
+    layers[p.name()]= layer;
+  }
+  // TODO topology sort layers
+  auto layer=sublayers_.begin();
+  for(auto const &p:proto.sublayers()){
+    std::vector<SLayer> src;
+    for(auto const &srcname: p.srclayers()){
+      src.push_back(layers[srcname]);
+      (*layer)->AddSrcLayer(layers[srcname]);
+    }
+    (*layer)->Setup(p, src);
+    layer++;
+  }
+  for(auto layer: sublayers_)
+    if(layer->is_parserlayer())
+      datablobs_[layer->name()]=Blob<float>(layer->data(this).shape());
+}
+
+const Blob<float>& PrefetchLayer::data(const Layer* from) const {
+  if(from!=nullptr){
+    return datablobs_.at(from->datablob());
+  }else{
+    //CHECK_EQ(datablobs_.size(),1);
+    return datablobs_.begin()->second;
+  }
+}
+
+Blob<float>* PrefetchLayer::mutable_data(const Layer* from) {
+  if(from!=nullptr){
+    return &(datablobs_.at(from->datablob()));
+  }else{
+    //CHECK_EQ(datablobs_.size(),1);
+    return &(datablobs_.begin()->second);
+  }
+}
+
+PrefetchLayer::~PrefetchLayer(){
+  if(thread_.joinable())
+    thread_.join();
+}
+/************* Implementation for SliceLayer****************/
 void SliceLayer::Setup(const LayerProto& proto,
     const vector<SLayer>& srclayers){
   int slice_dim=proto.slice_param().slice_dimension();
@@ -179,6 +261,7 @@ void SplitLayer::Setup(const LayerProto& proto,
   grad_.Reshape(srclayers[0]->data(this).shape());
 }
 
+/************* Implementation for SplitLayer****************/
 void SplitLayer::SetupAfterPartition(){
   Setup(layer_proto_, srclayers_);
   //LOG(ERROR)<<name()<<":"<<IntVecToString(shape_);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/src/neuralnet/layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc
index 71c6f2a..03eacc1 100644
--- a/src/neuralnet/layer.cc
+++ b/src/neuralnet/layer.cc
@@ -142,7 +142,7 @@ void DropoutLayer::SetupAfterPartition(const LayerProto& 
proto,
 void DropoutLayer::ComputeFeature(bool training, const vector<SLayer>& 
srclayers) {
   // check training
   if(!training){
-    data_.CopyFrom(srclayers[0]->data());
+    data_.CopyFrom(srclayers[0]->data(this));
     return;
   }
   float pkeep=1-pdrop_;
@@ -150,7 +150,7 @@ void DropoutLayer::ComputeFeature(bool training, const 
vector<SLayer>& srclayers
   mask = F<op::threshold>(ASingleton<Random<cpu>>::Instance()\
       ->uniform(mask.shape), pkeep ) * (1.0f/pkeep);
   Tensor<cpu, 1> data(data_.mutable_cpu_data(), Shape1(data_.count()));
-  Blob<float>* srcblob=srclayers[0]->mutable_data();
+  Blob<float>* srcblob=srclayers[0]->mutable_data(this);
   Tensor<cpu, 1> src(srcblob->mutable_cpu_data(), Shape1(srcblob->count()));
   data=src*mask;
 }
@@ -158,7 +158,7 @@ void DropoutLayer::ComputeFeature(bool training, const 
vector<SLayer>& srclayers
 void DropoutLayer::ComputeGradient(const vector<SLayer>& srclayers)  {
   Tensor<cpu, 1> grad(grad_.mutable_cpu_data(), Shape1(data_.count()));
   Tensor<cpu, 1> mask(mask_.mutable_cpu_data(), Shape1(mask_.count()));
-  Blob<float>* gsrcblob=srclayers[0]->mutable_grad();
+  Blob<float>* gsrcblob=srclayers[0]->mutable_grad(this);
   Tensor<cpu, 1> gsrc(gsrcblob->mutable_cpu_data(), Shape1(gsrcblob->count()));
   gsrc=grad*mask;
 }
@@ -189,8 +189,8 @@ void InnerProductLayer::SetupAfterPartition(const 
LayerProto& proto,
 
 void InnerProductLayer::ComputeFeature(bool training, const vector<SLayer>& 
srclayers) {
   Tensor<cpu, 2> data(data_.mutable_cpu_data(), Shape2(batchsize_,hdim_));
-  CHECK_EQ(srclayers[0]->data().count(), batchsize_*vdim_);
-  Tensor<cpu, 2> src(srclayers[0]->mutable_data()->mutable_cpu_data(),
+  CHECK_EQ(srclayers[0]->data(this).count(), batchsize_*vdim_);
+  Tensor<cpu, 2> src(srclayers[0]->mutable_data(this)->mutable_cpu_data(),
       Shape2(batchsize_,vdim_));
   Tensor<cpu, 2> weight(weight_->mutable_cpu_data(), Shape2(vdim_,hdim_));
   Tensor<cpu, 1> bias(bias_->mutable_cpu_data(), Shape1(hdim_));
@@ -200,7 +200,7 @@ void InnerProductLayer::ComputeFeature(bool training, const 
vector<SLayer>& srcl
 }
 
 void InnerProductLayer::ComputeGradient(const vector<SLayer>& srclayers) {
-  Tensor<cpu, 2> src(srclayers[0]->mutable_data()->mutable_cpu_data(),
+  Tensor<cpu, 2> src(srclayers[0]->mutable_data(this)->mutable_cpu_data(),
       Shape2(batchsize_,vdim_));
   Tensor<cpu, 2> grad(grad_.mutable_cpu_data(),Shape2(batchsize_,hdim_));
   Tensor<cpu, 2> weight(weight_->mutable_cpu_data(), Shape2(vdim_,hdim_));
@@ -225,10 +225,10 @@ void LabelLayer::Setup(const LayerProto& proto,
   data_.Reshape(vector<int>{batchsize});
 }
 
-void LabelLayer::ParseRecords(bool training, const vector<Record>& records, 
Blob<float>* blob){
-  LOG_IF(ERROR, records.size()==0)<<"Empty records to parse";
-  float *label= blob->mutable_cpu_data() ;
+void LabelLayer::ParseRecords(bool training, const vector<Record>& records,
+    Blob<float>* blob){
   int rid=0;
+  float *label= blob->mutable_cpu_data() ;
   for(const Record& record: records){
     label[rid++]=record.image().label();
     CHECK_LT(record.image().label(),10);
@@ -371,7 +371,7 @@ void LRNLayer::ComputeFeature(bool training, const 
vector<SLayer>& srclayers){
 void LRNLayer::ComputeGradient(const vector<SLayer>& srclayers) {
   const float salpha = alpha_ / lsize_;
   Shape<4> s=Shape4(batchsize_,channels_, height_, width_);
-  Tensor<cpu, 4> src(srclayers[0]->mutable_data()->mutable_cpu_data(), s);
+  Tensor<cpu, 4> src(srclayers[0]->mutable_data(this)->mutable_cpu_data(), s);
   Tensor<cpu, 4> norm(norm_.mutable_cpu_data(), s);
   Tensor<cpu, 4> grad(grad_.mutable_cpu_data(), s);
   Tensor<cpu, 4> gsrc(srclayers[0]->mutable_grad(this)->mutable_cpu_data(), s);
@@ -383,8 +383,8 @@ void LRNLayer::ComputeGradient(const vector<SLayer>& 
srclayers) {
 
 /**************** Implementation for MnistImageLayer******************/
 
-void MnistImageLayer::ParseRecords(bool training, const vector<Record>& 
records,
-    Blob<float>* blob){
+void MnistImageLayer::ParseRecords(bool training,
+    const vector<Record>& records, Blob<float>* blob){
   LOG_IF(ERROR, records.size()==0)<<"Empty records to parse";
   int ndim=records.at(0).image().shape_size();
   int inputsize =records.at(0).image().shape(ndim-1);
@@ -545,8 +545,8 @@ void PoolingLayer::ComputeGradient(const vector<SLayer>& 
srclayers) {
 
 void ReLULayer::Setup(const LayerProto& proto,
       const vector<SLayer>& srclayers){
-  data_.ReshapeLike(srclayers[0]->data());
-  grad_.ReshapeLike(*(srclayers[0]->mutable_grad()));
+  data_.ReshapeLike(srclayers[0]->data(this));
+  grad_.ReshapeLike(*(srclayers[0]->mutable_grad(this)));
 }
 
 void ReLULayer::SetupAfterPartition(const LayerProto& proto,
@@ -572,11 +572,10 @@ void ReLULayer::ComputeGradient(const vector<SLayer>& 
srclayers) {
 
 /*************** Implementation for RGBImageLayer *************************/
 
-void RGBImageLayer::ParseRecords(bool training, const vector<Record>& records,
-    Blob<float>* blob){
-  LOG_IF(ERROR, records.size()==0)<<"Empty records to parse";
+void RGBImageLayer::ParseRecords(bool training,
+    const vector<Record>& records, Blob<float>* blob){
   const vector<int>& s=blob->shape();
-  Tensor<cpu, 4> images(blob->mutable_cpu_data(), Shape4(s[0],s[1],s[2],s[3]));
+  Tensor<cpu, 4> images(data_.mutable_cpu_data(), Shape4(s[0],s[1],s[2],s[3]));
   const SingleLabelImageRecord& r=records.at(0).image();
   Tensor<cpu, 3> raw_image(Shape3(r.shape(0),r.shape(1),r.shape(2)));
   AllocSpace(raw_image);
@@ -638,8 +637,9 @@ void RGBImageLayer::Setup(const LayerProto& proto,
   Record sample=static_cast<DataLayer*>(srclayers[0].get())->sample();
   vector<int> shape;
   shape.push_back(batchsize);
-  for(int x: sample.image().shape())
+  for(int x: sample.image().shape()){
     shape.push_back(x);
+  }
   CHECK_EQ(shape.size(),4);
   if(cropsize_){
     shape[2]=cropsize_;
@@ -743,9 +743,9 @@ void SoftmaxLossLayer::SetupAfterPartition(const 
LayerProto& proto,
 void SoftmaxLossLayer::ComputeFeature(bool training, const vector<SLayer>& 
srclayers) {
   Shape<2> s=Shape2(batchsize_, dim_);
   Tensor<cpu, 2> prob(data_.mutable_cpu_data(), s);
-  Tensor<cpu, 2> src(srclayers[0]->mutable_data()->mutable_cpu_data(), s);
+  Tensor<cpu, 2> src(srclayers[0]->mutable_data(this)->mutable_cpu_data(), s);
   Softmax(prob, src);
-  const float* label=srclayers[1]->data().cpu_data();
+  const float* label=srclayers[1]->data(this).cpu_data();
   const float* probptr=prob.dptr;
   float loss=0, precision=0;
   for(int n=0;n<batchsize_;n++){
@@ -777,8 +777,8 @@ void SoftmaxLossLayer::ComputeFeature(bool training, const 
vector<SLayer>& srcla
 }
 
 void SoftmaxLossLayer::ComputeGradient(const vector<SLayer>& srclayers) {
-  const float* label=srclayers[1]->data().cpu_data();
-  Blob<float>* gsrcblob=srclayers[0]->mutable_grad();
+  const float* label=srclayers[1]->data(this).cpu_data();
+  Blob<float>* gsrcblob=srclayers[0]->mutable_grad(this);
   gsrcblob->CopyFrom(data_);
   float* gsrcptr=gsrcblob->mutable_cpu_data();
   for(int n=0;n<batchsize_;n++){

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index 0bca26e..accd619 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -12,18 +12,19 @@ namespace singa {
 
 void NeuralNet::RegisterLayers(){
   Factory<Layer>* factory=Singleton<Factory<Layer>>::Instance();
+  factory->Register("kBridgeDst", CreateLayer(BridgeDstLayer));
+  factory->Register("kBridgeSrc", CreateLayer(BridgeSrcLayer));
   factory->Register("kConvolution", CreateLayer(ConvolutionLayer));
   factory->Register("kConcate", CreateLayer(ConcateLayer));
   factory->Register("kDropout", CreateLayer(DropoutLayer));
   factory->Register("kInnerProduct", CreateLayer(InnerProductLayer));
-  factory->Register("kRGBImage", CreateLayer(RGBImageLayer));
   factory->Register("kLabel", CreateLayer(LabelLayer));
   factory->Register("kLMDBData", CreateLayer(LMDBDataLayer));
   factory->Register("kLRN", CreateLayer(LRNLayer));
   factory->Register("kMnistImage", CreateLayer(MnistImageLayer));
-  factory->Register("kBridgeDst", CreateLayer(BridgeDstLayer));
-  factory->Register("kBridgeSrc", CreateLayer(BridgeSrcLayer));
   factory->Register("kPooling", CreateLayer(PoolingLayer));
+  factory->Register("kPrefetch", CreateLayer(PrefetchLayer));
+  factory->Register("kRGBImage", CreateLayer(RGBImageLayer));
   factory->Register("kReLU", CreateLayer(ReLULayer));
   factory->Register("kShardData", CreateLayer(ShardDataLayer));
   factory->Register("kSlice", CreateLayer(SliceLayer));
@@ -361,7 +362,7 @@ string NeuralNet::DebugInfo(){
   for(auto& layer: layers_){
     if(!layer->is_datalayer()){
       sprintf(display, "Forward layer  %10s data norm1 %13.9f\n",
-          layer->name().c_str(), layer->data().asum_data());
+          layer->name().c_str(), layer->data(nullptr).asum_data());
       ret+=string(display);
     }
   }
@@ -369,7 +370,7 @@ string NeuralNet::DebugInfo(){
     shared_ptr<Layer> layer=*it;
     
if(!(layer->is_datalayer()||layer->is_losslayer()||layer->is_parserlayer())){
       sprintf(display, "Backward layer %10s grad norm1 %13.9f\n",
-          layer->name().c_str(), layer->grad().asum_data());
+          layer->name().c_str(), layer->grad(nullptr).asum_data());
       ret+=string(display);
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/src/proto/model.pb.h
----------------------------------------------------------------------
diff --git a/src/proto/model.pb.h b/src/proto/model.pb.h
index bc4d952..b111a2f 100644
--- a/src/proto/model.pb.h
+++ b/src/proto/model.pb.h
@@ -1228,6 +1228,18 @@ class LayerProto : public ::google::protobuf::Message {
   inline ::singa::PartitionType partition_type() const;
   inline void set_partition_type(::singa::PartitionType value);
 
+  // optional string datablob = 7;
+  inline bool has_datablob() const;
+  inline void clear_datablob();
+  static const int kDatablobFieldNumber = 7;
+  inline const ::std::string& datablob() const;
+  inline void set_datablob(const ::std::string& value);
+  inline void set_datablob(const char* value);
+  inline void set_datablob(const char* value, size_t size);
+  inline ::std::string* mutable_datablob();
+  inline ::std::string* release_datablob();
+  inline void set_allocated_datablob(::std::string* datablob);
+
   // repeated string share_ary = 11;
   inline int share_ary_size() const;
   inline void clear_share_ary();
@@ -1354,6 +1366,18 @@ class LayerProto : public ::google::protobuf::Message {
   inline ::singa::PoolingProto* release_pooling_param();
   inline void set_allocated_pooling_param(::singa::PoolingProto* 
pooling_param);
 
+  // repeated .singa.LayerProto sublayers = 35;
+  inline int sublayers_size() const;
+  inline void clear_sublayers();
+  static const int kSublayersFieldNumber = 35;
+  inline const ::singa::LayerProto& sublayers(int index) const;
+  inline ::singa::LayerProto* mutable_sublayers(int index);
+  inline ::singa::LayerProto* add_sublayers();
+  inline const ::google::protobuf::RepeatedPtrField< ::singa::LayerProto >&
+      sublayers() const;
+  inline ::google::protobuf::RepeatedPtrField< ::singa::LayerProto >*
+      mutable_sublayers();
+
   // optional .singa.SliceProto slice_param = 32;
   inline bool has_slice_param() const;
   inline void clear_slice_param();
@@ -1420,6 +1444,8 @@ class LayerProto : public ::google::protobuf::Message {
   inline void clear_has_partitionid();
   inline void set_has_partition_type();
   inline void clear_has_partition_type();
+  inline void set_has_datablob();
+  inline void clear_has_datablob();
   inline void set_has_convolution_param();
   inline void clear_has_convolution_param();
   inline void set_has_concate_param();
@@ -1456,6 +1482,7 @@ class LayerProto : public ::google::protobuf::Message {
   ::google::protobuf::RepeatedPtrField< ::std::string> srclayers_;
   ::google::protobuf::int32 locationid_;
   ::google::protobuf::int32 partitionid_;
+  ::std::string* datablob_;
   ::google::protobuf::RepeatedPtrField< ::std::string> share_ary_;
   ::google::protobuf::RepeatedPtrField< ::singa::ParamProto > param_;
   ::google::protobuf::RepeatedPtrField< ::std::string> share_param_;
@@ -1468,6 +1495,7 @@ class LayerProto : public ::google::protobuf::Message {
   ::singa::LRNProto* lrn_param_;
   ::singa::MnistProto* mnist_param_;
   ::singa::PoolingProto* pooling_param_;
+  ::google::protobuf::RepeatedPtrField< ::singa::LayerProto > sublayers_;
   ::singa::SliceProto* slice_param_;
   ::singa::SplitProto* split_param_;
   ::singa::ReLUProto* relu_param_;
@@ -1477,7 +1505,7 @@ class LayerProto : public ::google::protobuf::Message {
   int partition_type_;
 
   mutable int _cached_size_;
-  ::google::protobuf::uint32 _has_bits_[(24 + 31) / 32];
+  ::google::protobuf::uint32 _has_bits_[(26 + 31) / 32];
 
   friend void  protobuf_AddDesc_model_2eproto();
   friend void protobuf_AssignDesc_model_2eproto();
@@ -5285,6 +5313,76 @@ inline void 
LayerProto::set_partition_type(::singa::PartitionType value) {
   partition_type_ = value;
 }
 
+// optional string datablob = 7;
+inline bool LayerProto::has_datablob() const {
+  return (_has_bits_[0] & 0x00000040u) != 0;
+}
+inline void LayerProto::set_has_datablob() {
+  _has_bits_[0] |= 0x00000040u;
+}
+inline void LayerProto::clear_has_datablob() {
+  _has_bits_[0] &= ~0x00000040u;
+}
+inline void LayerProto::clear_datablob() {
+  if (datablob_ != &::google::protobuf::internal::kEmptyString) {
+    datablob_->clear();
+  }
+  clear_has_datablob();
+}
+inline const ::std::string& LayerProto::datablob() const {
+  return *datablob_;
+}
+inline void LayerProto::set_datablob(const ::std::string& value) {
+  set_has_datablob();
+  if (datablob_ == &::google::protobuf::internal::kEmptyString) {
+    datablob_ = new ::std::string;
+  }
+  datablob_->assign(value);
+}
+inline void LayerProto::set_datablob(const char* value) {
+  set_has_datablob();
+  if (datablob_ == &::google::protobuf::internal::kEmptyString) {
+    datablob_ = new ::std::string;
+  }
+  datablob_->assign(value);
+}
+inline void LayerProto::set_datablob(const char* value, size_t size) {
+  set_has_datablob();
+  if (datablob_ == &::google::protobuf::internal::kEmptyString) {
+    datablob_ = new ::std::string;
+  }
+  datablob_->assign(reinterpret_cast<const char*>(value), size);
+}
+inline ::std::string* LayerProto::mutable_datablob() {
+  set_has_datablob();
+  if (datablob_ == &::google::protobuf::internal::kEmptyString) {
+    datablob_ = new ::std::string;
+  }
+  return datablob_;
+}
+inline ::std::string* LayerProto::release_datablob() {
+  clear_has_datablob();
+  if (datablob_ == &::google::protobuf::internal::kEmptyString) {
+    return NULL;
+  } else {
+    ::std::string* temp = datablob_;
+    datablob_ = const_cast< 
::std::string*>(&::google::protobuf::internal::kEmptyString);
+    return temp;
+  }
+}
+inline void LayerProto::set_allocated_datablob(::std::string* datablob) {
+  if (datablob_ != &::google::protobuf::internal::kEmptyString) {
+    delete datablob_;
+  }
+  if (datablob) {
+    set_has_datablob();
+    datablob_ = datablob;
+  } else {
+    clear_has_datablob();
+    datablob_ = const_cast< 
::std::string*>(&::google::protobuf::internal::kEmptyString);
+  }
+}
+
 // repeated string share_ary = 11;
 inline int LayerProto::share_ary_size() const {
   return share_ary_.size();
@@ -5427,13 +5525,13 @@ LayerProto::mutable_exclude() {
 
 // optional .singa.ConvolutionProto convolution_param = 21;
 inline bool LayerProto::has_convolution_param() const {
-  return (_has_bits_[0] & 0x00000400u) != 0;
+  return (_has_bits_[0] & 0x00000800u) != 0;
 }
 inline void LayerProto::set_has_convolution_param() {
-  _has_bits_[0] |= 0x00000400u;
+  _has_bits_[0] |= 0x00000800u;
 }
 inline void LayerProto::clear_has_convolution_param() {
-  _has_bits_[0] &= ~0x00000400u;
+  _has_bits_[0] &= ~0x00000800u;
 }
 inline void LayerProto::clear_convolution_param() {
   if (convolution_param_ != NULL) 
convolution_param_->::singa::ConvolutionProto::Clear();
@@ -5465,13 +5563,13 @@ inline void 
LayerProto::set_allocated_convolution_param(::singa::ConvolutionProt
 
 // optional .singa.ConcateProto concate_param = 31;
 inline bool LayerProto::has_concate_param() const {
-  return (_has_bits_[0] & 0x00000800u) != 0;
+  return (_has_bits_[0] & 0x00001000u) != 0;
 }
 inline void LayerProto::set_has_concate_param() {
-  _has_bits_[0] |= 0x00000800u;
+  _has_bits_[0] |= 0x00001000u;
 }
 inline void LayerProto::clear_has_concate_param() {
-  _has_bits_[0] &= ~0x00000800u;
+  _has_bits_[0] &= ~0x00001000u;
 }
 inline void LayerProto::clear_concate_param() {
   if (concate_param_ != NULL) concate_param_->::singa::ConcateProto::Clear();
@@ -5503,13 +5601,13 @@ inline void 
LayerProto::set_allocated_concate_param(::singa::ConcateProto* conca
 
 // optional .singa.DataProto data_param = 22;
 inline bool LayerProto::has_data_param() const {
-  return (_has_bits_[0] & 0x00001000u) != 0;
+  return (_has_bits_[0] & 0x00002000u) != 0;
 }
 inline void LayerProto::set_has_data_param() {
-  _has_bits_[0] |= 0x00001000u;
+  _has_bits_[0] |= 0x00002000u;
 }
 inline void LayerProto::clear_has_data_param() {
-  _has_bits_[0] &= ~0x00001000u;
+  _has_bits_[0] &= ~0x00002000u;
 }
 inline void LayerProto::clear_data_param() {
   if (data_param_ != NULL) data_param_->::singa::DataProto::Clear();
@@ -5541,13 +5639,13 @@ inline void 
LayerProto::set_allocated_data_param(::singa::DataProto* data_param)
 
 // optional .singa.DropoutProto dropout_param = 23;
 inline bool LayerProto::has_dropout_param() const {
-  return (_has_bits_[0] & 0x00002000u) != 0;
+  return (_has_bits_[0] & 0x00004000u) != 0;
 }
 inline void LayerProto::set_has_dropout_param() {
-  _has_bits_[0] |= 0x00002000u;
+  _has_bits_[0] |= 0x00004000u;
 }
 inline void LayerProto::clear_has_dropout_param() {
-  _has_bits_[0] &= ~0x00002000u;
+  _has_bits_[0] &= ~0x00004000u;
 }
 inline void LayerProto::clear_dropout_param() {
   if (dropout_param_ != NULL) dropout_param_->::singa::DropoutProto::Clear();
@@ -5579,13 +5677,13 @@ inline void 
LayerProto::set_allocated_dropout_param(::singa::DropoutProto* dropo
 
 // optional .singa.InnerProductProto inner_product_param = 24;
 inline bool LayerProto::has_inner_product_param() const {
-  return (_has_bits_[0] & 0x00004000u) != 0;
+  return (_has_bits_[0] & 0x00008000u) != 0;
 }
 inline void LayerProto::set_has_inner_product_param() {
-  _has_bits_[0] |= 0x00004000u;
+  _has_bits_[0] |= 0x00008000u;
 }
 inline void LayerProto::clear_has_inner_product_param() {
-  _has_bits_[0] &= ~0x00004000u;
+  _has_bits_[0] &= ~0x00008000u;
 }
 inline void LayerProto::clear_inner_product_param() {
   if (inner_product_param_ != NULL) 
inner_product_param_->::singa::InnerProductProto::Clear();
@@ -5617,13 +5715,13 @@ inline void 
LayerProto::set_allocated_inner_product_param(::singa::InnerProductP
 
 // optional .singa.LRNProto lrn_param = 25;
 inline bool LayerProto::has_lrn_param() const {
-  return (_has_bits_[0] & 0x00008000u) != 0;
+  return (_has_bits_[0] & 0x00010000u) != 0;
 }
 inline void LayerProto::set_has_lrn_param() {
-  _has_bits_[0] |= 0x00008000u;
+  _has_bits_[0] |= 0x00010000u;
 }
 inline void LayerProto::clear_has_lrn_param() {
-  _has_bits_[0] &= ~0x00008000u;
+  _has_bits_[0] &= ~0x00010000u;
 }
 inline void LayerProto::clear_lrn_param() {
   if (lrn_param_ != NULL) lrn_param_->::singa::LRNProto::Clear();
@@ -5655,13 +5753,13 @@ inline void 
LayerProto::set_allocated_lrn_param(::singa::LRNProto* lrn_param) {
 
 // optional .singa.MnistProto mnist_param = 26;
 inline bool LayerProto::has_mnist_param() const {
-  return (_has_bits_[0] & 0x00010000u) != 0;
+  return (_has_bits_[0] & 0x00020000u) != 0;
 }
 inline void LayerProto::set_has_mnist_param() {
-  _has_bits_[0] |= 0x00010000u;
+  _has_bits_[0] |= 0x00020000u;
 }
 inline void LayerProto::clear_has_mnist_param() {
-  _has_bits_[0] &= ~0x00010000u;
+  _has_bits_[0] &= ~0x00020000u;
 }
 inline void LayerProto::clear_mnist_param() {
   if (mnist_param_ != NULL) mnist_param_->::singa::MnistProto::Clear();
@@ -5693,13 +5791,13 @@ inline void 
LayerProto::set_allocated_mnist_param(::singa::MnistProto* mnist_par
 
 // optional .singa.PoolingProto pooling_param = 27;
 inline bool LayerProto::has_pooling_param() const {
-  return (_has_bits_[0] & 0x00020000u) != 0;
+  return (_has_bits_[0] & 0x00040000u) != 0;
 }
 inline void LayerProto::set_has_pooling_param() {
-  _has_bits_[0] |= 0x00020000u;
+  _has_bits_[0] |= 0x00040000u;
 }
 inline void LayerProto::clear_has_pooling_param() {
-  _has_bits_[0] &= ~0x00020000u;
+  _has_bits_[0] &= ~0x00040000u;
 }
 inline void LayerProto::clear_pooling_param() {
   if (pooling_param_ != NULL) pooling_param_->::singa::PoolingProto::Clear();
@@ -5729,15 +5827,40 @@ inline void 
LayerProto::set_allocated_pooling_param(::singa::PoolingProto* pooli
   }
 }
 
+// repeated .singa.LayerProto sublayers = 35;
+inline int LayerProto::sublayers_size() const {
+  return sublayers_.size();
+}
+inline void LayerProto::clear_sublayers() {
+  sublayers_.Clear();
+}
+inline const ::singa::LayerProto& LayerProto::sublayers(int index) const {
+  return sublayers_.Get(index);
+}
+inline ::singa::LayerProto* LayerProto::mutable_sublayers(int index) {
+  return sublayers_.Mutable(index);
+}
+inline ::singa::LayerProto* LayerProto::add_sublayers() {
+  return sublayers_.Add();
+}
+inline const ::google::protobuf::RepeatedPtrField< ::singa::LayerProto >&
+LayerProto::sublayers() const {
+  return sublayers_;
+}
+inline ::google::protobuf::RepeatedPtrField< ::singa::LayerProto >*
+LayerProto::mutable_sublayers() {
+  return &sublayers_;
+}
+
 // optional .singa.SliceProto slice_param = 32;
 inline bool LayerProto::has_slice_param() const {
-  return (_has_bits_[0] & 0x00040000u) != 0;
+  return (_has_bits_[0] & 0x00100000u) != 0;
 }
 inline void LayerProto::set_has_slice_param() {
-  _has_bits_[0] |= 0x00040000u;
+  _has_bits_[0] |= 0x00100000u;
 }
 inline void LayerProto::clear_has_slice_param() {
-  _has_bits_[0] &= ~0x00040000u;
+  _has_bits_[0] &= ~0x00100000u;
 }
 inline void LayerProto::clear_slice_param() {
   if (slice_param_ != NULL) slice_param_->::singa::SliceProto::Clear();
@@ -5769,13 +5892,13 @@ inline void 
LayerProto::set_allocated_slice_param(::singa::SliceProto* slice_par
 
 // optional .singa.SplitProto split_param = 33;
 inline bool LayerProto::has_split_param() const {
-  return (_has_bits_[0] & 0x00080000u) != 0;
+  return (_has_bits_[0] & 0x00200000u) != 0;
 }
 inline void LayerProto::set_has_split_param() {
-  _has_bits_[0] |= 0x00080000u;
+  _has_bits_[0] |= 0x00200000u;
 }
 inline void LayerProto::clear_has_split_param() {
-  _has_bits_[0] &= ~0x00080000u;
+  _has_bits_[0] &= ~0x00200000u;
 }
 inline void LayerProto::clear_split_param() {
   if (split_param_ != NULL) split_param_->::singa::SplitProto::Clear();
@@ -5807,13 +5930,13 @@ inline void 
LayerProto::set_allocated_split_param(::singa::SplitProto* split_par
 
 // optional .singa.ReLUProto relu_param = 28;
 inline bool LayerProto::has_relu_param() const {
-  return (_has_bits_[0] & 0x00100000u) != 0;
+  return (_has_bits_[0] & 0x00400000u) != 0;
 }
 inline void LayerProto::set_has_relu_param() {
-  _has_bits_[0] |= 0x00100000u;
+  _has_bits_[0] |= 0x00400000u;
 }
 inline void LayerProto::clear_has_relu_param() {
-  _has_bits_[0] &= ~0x00100000u;
+  _has_bits_[0] &= ~0x00400000u;
 }
 inline void LayerProto::clear_relu_param() {
   if (relu_param_ != NULL) relu_param_->::singa::ReLUProto::Clear();
@@ -5845,13 +5968,13 @@ inline void 
LayerProto::set_allocated_relu_param(::singa::ReLUProto* relu_param)
 
 // optional .singa.RGBImage rgbimage_param = 34;
 inline bool LayerProto::has_rgbimage_param() const {
-  return (_has_bits_[0] & 0x00200000u) != 0;
+  return (_has_bits_[0] & 0x00800000u) != 0;
 }
 inline void LayerProto::set_has_rgbimage_param() {
-  _has_bits_[0] |= 0x00200000u;
+  _has_bits_[0] |= 0x00800000u;
 }
 inline void LayerProto::clear_has_rgbimage_param() {
-  _has_bits_[0] &= ~0x00200000u;
+  _has_bits_[0] &= ~0x00800000u;
 }
 inline void LayerProto::clear_rgbimage_param() {
   if (rgbimage_param_ != NULL) rgbimage_param_->::singa::RGBImage::Clear();
@@ -5883,13 +6006,13 @@ inline void 
LayerProto::set_allocated_rgbimage_param(::singa::RGBImage* rgbimage
 
 // optional .singa.SoftmaxLossProto softmaxloss_param = 29;
 inline bool LayerProto::has_softmaxloss_param() const {
-  return (_has_bits_[0] & 0x00400000u) != 0;
+  return (_has_bits_[0] & 0x01000000u) != 0;
 }
 inline void LayerProto::set_has_softmaxloss_param() {
-  _has_bits_[0] |= 0x00400000u;
+  _has_bits_[0] |= 0x01000000u;
 }
 inline void LayerProto::clear_has_softmaxloss_param() {
-  _has_bits_[0] &= ~0x00400000u;
+  _has_bits_[0] &= ~0x01000000u;
 }
 inline void LayerProto::clear_softmaxloss_param() {
   if (softmaxloss_param_ != NULL) 
softmaxloss_param_->::singa::SoftmaxLossProto::Clear();
@@ -5921,13 +6044,13 @@ inline void 
LayerProto::set_allocated_softmaxloss_param(::singa::SoftmaxLossProt
 
 // optional .singa.TanhProto tanh_param = 30;
 inline bool LayerProto::has_tanh_param() const {
-  return (_has_bits_[0] & 0x00800000u) != 0;
+  return (_has_bits_[0] & 0x02000000u) != 0;
 }
 inline void LayerProto::set_has_tanh_param() {
-  _has_bits_[0] |= 0x00800000u;
+  _has_bits_[0] |= 0x02000000u;
 }
 inline void LayerProto::clear_has_tanh_param() {
-  _has_bits_[0] &= ~0x00800000u;
+  _has_bits_[0] &= ~0x02000000u;
 }
 inline void LayerProto::clear_tanh_param() {
   if (tanh_param_ != NULL) tanh_param_->::singa::TanhProto::Clear();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index 4ea621d..dc45313 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -166,6 +166,7 @@ message LayerProto {
   optional int32 locationid=4 [default=0]; // todo make locationID an array
   optional int32 partitionid=5 [default=0];
   optional PartitionType partition_type=6;
+  optional string datablob=7;
   // can be pos/neg neuron value for CD, neuron value/grad for BP
   //repeated DAryProto ary = 10;
   repeated string share_ary =11;
@@ -188,6 +189,7 @@ message LayerProto {
   optional LRNProto lrn_param = 25;
   optional MnistProto mnist_param= 26;
   optional PoolingProto pooling_param = 27;
+  repeated LayerProto sublayers=35;
   optional SliceProto slice_param = 32;
   optional SplitProto split_param = 33;
   optional ReLUProto relu_param = 28;
@@ -195,6 +197,7 @@ message LayerProto {
   optional SoftmaxLossProto softmaxloss_param = 29;
   optional TanhProto tanh_param=30;
 }
+
 message RGBImage {
   optional float scale=1 [default=1.0];
   optional int32 cropsize=2 [default=0];

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/831efef0/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index 047ec2d..6ead6c8 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -174,13 +174,15 @@ void BPWorker::Forward(shared_ptr<NeuralNet> net, int 
step,  bool training){
           }
         }
       }
+      //clock_t s=clock();
       layer->ComputeFeature(training);
+      //LOG(ERROR)<<layer->name()<<":"<<(clock()-s)*1.0/CLOCKS_PER_SEC;
       if(layer->is_bridgesrclayer()){
         // send fea blobs
       }
-      if(training&&DisplayDebugInfo(step)&&layer->mutable_data()!=nullptr){
+      
if(training&&DisplayDebugInfo(step)&&layer->mutable_data(nullptr)!=nullptr){
         LOG(INFO)<<StringPrintf("Forward layer  %10s data norm1 %13.9f",
-            layer->name().c_str(), layer->data().asum_data());
+            layer->name().c_str(), layer->data(nullptr).asum_data());
       }
     }
   }
@@ -196,9 +198,9 @@ void BPWorker::Backward(shared_ptr<NeuralNet> net, int 
step){
         // receive grad blobs
       }
       layer->ComputeGradient();
-      if(DisplayDebugInfo(step)&&layer->mutable_grad()!=nullptr){
+      if(DisplayDebugInfo(step)&&layer->mutable_grad(nullptr)!=nullptr){
         LOG(INFO)<<StringPrintf("Backward layer %10s grad norm1 %13.9f\t",
-            layer->name().c_str(), layer->grad().asum_data());
+            layer->name().c_str(), layer->grad(nullptr).asum_data());
         for(shared_ptr<Param> p: layer->GetParams())
           LOG(INFO)<<StringPrintf("param id %2d, name %10s,\
               value norm1 %13.9f, grad norm1 %13.9f",

Reply via email to