Repository: incubator-singa
Updated Branches:
  refs/heads/master bd2e3453c -> f16b1be6f


SINGA-110 Add Layer members datavec_ and gradvec_

* add new members datavec_ and gradvec_, of type vector<Blob<float>*>
* update Blob<float>& data(const Layer* from) and
  Blob<float>* mutable_data(const Layer* from) in layer.h
* in rbm.h/rbm.cc, add pos_data_ and pos_sample_ to make the code easier
  to read
* in rbm.cc, push pos_data_, neg_data_, neg_sample_ and pos_sample_ into
  datavec_ (a usage sketch of the new accessors follows below)
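
A minimal usage sketch of the new multi-Blob API (the subclass name and the
extra Blob members below are hypothetical, for illustration only):

  // A Layer subclass exposing two feature Blobs through datavec_/gradvec_.
  class TwoOutputLayer : public Layer {
   public:
    void Setup(const LayerProto& conf,
               const vector<Layer*>& srclayers) override {
      Layer::Setup(conf, srclayers);    // registers &data_/&grad_ at index 0
      datavec_.push_back(&extra_);      // second feature Blob at index 1
      gradvec_.push_back(&extra_grad_);
    }

   protected:
    Blob<float> extra_, extra_grad_;
  };

  // Consumers index Blobs by position instead of passing a Layer*:
  //   const Blob<float>& first = layer->data(0);
  //   Blob<float>* second = layer->mutable_data(1);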


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e8d01dca
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e8d01dca
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e8d01dca

Branch: refs/heads/master
Commit: e8d01dca8b48d0fd4e9870747b0600b961689981
Parents: bd2e345
Author: zhaojing <[email protected]>
Authored: Wed Dec 2 22:29:01 2015 +0800
Committer: zhaojing <[email protected]>
Committed: Tue Dec 8 11:50:26 2015 +0800

----------------------------------------------------------------------
 include/singa/neuralnet/layer.h            | 50 +++++++++++++++++++++++++
 include/singa/neuralnet/neuron_layer/rbm.h | 10 +----
 src/neuralnet/neuron_layer/rbm.cc          | 49 +++++++++++++-----------
 3 files changed, 80 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e8d01dca/include/singa/neuralnet/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/layer.h b/include/singa/neuralnet/layer.h
index 3f63165..a78deb7 100644
--- a/include/singa/neuralnet/layer.h
+++ b/include/singa/neuralnet/layer.h
@@ -69,6 +69,8 @@ class Layer {
    */
   virtual void Setup(const LayerProto& conf, const vector<Layer*>& srclayers) {
     layer_conf_ = conf;
+    datavec_.push_back(&data_);
+    gradvec_.push_back(&grad_);
   }
   /**
    * Compute features of this layer based on connected layers.
@@ -178,18 +180,41 @@ class Layer {
    * more than one data Blob. In this case, this argument identifies the layer
    * that is requesting the data Blob.
    * @return a const ref for Blob storing feature values of this layer.
+   * @deprecated {This function will be deleted, use
+   * virtual const vector<Blob<float>*>& data() const or
+   * virtual const Blob<float>& data(int k) const instead}.
    */
   virtual const Blob<float>& data(const Layer* from) const {
     return data_;
   }
   /**
+   * @return a const ref for Blob vector storing feature values of this layer.
+   */
+  virtual const vector<Blob<float>*>& data() const {
+    return datavec_;
+  }
+  /**
+   * @return a const ref for the kth Blob.
+   */
+  virtual const Blob<float>& data(int k) const {
+    return *datavec_.at(k);
+  }
+  /**
    * @see data().
    * @return the pointer to the Blob storing feature values of this layer.
+   * @deprecated {This function will be deleted, use
+   * virtual Blob<float>* mutable_data(int k) instead}.
    */
   virtual Blob<float>* mutable_data(const Layer* from) {
     return &data_;
   }
   /**
+   * @return the pointer to the kth Blob.
+   */
+  virtual Blob<float>* mutable_data(int k) {
+    return datavec_.at(k);
+  }
+  /**
    * @return auxiliary data, e.g., image label.
    */
   virtual const vector<AuxType>& aux_data(const Layer* from = nullptr) const {
@@ -199,23 +224,48 @@ class Layer {
    * @see data().
    * @return the const ref of the Blob for the gradient of this layer, mainly
    * used in BP algorithm.
+   * @deprecated {This function will be deleted, use
+   * virtual const vector<Blob<float>*>& grad() const or
+   * virtual const Blob<float>& grad(int k) const instead}.
    */
   virtual const Blob<float>& grad(const Layer* from) const {
     return grad_;
   }
   /**
    * @see data().
+   * @return the const ref of the Blob vector for the gradient of this layer.
+   */
+  virtual const vector<Blob<float>*>& grad() const {
+    return gradvec_;
+  }
+  /**
+   * @return the const ref of the kth Blob for the gradient of this layer.
+   */
+  virtual const Blob<float>& grad(int k) const {
+    return *gradvec_.at(k);
+  }
+  /**
+   * @see data().
    * @return a pointer to the Blob storing gradients of this layer, mainly
    * used in BP algorithm.
    */
   virtual Blob<float>* mutable_grad(const Layer* from) {
     return &grad_;
   }
+  /**
+   * @see data().
+   * @return a pointer to the kth Blob storing gradients of this layer, mainly
+   * used in BP algorithm.
+   */
+  virtual Blob<float>* mutable_grad(int k) {
+    return gradvec_.at(k);
+  }
 
  protected:
   LayerProto layer_conf_;
   Blob<float> data_, grad_;
   vector<AuxType> aux_data_;
+  vector<Blob<float>*> datavec_, gradvec_;
 };
 
 /**

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e8d01dca/include/singa/neuralnet/neuron_layer/rbm.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer/rbm.h b/include/singa/neuralnet/neuron_layer/rbm.h
index 7c6f81a..432c499 100644
--- a/include/singa/neuralnet/neuron_layer/rbm.h
+++ b/include/singa/neuralnet/neuron_layer/rbm.h
@@ -34,12 +34,6 @@ class RBMLayer: virtual public Layer {
  public:
   virtual ~RBMLayer() {}
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  const Blob<float>& neg_data(const Layer* layer) {
-    return neg_data_;
-  }
-  Blob<float>* mutable_neg_data(const Layer* layer) {
-    return &neg_data_;
-  }
   const std::vector<Param*> GetParams() const override {
     std::vector<Param*> params{weight_, bias_};
     return params;
@@ -56,10 +50,10 @@ class RBMLayer: virtual public Layer {
   int batchsize_;
   bool first_gibbs_;
   Param* weight_, *bias_;
-
+  Blob<float> pos_data_;
   Blob<float> neg_data_;
   Blob<float> neg_sample_;
-  Blob<float> sample_;
+  Blob<float> pos_sample_;
 };
 
 /**

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e8d01dca/src/neuralnet/neuron_layer/rbm.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/rbm.cc b/src/neuralnet/neuron_layer/rbm.cc
index f0b73b3..fc15e79 100644
--- a/src/neuralnet/neuron_layer/rbm.cc
+++ b/src/neuralnet/neuron_layer/rbm.cc
@@ -32,8 +32,8 @@ using std::vector;
 Blob<float>* RBMLayer::Sample(int flag) {
   Tensor<cpu, 2> sample, data;
   if ((flag & kPositive) == kPositive || first_gibbs_) {
-    data = Tensor2(&data_);
-    sample = Tensor2(&sample_);
+    data = Tensor2(&pos_data_);
+    sample = Tensor2(&pos_sample_);
   } else {
     data = Tensor2(&neg_data_);
     sample = Tensor2(&neg_sample_);
@@ -46,13 +46,20 @@ Blob<float>* RBMLayer::Sample(int flag) {
     random->SampleBinary(sample, data);
   }
   return (flag & kPositive) == kPositive || first_gibbs_ ?
-    &sample_ : &neg_sample_;
+    &pos_sample_ : &neg_sample_;
 }
 void RBMLayer::Setup(const LayerProto& conf, const vector<Layer*>& srclayers) {
   Layer::Setup(conf, srclayers);
   hdim_ = conf.rbm_conf().hdim();
   gaussian_ = conf.rbm_conf().gaussian();
   first_gibbs_ = true;
+  // datavec_ order: pos_data_, neg_data_, neg_sample_, pos_sample_
+  datavec_.clear();
+  datavec_.push_back(&pos_data_);
+  datavec_.push_back(&neg_data_);
+  datavec_.push_back(&neg_sample_);
+  datavec_.push_back(&pos_sample_);
+  gradvec_.resize(4);
 }
 /**************** Implementation for RBMVisLayer********************/
 RBMVisLayer::~RBMVisLayer() {
@@ -76,9 +83,9 @@ void RBMVisLayer::Setup(const LayerProto& conf,
   input_layer_ = srclayers[0] != hid_layer_ ? srclayers[0]: srclayers[1];
   const auto& src = input_layer_->data(this);
   batchsize_ = src.shape()[0];
-  data_.ReshapeLike(src);
-  neg_data_.ReshapeLike(data_);
-  neg_sample_.ReshapeLike(data_);
+  pos_data_.ReshapeLike(src);
+  neg_data_.ReshapeLike(pos_data_);
+  neg_sample_.ReshapeLike(pos_data_);
   vdim_ = src.count() / batchsize_;
   weight_ = Param::Create(conf.param(0));
   weight_ ->Setup(vector<int>{hdim_, vdim_});
@@ -88,7 +95,7 @@ void RBMVisLayer::Setup(const LayerProto& conf,
 
 void RBMVisLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
   if ((flag & kPositive) == kPositive) {
-    data_.CopyFrom(input_layer_->data(this), true);
+    pos_data_.CopyFrom(input_layer_->data(this), true);
     first_gibbs_ = true;
   } else if ((flag & kNegative) == kNegative) {
     // fetch sampling results from hidden layer
@@ -100,9 +107,9 @@ void RBMVisLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
     data += expr::repmat(bias, batchsize_);
     data = expr::F<op::sigmoid>(data);
     if ((flag & kTest) == kTest) {
-      const float *dptr = data_.cpu_data(), *rcns = neg_data_.cpu_data();
+      const float *dptr = pos_data_.cpu_data(), *rcns = neg_data_.cpu_data();
       float err = 0.f;
-      for (int i = 0; i < data_.count(); i++) {
+      for (int i = 0; i < pos_data_.count(); i++) {
         err += (dptr[i] - rcns[i]) * (dptr[i] - rcns[i]);
       }
       metric_.Add("Squared Error", err / batchsize_);
@@ -112,10 +119,10 @@ void RBMVisLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
 }
 
 void RBMVisLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
-  auto vis_pos = Tensor2(&data_);
+  auto vis_pos = Tensor2(&pos_data_);
   auto vis_neg = Tensor2(&neg_data_);
-  auto hid_pos = Tensor2(hid_layer_->mutable_data(this));
-  auto hid_neg = Tensor2(hid_layer_->mutable_neg_data(this));
+  auto hid_pos = Tensor2(hid_layer_->mutable_data(0));
+  auto hid_neg = Tensor2(hid_layer_->mutable_data(1));
 
   auto gbias = Tensor1(bias_->mutable_grad());
   gbias = expr::sum_rows(vis_neg);
@@ -137,13 +144,13 @@ void RBMHidLayer::Setup(const LayerProto& conf,
       const vector<Layer*>& srclayers) {
   RBMLayer::Setup(conf, srclayers);
   CHECK_EQ(srclayers.size(), 1);
-  const auto& src_data = srclayers[0]->data(this);
+  const auto& src_data = srclayers[0]->data(0);
   batchsize_ = src_data.shape()[0];
   vdim_ = src_data.count() / batchsize_;
-  data_.Reshape(vector<int>{batchsize_, hdim_});
-  neg_data_.ReshapeLike(data_);
-  sample_.ReshapeLike(data_);
-  neg_sample_.ReshapeLike(data_);
+  pos_data_.Reshape(vector<int>{batchsize_, hdim_});
+  neg_data_.ReshapeLike(pos_data_);
+  pos_sample_.ReshapeLike(pos_data_);
+  neg_sample_.ReshapeLike(pos_data_);
   weight_ = Param::Create(conf.param(0));
   weight_->Setup(vector<int>{hdim_, vdim_});
   bias_ = Param::Create(conf.param(1));
@@ -157,13 +164,13 @@ void RBMHidLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
 
   Tensor<cpu, 2> data, src;
   if ((flag & kPositive) == kPositive) {
-    data = Tensor2(&data_);
-    src = Tensor2(vis_layer_->mutable_data(this));
+    data = Tensor2(&pos_data_);
+    src = Tensor2(vis_layer_->mutable_data(0));
     first_gibbs_ = true;
   } else {
     data = Tensor2(&neg_data_);
     // hinton's science paper does not sample the vis layer
-    src = Tensor2(vis_layer_->mutable_neg_data(this));
+    src = Tensor2(vis_layer_->mutable_data(1));
     first_gibbs_ = false;
   }
   data = dot(src, weight.T());
@@ -174,7 +181,7 @@ void RBMHidLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
 }
 
 void RBMHidLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
-  auto hid_pos = Tensor2(&data_);
+  auto hid_pos = Tensor2(&pos_data_);
   auto hid_neg = Tensor2(&neg_data_);
   auto gbias = Tensor1(bias_->mutable_grad());
   gbias = expr::sum_rows(hid_neg);

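For reference, a sketch of the Blob indexing convention established by
RBMLayer::Setup above (names taken from the patch itself):

  // RBMLayer::Setup registers the Blobs in this order:
  //   datavec_[0] = &pos_data_;   datavec_[1] = &neg_data_;
  //   datavec_[2] = &neg_sample_; datavec_[3] = &pos_sample_;
  // which is why the vis/hid layers fetch each other's features as:
  Blob<float>* hid_pos = hid_layer_->mutable_data(0);  // positive-phase data
  Blob<float>* hid_neg = hid_layer_->mutable_data(1);  // negative-phase data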