SINGA-100 Implement layers using CUDNN for GPU training

Test with multiple GPUs on CIFAR-10; with batchsize=500, one iteration takes
3s on a single GPU and 2s on 2 GPUs.

Fix bugs in CudnnSoftmax and CudnnSoftmaxLoss; TODO: debug the accuracy
problem. The accuracy improves more slowly than Caffe's and does not reach
0.8 in the end.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/81747603
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/81747603
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/81747603

Branch: refs/heads/master
Commit: 81747603149afeadd8b2b93ce48a170040982acf
Parents: eb97097
Author: Wei Wang <[email protected]>
Authored: Sat Dec 5 23:40:37 2015 +0800
Committer: Wei Wang <[email protected]>
Committed: Fri Dec 11 11:48:24 2015 +0800

----------------------------------------------------------------------
 examples/cifar10/cudnn.conf                   | 39 ++++++++++----
 examples/cifar10/job.conf                     |  8 +--
 include/singa/neuralnet/layer.h               |  3 --
 include/singa/neuralnet/loss_layer.h          | 22 +++++---
 include/singa/neuralnet/neuralnet.h           | 12 +++--
 include/singa/neuralnet/neuron_layer.h        | 12 +++--
 include/singa/neuralnet/output_layer.h        | 13 ++++-
 include/singa/utils/blob.h                    |  6 ++-
 include/singa/utils/context.h                 | 14 +++++
 include/singa/utils/math_blob.h               | 26 ++++-----
 include/singa/utils/math_kernel.h             |  4 +-
 include/singa/utils/param.h                   | 47 +++++++++--------
 src/driver.cc                                 | 12 +++--
 src/neuralnet/input_layer/deprecated.cc       | 22 ++++----
 src/neuralnet/input_layer/image_preprocess.cc |  8 +--
 src/neuralnet/input_layer/store.cc            | 19 +++----
 src/neuralnet/layer.cc                        | 11 ----
 src/neuralnet/loss_layer/cudnn_softmaxloss.cc | 28 ++++++----
 src/neuralnet/loss_layer/euclidean.cc         | 11 +++-
 src/neuralnet/loss_layer/softmax.cc           | 14 ++++-
 src/neuralnet/neuralnet.cc                    |  8 +--
 src/neuralnet/neuron_layer/cudnn_softmax.cc   | 17 +++---
 src/neuralnet/neuron_layer/rbm.cc             | 12 ++++-
 src/neuralnet/neuron_layer/softmax.cc         | 11 ++--
 src/neuralnet/output_layer/accuracy.cc        | 61 ++++++++++++++++++++++
 src/proto/job.proto                           |  1 +
 src/server.cc                                 | 10 ++--
 src/stub.cc                                   |  6 ++-
 src/utils/blob.cc                             | 10 +++-
 src/utils/math_kernel.cu                      | 44 ++++++++--------
 src/utils/param.cc                            | 35 +++++++------
 src/worker.cc                                 | 18 ++++---
 32 files changed, 366 insertions(+), 198 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/examples/cifar10/cudnn.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/cudnn.conf b/examples/cifar10/cudnn.conf
index 4a29b7f..6a5e15f 100644
--- a/examples/cifar10/cudnn.conf
+++ b/examples/cifar10/cudnn.conf
@@ -1,4 +1,5 @@
 name: "cifar10-convnet"
+<<<<<<< HEAD
 train_steps: 1000
 test_steps: 0
 test_freq: 0
@@ -6,6 +7,14 @@ test_freq: 0
 #validate_freq: 300
 disp_freq: 30
 gpu: 0
+=======
+train_steps: 13000
+test_steps: 100
+test_freq: 1000
+#validate_steps: 100
+#validate_freq: 300
+disp_freq: 200
+>>>>>>> bfb913f... fix bug from cudnnsoftmax and cudnnsoftmaxloss; todo debug 
accuracy problem. the accuracy improves slower than that from caffe and cannot 
reach 0.8 finally.
 gpu: 1
 #checkpoint_path: "examples/cifar10/checkpoint/step1000-worker0"
 train_one_batch {
@@ -93,6 +102,7 @@ neuralnet {
     param {
       name: "b1"
       lr_scale:2.0
+      wd_scale: 0
       init {
         type: kConstant
         value:0
@@ -148,6 +158,7 @@ neuralnet {
     param {
       name: "b2"
       lr_scale:2.0
+      wd_scale: 0
       init {
         type: kConstant
         value:0
@@ -201,6 +212,8 @@ neuralnet {
     }
     param {
       name: "b3"
+      lr_scale: 2
+      wd_scale: 0
       init {
         type: kConstant
         value:0
@@ -250,22 +263,26 @@ neuralnet {
       }
     }
   }
-#  layer {
-#   name : "softmax"
-#   type: kSoftmax
-#   srclayers: "ip1"
-#  }
-#
-#  layer {
-#   name : "argsort"
-#   type: kArgSort
-#   srclayers: "softmax"
-#  }
+  layer {
+   name : "softmax"
+   type: kCudnnSoftmax
+   srclayers: "ip1"
+   include: kTest
+  }
+
+  layer {
+   name : "accuracy"
+   type: kAccuracy
+   srclayers: "softmax"
+   srclayers: "data"
+   include: kTest
+  }
   layer{
     name: "loss"
     type: kSoftmaxLoss
     srclayers:"ip1"
     srclayers: "data"
+    include : kTrain
   }
 # uncomment "softmax", "argsort", "output" layer and comment "loss" layer
 # to extract features from argsort

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/examples/cifar10/job.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/job.conf b/examples/cifar10/job.conf
index 7e42ed8..1dad0f7 100644
--- a/examples/cifar10/job.conf
+++ b/examples/cifar10/job.conf
@@ -1,10 +1,10 @@
 name: "cifar10-convnet"
-train_steps: 30
+train_steps: 5
 test_steps: 100
 test_freq: 0
 #validate_steps: 100
 #validate_freq: 300
-disp_freq: 10
+disp_freq: 1
 debug: true
 #checkpoint_path: "examples/cifar10/checkpoint/step1000-worker0"
 train_one_batch {
@@ -34,8 +34,8 @@ neuralnet {
       backend: "kvfile"
       path: "examples/cifar10/train_data.bin"
       mean_file: "examples/cifar10/image_mean.bin"
-      batchsize: 64
-      random_skip: 5000
+      batchsize: 100
+      #random_skip: 5000
       shape: 3
       shape: 32
       shape: 32

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/neuralnet/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/layer.h b/include/singa/neuralnet/layer.h
index 5e2692e..599203f 100644
--- a/include/singa/neuralnet/layer.h
+++ b/include/singa/neuralnet/layer.h
@@ -308,7 +308,6 @@ class NeuronLayer : virtual public Layer {
  */
 class LossLayer : virtual public Layer {
  public:
-  const std::string ToString(bool debug, int flag) override;
   Blob<float>* mutable_grad(const Layer* layer) override {
     LOG(FATAL) << "Loss layer has no gradient blob";
     return nullptr;
@@ -317,8 +316,6 @@ class LossLayer : virtual public Layer {
     LOG(FATAL) << "Loss layer has no gradient blob";
     return grad_;
   }
- protected:
-  Metric metric_;
 };
 
 /**

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/neuralnet/loss_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/loss_layer.h 
b/include/singa/neuralnet/loss_layer.h
index f78d7e0..3a45370 100644
--- a/include/singa/neuralnet/loss_layer.h
+++ b/include/singa/neuralnet/loss_layer.h
@@ -35,25 +35,32 @@ class EuclideanLossLayer : public LossLayer {
   void Setup(const LayerProto& conf, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+  const std::string ToString(bool debug, int flag) override;
+
+ private:
+  int counter_ = 0;
+  float loss_ = 0.0f;
 };
 /**
  * Cross-entropy loss applied to the probabilities computed from Softmax.
  * @f$ L_i = -log P_{t_i}, t_i\in [0, C] @f$ is the label for the i-th object,
  * C is the total number of classes.
  */
-class SoftmaxLossLayer : public LossLayer, public SoftmaxLayer {
+class SoftmaxLossLayer : public LossLayer {
  public:
   void Setup(const LayerProto& conf, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+  const std::string ToString(bool debug, int flag) override;
 
  private:
+  int batchsize_, topk_, dim_, counter_ = 0;
   float scale_;
-  int topk_, dim_;
+  float loss_ = 0.0f, accuracy_ = 0.0f;
 };
 
 #ifdef USE_CUDNN
-class CudnnSoftmaxLossLayer : public LossLayer, public CudnnSoftmaxLayer {
+class CudnnSoftmaxLossLayer : public LossLayer{
  public:
   void Setup(const LayerProto& conf, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
@@ -61,12 +68,11 @@ class CudnnSoftmaxLossLayer : public LossLayer, public 
CudnnSoftmaxLayer {
   const std::string ToString(bool debug, int flag) override;
 
  private:
-  float scale_;
-  int topk_, dim_;
-  int counter_;
-  float loss_, accuracy_;
+  int batchsize_, dim_;
+  int counter_ = 0;
+  float loss_ = 0.0f;
 
-  CudnnSoftmaxLayer *softmax_;
+  CudnnSoftmaxLayer softmax_;
 };
 #endif
 }  // namespace singa
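
The loss layers now use composition instead of multiple inheritance:
CudnnSoftmaxLossLayer owns a CudnnSoftmaxLayer by value and delegates to it.
A minimal sketch of the idiom, using a hypothetical MyLoss layer (the real
Setup appears in src/neuralnet/loss_layer/cudnn_softmaxloss.cc below):

    class MyLoss : public LossLayer {
     public:
      void Setup(const LayerProto& conf, const vector<Layer*>& srclayers) override {
        // delegate to the owned softmax layer instead of a base class
        softmax_.Setup(conf, vector<Layer*>{srclayers.at(0)});
        // alias its output so data(this) of the loss layer sees the probabilities
        data_.Reshape(softmax_.data(this).shape());
        data_.ShareData(*softmax_.mutable_data(this), false);
      }
     private:
      CudnnSoftmaxLayer softmax_;  // owned by value, not inherited
    };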

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/neuralnet/neuralnet.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuralnet.h 
b/include/singa/neuralnet/neuralnet.h
index b9a58fe..7dd4edc 100644
--- a/include/singa/neuralnet/neuralnet.h
+++ b/include/singa/neuralnet/neuralnet.h
@@ -90,14 +90,18 @@ class NeuralNet {
    */
   /**
    * Share memory of parameter values from other neuralnet
+   * @param[in] other the neural net from which to share the Params
+   * @param[in] cpu_only if true only share cpu memory; else, share both cpu
+   * and gpu memory.
    */
-  void ShareParamsFrom(NeuralNet* other);
+  void ShareParamsFrom(NeuralNet* other, bool cpu_only);
   inline const std::vector<Layer*>& layers() const { return layers_; }
   inline const std::vector<Param*>& params() const { return params_; }
   inline Layer* name2layer(std::string name) const {
-    CHECK(name2layer_.find(name) != name2layer_.end())
-      << "No layer with name " << name;
-    return name2layer_.at(name);
+    if (name2layer_.find(name) == name2layer_.end())
+      return nullptr;
+    else
+      return name2layer_.at(name);
   }
   inline const std::vector<Layer*>& srclayers(const Layer* layer) const {
     CHECK(src_map_.find(layer) != src_map_.end())
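
name2layer() no longer aborts when a layer is missing; it returns nullptr so
that ShareParamsFrom() (see src/neuralnet/neuralnet.cc below) can skip layers
that exist in only one of the two nets, e.g. the accuracy layer included only
for kTest. A caller sketch, assuming a hypothetical loop over layers:

    Layer* otherlayer = other->name2layer(layer->name());
    if (otherlayer == nullptr)
      continue;  // layer absent from the other net; nothing to share
    // ...otherwise share each Param as before...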

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer.h 
b/include/singa/neuralnet/neuron_layer.h
index 9ae2738..39f4d69 100644
--- a/include/singa/neuralnet/neuron_layer.h
+++ b/include/singa/neuralnet/neuron_layer.h
@@ -217,11 +217,11 @@ class SoftmaxLayer : public NeuronLayer {
     return kOneToAll;
   }
  protected:
-  int batchsize_;
+  int batchsize_, dim_;
   //!< set by users (default is 1)
-  int num_softmax_per_instance_;
+  // int num_softmax_per_instance_;
   //!< size of the softmax area/length
-  int count_per_softmax_;
+  // int count_per_softmax_;
 };
 /**
  * @deprecated {please use ActivationLayer}
@@ -391,16 +391,20 @@ class RBMLayer: virtual public NeuronLayer {
 /**
  * RBM visible layer
  */
-class RBMVisLayer: public RBMLayer, public LossLayer {
+class RBMVisLayer: public RBMLayer {
  public:
   ~RBMVisLayer();
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) 
override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+  const std::string ToString(bool debug, int flag) override;
 
  private:
   RBMLayer* hid_layer_;
   Layer* input_layer_;
+
+  float error_ = 0.0f;
+  int counter_ = 0;
 };
 /**
  * RBM hidden layer

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/neuralnet/output_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/output_layer.h 
b/include/singa/neuralnet/output_layer.h
index a7d92d7..49a3c19 100644
--- a/include/singa/neuralnet/output_layer.h
+++ b/include/singa/neuralnet/output_layer.h
@@ -40,10 +40,21 @@ class ArgSortLayer : public OutputLayer {
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) 
override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
 
- private:
+ protected:
   int batchsize_, dim_;
   int topk_;
 };
+
+class AccuracyLayer : public ArgSortLayer {
+ public:
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) 
override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  const std::string ToString(bool debug, int flag) override;
+
+ private:
+  int counter_ = 0;
+  float accuracy_ = 0.0f;
+};
 /**
  * Output data (and label) for its source layer.
  */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/blob.h b/include/singa/utils/blob.h
index 97b59e0..b260862 100644
--- a/include/singa/utils/blob.h
+++ b/include/singa/utils/blob.h
@@ -242,8 +242,12 @@ class Blob {
    *
    * It may deallocate the SyncedMemory holding this Blob's data_, as
    * shared_ptr calls its destructor when reset with the "=" operator.
+   * @param other the Blob who owns the data
+   * @param cpu_only if true, only share the cpu data; if false, share the 
whole
+   * data_ field. For training with multi-gpu cards, cpu_only must be true,
+   * becuase gpu memory cannot be shared among different devices.
    */
-  void ShareData(const Blob& other);
+  void ShareData(const Blob& other, bool cpu_only = true);
 
   /*
   void Swap(Blob& other);
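
A minimal sketch of the new flag, assuming two equally sized Blobs built with
the shape constructor used elsewhere in this patch:

    Blob<float> a(std::vector<int>{10});
    Blob<float> b(std::vector<int>{10});
    // share only the host buffer (the default); safe when a and b are later
    // used on different GPU devices
    b.ShareData(a, true);
    // sharing the whole SyncedMemory (CPU + GPU) is only valid when both
    // Blobs are used on the same device:
    // b.ShareData(a, false);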

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/utils/context.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/context.h b/include/singa/utils/context.h
index 1d1802c..c23b338 100644
--- a/include/singa/utils/context.h
+++ b/include/singa/utils/context.h
@@ -124,6 +124,13 @@ class Context {
   }
 
   /**
+   * \copybrief rand_generator(const std::thread::id&);
+   * @return the CPU random generator for the calling thread.
+   */
+  std::mt19937* rand_generator() {
+    return rand_generator(std::this_thread::get_id());
+  }
+  /**
    * Get the CPU random generator.
    * If the generator does not exist, then create it now.
    * If the seed is not set, i.e., seed=-1, then get a seed from system time.
@@ -142,6 +149,13 @@ class Context {
   }
 #ifdef USE_GPU
   /**
+   * \copybrief cublas_handle_(const std::thread::id&);
+   * @return cublas handle for the calling thread.
+   */
+  cublasHandle_t cublas_handle() {
+    return cublas_handle(std::this_thread::get_id());
+  }
+  /**
    * Get the handler of the GPU which is assigned to the given thread.
    * Calls cublas_handle(const int);
    */
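
These overloads simply fill in the calling thread's id, which shortens the
pattern used by the input layers later in this patch. Usage sketch:

    auto* gen = Singleton<Context>::Instance()->rand_generator();
    std::uniform_int_distribution<int> dist(0, 9);
    int r = dist(*gen);  // generators are returned by pointer and dereferenced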

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/utils/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_blob.h b/include/singa/utils/math_blob.h
index cf989fa..bc38bd4 100644
--- a/include/singa/utils/math_blob.h
+++ b/include/singa/utils/math_blob.h
@@ -498,13 +498,13 @@ void MVAddRow(Dtype alpha, Dtype beta, const Blob<Dtype> 
& A, Blob<Dtype> * B) {
     B->set_transpose(true);
   } else {
     CHECK_EQ(B->count() % A.count(), 0) << "#col of B not match length of A";
-    int m = A.count(), n = B->count() / m;
-    Blob<Dtype> one(n);
-    one.SetValue(1);
+    int n = A.count(), m = B->count() / n;
     auto context = Singleton<Context>::Instance();
     int device = context->device_id(std::this_thread::get_id());
     if (device == -1) {
-      cpu_gemm(one.cpu_data(), A.cpu_data(), n, m, 1, alpha, beta,
+      Blob<Dtype> one(m);
+      one.SetValue(1);
+      cpu_gemm(one.cpu_data(), A.cpu_data(), m, n, 1, alpha, beta,
           false, false, B->mutable_cpu_data());
     } else {
 #ifdef USE_GPU
@@ -554,16 +554,16 @@ template<typename Dtype>
 void MVSumCol(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) 
{
   CHECK_EQ(A.count() % B->count(), 0) << "length of B must = # of cols of A";
   int m = B->count(), n = A.count() / m;
-  Blob<Dtype> one(n);
-  one.SetValue(1);
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
   if (device == -1) {
+    Blob<Dtype> one(n);
+    one.SetValue(1);
     cpu_gemm(A.cpu_data(), one.cpu_data(), m, 1, n, alpha, beta,
         A.transpose(), false, B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
-    singa_gpu_sum_by_col(A.gpu_data(), B->mutable_gpu_data(), m, n, n);
+    singa_gpu_sum_col(A.gpu_data(), B->mutable_gpu_data(), m, n, n);
     // gpu part (TODO check transpose case)
 #endif  // USE_GPU
   }
@@ -578,17 +578,17 @@ void MVSumCol(Dtype alpha, Dtype beta, const Blob<Dtype> 
& A, Blob<Dtype> * B) {
 template<typename Dtype>
 void MVSumRow(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) 
{
   CHECK_EQ(A.count() % B->count(), 0) << "length of B must = # of cols of A";
-  int m = B->count(), n = A.count() / m;
-  Blob<Dtype> one(n);
-  one.SetValue(1);
+  int n = B->count(), m = A.count() / n;
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
   if (device == -1) {
-    cpu_gemm(one.cpu_data(), A.cpu_data(), 1, m, n, alpha, beta, A.transpose(),
-      false, B->mutable_cpu_data());
+    Blob<Dtype> one(m);
+    one.SetValue(1);
+    cpu_gemm(one.cpu_data(), A.cpu_data(), 1, n, m, alpha, beta, false, 
A.transpose(),
+        B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
-    singa_gpu_sum_by_row(A.gpu_data(), B->mutable_gpu_data(), m, n, n);
+    singa_gpu_sum_row(A.gpu_data(), B->mutable_gpu_data(), m, n, n);
     // gpu part (TODO check transpose case)
 #endif  // USE_GPU
   }
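
The shape fixes swap which dimension the internal vector of ones must have.
A minimal sketch of the corrected MVAddRow semantics on the CPU path (the
concrete shapes and values are illustrative, not from the patch):

    // B is 4 x 3 and A has length 3; MVAddRow adds A to every row of B,
    // i.e. B_ij <- beta * B_ij + alpha * A_j, so the ones vector needs one
    // entry per row of B (4 here), not one per entry of A.
    Blob<float> A(std::vector<int>{3});
    Blob<float> B(std::vector<int>{4, 3});
    A.SetValue(1.0f);
    B.SetValue(0.0f);
    MVAddRow(1.0f, 1.0f, A, &B);  // every row of B becomes (1, 1, 1)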

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/utils/math_kernel.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_kernel.h 
b/include/singa/utils/math_kernel.h
index 59bc3bf..8eb7cf5 100644
--- a/include/singa/utils/math_kernel.h
+++ b/include/singa/utils/math_kernel.h
@@ -32,10 +32,10 @@ extern "C" {
 
   void singa_gpu_sum_vec(float *data, float *sum , int n);
 
-  void singa_gpu_sum_by_col(const float *src_mat_data, float *dst_vec_data,
+  void singa_gpu_sum_col(const float *src_mat_data, float *dst_vec_data,
     int rows, int cols, int stride);
 
-  void singa_gpu_sum_by_row(const float *src_mat_data, float *dst_vec_data,
+  void singa_gpu_sum_row(const float *src_mat_data, float *dst_vec_data,
     int rows, int cols, int stride);
 
   void singa_gpu_add_vec_row(const float *src_vec_data,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/include/singa/utils/param.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/param.h b/include/singa/utils/param.h
index bcfc3f9..33b61ff 100644
--- a/include/singa/utils/param.h
+++ b/include/singa/utils/param.h
@@ -143,8 +143,10 @@ class Param {
    * Share the data blob from other Param objects.
    *
    * @param other the Param object whose owner owns the data blob
+   * @param cpu_only if true, share only cpu memory (used for training with
+   * multi-gpu cards); else, share both cpu and gpu memory.
    */
-  void ShareFrom(const Param& other);
+  void ShareFrom(const Param& other, bool cpu_only);
   /**
    * Init param values from checkpoint blob.
    */
@@ -190,28 +192,27 @@ class Param {
     proto_.set_id(id);
     proto_.set_owner(id);
   }
+  inline int version() const { return version_; }
+  inline void set_version(int v) { version_ = v; }
   /**
-   * Param version is stored inside the data blob to enable all Param objs
-   * sharing the same values have the same version.
-   * @return the param version
+   * @return the version of the Param when the last Update request was issued.
    */
-  inline int version() const { return data_->version(); }
-  inline void set_version(int v) { data_->set_version(v); }
+  inline int last_version() const { return last_version_; }
+  inline void set_last_version(int v) { last_version_ = v; }
+
   /**
-   * @return the version of the parameter value local to a worker
+   * @return the sharing Param name which is configured by users in conf file.
    */
-  inline int local_version() const { return local_version_; }
-  inline void set_local_version(int v) { local_version_ = v; }
   inline const std::string& share_from() const { return proto_.share_from(); }
    /**
-    * @return num of floats.
+    * @return num of parameters in this Param obj.
     */
-  inline int size() const { return data_->count(); }
-  inline const Blob<float>& data() const { return *data_; }
-  inline Blob<float>* mutable_data() { return data_.get(); }
+  inline int size() const { return data_.count(); }
+  inline const Blob<float>& data() const { return data_; }
+  inline Blob<float>* mutable_data() { return &data_; }
   inline const Blob<float> &grad() const { return grad_; }
   inline Blob<float> *mutable_grad() { return &grad_; }
-  inline float* mutable_cpu_data() { return data_->mutable_cpu_data(); }
+  inline float* mutable_cpu_data() { return data_.mutable_cpu_data(); }
   inline float* mutable_cpu_grad() { return grad_.mutable_cpu_data(); }
   inline float* mutable_cpu_history() { return history_.mutable_cpu_data(); }
   /**
@@ -333,22 +334,24 @@ class Param {
   void ParseResponseMsg(Msg* msg, int slice_idx);
 
  protected:
-  int local_version_ = -1;
-  // the ID of the first slice
+  //!< param version updated by the Update/Sync/Get response
+  //!< only the owner param is initialized.
+  int version_ = -1;
+  //!< param version before last Update/Sync/Get request, set from version_
+  int last_version_ = -1;
+  //!< the global ID of the first slice
   int slice_start_ = 0;
+  //!< total num of slices for this Param obj
   int num_slices_ = 0;
   // offset and size of each slice
   std::vector<int> slice_offset_;
   std::vector<int> slice_size_;
-  // for debug checking
-  // since put request has no feedback, we do not track its pending status
+  // for debug. Put request has no feedback, we do not track its pending status
   std::vector<bool> pending_get_;
   std::vector<bool> pending_update_;
   int num_pending_requests_ = 0;
-  // data field
-  std::shared_ptr<Blob<float>> data_ = nullptr;
-  // gradient, history gradient of this parameter
-  Blob<float> grad_, history_;
+  // data, gradient, history gradient of this parameter
+  Blob<float> data_, grad_, history_;
   ParamProto proto_;
 };
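
version_ and last_version_ replace the blob-stored version and the old
local_version_. As used in src/worker.cc at the end of this patch (where
kCollectSleepTime is defined), the flow is:

    // before issuing an Update request, remember the version it was based on
    param->set_last_version(param->version());
    // ...the server applies the update and responds with a newer version...
    // Collect() then blocks until the newer value has been parsed:
    while (param->version() <= param->last_version())
      std::this_thread::sleep_for(std::chrono::milliseconds(kCollectSleepTime));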
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index b963912..b86da7c 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -72,6 +72,7 @@ void Driver::Init(int argc, char **argv) {
   RegisterLayer<BridgeDstLayer, int>(kBridgeDst);
   RegisterLayer<BridgeSrcLayer, int>(kBridgeSrc);
 
+  RegisterLayer<AccuracyLayer, int>(kAccuracy);
   RegisterLayer<ArgSortLayer, int>(kArgSort);
   RegisterLayer<ConvolutionLayer, int>(kConvolution);
   RegisterLayer<CConvolutionLayer, int>(kCConvolution);
@@ -219,10 +220,11 @@ void Driver::Train(const JobProto& job_conf) {
   // CHECK_LE(workers.size(), job_conf.gpu_size());
   for (auto worker : workers) {
     threads.push_back(std::thread(&Worker::Run, worker));
+    int device_id  = -1;
     if (gpu < job_conf.gpu_size()) {
-      int device_id = job_conf.gpu(gpu++);
-      context->SetupDevice(threads.back().get_id(), device_id);
+      device_id = job_conf.gpu(gpu++);
     }
+    context->SetupDevice(threads.back().get_id(), device_id);
   }
   if (grp_size > 1 || nserver_grps > 0) {
     int nservers_per_grp = cluster->nservers_per_group();
@@ -307,16 +309,16 @@ const vector<Worker*> Driver::CreateWorkers(const 
JobProto& job_conf,
       // test and validation are performed by the 1st group.
       if (gid == 0 && job_conf.test_steps() > 0) {
         test_net = NeuralNet::Create(job_conf.neuralnet(), kTest, 1);
-        test_net->ShareParamsFrom(train_net);
+        test_net->ShareParamsFrom(train_net, false);
       }
       if (gid == 0 && job_conf.validate_steps() > 0) {
         val_net = NeuralNet::Create(job_conf.neuralnet(), kVal, 1);
-        val_net->ShareParamsFrom(train_net);
+        val_net->ShareParamsFrom(train_net, false);
       }
     } else {
       train_net = NeuralNet::Create(job_conf.neuralnet(), kTrain, wgrp_size);
       if (cluster->share_memory()) {
-        train_net->ShareParamsFrom(net);
+        train_net->ShareParamsFrom(net, true);
       } else {
         Param::SliceParams(lcm, train_net->params());
       }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/input_layer/deprecated.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/input_layer/deprecated.cc 
b/src/neuralnet/input_layer/deprecated.cc
index 0f98279..dfba675 100644
--- a/src/neuralnet/input_layer/deprecated.cc
+++ b/src/neuralnet/input_layer/deprecated.cc
@@ -19,6 +19,7 @@
 *
 *************************************************************/
 
+#include <random>
 #include "singa/neuralnet/input_layer.h"
 #include "singa/utils/context.h"
 #include "singa/utils/singleton.h"
@@ -59,12 +60,11 @@ void ShardDataLayer::ComputeFeature(int flag, const 
vector<Layer*>& srclayers) {
     shard_ = new DataShard(layer_conf_.sharddata_conf().path(),
                            DataShard::kRead);
   if (random_skip_) {
-  std::uniform_int_distribution<int> distribution(0, random_skip_);
-  auto generator =
-    Singleton<Context>::Instance()->generator(std::this_thread::get_id());
-    int nskip = distribution(generator);
+    std::uniform_int_distribution<int> distribution(0, random_skip_);
+    auto generator = Singleton<Context>::Instance()->rand_generator();
+    int nskip = distribution(*generator);
     LOG(INFO) << "Random Skip " << nskip << " records, there are "
-              << shard_->Count() << " records in total";
+      << shard_->Count() << " records in total";
     string key;
     for (int i = 0; i < nskip; i++) {
       shard_->Next(&key, &sample_);
@@ -130,8 +130,8 @@ void LMDBDataLayer::ComputeFeature(int flag, const 
vector<Layer*>& srclayers) {
   if (random_skip_) {
     std::uniform_int_distribution<int> distribution(0, random_skip_);
     auto generator =
-      Singleton<Context>::Instance()->generator(std::this_thread::get_id());
-    int nskip = distribution(generator);
+      
Singleton<Context>::Instance()->rand_generator(std::this_thread::get_id());
+    int nskip = distribution(*generator);
 
     int n = 0;
     CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_,
@@ -266,12 +266,12 @@ void RGBImageLayer::ParseRecords(int flag, const 
vector<Record>& records,
 
   std::uniform_int_distribution<int> distribution(0, r.shape(0) - cropsize_);
   auto generator =
-    Singleton<Context>::Instance()->generator(std::this_thread::get_id());
+    Singleton<Context>::Instance()->rand_generator(std::this_thread::get_id());
   for (const Record& record : records) {
     auto image = images[rid];
     bool do_crop = cropsize_> 0 && ((flag & kTrain) == kTrain);
     bool do_mirror = mirror_
-                    && (distribution(generator) % 2)
+                    && (distribution(*generator) % 2)
                     && ((flag & kTrain) == kTrain);
     float* dptr = nullptr;
     if (do_crop || do_mirror)
@@ -289,8 +289,8 @@ void RGBImageLayer::ParseRecords(int flag, const 
vector<Record>& records,
     for (int i = 0; i < mean_.count(); i++)
       dptr[i] -= meandptr[i];
     if (do_crop) {
-      int hoff = distribution(generator);
-      int woff = distribution(generator);
+      int hoff = distribution(*generator);
+      int woff = distribution(*generator);
       Shape<2> cropshape = Shape2(cropsize_, cropsize_);
       if (do_mirror) {
         croped_image = expr::crop(raw_image, cropshape, hoff, woff);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/input_layer/image_preprocess.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/input_layer/image_preprocess.cc 
b/src/neuralnet/input_layer/image_preprocess.cc
index 251f456..c63c957 100644
--- a/src/neuralnet/input_layer/image_preprocess.cc
+++ b/src/neuralnet/input_layer/image_preprocess.cc
@@ -58,16 +58,16 @@ void ImagePreprocessLayer::ComputeFeature(int flag,
   std::uniform_int_distribution<int> rand1(0, srcdata.shape()[1] - cropsize_);
   std::uniform_int_distribution<int> rand2(0, srcdata.shape()[2] - cropsize_);
   auto generator =
-    Singleton<Context>::Instance()->generator(std::this_thread::get_id());
+    Singleton<Context>::Instance()->rand_generator(std::this_thread::get_id());
 
   for (int k = 0; k < batchsize; k++) {
     int h_offset = 0, w_offset = 0;
     if (cropsize_> 0 && ((flag & kTrain) == kTrain)) {
-      h_offset = rand1(generator);
-      w_offset = rand2(generator);
+      h_offset = rand1(*generator);
+      w_offset = rand2(*generator);
     }
     bool do_mirror = mirror_
-                    && (rand1(generator) % 2)
+                    && (rand1(*generator) % 2)
                     && ((flag & kTrain) == kTrain);
     ImageTransform(srcdptr + k * srcimage_size, nullptr, do_mirror, cropsize_,
         cropsize_, h_offset, w_offset, srcdata.shape()[1], height, width,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/input_layer/store.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/input_layer/store.cc 
b/src/neuralnet/input_layer/store.cc
index dbb1874..283b0c7 100644
--- a/src/neuralnet/input_layer/store.cc
+++ b/src/neuralnet/input_layer/store.cc
@@ -38,13 +38,6 @@ void StoreInputLayer::Setup(const LayerProto& conf,
   if (conf.partition_dim() == 0) {
     batchsize_ /= conf.num_partitions();
   }
-  if (conf.store_conf().random_skip() > 0) {
-    std::uniform_int_distribution<int>
-      distribution(0, conf.store_conf().random_skip());
-    auto generator =
-      Singleton<Context>::Instance()->generator(std::this_thread::get_id());
-    random_skip_ = distribution(generator);
-  }
 }
 
 void StoreInputLayer::ComputeFeature(int flag,
@@ -52,8 +45,16 @@ void StoreInputLayer::ComputeFeature(int flag,
   string key, val;
   if (store_ == nullptr) {
     store_ = io::OpenStore(layer_conf_.store_conf().backend(),
-                             layer_conf_.store_conf().path(),
-                             io::kRead);
+        layer_conf_.store_conf().path(),
+        io::kRead);
+    if (layer_conf_.store_conf().random_skip() > 0) {
+      std::uniform_int_distribution<int>
+        distribution(0, layer_conf_.store_conf().random_skip());
+      auto generator = Singleton<Context>::Instance()->rand_generator(
+          std::this_thread::get_id());
+      random_skip_ = distribution(*generator);
+    }
+
     while (random_skip_ > 0) {
       if (!store_->Read(&key, &val)) {
         store_->SeekToFirst();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc
index df77239..1953d6c 100644
--- a/src/neuralnet/layer.cc
+++ b/src/neuralnet/layer.cc
@@ -62,15 +62,4 @@ const std::string Layer::ToString(bool debug, int flag) {
   }
   return ret;
 }
-
-const std::string LossLayer::ToString(bool debug, int flag) {
-  std::string disp;
-  if (debug) {
-    disp = Layer::ToString(debug, flag);
-  } else {
-    disp = metric_.ToLogString();
-    metric_.Reset();
-  }
-  return disp;
-}
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/loss_layer/cudnn_softmaxloss.cc 
b/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
index b18a751..78a035a 100644
--- a/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
+++ b/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
@@ -21,21 +21,23 @@
 
 #include "singa/neuralnet/loss_layer.h"
 #include "singa/utils/blob.h"
+#include "singa/utils/math_blob.h"
 #include "singa/utils/math_kernel.h"
 
 namespace singa {
 void CudnnSoftmaxLossLayer::Setup(const LayerProto& conf,
     const vector<Layer*>& srclayers) {
-  CudnnSoftmaxLayer::Setup(conf, vector<Layer*> {srclayers.at(0)});
-  topk_ = conf.softmaxloss_conf().topk();
-  loss_ = accuracy_ = 0.0f;
-  counter_ = 0;
+  softmax_.Setup(conf, vector<Layer*> {srclayers.at(0)});
+  data_.Reshape(softmax_.data(this).shape());
+  data_.ShareData(*softmax_.mutable_data(this), false);
+  batchsize_ = data_.shape(0);
+  dim_ = data_.count() / batchsize_;
+  LOG(ERROR) << batchsize_ << " " << dim_;
 }
 void CudnnSoftmaxLossLayer::ComputeFeature(int flag,
     const vector<Layer*>& srclayers) {
-  CudnnSoftmaxLayer::ComputeFeature(flag, srclayers);
+  softmax_.ComputeFeature(flag, srclayers);
   // compute loss
-  float *prob = data_.mutable_gpu_data();
   Blob<int> label(batchsize_);
   int *labelptr = label.mutable_cpu_data();
 
@@ -46,9 +48,10 @@ void CudnnSoftmaxLossLayer::ComputeFeature(int flag,
 
   Blob<float> loss(batchsize_);
 
+  const float *prob = data_.gpu_data();
   singa_gpu_softmaxloss_forward(batchsize_, dim_, prob, label.gpu_data(),
       loss.mutable_gpu_data());
-
+  loss_ += Asum(loss);
   counter_++;
 }
 
@@ -66,15 +69,18 @@ void CudnnSoftmaxLossLayer::ComputeGradient(int flag,
     labelptr[i] = srclayers[1]->aux_data(this)[i];
   }
 
-  singa_gpu_softmaxloss_backward(batchsize_, dim_, scale_, label.gpu_data(),
+  singa_gpu_softmaxloss_backward(batchsize_, dim_, 1.0f, label.gpu_data(),
       gsrcptr);
+  Scale(1.0f / batchsize_, gsrcblob);
 }
 
 const std::string CudnnSoftmaxLossLayer::ToString(bool debug, int flag) {
-  string disp = "Loss = " + std::to_string(loss_ / counter_)
-    + ", accuracy = " + std::to_string(accuracy_ / counter_);
+  if (debug)
+    return Layer::ToString(debug, flag);
+
+  string disp = "Loss = " + std::to_string(loss_ / counter_);
   counter_ = 0;
-  loss_ = accuracy_ = 0;
+  loss_ = 0;
   return disp;
 }
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/loss_layer/euclidean.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/loss_layer/euclidean.cc 
b/src/neuralnet/loss_layer/euclidean.cc
index b6aa12a..49179d6 100644
--- a/src/neuralnet/loss_layer/euclidean.cc
+++ b/src/neuralnet/loss_layer/euclidean.cc
@@ -50,7 +50,8 @@ void EuclideanLossLayer::ComputeFeature(int flag,
       loss += (input_dptr[i] - reconstruct_dptr[i]) *
         (input_dptr[i] - reconstruct_dptr[i]);
   }
-  metric_.Add("loss", loss / srclayers[0]->data(this).shape()[0]);
+  loss_ += loss / srclayers[0]->data(this).shape()[0];
+  counter_ ++;
 }
 
 void EuclideanLossLayer::ComputeGradient(int flag,
@@ -67,5 +68,13 @@ void EuclideanLossLayer::ComputeGradient(int flag,
   Tensor<cpu, 1> gsrc(gsrcptr, Shape1(gsrcblob->count()));
   gsrc /= srclayers[0]->data(this).shape()[0];
 }
+const std::string EuclideanLossLayer::ToString(bool debug, int flag) {
+  if (debug)
+    return Layer::ToString(debug, flag);
 
+  string disp = "Loss = " + std::to_string(loss_ / counter_);
+  counter_ = 0;
+  loss_ = 0;
+  return disp;
+}
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/loss_layer/softmax.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/loss_layer/softmax.cc 
b/src/neuralnet/loss_layer/softmax.cc
index 0f3d5bf..ce858ac 100644
--- a/src/neuralnet/loss_layer/softmax.cc
+++ b/src/neuralnet/loss_layer/softmax.cc
@@ -79,8 +79,9 @@ void SoftmaxLossLayer::ComputeFeature(int flag,
     probptr += dim_;
   }
   CHECK_EQ(probptr, prob.dptr + prob.shape.Size());
-  metric_.Add("loss", loss * scale_ / (1.0f * batchsize_));
-  metric_.Add("accuracy", precision * scale_ / (1.0f * batchsize_));
+  loss_ += loss * scale_ / (1.0f * batchsize_);
+  accuracy_ += precision * scale_ / (1.0f * batchsize_);
+  counter_ ++;
 }
 
 void SoftmaxLossLayer::ComputeGradient(int flag,
@@ -95,5 +96,14 @@ void SoftmaxLossLayer::ComputeGradient(int flag,
   Tensor<cpu, 1> gsrc(gsrcptr, Shape1(gsrcblob->count()));
   gsrc *= scale_ / (1.0f * batchsize_);
 }
+const std::string SoftmaxLossLayer::ToString(bool debug, int flag) {
+  if (debug)
+    return Layer::ToString(debug, flag);
 
+  string disp = "Loss = " + std::to_string(loss_ / counter_)
+    + ", accuracy = " + std::to_string(accuracy_ / counter_);
+  counter_ = 0;
+  loss_ = accuracy_ = 0;
+  return disp;
+}
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index b7944e7..aabc361 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -155,7 +155,7 @@ std::string NeuralNet::ToAdjacency() {
 }
 */
 
-void NeuralNet::ShareParamsFrom(NeuralNet* other) {
+void NeuralNet::ShareParamsFrom(NeuralNet* other, bool cpu_only) {
   for (auto& layer : layers_) {
     auto otherlayer = other->name2layer(layer->name());
     if (otherlayer != nullptr) {
@@ -163,7 +163,7 @@ void NeuralNet::ShareParamsFrom(NeuralNet* other) {
       const auto& params = layer->GetParams();
       CHECK_EQ(params.size(), otherparams.size());
       for (size_t i = 0; i < params.size(); i++) {
-        params[i]->ShareFrom(*otherparams[i]);
+        params[i]->ShareFrom(*otherparams[i], cpu_only);
       }
     }
   }
@@ -416,7 +416,7 @@ void NeuralNet::CreateNetFromGraph(Graph* graph, int 
npartitions) {
     const string share_from = param->share_from();
     if (param->share_from() != "") {
       if (name2param.find(share_from) != name2param.end()) {
-        param->ShareFrom(*name2param.at(param->share_from()));
+        param->ShareFrom(*name2param.at(param->share_from()), false);
       } else {
         LOG(FATAL) << "No param with the name (share_from) " << share_from;
       }
@@ -430,7 +430,7 @@ void NeuralNet::CreateNetFromGraph(Graph* graph, int 
npartitions) {
       auto params = (*it)->GetParams();
       CHECK_EQ(params.size(), owner_params.size());
       for (size_t i = 0; i < params.size(); i++)
-        params.at(i)->ShareFrom(*owner_params.at(i));
+        params.at(i)->ShareFrom(*owner_params.at(i), true);
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/neuron_layer/cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/cudnn_softmax.cc 
b/src/neuralnet/neuron_layer/cudnn_softmax.cc
index 21d17c4..a1a492e 100644
--- a/src/neuralnet/neuron_layer/cudnn_softmax.cc
+++ b/src/neuralnet/neuron_layer/cudnn_softmax.cc
@@ -29,15 +29,15 @@ void CudnnSoftmaxLayer::InitCudnn() {
         CUDNN_TENSOR_NCHW,
         CUDNN_DATA_FLOAT,
         batchsize_,
-        num_softmax_per_instance_,
-        count_per_softmax_,
+        dim_,
+        1,
         1));
   CHECK_CUDNN(cudnnSetTensor4dDescriptor(my_desc_,
         CUDNN_TENSOR_NCHW,
         CUDNN_DATA_FLOAT,
         batchsize_,
-        num_softmax_per_instance_,
-        count_per_softmax_,
+        dim_,
+        1,
         1));
 }
 
@@ -46,12 +46,13 @@ void CudnnSoftmaxLayer::ComputeFeature(int flag,
   if (!has_init_cudnn_)
     InitCudnn();
   const float alpha = 1.0f, beta = 0.0f;
+  CHECK_EQ(srclayers.at(0)->data(this).shape().size(), 2);
   CHECK_CUDNN(cudnnSoftmaxForward(handle_,
         CUDNN_SOFTMAX_ACCURATE,
-        CUDNN_SOFTMAX_MODE_CHANNEL,
+        CUDNN_SOFTMAX_MODE_INSTANCE,
         &alpha,
         src_desc_,
-        srclayers[0]->data(this).gpu_data(),
+        srclayers.at(0)->data(this).gpu_data(),
         &beta,
         my_desc_,
         data_.mutable_gpu_data()));
@@ -62,7 +63,7 @@ void CudnnSoftmaxLayer::ComputeGradient(int flag,
   const float alpha = 1.f, beta = 0.f;
   CHECK_CUDNN(cudnnSoftmaxBackward(handle_,
         CUDNN_SOFTMAX_ACCURATE,
-        CUDNN_SOFTMAX_MODE_CHANNEL,
+        CUDNN_SOFTMAX_MODE_INSTANCE,
         &alpha,
         my_desc_,
         data_.gpu_data(),
@@ -70,6 +71,6 @@ void CudnnSoftmaxLayer::ComputeGradient(int flag,
         grad_.gpu_data(),
         &beta,
         src_desc_,
-        srclayers[0]->mutable_grad(this)->mutable_gpu_data()));
+        srclayers.at(0)->mutable_grad(this)->mutable_gpu_data()));
 }
 }  // namespace singa
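
With the descriptor shape changed to (batchsize_, dim_, 1, 1),
CUDNN_SOFTMAX_MODE_INSTANCE normalizes over all dim_ values of each instance;
CUDNN_SOFTMAX_MODE_CHANNEL would normalize across the channel dimension at
every spatial position, which with the old (batchsize,
num_softmax_per_instance, count_per_softmax, 1) layout presumably degenerated
when num_softmax_per_instance was 1. A standalone sketch of the forward call
(handle, batchsize, dim and the device pointers are placeholders):

    cudnnTensorDescriptor_t desc;
    CHECK_CUDNN(cudnnCreateTensorDescriptor(&desc));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW,
        CUDNN_DATA_FLOAT, batchsize, dim, 1, 1));
    const float alpha = 1.0f, beta = 0.0f;
    CHECK_CUDNN(cudnnSoftmaxForward(handle, CUDNN_SOFTMAX_ACCURATE,
        CUDNN_SOFTMAX_MODE_INSTANCE, &alpha,
        desc, src_gpu_ptr,           // input: raw scores, one row per instance
        &beta, desc, dst_gpu_ptr));  // output: normalized probabilities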

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/neuron_layer/rbm.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/rbm.cc 
b/src/neuralnet/neuron_layer/rbm.cc
index fadd1df..69ffa62 100644
--- a/src/neuralnet/neuron_layer/rbm.cc
+++ b/src/neuralnet/neuron_layer/rbm.cc
@@ -111,10 +111,11 @@ void RBMVisLayer::ComputeFeature(int flag, const 
vector<Layer*>& srclayers) {
       for (int i = 0; i < pos_data_.count(); i++) {
         err += (dptr[i] - rcns[i]) * (dptr[i] - rcns[i]);
       }
-      metric_.Add("Squared Error", err / batchsize_);
+      error_ += err / batchsize_;
     }
     first_gibbs_ = false;
   }
+  counter_ += batchsize_;
 }
 
 void RBMVisLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
@@ -133,6 +134,15 @@ void RBMVisLayer::ComputeGradient(int flag, const 
vector<Layer*>& srclayers) {
   gweight -= dot(hid_pos.T(), vis_pos);
   gweight /= batchsize_;
 }
+const std::string RBMVisLayer::ToString(bool debug, int flag) {
+  if (debug)
+    return Layer::ToString(debug, flag);
+
+  string disp = "Squared Error = " + std::to_string(error_ / counter_);
+  counter_ = 0;
+  error_ = 0;
+  return disp;
+}
 /**************** Implementation for RBMHidLayer********************/
 RBMHidLayer::~RBMHidLayer() {
   delete weight_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/neuron_layer/softmax.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/softmax.cc 
b/src/neuralnet/neuron_layer/softmax.cc
index 44c64f8..4a09241 100644
--- a/src/neuralnet/neuron_layer/softmax.cc
+++ b/src/neuralnet/neuron_layer/softmax.cc
@@ -38,11 +38,13 @@ void SoftmaxLayer::Setup(const LayerProto& proto,
   CHECK_EQ(srclayers.size(), 1);
   NeuronLayer::Setup(proto, srclayers);
   const auto& srcdata = srclayers[0]->data(this);
-  batchsize_ = data_.shape()[0];
+  batchsize_ = srcdata.shape()[0];
+  dim_ = srcdata.count() / batchsize_;
+  /*
   num_softmax_per_instance_ = proto.softmax_conf().num_softmax_per_instance();
-  count_per_softmax_ = data_.count() / batchsize_ / num_softmax_per_instance_;
-  data_.Reshape(vector<int>{batchsize_, num_softmax_per_instance_,
-      count_per_softmax_});
+  count_per_softmax_ = srcdata.count() / batchsize_ / 
num_softmax_per_instance_;
+  */
+  data_.Reshape(batchsize_, dim_);
   grad_.ReshapeLike(data_);
 }
 
@@ -58,6 +60,7 @@ void SoftmaxLayer::ComputeFeature(int flag,
 void SoftmaxLayer::ComputeGradient(int flag,
     const vector<Layer*>& srclayers) {
   int batchsize = data_.shape()[0];
+  LOG(FATAL) << "not implemented";
   for (int n = 0; n < batchsize; n++) {
     // TODO(wangwei) finish the code using new math API
     // gxi=[(gyi+gyi*yi)-\sum_k(gyk*yk)]*yi

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/neuralnet/output_layer/accuracy.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/output_layer/accuracy.cc 
b/src/neuralnet/output_layer/accuracy.cc
new file mode 100644
index 0000000..21107df
--- /dev/null
+++ b/src/neuralnet/output_layer/accuracy.cc
@@ -0,0 +1,61 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include <algorithm>
+#include "singa/neuralnet/output_layer.h"
+
+namespace singa {
+
+void AccuracyLayer::Setup(const LayerProto& proto,
+    const vector<Layer*>& srclayers) {
+  CHECK_EQ(srclayers.size(), 2);
+  ArgSortLayer::Setup(proto, vector<Layer*>{srclayers.at(0)});
+}
+
+void AccuracyLayer::ComputeFeature(int flag,
+    const vector<Layer*>& srclayers) {
+  ArgSortLayer::ComputeFeature(flag, vector<Layer*>{srclayers.at(0)});
+  const auto& label = srclayers[1]->aux_data(this);
+  int ncorrect = 0;
+  for (int n = 0; n < batchsize_; n++) {
+    const float* pos = data_.cpu_data() + topk_ * n;
+    // check if true label is in top k predictions
+    for (int k = 0; k < topk_; k++) {
+      if (pos[k] == label[n]) {
+        ncorrect++;
+        break;
+      }
+    }
+  }
+  accuracy_ += ncorrect * 1.0f / batchsize_;
+  counter_ ++;
+}
+
+const std::string AccuracyLayer::ToString(bool debug, int flag) {
+  if (debug)
+    return Layer::ToString(debug, flag);
+
+  string disp = "accuracy = " + std::to_string(accuracy_ / counter_);
+  counter_ = 0;
+  accuracy_ = 0;
+  return disp;
+}
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index 88caa44..ff27c05 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -602,6 +602,7 @@ enum LayerType {
   kRGBImage = 10;
   // Neuron layers
   //  - Feature transformation
+  kAccuracy = 36;
   kArgSort = 35;
   kConvolution = 1;
   kCConvolution = 27;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/server.cc
----------------------------------------------------------------------
diff --git a/src/server.cc b/src/server.cc
index 9af5fc4..003d4e3 100644
--- a/src/server.cc
+++ b/src/server.cc
@@ -153,7 +153,7 @@ Msg* Server::HandlePut(Msg **msg) {
   shard_[slice_id] = new ParamEntry(num_shares, param);
   // must set version after HandlePutMsg which allocates the memory
   param->set_version(version);
-  param->set_local_version(version);
+  param->set_last_version(version);
   param->set_id(slice_id);
   // allocate blob for param sync between groups.
   if (slice2group_[slice_id] != grp_id_) {
@@ -201,16 +201,18 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
     // extract and aggregate gradients
     param->ParseUpdateMsgs(request);
     updater_->Update(step, param, 1.0f / entry->num_total);
-    param->set_local_version(param->local_version() + 1);
+    param->set_version(param->version() + 1);
     // response to all shares of this param
     for (auto response : param->GenUpdateResponseMsgs(&request, false)) {
-      response->set_trgt(trgt_val, param->local_version());
+      response->set_trgt(trgt_val, param->version());
       ret.push_back(response);
     }
     entry->num_update = 0;
     n_updates_[sliceid]++;
     // sync with master group after at least sync_freq local updates
     // the last check is to avoid sending msg to stopped servers
+    // may send the update steps on this server since last sync, i.e.,
+    // version-last_version
     if (slice2group_[sliceid] != grp_id_
         && n_updates_[sliceid] >= Cluster::Get()->sync_freq()
         && n_pending_sync_[sliceid] <= Cluster::Get()->sync_freq()) {
@@ -221,7 +223,7 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
       int addr = Addr(slice2group_[sliceid], slice2server_[sliceid], kServer);
       Msg* sync = new Msg(Addr(grp_id_, id_, kServer), addr);
       sync->set_type(kSyncRequest);
-      sync->set_trgt(trgt_val, param->local_version());
+      sync->set_trgt(trgt_val, param->version());
       sync->AddFrame(tmp.dptr, param->size() * sizeof(float));
       Copy(tmp, cur);
       ret.push_back(sync);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/stub.cc
----------------------------------------------------------------------
diff --git a/src/stub.cc b/src/stub.cc
index a3605f7..184a1f3 100644
--- a/src/stub.cc
+++ b/src/stub.cc
@@ -270,7 +270,8 @@ void Stub::HandleGetResponse(ParamEntry* entry, Msg** msg) {
   int sliceid = SliceID((*msg)->trgt_val());
   auto param = entry->shares.at(0);
   if (param->ParseGetResponseMsg(*msg, sliceid-param->slice_start()))
-    param->set_version(version);
+    for (auto *p : entry->shares)
+      p->set_version(version);
   DeleteMsg(msg);
 }
 
@@ -279,7 +280,8 @@ void Stub::HandleUpdateResponse(ParamEntry* entry, Msg** 
msg) {
   int sliceid = SliceID((*msg)->trgt_val());
   auto param = entry->shares.at(0);
   if (param->ParseUpdateResponseMsg(*msg, sliceid-param->slice_start()))
-    param->set_version(version);
+    for (auto *p : entry->shares)
+      p->set_version(version);
   DeleteMsg(msg);
 }
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/utils/blob.cc
----------------------------------------------------------------------
diff --git a/src/utils/blob.cc b/src/utils/blob.cc
index 0dc797e..9607683 100644
--- a/src/utils/blob.cc
+++ b/src/utils/blob.cc
@@ -187,12 +187,15 @@ void SyncedMemory::to_gpu() {
 
 template <typename Dtype>
 void Blob<Dtype>::Reshape(const std::vector<int>& shape) {
+  int count = count_;
   count_ = 1;
   shape_ = shape;
   for (size_t i = 0; i < shape.size(); ++i) {
     CHECK(shape[i]);
     count_ *= shape[i];
   }
+  if (count > 0)
+    CHECK_EQ(count, count_);
   if (count_ > capacity_) {
     capacity_ = count_;
     data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
@@ -262,9 +265,12 @@ void Blob<Dtype>::SetValue(Dtype v) {
     ptr[i] = v;
 }
 template <typename Dtype>
-void Blob<Dtype>::ShareData(const Blob& other) {
+void Blob<Dtype>::ShareData(const Blob& other, bool cpu_only) {
   CHECK_EQ(count_, other.count());
-  data_ = other.data_;
+  if (cpu_only)
+    data_->set_cpu_data(other.cpu_data());
+  else
+    data_ = other.data_;
 }
 
 /*

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/utils/math_kernel.cu
----------------------------------------------------------------------
diff --git a/src/utils/math_kernel.cu b/src/utils/math_kernel.cu
index 9c54520..3650c09 100644
--- a/src/utils/math_kernel.cu
+++ b/src/utils/math_kernel.cu
@@ -37,8 +37,7 @@ void kernel_softmax_loss(const float *prob, const int *label 
, float *loss,
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   int num_threads = blockDim.x * gridDim.x;
   for (; index < n; index += num_threads) {
-    const int label_value = static_cast<int>(label[index]);
-    float prob_of_truth = prob[index * dim + label_value];
+    float prob_of_truth = prob[index * dim + label[index]];
     loss[index] -= log(max(prob_of_truth, FLT_MIN));
   }
 }
@@ -49,8 +48,8 @@ void kernel_softmax_gradient(float *grad, const int *label ,
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   int num_threads = blockDim.x * gridDim.x;
   for (; index < n; index += num_threads) {
-    int pos = index * dim + static_cast<int>(label[index]);
-    grad[pos] = (grad[pos] - 1.0f) * scale / (1.0 * n);
+    int pos = index * dim + label[index];
+    grad[pos] = (grad[pos] - 1.0f) * scale;
   }
 }
 
@@ -87,7 +86,20 @@ void kernel_sum_vec(float *data, float *sum , int n) {
 }
 
 __global__
-void kernel_sum_by_col(const float *src_mat_data,
+void kernel_sum_col(const float *src_mat_data,
+    float *dst_vec_data, int rows, int cols, int stride) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < rows; index += num_threads) {
+    dst_vec_data[index] = 0.0f;
+    for (int k = 0; k < cols; k++) {
+      dst_vec_data[index] += src_mat_data[index * stride + k];
+    }
+  }
+}
+
+__global__
+void kernel_sum_row(const float *src_mat_data,
     float *dst_vec_data, int rows, int cols, int stride) {
   int j = blockIdx.x;
   int THREADS = blockDim.x;
@@ -119,19 +131,9 @@ void kernel_sum_by_col(const float *src_mat_data,
 
   __syncthreads();
   dst_vec_data[j] = aux[0];
-}
 
-__global__
-void kernel_sum_by_row(const float *src_mat_data,
-    float *dst_vec_data, int rows, int cols, int stride) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int num_threads = blockDim.x * gridDim.x;
-  for (; index < rows; index += num_threads) {
-    dst_vec_data[index] = 0.0f;
-    for (int k = 0; k < cols; k++) {
-      dst_vec_data[index] += src_mat_data[index * stride + k];
-    }
-  }
+
+
 }
 
 __global__
@@ -337,21 +339,21 @@ void singa_gpu_sum_vec(float *data, float *sum , int n) {
   kernel_sum_vec<<<num_blocks, threads_per_block>>>(data, sum, n);
 }
 
-void singa_gpu_sum_by_col(const float *src_mat_data, float *dst_vec_data,
+void singa_gpu_sum_col(const float *src_mat_data, float *dst_vec_data,
     int rows, int cols, int stride) {
   int threads_per_block = rows > CU1DBLOCK ? CU1DBLOCK : rows;
   int num_blocks = cols;
 
-  kernel_sum_by_col<<<num_blocks, threads_per_block>>>(src_mat_data,
+  kernel_sum_col<<<num_blocks, threads_per_block>>>(src_mat_data,
       dst_vec_data, rows, cols, stride);
 }
 
-void singa_gpu_sum_by_row(const float *src_mat_data, float *dst_vec_data,
+void singa_gpu_sum_row(const float *src_mat_data, float *dst_vec_data,
     int rows, int cols, int stride) {
   int threads_per_block = cols > CU1DBLOCK ? CU1DBLOCK : cols;
   int num_blocks = rows;
 
-  kernel_sum_by_row<<<num_blocks, threads_per_block>>>(src_mat_data,
+  kernel_sum_row<<<num_blocks, threads_per_block>>>(src_mat_data,
       dst_vec_data, rows, cols, stride);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 54fe2aa..09f519b 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -151,7 +151,7 @@ void Param::SliceParams(int num, const vector<Param*>& 
params) {
 }
 
 void Param::Setup(const vector<int>& shape) {
-  data_ = std::make_shared<Blob<float>>(shape);
+  data_.Reshape(shape);
   grad_.Reshape(shape);
   history_.Reshape(shape);
 }
@@ -162,17 +162,18 @@ void Param::InitValues() {
 
 void Param::InitValues(int version) {
   ParamGenerator* gen = ParamGenerator::Create(proto_.init());
-  gen->Fill(data_.get());
+  gen->Fill(&data_);
   set_version(version);
 }
 
-void Param::ShareFrom(const Param& other) {
+void Param::ShareFrom(const Param& other, bool cpu_only) {
   proto_.set_owner(other.owner());
-  if (data_ != nullptr)
-    CHECK(data_->shape() == other.data_->shape());
-  data_ = other.data_;
+  CHECK(data_.shape() == other.data_.shape());
+  data_.ShareData(other.data_, cpu_only);
   if (grad_.count() == 0)
-    grad_.Reshape(data_->shape());
+    grad_.Reshape(data_.shape());
+  version_ = other.version_;
+  last_version_ = other.last_version_;
   slice_start_ = other.slice_start_;
   num_slices_ = other.num_slices_;
   slice_offset_ = other.slice_offset_;
@@ -183,11 +184,11 @@ void Param::ShareFrom(const Param& other) {
 }
 
 void Param::FromProto(const BlobProto& blob) {
-  data_->FromProto(blob);
+  data_.FromProto(blob);
 }
 
 void Param::ToProto(BlobProto* blob) {
-  data_->ToProto(blob);
+  data_.ToProto(blob);
 }
 
 void Param::AddSlice(int slice_id, int size) {
@@ -225,7 +226,7 @@ Msg* Param::GenGetMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
   Msg* msg = new Msg();
   msg->set_type(kGet);
-  msg->AddFormatFrame("ip",  copy, data_->cpu_data() + slice_offset_[idx]);
+  msg->AddFormatFrame("ip",  copy, data_.cpu_data() + slice_offset_[idx]);
   pending_get_[idx] = true;
   num_pending_requests_++;
   return msg;
@@ -239,7 +240,7 @@ Msg* Param::GenUpdateMsg(bool copy, int idx) {
   void* ptr = grad_.mutable_cpu_data() + slice_offset_[idx];
   // to change the head of SyncMem to cpu; otherwise, the updated parameter
   //   // values would not be synced to gpu (since the head is at gpu).
-  data_->mutable_cpu_data();
+  data_.mutable_cpu_data();
   if (copy) {
     msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
   } else {
@@ -254,9 +255,9 @@ Msg* Param::GenUpdateMsg(bool copy, int idx) {
 Msg* Param::GenSyncMsg(int offset, int size) {
   Msg* msg = new Msg();
   msg->set_type(kSyncRequest);
-  msg->set_trgt(ParamTrgt(-1, id()), local_version());
+  msg->set_trgt(ParamTrgt(-1, id()), last_version());
   // always copy data because syn is between server groups in diff procs
-  msg->AddFrame(mutable_cpu_data(), data_->count()*sizeof(float));
+  msg->AddFrame(mutable_cpu_data(), data_.count()*sizeof(float));
   return msg;
 }
 
@@ -278,7 +279,7 @@ Msg* Param::HandlePutMsg(Msg** msg, bool reserve) {
     CHECK_EQ(size * sizeof(float), (*msg)->FrameSize());
     memcpy(mutable_cpu_data(), (*msg)->FrameData(), size * sizeof(float));
   } else {
-    data_->set_cpu_data(ptr);
+    data_.set_cpu_data(ptr);
   }
   if (!reserve) DeleteMsg(msg);
   return nullptr;
@@ -292,7 +293,7 @@ Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
   (*msg)->ParseFormatFrame("ip", &copy, &ptr);
   if (copy) {
     (*msg)->AddFrame(mutable_cpu_data(), sizeof(float) * size());
-  } else if (ptr != data_->cpu_data()) {
+  } else if (ptr != data_.cpu_data()) {
     // this case reflects following situation:
     // worker 0 and server are in the same process, while worker 1 is not.
     // worker 1 "put" data into server, so server need to allocate memory.
@@ -300,8 +301,8 @@ Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
     //  1. copy the data to the worker0 provided space
     //  2. change its own pointer to that space in order to share memory
     // in this case, the server always points to last worker's space
-    memcpy(ptr, data_->cpu_data(), sizeof(float) * size());
-    data_->set_cpu_data(ptr);
+    memcpy(ptr, data_.cpu_data(), sizeof(float) * size());
+    data_.set_cpu_data(ptr);
   }
   // else the mem space is shared among all worker and servers
   Msg* ret = nullptr;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/81747603/src/worker.cc
----------------------------------------------------------------------
diff --git a/src/worker.cc b/src/worker.cc
index a41e3a8..333408d 100644
--- a/src/worker.cc
+++ b/src/worker.cc
@@ -64,10 +64,11 @@ void Worker::Run() {
   // setup gpu device
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  LOG(ERROR) << "Worker (group = " << grp_id_ <<", id = " << id_
-    << ") start on device " << device;
+  LOG(ERROR) << "Worker (group = " << grp_id_ <<", id = " << id_ << ") "
+    << " start on " << (device >= 0 ? "GPU " + std::to_string(device) : "CPU");
   if (device >= 0)
     context->ActivateDevice(device);
+
   auto cluster = Cluster::Get();
   int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
   CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp));
@@ -76,12 +77,12 @@ void Worker::Run() {
   InitNetParams(job_conf_, train_net_);
   while (!StopNow(step_)) {
     if (ValidateNow(step_) && val_net_ != nullptr) {
-      CollectAll(step_, val_net_);
+      CollectAll(step_, train_net_);
       LOG(ERROR) << "Validation @ step " + std::to_string(step_);
       Test(job_conf_.validate_steps(), kVal, val_net_);
     }
     if (TestNow(step_) && test_net_ != nullptr) {
-      CollectAll(step_, test_net_);
+      CollectAll(step_, train_net_);
       LOG(ERROR) << "Test @ step " + std::to_string(step_);
       Test(job_conf_.test_steps(), kTest, test_net_);
     }
@@ -195,8 +196,8 @@ void Worker::InitNetParams(const JobProto& job_conf, 
NeuralNet* net) {
     }
 
     // warmup training before put params to servers
-    for (; step_ < job_conf.warmup_steps(); step_++)
-      TrainOneBatch(step_, net);
+    // for (; step_ < job_conf.warmup_steps(); step_++)
+    //  TrainOneBatch(step_, net);
     for (auto layer : net->layers()) {
       if (layer->partition_id() == id_)
         for (auto param : layer->GetParams())
@@ -213,6 +214,7 @@ void Worker::InitNetParams(const JobProto& job_conf, 
NeuralNet* net) {
   }
 }
 
+
 void Worker::Checkpoint(int step, const std::string& folder, NeuralNet* net) {
   BlobProtos bps;
   for (auto layer : net->layers()) {
@@ -261,7 +263,7 @@ int Worker::Get(int step, Param* param) {
 }
 
 int Worker::Update(int step, Param* param) {
-  param->set_local_version(param->version());
+  param->set_last_version(param->version());
   if (dealer_ == nullptr) {
     LOG(WARNING) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
     return 1;
@@ -286,7 +288,7 @@ int Worker::CollectAll(int step, NeuralNet* net) {
 }
 
 int Worker::Collect(int step, Param* param) {
-  while (param->version() <= param->local_version())
+  while (param->version() <= param->last_version())
     std::this_thread::sleep_for(std::chrono::milliseconds(kCollectSleepTime));
   return 1;
 }
