Repository: incubator-singa
Updated Branches:
  refs/heads/master e8e592f1e -> 4c289744d
SINGA-9 Add Support for Restricted Boltzmann Machine (RBM) model

Implement the Contrastive Divergence (CD) algorithm to train the RBM model.
We already have a BPWorker that runs the Back-Propagation algorithm.
Following the same approach, this commit adds a CDWorker whose RunOneBatch
function controls the logic of the CD algorithm, including the positive
phase, the negative phase and the gradient-computation phase.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/4afa4685
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/4afa4685
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/4afa4685

Branch: refs/heads/master
Commit: 4afa468581b78903648e8bdcdb87275061fb51a3
Parents: c3a248a
Author: wang wei <[email protected]>
Authored: Sun Jul 19 11:45:26 2015 +0800
Committer: zhaojing <[email protected]>
Committed: Fri Jul 24 16:14:31 2015 +0800

----------------------------------------------------------------------
 examples/mnist/rbm_job.conf     |  99 +++++++++++++++++++++
 include/mshadow/tensor_random.h |  34 ++++++++
 include/neuralnet/base_layer.h  |  26 +++---
 include/neuralnet/layer.h       | 103 ++++++++++++++++++++++
 include/trainer/worker.h        |  13 +++
 src/neuralnet/base_layer.cc     |   8 +-
 src/neuralnet/layer.cc          | 161 +++++++++++++++++++++++++++++++++++
 src/neuralnet/neuralnet.cc      |   3 +
 src/proto/job.proto             |  21 +++++
 src/trainer/trainer.cc          |   2 +-
 src/trainer/worker.cc           |  60 +++++++++++++
 src/utils/graph.cc              |  31 +++++--
 src/utils/param.cc              |   3 +-
 13 files changed, 542 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/examples/mnist/rbm_job.conf
----------------------------------------------------------------------
diff --git a/examples/mnist/rbm_job.conf b/examples/mnist/rbm_job.conf
new file mode 100644
index 0000000..3dcbbc9
--- /dev/null
+++ b/examples/mnist/rbm_job.conf
@@ -0,0 +1,99 @@
+cluster {
+  nworker_groups: 1
+  nserver_groups: 1
+  nservers_per_group: 1
+  nworkers_per_group: 1
+}
+
+model {
+  name: "deep-big-simple-dbm"
+  train_steps: 46000
+  test_steps:1
+  test_frequency:100
+  display_frequency: 100
+  debug: true
+  alg: kContrastiveDivergence
+  pcd_k: 15
+  visualization_frequency: 5000
+  updater{
+    base_lr: 0.1
+    lr_change: kFixed
+    type: kSGD
+# param_type: "Param"
+  }
+
+
+neuralnet {
+layer {
+  name: "data"
+  type: kShardData
+  sharddata_conf {
+    path: "examples/mnist/mnist_train_shard"
+    batchsize: 20
+  }
+  exclude: kTest
+}
+
+
+layer {
+  name: "data"
+  type: kShardData
+  sharddata_conf {
+    path: "examples/mnist/mnist_test_shard"
+    batchsize: 20
+  }
+  exclude: kTrain
+}
+
+
+layer{
+  name:"mnist"
+  type: kMnist
+  srclayers: "data"
+  mnist_conf {
+    norm_a: 255
+    norm_b: 0
+  }
+}
+
+layer{
+  name: "RBMVis"
+  type: kRBMVis
+  srclayers:"mnist"
+  srclayers:"RBMHid"
+  rbmvis_conf{
+    num_output: 500
+  }
+  param{
+    name: "w1"
+    init_method: kUniformSqrtFanInOut
+    low:-9.79
+    high:9.79
+  }
+  param{
+    name: "b1"
+    init_method: kConstant
+    value: 0.0
+  }
+}
+
+layer{
+  name: "RBMHid"
+  type: kRBMHid
+  srclayers:"RBMVis"
+  rbmhid_conf{
+    hid_dim: 500
+  }
+  param{
+    name: "w2"
+# init_method: kUniformSqrtFanInOut
+    share_from: "w1"
+  }
+  param{
+    name: "b2"
+    init_method: kConstant
+    value: 0.0
+  }
+}
+}
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/include/mshadow/tensor_random.h
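For readers following the diff, the one-batch control flow described in the commit message (positive phase, k Gibbs-sampling steps, gradient phase) is summarised by the sketch below. This is a minimal, simplified illustration: the Net/Layer types and the TrainOneBatch free function are stand-ins invented for this note, not the SINGA classes defined later in src/trainer/worker.cc.

// Minimal sketch of the CD-k control flow that CDWorker::TrainOneBatch implements.
// The Net/Layer interfaces below are simplified stand-ins for illustration only.
#include <memory>
#include <vector>

enum Phase { kPositive, kNegative, kTrain, kLoss };

struct Layer {
  virtual ~Layer() = default;
  virtual bool is_vislayer() const { return false; }
  virtual bool is_hidlayer() const { return false; }
  virtual void ComputeFeature(Phase phase) = 0;
  virtual void ComputeGradient(Phase phase) = 0;
};

struct Net { std::vector<std::unique_ptr<Layer>> layers; };

// One CD-k training step: positive phase over all layers, k Gibbs-sampling
// steps restricted to the RBM visible/hidden layers, then gradient computation.
void TrainOneBatch(Net& net, int pcd_k) {
  for (auto& layer : net.layers)                     // positive phase
    layer->ComputeFeature(kPositive);
  for (int i = 0; i < pcd_k; ++i)                    // negative phase (Gibbs sampling)
    for (auto& layer : net.layers)
      if (layer->is_vislayer() || layer->is_hidlayer())
        layer->ComputeFeature(kNegative);
  for (auto& layer : net.layers)                     // gradient phase
    layer->ComputeGradient(kTrain);
}

The real CDWorker additionally calls Update() on every Param after computing gradients and runs a LossPhase that reports the reconstruction error, as shown in the src/trainer/worker.cc hunk below.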
---------------------------------------------------------------------- diff --git a/include/mshadow/tensor_random.h b/include/mshadow/tensor_random.h index 717d32c..72164a8 100644 --- a/include/mshadow/tensor_random.h +++ b/include/mshadow/tensor_random.h @@ -69,6 +69,40 @@ namespace mshadow { #endif } /*! + * \brief generate binary data according to a probability matrix + * \param dst destination + * \param a lower bound of uniform + * \param b upper bound of uniform + * \tparam dim dimension of tensor + */ + template<int dim> + inline void SampleBinary( Tensor<cpu, dim> &dst) { + real_t a=0.0f; + real_t b=1.0f; + Tensor<cpu, 2> mat = dst.FlatTo2D(); + std::uniform_real_distribution<real_t> distribution (a,b); + for ( index_t i = 0; i < mat.shape[1]; ++i ) { + #if MSHADOW_USE_MKL + #if MSHADOW_SINGLE_PRECISION + int status = vsRngUniform( 0, vStream_, mat.shape[0], mat[i].dptr, a, b ); + #else + int status = vdRngUniform( 0, vStream_, mat.shape[0], mat[i].dptr, a, b ); + #endif + utils::Assert(status == VSL_STATUS_OK, "Failed to generate random number by MKL.\n" ); + #else + // use stdlib + /* + for ( index_t j = 0; j < mat.shape[0]; ++j ) { + mat[i][j] = this->RandNext()*(b-a) + a; + } + */ + for ( index_t j = 0; j < mat.shape[0]; ++j ) { + mat[i][j] = distribution(gen_) > mat[i][j] ? 0.0f: 1.0f; + } + #endif + } + } + /*! * \brief generate data from uniform [a,b) * \param dst destination * \param a lower bound of uniform http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/include/neuralnet/base_layer.h ---------------------------------------------------------------------- diff --git a/include/neuralnet/base_layer.h b/include/neuralnet/base_layer.h index c00a1c9..ca63da0 100644 --- a/include/neuralnet/base_layer.h +++ b/include/neuralnet/base_layer.h @@ -56,8 +56,8 @@ class Layer { * * @param phase kTrain, kTest, kPositive, etc. */ + virtual void ComputeLoss(Metric* perf) {} virtual void ComputeGradient(Phase phase) = 0; - /** * For print debug info about each layer, e.g., norm of feature vector, * norm of parameters. 
@@ -140,10 +140,10 @@ class Layer { /** * @return a const ref for Blob storing neuron values of this layer for BP */ - virtual const Blob<float>& data(const Layer* from) const { + virtual const Blob<float>& data(const Layer* from, Phase = kPositive) const { return data_; } - virtual Blob<float>* mutable_data(const Layer* from) { + virtual Blob<float>* mutable_data(const Layer* from, Phase = kPositive) { return &data_; } @@ -207,6 +207,12 @@ class Layer { virtual bool is_bridgelayer() const { return false; } + virtual bool is_vislayer() const { + return false; + } + virtual bool is_hidlayer() const { + return false; + } protected: LayerProto layer_proto_; @@ -244,10 +250,10 @@ class BridgeSrcLayer: public BridgeLayer { ready_ = false; } - const Blob<float>& data(const Layer* from) const override { + const Blob<float>& data(const Layer* from, Phase phase) const override { return srclayers_[0]->data(this); } - Blob<float>* mutable_data(const Layer* from) override { + Blob<float>* mutable_data(const Layer* from, Phase phase) override { return srclayers_[0]->mutable_data(this); } const Blob<float>& grad(const Layer* from) const override { @@ -308,7 +314,7 @@ class DataLayer: public Layer{ bool is_datalayer() const override { return true; } - Blob<float>* mutable_data(const Layer* layer) override { + Blob<float>* mutable_data(const Layer* layer, Phase phase) override { return nullptr; } Blob<float>* mutable_grad(const Layer* layer) override { @@ -353,8 +359,8 @@ class PrefetchLayer : public Layer { void ComputeFeature(Phase phase, Metric* perf) override; void ComputeGradient(Phase phase) override {}; - const Blob<float>& data(const Layer* from) const override; - Blob<float>* mutable_data(const Layer* layer) override; + const Blob<float>& data(const Layer* from, Phase phase) const override; + Blob<float>* mutable_data(const Layer* layer, Phase phase) override; Blob<float>* mutable_grad(const Layer* layer) override { return nullptr; @@ -387,9 +393,9 @@ class SliceLayer: public Layer { ConnectionType dst_layer_connection() const override { return kOneToMany; } - const Blob<float>& data(const Layer* layer) const override; + const Blob<float>& data(const Layer* layer, Phase phase) const override; const Blob<float>& grad(const Layer* layer) const override; - Blob<float>* mutable_data(const Layer* layer) override; + Blob<float>* mutable_data(const Layer* layer, Phase phase) override; Blob<float>* mutable_grad(const Layer* layer) override; protected: http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/include/neuralnet/layer.h ---------------------------------------------------------------------- diff --git a/include/neuralnet/layer.h b/include/neuralnet/layer.h index 7f3b256..05db916 100644 --- a/include/neuralnet/layer.h +++ b/include/neuralnet/layer.h @@ -68,7 +68,110 @@ class DropoutLayer: public Layer { */ Blob<float> mask_; }; +/** + * RBM visible layer + */ +class RBMVisLayer: public Layer { + public: + using Layer::ComputeFeature; + using Layer::ComputeGradient; + + void Setup(const LayerProto& proto, + int npartitions) override; + virtual bool is_vislayer() const { + return true; + } + + void ComputeFeature(Phase phase, + Metric *perf) override; + void ComputeGradient(Phase phase) override; + virtual void ComputeLoss(Metric* perf); + virtual Blob<float>* mutable_data(const Layer* from, Phase phase) { + if (phase == kPositive) { + return &data_; + } else { + return &vis_sample_; + } + } + virtual const Blob<float>& data(const Layer* from, Phase phase) const { + if (phase == 
kPositive) { + return data_; + } else { + return vis_sample_; + } + } + // virtual void ToProto(LayerProto *layer_proto, bool copyData); + const vector<Param*> GetParams() const override { + vector<Param*> params{weight_, bias_}; + return params; + } + ~RBMVisLayer(); + + + private: + //! dimension of the hidden layer + int hdim_; + //! dimension of the visible layer + int vdim_; + int batchsize_; + // batchsize of negative phase + int neg_batchsize_; + bool is_first_iteration_vis_; + float scale_; + // srclayer index + int data_idx_; + int hid_idx_; + Param* weight_, *bias_; + // data to store sampling result + Blob<float> vis_sample_; + // in order to implement Persistent Contrastive Divergence, +}; +/** + * RBM hidden layer + */ +class RBMHidLayer: public Layer { + public: + using Layer::ComputeFeature; + using Layer::ComputeGradient; + + void Setup(const LayerProto& proto, + int npartitions) override; + virtual bool is_hidlayer() const { + return true; + } + + void ComputeFeature(Phase phase, + Metric *perf) override; + void ComputeGradient(Phase phase) override; + virtual Blob<float>* mutable_data(const Layer* from, Phase phase) { + if (phase == kPositive) + return &data_; + else + return &hid_sample_; + } + virtual const Blob<float>& data(const Layer* from, Phase phase) const { + if (phase == kPositive) + return data_; + else + return hid_sample_; + } + const vector<Param*> GetParams() const override { + vector<Param*> params{weight_, bias_}; + return params; + } + ~RBMHidLayer(); + private: + //! dimension of the hidden layer + int hdim_; + int vdim_; // dimension of visible layer + int batchsize_; + // batchsize of negative phase + int neg_batchsize_; + float scale_; + Blob<float> hid_sample_; + Param* weight_, *bias_; +}; /** * fully connected layer */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/include/trainer/worker.h ---------------------------------------------------------------------- diff --git a/include/trainer/worker.h b/include/trainer/worker.h index db8cac3..025bcc1 100644 --- a/include/trainer/worker.h +++ b/include/trainer/worker.h @@ -191,6 +191,19 @@ class BPWorker: public Worker{ void Backward(int step, shared_ptr<NeuralNet> net); }; +class CDWorker: public Worker{ + public: + CDWorker(int thread_id, int group_id, int worker_id); + ~CDWorker() {} + virtual void TrainOneBatch(int step, Metric* perf); + virtual void TestOneBatch(int step, Phase phase, + shared_ptr<NeuralNet> net, Metric* perf); + void PositivePhase(int step, shared_ptr<NeuralNet> net, Metric* perf); + void NegativePhase(int step, shared_ptr<NeuralNet> net, Metric* perf); + void GradientPhase(int step, shared_ptr<NeuralNet> net); + void LossPhase(int step, shared_ptr<NeuralNet> net, Metric* perf); +}; + inline int BlobTrgt(int grp, int layer) { return (grp << 16) | layer; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/src/neuralnet/base_layer.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/base_layer.cc b/src/neuralnet/base_layer.cc index e5fd822..57163e9 100644 --- a/src/neuralnet/base_layer.cc +++ b/src/neuralnet/base_layer.cc @@ -124,7 +124,7 @@ void PrefetchLayer::Setup(const LayerProto& proto, int npartitions) { datablobs_[layer->name()]=Blob<float>(layer->data(this).shape()); } -const Blob<float>& PrefetchLayer::data(const Layer* from) const { +const Blob<float>& PrefetchLayer::data(const Layer* from, Phase phase) const { if(from!=nullptr){ return datablobs_.at(from->datablob()); }else{ @@ -133,7 
+133,7 @@ const Blob<float>& PrefetchLayer::data(const Layer* from) const { } } -Blob<float>* PrefetchLayer::mutable_data(const Layer* from) { +Blob<float>* PrefetchLayer::mutable_data(const Layer* from, Phase phase) { if(from!=nullptr){ return &(datablobs_.at(from->datablob())); }else{ @@ -183,7 +183,7 @@ int SliceLayer::SliceID(const Layer* layer) const { return -1; } -const Blob<float>& SliceLayer::data(const Layer* layer) const { +const Blob<float>& SliceLayer::data(const Layer* layer, Phase phase) const { if(layer==nullptr) return data_; return datavec_[SliceID(layer)]; @@ -193,7 +193,7 @@ const Blob<float>& SliceLayer::grad(const Layer* layer) const { return grad_; return gradvec_[SliceID(layer)]; } -Blob<float>* SliceLayer::mutable_data(const Layer* layer) { +Blob<float>* SliceLayer::mutable_data(const Layer* layer, Phase phase) { if(layer==nullptr) return &data_; return &datavec_[SliceID(layer)]; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/src/neuralnet/layer.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc index 3d9def4..926bd17 100644 --- a/src/neuralnet/layer.cc +++ b/src/neuralnet/layer.cc @@ -160,7 +160,168 @@ void DropoutLayer::ComputeGradient(Phase phase) { auto gsrc = Tensor1(srclayers_[0]->mutable_grad(this)); gsrc = grad * mask; } +/**************** Implementation for RBMVisLayer********************/ +RBMVisLayer::~RBMVisLayer() { + delete weight_; + delete bias_; +} +void RBMVisLayer::Setup(const LayerProto& proto, + int npartitions) { + Layer::Setup(proto, npartitions); + CHECK_EQ(srclayers_.size(), 2); + // hid_idx_: index indicating which srclayer is is hidden layer + // data_idx_: index indicating which srclayer is data layer + for (unsigned int i = 0; i < srclayers_.size(); i++) + for (unsigned int j = 0; j < (srclayers_[i]-> dstlayers()).size(); j++) + if (strcmp(((srclayers_[i]->dstlayers()).at(j)->name().c_str()), + (this->name()).c_str()) == 0) + hid_idx_ = i; + for (unsigned int i = 0; i < srclayers_.size(); i++) + if (i != static_cast<unsigned int>(hid_idx_) ) + data_idx_ = i; + const auto& src = srclayers_[data_idx_]->data(this); + is_first_iteration_vis_ = true; + batchsize_ = src.shape()[0]; + neg_batchsize_ = batchsize_; + /*gibbs sampling size and input have the same size*/ + vdim_ = src.count()/batchsize_; + hdim_ = proto.rbmvis_conf().num_output(); + data_.Reshape(vector<int>{batchsize_, vdim_}); // this is visible dimension + vis_sample_.Reshape(vector<int>{neg_batchsize_, vdim_}); + Factory<Param>* factory = Singleton<Factory<Param>>::Instance(); + weight_ = factory->Create("Param"); + bias_ = factory->Create("Param"); + weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_}); + bias_->Setup(proto.param(1), vector<int>{vdim_}); +} + +void RBMVisLayer::ComputeFeature(Phase phase, Metric* perf) { + if (phase == kPositive) { /*positive phase*/ + auto data = Tensor2(&data_); + CHECK_EQ(srclayers_[data_idx_]->data(this).count(), batchsize_*vdim_); + auto src = Tensor2(srclayers_[data_idx_]->mutable_data(this)); + Copy(data, src); + } else if (phase == kNegative) { /*negative phase*/ + if (is_first_iteration_vis_) { + CHECK_EQ(srclayers_[data_idx_]->data(this).count(), batchsize_*vdim_); + auto src = Tensor2(srclayers_[data_idx_]->mutable_data(this)); + auto vis_sample = Tensor2(&vis_sample_); + Copy(vis_sample, src); + is_first_iteration_vis_ = false; + } else { + auto hid_sample = + Tensor2(srclayers_[hid_idx_]->mutable_data(this, kNegative)); 
+ // fetch sampling results from hidden layer + auto vis_sample = Tensor2(&vis_sample_); + auto weight = Tensor2(weight_->mutable_data()); + auto bias = Tensor1(bias_->mutable_data()); + vis_sample = dot(hid_sample, weight.T()); + vis_sample+=repmat(bias, neg_batchsize_); + vis_sample = F<op::sigmoid>(vis_sample); + TSingleton<Random<cpu>>::Instance()->SampleBinary(vis_sample); + } + } +} + +void RBMVisLayer::ComputeGradient(Phase phase) { + auto data = Tensor2(&data_); + auto hid_data = Tensor2(srclayers_[hid_idx_]->mutable_data(this, kPositive)); + auto vis_sample = Tensor2(&vis_sample_); + auto hid_sample = + Tensor2(srclayers_[hid_idx_]->mutable_data(this, kNegative)); + // fetch sampling results from hidden layer + auto gweight = Tensor2(weight_->mutable_grad()); + auto gbias = Tensor1(bias_->mutable_grad()); + gbias = sum_rows(vis_sample); + gbias -= sum_rows(data); + gweight = dot(vis_sample.T(), hid_sample); + gweight -= dot(data.T(), hid_data); + gbias*=(1.0f)/(1.0f*batchsize_); + gweight*=(1.0f)/(1.0f*batchsize_); +} + +void RBMVisLayer::ComputeLoss(Metric* perf) { + float loss = (0.0f); + CHECK_EQ(srclayers_[data_idx_]->data(this).count(), batchsize_*vdim_); + auto src = Tensor2(srclayers_[data_idx_]->mutable_data(this)); + auto hid_data = Tensor2(srclayers_[hid_idx_]->mutable_data(this, kPositive)); + // gibbs using u + auto weight = Tensor2(weight_->mutable_data()); + auto bias = Tensor1(bias_->mutable_data()); + Tensor<cpu, 2> reconstruct(Shape2(batchsize_, vdim_)); /*reconstruct error*/ + AllocSpace(reconstruct); + reconstruct = dot(hid_data, weight.T()); + reconstruct+=repmat(bias, batchsize_); + reconstruct = F<op::sigmoid>(reconstruct); + float *src_dptr = src.dptr; + float *reconstruct_dptr = reconstruct.dptr; + for (int i = 0; i < vdim_*batchsize_; i++) + loss += -(src_dptr[i]*log(reconstruct_dptr[i]) + +(1-src_dptr[i])*log(1-reconstruct_dptr[i])); + loss/=batchsize_; + FreeSpace(reconstruct); + perf->Reset(); + perf->Add("reconstruct_error", loss); +} +/**************** Implementation for RBMHidLayer********************/ +RBMHidLayer::~RBMHidLayer() { + delete weight_; + delete bias_; +} +void RBMHidLayer::Setup(const LayerProto& proto, + int npartitions) { + Layer::Setup(proto, npartitions); + CHECK_EQ(srclayers_.size(), 1); + const auto& src_data = srclayers_[0]->data(this, kPositive); + const auto& src_sample = srclayers_[0]->data(this, kNegative); + scale_ = static_cast<float> (1.0f); + batchsize_ = src_data.shape()[0]; + neg_batchsize_ = src_sample.shape()[0]; + vdim_ = src_data.count()/batchsize_; + hdim_ = proto.rbmhid_conf().hid_dim(); + data_.Reshape(vector<int>{batchsize_, hdim_}); + hid_sample_.Reshape(vector<int>{neg_batchsize_, hdim_}); + Factory<Param>* factory = Singleton<Factory<Param>>::Instance(); + bias_ = factory->Create("Param"); + weight_ = factory->Create("Param"); + bias_->Setup(proto.param(1), vector<int>{hdim_}); + weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_}); +} +void RBMHidLayer::ComputeFeature(Phase phase, Metric* perf) { + if (phase == kPositive) { /*postive phase*/ + auto data = Tensor2(&data_); + CHECK_EQ(srclayers_[0]->data(this, kPositive).count(), batchsize_*vdim_); + auto src = Tensor2(srclayers_[0]->mutable_data(this, kPositive)); + auto weight = Tensor2(weight_->mutable_data()); + auto bias = Tensor1(bias_->mutable_data()); + data = dot(src, weight); + data += repmat(bias, batchsize_); + data = F<op::sigmoid>(data); + } else if (phase == kNegative) { /*negative phase*/ + CHECK_EQ(srclayers_[0]->data(this, 
kNegative).count(), + neg_batchsize_*vdim_); + auto src_sample = Tensor2(srclayers_[0]->mutable_data(this, kNegative)); + auto hid_sample = Tensor2(&hid_sample_); + auto bias = Tensor1(bias_->mutable_data()); + auto weight = Tensor2(weight_->mutable_data()); + hid_sample = dot(src_sample, weight); + hid_sample += repmat(bias, neg_batchsize_); + hid_sample = F<op::sigmoid>(hid_sample); + TSingleton<Random<cpu>>::Instance()->SampleBinary(hid_sample); + } else if (phase == kLoss) { /*test phase*/ + auto data = Tensor2(&data_); // data: sigmoid(Wv+b) + TSingleton<Random<cpu>>::Instance()->SampleBinary(data); + } +} +void RBMHidLayer::ComputeGradient(Phase phase) { + auto data = Tensor2(&data_); + auto hid_sample = Tensor2(&hid_sample_); + auto gbias = Tensor1(bias_->mutable_grad()); + gbias = sum_rows(hid_sample); + gbias -= sum_rows(data); + gbias *= scale_/(1.0f*batchsize_); +} /*********** Implementation for InnerProductLayer**********/ InnerProductLayer::~InnerProductLayer() { delete weight_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/src/neuralnet/neuralnet.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc index 7769e45..10ddcf1 100644 --- a/src/neuralnet/neuralnet.cc +++ b/src/neuralnet/neuralnet.cc @@ -33,6 +33,8 @@ void NeuralNet::RegisterLayers() { RegisterLayer(factory, SoftmaxLoss); RegisterLayer(factory, Split); RegisterLayer(factory, Tanh); + RegisterLayer(factory, RBMVis); + RegisterLayer(factory, RBMHid); #ifdef USE_OPTIONAL_LAYER RegisterLayer(factory, LMDBData); @@ -134,6 +136,7 @@ void NeuralNet::CreateNetFromGraph(Graph* graph, int npartitions) { for (Node* node : graph->nodes()) { auto layer = name2layer_[node->name]; layer->Setup(*(static_cast<LayerProto*>(node->proto)), npartitions); + LOG(INFO) << "constructing graph: " << layer->name(); layerinfo[layer->name()] = IntVecToString(layer->data(nullptr).shape()); string param_name = "$"; for (auto param : layer->GetParams()) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/src/proto/job.proto ---------------------------------------------------------------------- diff --git a/src/proto/job.proto b/src/proto/job.proto index 7c462d2..068867c 100644 --- a/src/proto/job.proto +++ b/src/proto/job.proto @@ -43,6 +43,7 @@ enum Phase { kNegative = 4; kForward = 5; kBackward = 6; + kLoss = 7; } message ModelProto { @@ -74,6 +75,8 @@ message ModelProto { optional int32 test_frequency = 33 [default = 0]; // frequency of checkpoint optional int32 checkpoint_frequency = 34 [default = 0]; + //frequency of visualization + optional int32 visualization_frequency = 37 [default=5000]; // send parameters to servers after training for this num of steps optional int32 warmup_steps = 35 [default = 0]; // checkpoint path @@ -95,6 +98,8 @@ message ModelProto { repeated string checkpoint = 66; // reset the version of params loaded from checkpoint file to step optional bool reset_param_version = 67 [default = false]; + //number of steps for gibbs sampling + optional int32 pcd_k=69 [default=15]; } message NetProto { @@ -185,6 +190,8 @@ message LayerProto { kSlice = 12; kSplit = 13; kTanh = 14; + kRBMVis = 23; + kRBMHid = 24; } // source layer names repeated string srclayers = 3; @@ -228,6 +235,10 @@ message LayerProto { optional SplitProto split_conf = 42; // configuration for tanh layer optional TanhProto tanh_conf = 43; + // configuration for rbmvis layer + optional RBMVisProto rbmvis_conf = 48; + // 
configuration for rbmhid layer + optional RBMHidProto rbmhid_conf = 49; // overrides the partition dimension for neural net @@ -326,6 +337,16 @@ message DropoutProto { optional float dropout_ratio = 30 [default = 0.5]; } +message RBMVisProto { + optional int32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms +} + +message RBMHidProto { + optional int32 hid_dim = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms +} + // Message that stores parameters used by InnerProductLayer message InnerProductProto { // number of outputs for the layer http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/src/trainer/trainer.cc ---------------------------------------------------------------------- diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc index 4a0a47a..d5c885f 100644 --- a/src/trainer/trainer.cc +++ b/src/trainer/trainer.cc @@ -186,7 +186,7 @@ vector<Worker*> Trainer::CreateWorkers(int nthreads, const ModelProto& mconf){ if (mconf.alg() == ModelProto_GradCalcAlg_kBackPropagation) worker = new BPWorker(nthreads++,gid, wid); else { - // TODO add CDWorker and BPTTWorker + worker=new CDWorker(nthreads++,gid, wid); } workers.push_back(worker); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/src/trainer/worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc index 87d251d..d9c6a59 100644 --- a/src/trainer/worker.cc +++ b/src/trainer/worker.cc @@ -378,4 +378,64 @@ void BPWorker::TestOneBatch(int step, Phase phase, Forward(step, phase, net, perf); } +/****************************CDWorker**********************************/ +CDWorker::CDWorker(int thread_id, int group_id, int worker_id): + Worker(thread_id, group_id, worker_id) { +} + +void CDWorker::PositivePhase(int step, + shared_ptr<NeuralNet> net, Metric* perf) { + auto& layers = net->layers(); + for (auto& layer : layers) { + // clock_t s=clock(); + layer->ComputeFeature(kPositive, perf); + } +} + +void CDWorker::NegativePhase(int step, + shared_ptr<NeuralNet> net, Metric* perf) { +// for negative phase, gibbs sampling only concerns RBM bottom and top layer + auto& layers = net->layers(); + for (int i = 0; i < modelproto_.pcd_k(); i++) { + for (auto& layer : layers) { + if (layer->is_vislayer() || layer->is_hidlayer()) + layer->ComputeFeature(kNegative, perf); + } + } +} + +void CDWorker::GradientPhase(int step, shared_ptr<NeuralNet> net) { + auto& layers = net->layers(); + for (auto& layer : layers) { + layer->ComputeGradient(kTrain); + for (Param* p : layer->GetParams()) { + Update(p, step); + } + } +} + +void CDWorker::LossPhase(int step, shared_ptr<NeuralNet> net, Metric* perf) { + auto& layers = net->layers(); + for (auto& layer : layers) { + if (layer->is_hidlayer()) + layer->ComputeFeature(kLoss, perf); + } + for (auto& layer : layers) { + if (layer->is_vislayer()) + layer->ComputeLoss(perf); + } +} + +void CDWorker::TrainOneBatch(int step, Metric* perf) { + PositivePhase(step, train_net_, perf); + NegativePhase(step, train_net_, perf); + GradientPhase(step, train_net_); + LossPhase(step, train_net_, perf); +} + +void CDWorker::TestOneBatch(int step, Phase phase, + shared_ptr<NeuralNet> net, Metric* perf) { + PositivePhase(step, test_net_, perf); + LossPhase(step, test_net_, perf); +} } // namespace singa 
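The mshadow expressions in RBMVisLayer/RBMHidLayer above compute the usual CD statistics. The standalone, loop-based sketch below spells out the same computation for a single example with plain std:: containers (all names here are illustrative, not SINGA APIs): a hidden unit is sampled from sigmoid(v·W + c), as RBMHidLayer::ComputeFeature plus SampleBinary do on batches, and the weight gradient accumulates (negative statistics - positive statistics) / batchsize, matching the sign convention in RBMVisLayer::ComputeGradient, which an SGD updater then subtracts from the weights.

// Per-example sketch of one hidden-unit sampling step and the CD weight gradient.
// Illustration only; the SINGA layers perform these operations on whole batches.
#include <cmath>
#include <random>
#include <vector>

static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// v: visible units (vdim), W: vdim x hdim row-major, c: hidden bias (hdim).
// Returns Bernoulli-sampled hidden units, i.e. h_j = 1 with prob sigmoid(v.W_:j + c_j).
std::vector<float> SampleHidden(const std::vector<float>& v,
                                const std::vector<float>& W,
                                const std::vector<float>& c,
                                std::mt19937& gen) {
  size_t vdim = v.size(), hdim = c.size();
  std::uniform_real_distribution<float> u(0.0f, 1.0f);
  std::vector<float> h(hdim);
  for (size_t j = 0; j < hdim; ++j) {
    float act = c[j];
    for (size_t i = 0; i < vdim; ++i) act += v[i] * W[i * hdim + j];
    float p = sigmoid(act);            // p(h_j = 1 | v)
    h[j] = u(gen) < p ? 1.0f : 0.0f;   // same idea as Random::SampleBinary
  }
  return h;
}

// CD weight gradient, following the sign convention of RBMVisLayer::ComputeGradient:
// grad = (negative-phase stats - positive-phase stats) / batchsize.
void AccumulateWeightGrad(const std::vector<float>& v_pos, const std::vector<float>& h_pos,
                          const std::vector<float>& v_neg, const std::vector<float>& h_neg,
                          int batchsize, std::vector<float>* gW) {
  size_t vdim = v_pos.size(), hdim = h_pos.size();
  for (size_t i = 0; i < vdim; ++i)
    for (size_t j = 0; j < hdim; ++j)
      (*gW)[i * hdim + j] +=
          (v_neg[i] * h_neg[j] - v_pos[i] * h_pos[j]) / batchsize;
}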
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/src/utils/graph.cc ---------------------------------------------------------------------- diff --git a/src/utils/graph.cc b/src/utils/graph.cc index d92e241..1f7df06 100644 --- a/src/utils/graph.cc +++ b/src/utils/graph.cc @@ -175,12 +175,27 @@ void Graph::Sort() { auto node = visiting_nodes.front(); visiting_nodes.pop(); bool visit = true; - for (auto src : node->srcnodes) { - // visit this node only if all srouce nodes have been visited - if (visited_set.find(src) == visited_set.end()) { - visit = false; - break; - } + bool bi_direction = false; + // check if a node has a bi-direction edge with its neighbour + for (auto src : node->srcnodes) + for (auto src_of_src : src->srcnodes) + if (strcmp((src_of_src->name).c_str(), (node->name).c_str()) == 0) { + bi_direction = true; + break; + } + // check whether its src nodes number greater than 1 + if (bi_direction && (node->srcnodes).size() > 1) { + auto src = node->srcnodes.at(0); + if (visited_set.find(src) == visited_set.end()) { + visit = false; + } + } + else { + for (auto src : node->srcnodes) + if (visited_set.find(src) == visited_set.end()) { + visit = false; + break; + } } if (visit) { nodes_.push_back(node); @@ -196,6 +211,10 @@ void Graph::Sort() { visiting_nodes.push(node); } } + for (auto node : nodes_) { + LOG(INFO) << "nodes: " << node->name; + } + LOG(INFO) << "finish printing nodes "; CHECK_EQ(nodes_.size(), n); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4afa4685/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index 5541acc..4932221 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -274,9 +274,10 @@ void Param::ParseResponseMsg(Msg* msg, int slice_idx) { void Param::ShareFrom(const Param& other) { proto_.set_owner(other.owner()); - if(data_!=nullptr) + if(data_!=nullptr) { CHECK(std::equal(data_->shape().begin(), data_->shape().end(), other.data_->shape().begin())); + } data_ = other.data_; slice_offset_ = other.slice_offset_; slice_size_ = other.slice_size_;
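A note on the src/utils/graph.cc change above: RBMVis lists both mnist and RBMHid as source layers while RBMHid lists RBMVis, so the layer graph contains a two-node cycle and a plain topological sort would never find RBMVis ready. The new code relaxes the readiness test for such nodes, restated in the condensed sketch below (Node is a simplified stand-in, and the bi-direction test compares pointers rather than names purely for brevity).

// Condensed form of the relaxed readiness check used in Graph::Sort: a node that
// shares a bi-directional edge with one of its sources (RBMVis <-> RBMHid) is
// scheduled as soon as its first source has been visited.
#include <set>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> srcnodes;
};

bool ReadyToVisit(const Node* node, const std::set<const Node*>& visited) {
  bool bi_direction = false;
  for (const Node* src : node->srcnodes)
    for (const Node* src_of_src : src->srcnodes)
      if (src_of_src == node) { bi_direction = true; break; }
  if (bi_direction && node->srcnodes.size() > 1)   // e.g. RBMVis with sources {mnist, RBMHid}
    return visited.count(node->srcnodes.front()) > 0;
  for (const Node* src : node->srcnodes)           // ordinary DAG node: all sources first
    if (visited.count(src) == 0) return false;
  return true;
}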
