SINGA-120 - Implemented GRU and BPTT; added input layers for the char-rnn example.
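As a reading aid for the gru.cc changes further down (not part of the commit): writing x_t for the input blob, h_{t-1} for the context blob, and W_{*x}, W_{*h}, b_* for the weight_*_hx_, weight_*_hh_ and bias_* Params, the rewritten GRULayer::ComputeFeature appears to compute

\begin{aligned}
z_t &= \sigma\left(x_t W_{zx}^\top + h_{t-1} W_{zh}^\top + b_z\right) \\
r_t &= \sigma\left(x_t W_{rx}^\top + h_{t-1} W_{rh}^\top + b_r\right) \\
c_t &= \tanh\left(r_t \odot \left(x_t W_{cx}^\top + b_c\right) + h_{t-1} W_{ch}^\top\right) \\
h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot c_t
\end{aligned}

with h_t stored in data_. The hidden state is consumed both by the next unrolled GRU and by the upper layer, which is why dst_layer_connection() now returns kOneToMany and a second gradient Blob (gradvec_[1]) is aggregated at the start of ComputeGradient.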
Fix the bug from worker.cc for flag setting in computegradient Run with GPU; Loss decreases slowly to 3 per unit; Todo add RNNDummyLayer and train with RMSProp Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/959ef705 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/959ef705 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/959ef705 Branch: refs/heads/master Commit: 959ef705a66b20b474dfad3e85a9f35635e8690f Parents: 1f03f9d Author: Wei Wang <[email protected]> Authored: Sat Jan 2 22:54:20 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Wed Jan 6 01:55:08 2016 +0800 ---------------------------------------------------------------------- Makefile.am | 4 + include/singa/neuralnet/input_layer.h | 38 ++++ include/singa/neuralnet/layer.h | 6 + include/singa/neuralnet/neuralnet.h | 11 + include/singa/neuralnet/neuron_layer.h | 50 +++++ include/singa/utils/common.h | 5 + include/singa/utils/context.h | 6 + include/singa/utils/math_blob.h | 15 ++ include/singa/utils/param.h | 6 +- include/singa/utils/updater.h | 7 +- include/singa/worker.h | 31 ++- src/driver.cc | 5 + src/neuralnet/input_layer/char_rnn.cc | 95 +++++++++ src/neuralnet/input_layer/rnn_label.cc | 35 ++++ src/neuralnet/neuralnet.cc | 243 +++++++++++++---------- src/neuralnet/neuron_layer/embedding.cc | 98 +++++++++ src/neuralnet/neuron_layer/gru.cc | 115 +++++------ src/neuralnet/neuron_layer/inner_product.cc | 13 +- src/proto/job.proto | 42 +++- src/stub.cc | 2 + src/test/test_gru_layer.cc | 1 - src/test/test_math.cc | 1 - src/utils/common.cc | 26 +++ src/utils/param.cc | 17 +- src/utils/updater.cc | 17 ++ src/worker.cc | 84 +++++++- 26 files changed, 787 insertions(+), 186 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/Makefile.am ---------------------------------------------------------------------- diff --git a/Makefile.am b/Makefile.am index aa88348..d2b2aa8 100644 --- a/Makefile.am +++ b/Makefile.am @@ -75,12 +75,15 @@ SINGA_SRCS := src/driver.cc \ src/neuralnet/connection_layer/concate.cc \ src/neuralnet/connection_layer/slice.cc \ src/neuralnet/connection_layer/split.cc \ + src/neuralnet/input_layer/char_rnn.cc \ + src/neuralnet/input_layer/onehot.cc \ src/neuralnet/input_layer/csv.cc \ src/neuralnet/input_layer/image_preprocess.cc \ src/neuralnet/input_layer/prefetch.cc \ src/neuralnet/input_layer/record.cc \ src/neuralnet/input_layer/deprecated.cc \ src/neuralnet/input_layer/store.cc \ + src/neuralnet/input_layer/rnn_label.cc \ src/neuralnet/output_layer/accuracy.cc \ src/neuralnet/output_layer/argsort.cc \ src/neuralnet/output_layer/csv.cc \ @@ -91,6 +94,7 @@ SINGA_SRCS := src/driver.cc \ src/neuralnet/neuron_layer/convolution.cc \ src/neuralnet/neuron_layer/dropout.cc \ src/neuralnet/neuron_layer/dummy.cc \ + src/neuralnet/neuron_layer/embedding.cc \ src/neuralnet/neuron_layer/inner_product.cc \ src/neuralnet/neuron_layer/lrn.cc \ src/neuralnet/neuron_layer/pooling.cc \ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/neuralnet/input_layer.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/input_layer.h b/include/singa/neuralnet/input_layer.h index 2825d65..e701eec 100644 --- a/include/singa/neuralnet/input_layer.h +++ b/include/singa/neuralnet/input_layer.h @@ -162,6 +162,44 @@ class 
PrefetchLayer : public Layer { std::thread thread_; }; +class OneHotLayer : public InputLayer { + public: + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers); + + private: + int batchsize_, dim_; +}; + +/** + * * Read the ASCII file as a large string used for RNN model where each character + * * is a single input to the unrolled RNN layer. + * * max string length is string::max_size(); + * */ +class CharRNNInputLayer : public InputLayer { + public: + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers); + + private: + int batchsize_ = 0, unroll_len_ = 1; + unsigned offset_ = 0; + string path_, vocab_path_; + string buf_; + vector<int> start_; + std::unordered_map<char, int> char2index_; +}; + +/** + * Label layer for fetching labels from the src input layer for RNN models. + * The i-th unrolled layer fetch label from the input layer via data(i+1). + * Particularly, it shares data_ Blob with data(i+1) of its src layer. + */ +class RNNLabelLayer : public InputLayer { + public: + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers); + void ComputeFeature(int flag, const vector<Layer*>& srclayers); +}; /****************Deprecated layers******************/ /** * @deprecated please use the StoreInputLayer. http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/neuralnet/layer.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/layer.h b/include/singa/neuralnet/layer.h index 28ab92e..f4738fa 100644 --- a/include/singa/neuralnet/layer.h +++ b/include/singa/neuralnet/layer.h @@ -174,6 +174,12 @@ class Layer { */ inline const std::string& name() const { return layer_conf_.name(); } /** + * Return the index of the unrolled layer within the unrolling group, which + * should be [0, max_unrolling_length) + */ + inline const int unroll_index() const { return layer_conf_.unroll_index(); } + + /** * @return a const ref for Blob vector storing feature values of this layer. */ virtual const vector<Blob<float>*>& data() const { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/neuralnet/neuralnet.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/neuralnet.h b/include/singa/neuralnet/neuralnet.h index be8f5c8..33ad38c 100644 --- a/include/singa/neuralnet/neuralnet.h +++ b/include/singa/neuralnet/neuralnet.h @@ -109,6 +109,15 @@ class NeuralNet { << "layer (" << layer->name() << " ) has no source layers"; return src_map_.at(layer); } + Layer* last_unroll_layer(const Layer* layer) const { + auto pos = layer->name().find("#"); + if (pos == std::string::npos) + return nullptr; + string last_name = std::to_string(unroll_len_) + layer->name().substr(pos); + CHECK(name2layer_.find(last_name) != name2layer_.end()) + << "layer name = " << last_name << " has no unroll layers"; + return name2layer_.at(last_name); + } inline Param* paramid2param(int id) const { return paramid2param_.at(id); } /** @@ -137,6 +146,7 @@ class NeuralNet { * prepare data structures, e.g., params_, layers_, etc. 
*/ void PrepareDataStructures(); + void PrepareDataStructures(const NetProto& proto); /** * add split layers, due to connections to multiple dst-layers */ @@ -149,6 +159,7 @@ class NeuralNet { int npartitions); protected: + int unroll_len_ = 1; std::vector<Layer*> layers_; std::vector<Param*> params_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/neuralnet/neuron_layer.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h index 3f126ab..e587e38 100644 --- a/include/singa/neuralnet/neuron_layer.h +++ b/include/singa/neuralnet/neuron_layer.h @@ -131,12 +131,60 @@ class DummyLayer: public NeuronLayer { bool output_ = false; // use as output layer }; +/** + * Embedding layer that converts an array of index ID into a matrix. + * + * Each index ID corresponds to a word (or feature) vector in the vocabulary + * matrix maintained by the embedding layer. + * The index ID ranges within [0, |D|), where |D| is the size of the vocabulary, + * i.e., the number of rows of the vocabulary matrix. + * If the index is -1, which means it is a padding word. A feature vector with + * all values 0 will be constructed and inserted into the feature Blob. + * Users handle special words by themseleves. For example, the index 0 could be + * the starting word/symbol of a sentence, the index 1 could be the ending + * word/symbol of a sentence. + */ +class EmbeddingLayer : public NeuronLayer { + public: + ~EmbeddingLayer() { + delete vocab_; + } + void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; + void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; + void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; + const std::vector<Param*> GetParams() const override { + std::vector<Param*> params; + params.push_back(vocab_); + return params; + } + + private: + int vocab_size_, feature_dim_, batchsize_; + //!< the vocabulary matrix to be learned + Param *vocab_; +}; + class GRULayer : public NeuronLayer { public: ~GRULayer(); void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override; void ComputeFeature(int flag, const vector<Layer*>& srclayers) override; void ComputeGradient(int flag, const vector<Layer*>& srclayers) override; + ConnectionType dst_layer_connection() const override{ + return kOneToMany; + } + Blob<float>* mutable_grad(const Layer* from) override { + if (typeid(*from) == typeid(GRULayer)) + return gradvec_[1]; + else + return gradvec_[0]; + } + const Blob<float>& grad(const Layer* from) override{ + if (typeid(*from) == typeid(GRULayer)) + return *gradvec_[1]; + else + return *gradvec_[0]; + } const std::vector<Param*> GetParams() const override { if (bias_z_ != nullptr && bias_r_ != nullptr && bias_c_ != nullptr) { @@ -156,6 +204,8 @@ class GRULayer : public NeuronLayer { int vdim_, hdim_; // dimensions Blob<float> *update_gate, *reset_gate, *new_memory; + //!< gru layer connect to two dst layers, hence need to grad blobs. 
+ Blob<float> aux_grad_; Param *weight_z_hx_, *weight_z_hh_, *bias_z_; // update gate Param *weight_r_hx_, *weight_r_hh_, *bias_r_; // reset gate http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/utils/common.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/common.h b/include/singa/utils/common.h index afbe954..0bcec58 100644 --- a/include/singa/utils/common.h +++ b/include/singa/utils/common.h @@ -155,6 +155,11 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename); * Write a string (e.g., graph reprensetation of a net) into a text file. */ void WriteStringToTextFile(const string& filename, const string& context); + +/** + * Parse metric pairs (key = value[, key = value]) from string + */ +const vector<std::pair<string, float>> GetMetricFromString(const string& disp); } // namespace singa #endif // SINGA_UTILS_COMMON_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/utils/context.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/context.h b/include/singa/utils/context.h index 8e7bbb8..b1128c1 100644 --- a/include/singa/utils/context.h +++ b/include/singa/utils/context.h @@ -100,6 +100,12 @@ class Context { } /** + * @return the device ID of the current thread. + */ + int device_id() { + return device_id(std::this_thread::get_id()); + } + /** * @return the ID of the device attached to a given CPU thread, or -1 if this * thread has not been attached GPU device. */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/utils/math_blob.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/math_blob.h b/include/singa/utils/math_blob.h index 125509a..bdaf914 100644 --- a/include/singa/utils/math_blob.h +++ b/include/singa/utils/math_blob.h @@ -712,6 +712,21 @@ void Softmax(int nb_rows, const Blob<Dtype>& A, Blob<Dtype>* B) { #endif // USE_GPU } } + +template<typename Dtype> +void Zero(Blob<Dtype>* B) { + auto context = Singleton<Context>::Instance(); + int device = context->device_id(std::this_thread::get_id()); + if (device == -1) { + B->SetValue(0); + } else { +#ifdef USE_GPU + cudaMemset(B->mutable_gpu_data(), 0, B->count() * sizeof(float)); +#else + LOG(FATAL) << "Not implemented"; +#endif // USE_GPU + } +} } // end of namespace singa #endif // SINGA_UTILS_MATH_BLOB_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/utils/param.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/param.h b/include/singa/utils/param.h index 415490e..9930710 100644 --- a/include/singa/utils/param.h +++ b/include/singa/utils/param.h @@ -146,7 +146,11 @@ class Param { * @param cpu_only if true, share only cpu memory (used for training with * multi-gpu cards); else, share both cpu and gpu memory. */ - void ShareFrom(Param* other, bool cpu_only); + void ShareDataFrom(Param* other, bool cpu_only); + /** + * Share both data and grad from other param + */ + void ShareFrom(Param* other); /** * Init param values from checkpoint blob. 
*/ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/utils/updater.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/updater.h b/include/singa/utils/updater.h index 6413a80..575ab86 100644 --- a/include/singa/utils/updater.h +++ b/include/singa/utils/updater.h @@ -7,9 +7,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at -* +* * http://www.apache.org/licenses/LICENSE-2.0 -* +* * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -93,12 +93,13 @@ class Updater { virtual void Init(const UpdaterProto &proto); virtual void Update(int step, Param* param, float grad_scale) = 0; - + void Clip(const float low, const float high, Param* param); protected: UpdaterProto proto_; LRGenerator* lr_gen_; float weight_decay_; float momentum_; + float clip_low_, clip_high_; }; class SGDUpdater : public Updater { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/include/singa/worker.h ---------------------------------------------------------------------- diff --git a/include/singa/worker.h b/include/singa/worker.h index 8738c27..34c8000 100644 --- a/include/singa/worker.h +++ b/include/singa/worker.h @@ -165,7 +165,7 @@ class Worker { * @param prefix display prefix, e.g., 'Train step 100', 'Test step 90'. * @param net display layers from this neural net. */ - void Display(int flag, const std::string& prefix, NeuralNet* net); + virtual void Display(int flag, const std::string& prefix, NeuralNet* net); /** * Put Param values to server. * @@ -284,10 +284,35 @@ class BPWorker: public Worker { public: void TrainOneBatch(int step, NeuralNet* net) override; void TestOneBatch(int step, Phase phase, NeuralNet* net) override; - void Forward(int step, Phase phase, NeuralNet* net); - void Backward(int step, NeuralNet* net); + virtual void Forward(int step, Phase phase, NeuralNet* net); + virtual void Backward(int step, NeuralNet* net); }; +/** + * Subclass of Worker that implements BPTT (Backpropagation through time) + * algorithm for computing gradients of RNN models. + * Max BPTT/unrolling length is configured by users. + */ +class BPTTWorker: public BPWorker { + public: + void Forward(int step, Phase phase, NeuralNet* net) override; + void Backward(int step, NeuralNet* net) override; + void Display(int flag, const std::string& prefix, NeuralNet* net) override; + + private: + /* + * indicator used in truncted BPTT, which feeds the hidden state of the last + * unrolled unit to the first unit in Forward() for the next iteration. + * currently always feed the last hidden state to the first. + */ + bool full_state_ = false; + //!< indicator used for the starting of a new pass of the dataset. + bool begin_ = false; +}; +/** + * Subclass of Worker that implements the Contrastive Divergence algorithm for + * computing the gradients of paramters of energy models. 
+ */ class CDWorker: public Worker { public: void TrainOneBatch(int step, NeuralNet* net) override; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/driver.cc ---------------------------------------------------------------------- diff --git a/src/driver.cc b/src/driver.cc index 5e0772b..21968bb 100644 --- a/src/driver.cc +++ b/src/driver.cc @@ -71,6 +71,9 @@ void Driver::Init(int argc, char **argv) { RegisterLayer<ImagePreprocessLayer, int>(kImagePreprocess); RegisterLayer<RecordOutputLayer, int>(kRecordOutput); RegisterLayer<CSVOutputLayer, int>(kCSVOutput); + RegisterLayer<CharRNNInputLayer, int>(kCharRNN); + RegisterLayer<RNNLabelLayer, int>(kRNNLabel); + RegisterLayer<OneHotLayer, int>(kOneHot); // connection layers RegisterLayer<BridgeDstLayer, int>(kBridgeDst); @@ -84,6 +87,7 @@ void Driver::Init(int argc, char **argv) { RegisterLayer<ConvolutionLayer, int>(kConvolution); RegisterLayer<CConvolutionLayer, int>(kCConvolution); RegisterLayer<CPoolingLayer, int>(kCPooling); + RegisterLayer<EmbeddingLayer, int>(kEmbedding); #ifdef USE_CUDNN RegisterLayer<CudnnActivationLayer, int>(kCudnnActivation); @@ -135,6 +139,7 @@ void Driver::Init(int argc, char **argv) { // register workers RegisterWorker<BPWorker>(kBP); + RegisterWorker<BPTTWorker>(kBPTT); RegisterWorker<CDWorker>(kCD); // register params http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/neuralnet/input_layer/char_rnn.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/input_layer/char_rnn.cc b/src/neuralnet/input_layer/char_rnn.cc new file mode 100644 index 0000000..cc13b1b --- /dev/null +++ b/src/neuralnet/input_layer/char_rnn.cc @@ -0,0 +1,95 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ +#include <sstream> +#include <fstream> +#include "singa/neuralnet/input_layer.h" +namespace singa { + +void CharRNNInputLayer::Setup(const LayerProto& conf, + const vector<Layer*>& srclayers) { + InputLayer::Setup(conf, srclayers); + batchsize_ = conf.char_rnn_conf().batchsize(); + path_ = conf.char_rnn_conf().path(); + vocab_path_ = conf.char_rnn_conf().vocab_path(); + unroll_len_ = conf.char_rnn_conf().unroll_len(); + datavec_.clear(); + // each unroll layer has a input blob + for (int i = 0; i <= unroll_len_; i++) { + datavec_.push_back(new Blob<float>(batchsize_)); + } +} + +void CharRNNInputLayer::ComputeFeature(int flag, + const vector<Layer*>& srclayers) { + if (buf_.size() == 0) { + + // read the vocab + { + std::ifstream fin; + fin.open(vocab_path_); + CHECK(fin.is_open()) << "Can't open vocab_path = " << vocab_path_; + std::stringstream stream; + stream << fin.rdbuf(); + string vocab = stream.str(); + LOG(ERROR) << "Vocab_size = " << vocab.length(); + for (char c : vocab) + char2index_[c] = char2index_.size() - 1; + fin.close(); + } + + // read the whole text file + { + std::ifstream fin; + fin.open(path_); + CHECK(fin.is_open()) << "Can't open filepath = " << path_; + std::stringstream stream; + stream << fin.rdbuf(); + buf_ = stream.str(); + fin.close(); + } + + // decide the start pos of each instance in one mini-batch + int max_offset = buf_.length() / batchsize_; + CHECK_GT(max_offset, unroll_len_); + for (int i = 0; i < batchsize_; i ++) { + start_.push_back(i * max_offset); + } + } + + for (int l = 0; l < unroll_len_ + 1; l++) { + float* ptr = datavec_[l]->mutable_cpu_data(); + for (int i = 0; i < batchsize_; i++) { + ptr[i] = static_cast<float>(char2index_.at(buf_[start_[i] + l])); + } + } + offset_ += unroll_len_; + if (offset_ >= buf_.length() / batchsize_) { +// unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); +// std::mt19937 g(seed); +// std::shuffle(start_.begin(), start_.end(), g); + offset_ = 0; + // return -1; + } else { + // return 0; + } +} + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/neuralnet/input_layer/rnn_label.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/input_layer/rnn_label.cc b/src/neuralnet/input_layer/rnn_label.cc new file mode 100644 index 0000000..4924d87 --- /dev/null +++ b/src/neuralnet/input_layer/rnn_label.cc @@ -0,0 +1,35 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +#include "singa/neuralnet/input_layer.h" +namespace singa { +void RNNLabelLayer::Setup(const LayerProto& proto, + const vector<Layer*>& srclayers) { + InputLayer::Setup(proto, srclayers); + aux_data_.resize(srclayers[0]->data(unroll_index() + 1).shape(0)); +} +void RNNLabelLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) { + const float* input = srclayers[0]->data(unroll_index() + 1).cpu_data(); + for (unsigned i = 0; i < aux_data_.size(); i++) { + aux_data_[i] = static_cast<int>(input[i]); + } +} +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/neuralnet/neuralnet.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc index 6bb0ecd..f9579b1 100644 --- a/src/neuralnet/neuralnet.cc +++ b/src/neuralnet/neuralnet.cc @@ -33,10 +33,36 @@ using std::map; using std::string; using std::vector; +/** + * Check user defined net config and make some preprocessing, e.g., assing names + * to params. + * TODO(wnagwei) implement the following functions. + * 1. layer and paramname should not include '@', '+' and '#'. '@<suffix>' + * is used for identifying layer location/partition. '<prefix>#' is used for + * identifying the unrolled Param in RNN models. + * 2. assign names to unnamed Param, e.g., p<param_id>+<layer_name>. + */ +const NetProto NetConfPreprocess(const NetProto& conf) { + /* + string param_name = "$"; + // if user does not name the param, then name it based on layer name. + if (param->name() == "") { + param->set_name(layer->name() + param_name); + param_name += "$"; + } + */ + NetProto proto = conf; + for (int i = 0; i < proto.layer_size(); i++) { + if (!proto.layer(i).has_unroll_len()) + proto.mutable_layer(i)->set_unroll_len(proto.unroll_len()); + } + return proto; +} + NeuralNet* NeuralNet::Create(const NetProto& net_conf, Phase phase, int npartitions) { - NetProto conf; - conf.CopyFrom(net_conf); + const NetProto& full_net_conf = NetConfPreprocess(net_conf); + NetProto conf = full_net_conf; conf.clear_layer(); // flag=0: neither exclude nor include field appears // flag=1: exclude field appears @@ -45,25 +71,19 @@ NeuralNet* NeuralNet::Create(const NetProto& net_conf, Phase phase, // exclude layers according to phase // exclude field is deprecated // please use include field instead - for (const auto& layer : net_conf.layer()) { + for (const auto& layer : full_net_conf.layer()) { bool include = true; for (auto p : layer.exclude()) { // check whether both exclude and include field // appear in the same .conf file - CHECK(flag == 0 || flag == 1) - << "include and exclude field should not simultaneously" - << " appear in the same .conf file"; + CHECK(flag == 0 || flag == 1) << "Don't use include and exclude together"; if (p == phase) include = false; flag = 1; } // neural net only include the specified layer in the include field for (auto p : layer.include()) { - // check whether both exclude and include field - // appear in the same .conf file - CHECK(flag == 0 || flag == 2) - << "include and exclude field should not simultaneously" - << " appear in the same .conf file"; + CHECK(flag == 0 || flag == 2) << "Don't use include and exclude together"; if (p == phase) { include = true; break; @@ -78,21 +98,19 @@ NeuralNet* NeuralNet::Create(const NetProto& net_conf, Phase phase, if (!layer_conf->has_partition_dim()) layer_conf->set_partition_dim(net_conf.partition_dim()); } - 
//LOG(INFO) << "Before unrolling: \n" << conf.DebugString(); + // LOG(INFO) << "Before unrolling: \n" << conf.DebugString(); conf = Unrolling (conf); // Copy shared parameters for sharing param conf - std::unordered_map<string, ParamProto*> name2param; std::vector<ParamProto*> shares; + std::unordered_map<string, ParamProto*> name2param; for (int index = 0; index < conf.layer_size();index ++) { LayerProto* layer = conf.mutable_layer(index); for (int i = 0; i < layer->param_size(); i++) { ParamProto* param = layer->mutable_param(i); - if (param->has_name() && param->name() != "") { - CHECK(name2param.find(param->name()) == name2param.end()) - << "param name is repeated: " << param->name(); + CHECK(name2param.find(param->name()) == name2param.end()) + << "Repeated param = " << param->name(); name2param[param->name()] = param; - } if (param->has_share_from() && param->share_from() != "") shares.push_back(param); } @@ -101,99 +119,108 @@ NeuralNet* NeuralNet::Create(const NetProto& net_conf, Phase phase, const std::string from = param->share_from(); const std::string name = param->name(); CHECK(name2param.find(from) != name2param.end()) - << "can't find param " << from; + << "can't find share_from = " << from; // CopyFrom will overwrite the name and share_from fields param->CopyFrom(*name2param.at(from)); param->set_name(name); param->set_share_from(from); } LOG(INFO) << "Initial NeuralNet Config is\n" << conf.DebugString(); - // TODO(wangwei) create net based on net type, e.g., directed, undirected, etc + // TODO(wangwei) create net based on net type, e.g., directed, undirected. return new NeuralNet(conf, npartitions); } const NetProto NeuralNet::Unrolling(const NetProto& net_conf) { - // Step 1: Unroll each layer & set parameter sharing - NetProto conf; + // Step 1: Unroll each layer & set parameter sharing + NetProto conf; - std::vector<std::vector<int>> layer_groups; - std::unordered_map<string,int> org_layer_names; - for (int index = 0; index < net_conf.layer_size(); index ++) { - const LayerProto& org_layer = net_conf.layer(index); - org_layer_names[org_layer.name()] = index; // layer_name -> index + std::vector<std::vector<int>> layer_groups; + std::unordered_map<string,int> org_layer_names; + for (int index = 0; index < net_conf.layer_size(); index ++) { + const LayerProto& org_layer = net_conf.layer(index); + org_layer_names[org_layer.name()] = index; // layer_name -> index - std::vector<int> layer_group; - for (int i = 0; i < org_layer.unroll_len(); i ++) { // unroll - LayerProto* unroll_layer = conf.add_layer(); - unroll_layer->CopyFrom(org_layer); // create a new layer conf - if (org_layer.unroll_len() > 1) { - // update layer names - std::stringstream sstm; - sstm << unroll_layer->name() << "_" << i; - unroll_layer->set_name(sstm.str()); - // update layer parameter sharing - for (int j = 0; j < unroll_layer->param_size(); j ++) { - ParamProto* param = unroll_layer->mutable_param(j); - if (i == 0) continue; // no need to rename parameters in the i-th unrolled layer - if (!param->has_share_from() || param->share_from() == "") {// not shared from others - param->set_share_from(param->name()); - } - std::stringstream sstm1; - sstm1 << param->name() << "_" << i; - param->set_name(sstm1.str()); - } - } - // clear unrolling related fields - unroll_layer->clear_unroll_len(); - unroll_layer->clear_unroll_conn_type(); - unroll_layer->clear_shift(); - unroll_layer->clear_srclayers(); + std::vector<int> layer_group; + for (int i = 0; i < org_layer.unroll_len(); i ++) { // unroll + LayerProto* 
unroll_layer = conf.add_layer(); + unroll_layer->CopyFrom(org_layer); // create a new layer conf + if (org_layer.unroll_len() > 1) { + // update layer names + std::stringstream sstm; + sstm << i << '#' << unroll_layer->name(); + unroll_layer->set_name(sstm.str()); + unroll_layer->set_unroll_index(i); + // update layer parameter sharing + for (int j = 0; j < unroll_layer->param_size(); j ++) { + ParamProto* param = unroll_layer->mutable_param(j); + if (i > 0) { + param->set_share_from("0#" + param->name()); + } + std::stringstream sstm1; + sstm1 << i << '#' << param->name(); + param->set_name(sstm1.str()); + } + } + // clear unrolling related fields + unroll_layer->clear_unroll_len(); + unroll_layer->clear_unroll_conn_type(); + unroll_layer->clear_shift(); + unroll_layer->clear_srclayers(); - layer_group.push_back(conf.layer_size() - 1); - } - layer_groups.push_back(layer_group); - } - // Step 2: Connect unrolled layers by setting `srclayers` - for (int index = 0; index < net_conf.layer_size(); index ++) { - const LayerProto& org_layer = net_conf.layer(index); - if (org_layer.srclayers_size() == 0) continue; // no src layer - //TODO(fanju): add LSTM when it is ready - if (org_layer.type() == kGRU) { // connect GRU layers - for (unsigned int j = 1; j < layer_groups[index].size(); j ++) { - LayerProto* unroll_layer = conf.mutable_layer(layer_groups[index][j]); - unroll_layer->add_srclayers(conf.layer(layer_groups[index][j-1]).name()); - } - } - for (int i = 0; i < org_layer.srclayers_size(); i ++) { - const string& org_layer_src = org_layer.srclayers(i); + layer_group.push_back(conf.layer_size() - 1); + // LOG(ERROR) << "unrolling layer " << unroll_layer->name(); + } + layer_groups.push_back(layer_group); + } + // Step 2: Connect unrolled layers by setting `srclayers` + for (int index = 0; index < net_conf.layer_size(); index ++) { + const LayerProto& org_layer = net_conf.layer(index); + if (org_layer.srclayers_size() == 0) + continue; // no src layer + for (int i = 0; i < org_layer.srclayers_size(); i ++) { + const string& org_layer_src = org_layer.srclayers(i); + singa::UnrollConnType unroll_conn_type = kUnrollOneToOne; + if (i < org_layer.unroll_conn_type_size()) + unroll_conn_type = org_layer.unroll_conn_type(i); + unsigned int shift = 0; + if (i < org_layer.shift_size()) + shift = org_layer.shift(i); - singa::UnrollConnType unroll_conn_type = kUnrollOneToOne; // Default value - if (i < org_layer.unroll_conn_type_size()) unroll_conn_type = org_layer.unroll_conn_type(i); - unsigned int shift = 0; // Default shift value - if (i < org_layer.shift_size()) shift = org_layer.shift(i); + const std::vector<int> unroll_layer_srcs + = layer_groups[org_layer_names[org_layer_src]]; - const std::vector<int> unroll_layer_srcs = layer_groups[org_layer_names[org_layer_src]]; + for (unsigned int j = 0; j < layer_groups[index].size(); j ++) { + LayerProto* unroll_layer = conf.mutable_layer(layer_groups[index][j]); + // Update src layers of `unroll_layer` by considering the types + if (unroll_conn_type == kUnrollOneToAll) { + for (int unroll_layer_src : unroll_layer_srcs) { + unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); + } + } else if (unroll_conn_type == kUnrollOneToOne) { + if (j < shift) continue; // no need to connect with the src + int unroll_layer_src = unroll_layer_srcs[j - shift]; + unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); + } else if (unroll_conn_type == kUnrollFirstToLast) { + if (j > 0) break; + int unroll_layer_src = + 
unroll_layer_srcs[unroll_layer_srcs.size() - 1]; + unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); + } + } + } - for (unsigned int j = 0; j < layer_groups[index].size(); j ++) { - LayerProto* unroll_layer = conf.mutable_layer(layer_groups[index][j]); - // Update src layers of `unroll_layer` by considering the types - if (unroll_conn_type == kUnrollOneToAll) { - for (int unroll_layer_src : unroll_layer_srcs) { - unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); - } - } else if (unroll_conn_type == kUnrollOneToOne) { - if (j < shift) continue; // no need to connect with the src - int unroll_layer_src = unroll_layer_srcs[j - shift]; - unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); - } else if (unroll_conn_type == kUnrollFirstToLast) { - if (j > 0) break; - int unroll_layer_src = unroll_layer_srcs[unroll_layer_srcs.size() - 1]; - unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); - } - } - } - } - return conf; + //TODO(fanju): add LSTM when it is ready + if (org_layer.type() == kGRU) { // connect GRU layers + for (unsigned int j = 1; j < layer_groups[index].size(); j ++) { + LayerProto* unroll_layer = conf.mutable_layer(layer_groups[index][j]); + string srcname = conf.layer(layer_groups[index][j-1]).name(); + unroll_layer->add_srclayers(srcname); + // LOG(ERROR) << "connect " << unroll_layer->name() << " from " << srcname; + } + } + + } + return conf; } @@ -202,10 +229,12 @@ NeuralNet::NeuralNet(NetProto netproto, int npartitions) { auto graph = CreateGraph(netproto, npartitions); CreateNetFromGraph(graph); PrepareDataStructures(); + for (Node* node : graph->nodes()) delete static_cast<LayerProto*>(node->proto); delete graph; LOG(INFO) << "NeuralNet Constructed"; + unroll_len_ = netproto.unroll_len(); } NeuralNet::~NeuralNet() { @@ -243,7 +272,7 @@ void NeuralNet::ShareParamsFrom(NeuralNet* other, bool cpu_only) { const auto& params = layer->GetParams(); CHECK_EQ(params.size(), otherparams.size()); for (size_t i = 0; i < params.size(); i++) { - params[i]->ShareFrom(otherparams[i], cpu_only); + params[i]->ShareDataFrom(otherparams[i], cpu_only); } } } @@ -442,6 +471,7 @@ Graph* NeuralNet::CreateGraph(const NetProto& netproto, int npartitions) { proto->set_num_partitions(npartitions); Node* node = graph->AddNode(nodename, layer.name(), i, proto); nodes.push_back(node); + // TODO(wangwei) update param name } name2nodes[layer.name()] = nodes; name2proto[layer.name()] = &layer; @@ -526,14 +556,8 @@ void NeuralNet::CreateNetFromGraph(Graph* graph) { layer->Setup(*(static_cast<LayerProto*>(node->proto)), srclayers(layer)); DLOG(INFO) << "constructing graph: " << layer->name(); layerinfo[layer->name()] = IntVecToString(layer->data(nullptr).shape()); - string param_name = "$"; for (auto param : layer->GetParams()) { param->set_id(paramid++); - // if user does not name the param, then name it based on layer name. 
- if (param->name() == "") { - param->set_name(layer->name() + param_name); - param_name += "$"; - } } if (layer->partition_dim() == 0) share_param_layers[node->origin].push_back(layer); @@ -556,12 +580,25 @@ void NeuralNet::CreateNetFromGraph(Graph* graph) { const string share_from = param->share_from(); if (param->share_from() != "") { if (name2param.find(share_from) != name2param.end()) { - param->ShareFrom(name2param.at(param->share_from()), false); + param->ShareDataFrom(name2param.at(param->share_from()), false); } else { LOG(FATAL) << "No param with the name (share_from) " << share_from; } } } + + // share params due to laye unrolling + for (auto & entry : name2param) { + Param* param = entry.second; + auto pos = param->name().find("#"); + if (pos != std::string::npos && param->owner() != param->id()) { + string from = "0" + param->name().substr(pos); + CHECK(name2param.find(from) != name2param.end()) + << "Can't find owner = " << from << " for param = " << param->name(); + Param* owner = name2param.at(from); + param->ShareFrom(owner); + } + } // share Params for layers generated (partitioned) from the same origin layer for (auto & entry : share_param_layers) { const auto& owner = entry.second.begin(); @@ -570,7 +607,7 @@ void NeuralNet::CreateNetFromGraph(Graph* graph) { auto params = (*it)->GetParams(); CHECK_EQ(params.size(), owner_params.size()); for (size_t i = 0; i < params.size(); i++) - params.at(i)->ShareFrom(owner_params.at(i), true); + params.at(i)->ShareDataFrom(owner_params.at(i), true); } } } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/neuralnet/neuron_layer/embedding.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuron_layer/embedding.cc b/src/neuralnet/neuron_layer/embedding.cc new file mode 100644 index 0000000..00e9139 --- /dev/null +++ b/src/neuralnet/neuron_layer/embedding.cc @@ -0,0 +1,98 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +#include "singa/neuralnet/neuron_layer.h" +#include "singa/utils/math_addr.h" +#include "singa/utils/math_blob.h" +#include "singa/utils/singleton.h" +#include "singa/utils/context.h" + +namespace singa { + +void EmbeddingLayer::Setup(const LayerProto& conf, + const vector<Layer*>& srclayers) { + NeuronLayer::Setup(conf, srclayers); + vocab_size_ = conf.embedding_conf().vocab_size(); + feature_dim_ = conf.embedding_conf().feature_dim(); + vocab_ = Param::Create(conf.param(0)); + vocab_->Setup(vector<int>{vocab_size_, feature_dim_}); + batchsize_ = srclayers.at(0)->data(unroll_index()).shape(0); + data_.Reshape(batchsize_, feature_dim_); + grad_.ReshapeLike(data_); +} + +void EmbeddingLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) { + const float* word_idx = srclayers.at(0)->data(unroll_index()).cpu_data(); + int device = Singleton<Context>::Instance()->device_id(); + if (device == -1) { + const float* src = vocab_->data().cpu_data(); + float* dst = data_.mutable_cpu_data(); + for (int i = 0; i < batchsize_; i++) { + memcpy(dst + i * feature_dim_, + src + static_cast<int>(word_idx[i]) * feature_dim_, + feature_dim_ * sizeof(float)); + } + } else { +#ifdef USE_GPU + const float* src = vocab_->data().gpu_data(); + float* dst = data_.mutable_gpu_data(); + for (int i = 0; i < batchsize_; i++) { + cudaMemcpy(dst + i * feature_dim_, + src + static_cast<int>(word_idx[i]) * feature_dim_, + feature_dim_ * sizeof(float), cudaMemcpyDefault); + } +#else + LOG(FATAL) << "Not implemented"; +#endif + } +} + +void EmbeddingLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) +{ + const float* word_idx = srclayers.at(0)->data(unroll_index()).cpu_data(); + auto context = Singleton<Context>::Instance(); + if ((flag & kAggGrad) == 0) + Zero(vocab_->mutable_grad()); + + if (context->device_id() == -1) { + const float* src = grad_.cpu_data(); + float* dst = vocab_->mutable_grad()->mutable_cpu_data(); + memset(dst, 0 , sizeof(float) * grad_.count()); + for (int i = 0; i < batchsize_; i++) { + cpu_axpy(feature_dim_, 1.0f, src + i * feature_dim_, + dst + static_cast<int>(word_idx[i]) * feature_dim_); + } + } else { +#ifdef USE_GPU + const float* src = grad_.gpu_data(); + float* dst = vocab_->mutable_grad()->mutable_gpu_data(); + for (int i = 0; i < batchsize_; i++) { + gpu_axpy(context->cublas_handle(), grad_.count(), 1.0f, + src + i * feature_dim_, + dst + static_cast<int>(word_idx[i]) * feature_dim_); + } +#else + LOG(FATAL) << "Not implemented"; +#endif + } +} + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/neuralnet/neuron_layer/gru.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuron_layer/gru.cc b/src/neuralnet/neuron_layer/gru.cc index 45d7873..9ba5a50 100644 --- a/src/neuralnet/neuron_layer/gru.cc +++ b/src/neuralnet/neuron_layer/gru.cc @@ -64,6 +64,8 @@ void GRULayer::Setup(const LayerProto& conf, data_.Reshape(vector<int>{batchsize_, hdim_}); grad_.ReshapeLike(data_); + // one for grad from dst GRU, one for grad from upper layer + gradvec_.push_back(new Blob<float>(grad_.shape())); // Initialize the parameters weight_z_hx_ = Param::Create(conf.param(0)); @@ -74,7 +76,7 @@ void GRULayer::Setup(const LayerProto& conf, weight_r_hh_ = Param::Create(conf.param(4)); weight_c_hh_ = Param::Create(conf.param(5)); - if (conf.gru_conf().bias_term()) { + if (conf.param_size() > 6) { bias_z_ = 
Param::Create(conf.param(6)); bias_r_ = Param::Create(conf.param(7)); bias_c_ = Param::Create(conf.param(8)); @@ -88,7 +90,7 @@ void GRULayer::Setup(const LayerProto& conf, weight_r_hh_->Setup(vector<int>{hdim_, hdim_}); weight_c_hh_->Setup(vector<int>{hdim_, hdim_}); - if (conf.gru_conf().bias_term()) { + if (conf.param_size() > 6) { bias_z_->Setup(vector<int>{hdim_}); bias_r_->Setup(vector<int>{hdim_}); bias_c_->Setup(vector<int>{hdim_}); @@ -97,7 +99,6 @@ void GRULayer::Setup(const LayerProto& conf, update_gate = new Blob<float>(batchsize_, hdim_); reset_gate = new Blob<float>(batchsize_, hdim_); new_memory = new Blob<float>(batchsize_, hdim_); - } void GRULayer::ComputeFeature(int flag, @@ -105,11 +106,11 @@ void GRULayer::ComputeFeature(int flag, CHECK_LE(srclayers.size(), 2); // Do transpose - Blob<float> *w_z_hx_t = Transpose (weight_z_hx_->data()); + Blob<float> *w_z_hx_t = Transpose (weight_z_hx_->data()); Blob<float> *w_z_hh_t = Transpose (weight_z_hh_->data()); - Blob<float> *w_r_hx_t = Transpose (weight_r_hx_->data()); + Blob<float> *w_r_hx_t = Transpose (weight_r_hx_->data()); Blob<float> *w_r_hh_t = Transpose (weight_r_hh_->data()); - Blob<float> *w_c_hx_t = Transpose (weight_c_hx_->data()); + Blob<float> *w_c_hx_t = Transpose (weight_c_hx_->data()); Blob<float> *w_c_hh_t = Transpose (weight_c_hh_->data()); // Prepare the data input and the context @@ -123,49 +124,34 @@ void GRULayer::ComputeFeature(int flag, // Compute the update gate GEMM(1.0f, 0.0f, src,*w_z_hx_t,update_gate); - if (bias_z_ != nullptr) + if (bias_z_ != nullptr) MVAddRow(1.0f,1.0f,bias_z_->data(),update_gate); - Blob<float> zprev (batchsize_,hdim_); - GEMM(1.0f, 0.0f, *context,*w_z_hh_t, &zprev); - Add<float>(*update_gate, zprev, update_gate); + GEMM(1.0f, 1.0f, *context, *w_z_hh_t, update_gate); Map<op::Sigmoid<float>,float>(*update_gate, update_gate); // Compute the reset gate GEMM(1.0f, 0.0f, src,*w_r_hx_t,reset_gate); if (bias_r_ != nullptr) MVAddRow(1.0f,1.0f,bias_r_->data(),reset_gate); - Blob<float> rprev (batchsize_, hdim_); - GEMM(1.0f, 0.0f, *context, *w_r_hh_t, &rprev); - Add<float>(*reset_gate, rprev, reset_gate); + GEMM(1.0f, 1.0f, *context, *w_r_hh_t, reset_gate); Map<op::Sigmoid<float>,float>(*reset_gate, reset_gate); // Compute the new memory GEMM(1.0f, 0.0f, src, *w_c_hx_t, new_memory); if (bias_c_ != nullptr) MVAddRow(1.0f,1.0f,bias_c_->data(), new_memory); - Blob<float> cprev (batchsize_, hdim_); - GEMM(1.0f, 0.0f, *context, *w_c_hh_t, &cprev); - //Blob<float> new_cprev (batchsize_, hdim_); - Mult<float>(*reset_gate, cprev, &cprev); - Add<float>(*new_memory, cprev, new_memory); + Mult<float>(*reset_gate, *new_memory, new_memory); + GEMM(1.0f, 1.0f, *context, *w_c_hh_t, new_memory); Map<op::Tanh<float>,float>(*new_memory, new_memory); - // Compute data - new memory part - Blob<float> z1 (batchsize_,hdim_); - for (int i = 0; i < z1.count(); i ++) { - z1.mutable_cpu_data()[i] = 1.0f; // generate a matrix with ones - } - AXPY<float>(-1.0f, *update_gate, &z1); - Mult<float>(z1, *new_memory, &data_); - // Compute data - context part - Blob<float> data_prev (batchsize_, hdim_); - Mult<float>(*update_gate,*context,&data_prev); - Add<float>(data_, data_prev, &data_); + Sub(*context, *new_memory, &data_); + Mult(data_, *update_gate, &data_); + Add(data_, *new_memory, &data_); // delete the pointers - if (srclayers.size() == 1) delete context; - else context = NULL; + if (srclayers.size() == 1) + delete context; delete w_z_hx_t; delete w_z_hh_t; @@ -178,14 +164,20 @@ void 
GRULayer::ComputeFeature(int flag, void GRULayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) { CHECK_LE(srclayers.size(), 2); + // agg grad from two dst layers + AXPY(1.0f, *gradvec_[1], &grad_); + float beta = 1.0f; // agg param gradients + Layer* ilayer = srclayers[0]; // input layer + Layer* clayer = nullptr; // context layer // Prepare the data input and the context - const Blob<float>& src = srclayers[0]->data(this); + const Blob<float>& src = ilayer->data(this); const Blob<float> *context; if (srclayers.size() == 1) { // only have data input context = new Blob<float>(batchsize_, hdim_); } else { // have data input & context - context = &srclayers[1]->data(this); + clayer = srclayers[1]; + context = &(clayer->data(this)); } // Prepare gradient of output neurons @@ -197,7 +189,7 @@ void GRULayer::ComputeGradient(int flag, Blob<float> drgatedr (batchsize_, hdim_); Map<singa::op::SigmoidGrad<float>, float>(*reset_gate, &drgatedr); Blob<float> dnewmdc (batchsize_, hdim_); - Map<singa::op::TanhGrad<float>, float>(*new_memory,&dnewmdc); + Map<singa::op::TanhGrad<float>, float>(*new_memory, &dnewmdc); Blob<float> dLdz (batchsize_, hdim_); Sub<float>(*context, *new_memory, &dLdz); @@ -206,9 +198,7 @@ void GRULayer::ComputeGradient(int flag, Blob<float> dLdc (batchsize_,hdim_); Blob<float> z1 (batchsize_,hdim_); - for (int i = 0; i < z1.count(); i ++) { - z1.mutable_cpu_data()[i] = 1.0f; // generate a matrix with ones - } + z1.SetValue(1.0f); AXPY<float>(-1.0f, *update_gate, &z1); Mult(grad_,z1,&dLdc); Mult(dLdc,dnewmdc,&dLdc); @@ -218,57 +208,58 @@ void GRULayer::ComputeGradient(int flag, Blob<float> dLdr (batchsize_, hdim_); Blob<float> cprev (batchsize_, hdim_); - Blob<float> *w_c_hh_t = Transpose(weight_c_hh_->data()); - GEMM(1.0f,0.0f,*context,*w_c_hh_t, &cprev); - delete w_c_hh_t; - Mult(dLdc,cprev,&dLdr); - Mult(dLdr,drgatedr,&dLdr); - + GEMM(1.0f, 0.0f, *context, weight_c_hh_->data().T(), &cprev); + Mult(dLdc, cprev, &dLdr); + Mult(dLdr, drgatedr, &dLdr); // Compute gradients for parameters of update gate Blob<float> *dLdz_t = Transpose(dLdz); - GEMM(1.0f,0.0f,*dLdz_t,src,weight_z_hx_->mutable_grad()); - GEMM(1.0f,0.0f,*dLdz_t,*context,weight_z_hh_->mutable_grad()); + GEMM(1.0f, beta, *dLdz_t, src, weight_z_hx_->mutable_grad()); + GEMM(1.0f, beta, *dLdz_t, *context, weight_z_hh_->mutable_grad()); if (bias_z_ != nullptr) - MVSumRow<float>(1.0f,0.0f,dLdz,bias_z_->mutable_grad()); + MVSumRow<float>(1.0f, beta, dLdz, bias_z_->mutable_grad()); delete dLdz_t; // Compute gradients for parameters of reset gate Blob<float> *dLdr_t = Transpose(dLdr); - GEMM(1.0f,0.0f,*dLdr_t,src,weight_r_hx_->mutable_grad()); - GEMM(1.0f,0.0f,*dLdr_t,*context,weight_r_hh_->mutable_grad()); + GEMM(1.0f, beta, *dLdr_t, src, weight_r_hx_->mutable_grad()); + GEMM(1.0f, beta, *dLdr_t, *context, weight_r_hh_->mutable_grad()); if (bias_r_ != nullptr) - MVSumRow(1.0f,0.0f,dLdr,bias_r_->mutable_grad()); + MVSumRow(1.0f, beta, dLdr, bias_r_->mutable_grad()); delete dLdr_t; // Compute gradients for parameters of new memory Blob<float> *dLdc_t = Transpose(dLdc); - GEMM(1.0f,0.0f,*dLdc_t,src,weight_c_hx_->mutable_grad()); + GEMM(1.0f, beta, *dLdc_t, src,weight_c_hx_->mutable_grad()); if (bias_c_ != nullptr) - MVSumRow(1.0f,0.0f,dLdc,bias_c_->mutable_grad()); + MVSumRow(1.0f, beta, dLdc, bias_c_->mutable_grad()); delete dLdc_t; Blob<float> *reset_dLdc_t = Transpose(reset_dLdc); - GEMM(1.0f,0.0f,*reset_dLdc_t,*context,weight_c_hh_->mutable_grad()); + GEMM(1.0f, beta, *reset_dLdc_t, *context, 
weight_c_hh_->mutable_grad()); delete reset_dLdc_t; // Compute gradients for data input layer if (srclayers[0]->mutable_grad(this) != nullptr) { - GEMM(1.0f,0.0f,dLdc,weight_c_hx_->data(),srclayers[0]->mutable_grad(this)); - GEMM(1.0f,1.0f,dLdz,weight_z_hx_->data(),srclayers[0]->mutable_grad(this)); - GEMM(1.0f,1.0f,dLdr,weight_r_hx_->data(), srclayers[0]->mutable_grad(this)); + GEMM(1.0f,0.0f,dLdc, weight_c_hx_->data(), ilayer->mutable_grad(this)); + GEMM(1.0f,1.0f,dLdz, weight_z_hx_->data(), ilayer->mutable_grad(this)); + GEMM(1.0f,1.0f,dLdr, weight_r_hx_->data(), ilayer->mutable_grad(this)); } - if (srclayers.size() > 1 && srclayers[1]->mutable_grad(this) != nullptr) { + if (clayer != nullptr && clayer->mutable_grad(this) != nullptr) { // Compute gradients for context layer - GEMM(1.0f,0.0f,reset_dLdc,weight_c_hh_->data(), srclayers[1]->mutable_grad(this)); - GEMM(1.0f,1.0f,dLdr, weight_r_hh_->data(), srclayers[1]->mutable_grad(this)); - GEMM(1.0f,1.0f,dLdz,weight_z_hh_->data(), srclayers[1]->mutable_grad(this)); - Add(srclayers[1]->grad(this), *update_gate, srclayers[1]->mutable_grad(this)); + GEMM(1.0f, 0.0f, reset_dLdc, weight_c_hh_->data(), + clayer->mutable_grad(this)); + GEMM(1.0f, 1.0f, dLdr, weight_r_hh_->data(), clayer->mutable_grad(this)); + GEMM(1.0f, 1.0f, dLdz, weight_z_hh_->data(), clayer->mutable_grad(this)); + Add(clayer->grad(this), *update_gate, clayer->mutable_grad(this)); + // LOG(ERROR) << "grad to prev gru " << Asum(clayer->grad(this)); } - if (srclayers.size() == 1) delete context; - else context = NULL; + if (srclayers.size() == 1) + delete context; + else + context = NULL; delete grad_t; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/neuralnet/neuron_layer/inner_product.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuron_layer/inner_product.cc b/src/neuralnet/neuron_layer/inner_product.cc index f50afba..1e5e93e 100644 --- a/src/neuralnet/neuron_layer/inner_product.cc +++ b/src/neuralnet/neuron_layer/inner_product.cc @@ -66,12 +66,17 @@ void InnerProductLayer::ComputeFeature(int flag, void InnerProductLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) { - - MVSumRow(1.0f, 0.0f, grad_, bias_->mutable_grad()); + float beta = 0.0f; + if (flag & kAggGrad) + beta = 1.0f; + MVSumRow(1.0f, beta, grad_, bias_->mutable_grad()); if (transpose_) - MMDot(srclayers[0]->data(this).T(), grad_, weight_->mutable_grad()); + GEMM(1.0f, beta, srclayers[0]->data(this).T(), grad_, + weight_->mutable_grad()); else - MMDot(grad_.T(), srclayers[0]->data(this), weight_->mutable_grad()); + GEMM(1.0f, beta, grad_.T(), srclayers[0]->data(this), + weight_->mutable_grad()); + if (srclayers[0]->mutable_grad(this) != nullptr) { if (transpose_) MMDot(grad_, weight_->data().T(), srclayers[0]->mutable_grad(this)); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/proto/job.proto ---------------------------------------------------------------------- diff --git a/src/proto/job.proto b/src/proto/job.proto index 7cdc287..e520eba 100644 --- a/src/proto/job.proto +++ b/src/proto/job.proto @@ -120,6 +120,14 @@ message NetProto { repeated LayerProto layer = 1; // partitioning type for parallelism optional int32 partition_dim = 20 [default = 0]; + // Each layer corresponds to a group of unrolled layers, used in RNN models + repeated LayerGroupProto layer_group = 21; + optional int32 unroll_len = 22 [default = 1]; +} + +message LayerGroupProto { + // name of the layers belong to the same 
group + repeated string layer = 1; } message UpdaterProto { @@ -139,6 +147,9 @@ message UpdaterProto { // used to avoid divide by 0, i.e. x/(y+delta) optional float delta = 35 [default = 0.00000001]; + optional float clip_low = 36 [default = 0]; + optional float clip_high = 37 [default = 0]; + extensions 101 to 200; } @@ -195,10 +206,11 @@ message LayerProto { // share data and grad blob with the single src layer, e.g., relu layer can // share blobs from conv layer. It is useful for saving memory space. optional bool share_src_blobs = 22 [default = false]; - + // for unrolling layers in RNN model optional int32 unroll_len = 23 [default = 1]; - repeated UnrollConnType unroll_conn_type = 24; - repeated int32 shift = 25; + optional int32 unroll_index = 24 [default = 0]; + repeated UnrollConnType unroll_conn_type = 25; + repeated int32 shift = 26; // overrides the partition dimension for neural net optional int32 partition_dim = 60 [default = -1]; @@ -215,6 +227,7 @@ message LayerProto { optional MnistProto mnist_conf = 192; optional RGBImageProto rgbimage_conf = 193; optional DataProto sharddata_conf = 194; + optional CharRNNProto char_rnn_conf = 195; // configuration for neuron layers id range [200, 300) optional ActivationProto activation_conf = 200; @@ -228,6 +241,7 @@ message LayerProto { optional ReLUProto relu_conf = 211; optional SoftmaxProto softmax_conf = 214; optional GRUProto gru_conf = 215; + optional EmbeddingProto embedding_conf = 216; // configuration for loss layers, id range [300, 400) optional SoftmaxLossProto softmaxloss_conf = 301; @@ -354,7 +368,19 @@ message StoreProto { optional int32 random_skip = 11 [default = 0]; optional bool has_label = 12 [default = true]; } +message CharRNNProto { + optional string path = 1; + optional string vocab_path = 2; + // num of chars to read per instance, should = NetProto::unroll_len + optional int32 unroll_len = 3 [default = 50]; + optional int32 batchsize = 4 [default = 1]; +} +message EmbeddingProto { + optional int32 vocab_size = 1 [default = 0]; + optional int32 feature_dim = 2 [default = 100]; + +} message SoftmaxLossProto { // computing accuracy against topk results optional int32 topk = 1 [default = 1]; @@ -572,6 +598,8 @@ enum AlgType { kBP = 1; // Contrastive Divergence algorithm for RBM, DBM, etc. kCD = 2; + // BPTT for training RNN models + kBPTT = 3; // For user defined algorithm. 
kUserAlg = 104; } @@ -590,6 +618,9 @@ enum LayerType { kMnist = 192; // deprecated kRGBImage = 193; // deprecated kShardData = 194; // deprecated + kCharRNN = 195; + kRNNLabel = 196; + kOneHot = 197; /* * Neuron layers @@ -610,6 +641,8 @@ enum LayerType { kSigmoid = 213; kSoftmax = 214; kGRU = 215; + kEmbedding = 216; + // cudnn v3 kCudnnConv = 250; kCudnnPool = 251; @@ -678,6 +711,9 @@ enum Phase { kBackward = 64; kLoss = 128; kDeploy = 256; + + // used for aggregate parameter gradients when Param is shared + kAggGrad = 512; } enum ParamType { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/stub.cc ---------------------------------------------------------------------- diff --git a/src/stub.cc b/src/stub.cc index 7c0ec90..c06128c 100644 --- a/src/stub.cc +++ b/src/stub.cc @@ -83,6 +83,8 @@ const std::unordered_map<int, ParamEntry*> CreateParamShard( int grp = entry.first; int wstart = grp2workers[grp].first, wend = grp2workers[grp].second; for (auto layer : entry.second->layers()) { + if (layer->unroll_index() > 0) + continue; int partition = layer->partition_id(); bool local = partition >= wstart && partition < wend; for (auto param : layer->GetParams()) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/test/test_gru_layer.cc ---------------------------------------------------------------------- diff --git a/src/test/test_gru_layer.cc b/src/test/test_gru_layer.cc index 296b795..e432ae1 100644 --- a/src/test/test_gru_layer.cc +++ b/src/test/test_gru_layer.cc @@ -239,7 +239,6 @@ TEST_F(GRULayerTest, ComputeFeature) { singa::GRULayer gru_layer_2; gru_layer_2.Setup(gru2_conf, std::vector<singa::Layer*>{&in_layer_2, &gru_layer_1}); - for (unsigned int i = 0; i < gru_layer_2.GetParams().size(); i ++) { gru_layer_2.GetParams()[i]->InitValues(); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/test/test_math.cc ---------------------------------------------------------------------- diff --git a/src/test/test_math.cc b/src/test/test_math.cc index 2e7deec..2627b2e 100644 --- a/src/test/test_math.cc +++ b/src/test/test_math.cc @@ -286,7 +286,6 @@ TEST(MathTest, TestAxpyGPU) { TEST(MathTest, TestDotGPU) { float A[12]; float B[12]; - for (int i = 0; i < 12; i++) { A[i] = i - 1; B[i] = i + 1; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/utils/common.cc ---------------------------------------------------------------------- diff --git a/src/utils/common.cc b/src/utils/common.cc index 928d4bb..eefdb5c 100644 --- a/src/utils/common.cc +++ b/src/utils/common.cc @@ -570,4 +570,30 @@ void WriteStringToTextFile(const string& filename, const string& context) { ofs.flush(); ofs.close(); } + + +const vector<std::pair<string, float>> GetMetricFromString(const string& disp) { + size_t pos = 0; + vector<string> terms; + while (pos != string::npos) { + auto next = disp.find_first_of(" ,", pos); // delimiter: space or comma + if (next != string::npos) { + terms.push_back(disp.substr(pos, next - pos)); + pos = disp.find_first_not_of(" ,", next + 1); + } else { + break; + } + } + if (pos != string::npos) + terms.push_back(disp.substr(pos)); + vector<std::pair<string, float>> ret; + for (unsigned i = 0; i < terms.size(); i++) { + if (terms[i] == "=") { + CHECK_GE(i, 1); + CHECK_LT(i, terms.size() - 1) << "terms[i] = " << terms[i]; + ret.push_back(std::make_pair(terms[i-1], std::stof(terms[i + 1]))); + } + } + return ret; +} } // namespace singa 
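The GetMetricFromString helper added to common.cc above parses a display string of the form "key = value[, key = value]" into (name, value) pairs, as used when displaying per-layer metrics. A minimal usage sketch, assuming only what the diff declares (the input string and the main() wrapper are illustrative):

#include <iostream>
#include <string>
#include "singa/utils/common.h"

int main() {
  // Display string in the "key = value[, key = value]" form the parser expects.
  const std::string disp = "loss = 3.25, accuracy = 0.41";
  // For this input the helper yields {"loss", 3.25f} and {"accuracy", 0.41f}.
  for (const auto& metric : singa::GetMetricFromString(disp)) {
    std::cout << metric.first << " = " << metric.second << std::endl;
  }
  return 0;
}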
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index bdae72f..95396bc 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -166,7 +166,12 @@ void Param::InitValues(int version) { set_version(version); } -void Param::ShareFrom(Param* other, bool cpu_only) { +void Param::ShareDataFrom(Param* other, bool cpu_only) { + if (this == other) { + LOG(WARNING) << "No need to share Param with itself"; + return; + } + proto_.set_owner(other->owner()); CHECK_EQ(data_.count(), other->data_.count()); data_.ShareData(&(other->data_), cpu_only); @@ -183,6 +188,16 @@ void Param::ShareFrom(Param* other, bool cpu_only) { pending_update_.resize(other->pending_update_.size()); } +void Param::ShareFrom(Param* other) { + if (this == other) { + LOG(WARNING) << "No need to share Param with itself"; + return; + } + + ShareDataFrom(other, false); + grad_.ShareData(&(other->grad_), false); +} + void Param::FromProto(const BlobProto& blob) { data_.FromProto(blob); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/utils/updater.cc ---------------------------------------------------------------------- diff --git a/src/utils/updater.cc b/src/utils/updater.cc index 21608fa..a9f70c0 100644 --- a/src/utils/updater.cc +++ b/src/utils/updater.cc @@ -101,9 +101,24 @@ void Updater::Init(const UpdaterProto& proto) { momentum_ = proto.momentum(); weight_decay_ = proto.weight_decay(); lr_gen_ = LRGenerator::Create(proto.learning_rate()); + clip_low_ = proto.clip_low(); + clip_high_ = proto.clip_high(); +} + +void Updater::Clip(const float low, const float high, Param* param) { + Blob<float>* grad = param->mutable_grad(); + float* ptr = grad->mutable_cpu_data(); + for (int i = 0; i < grad->count(); i++) { + if (ptr[i] > high) + ptr[i] = high; + else if (ptr[i] < low) + ptr[i] = low; + } } void SGDUpdater::Update(int step, Param* param, float grad_scale) { + if (clip_high_ > clip_low_) + Clip(clip_low_, clip_high_, param); Shape<1> s = Shape1(param->size()); Tensor<cpu, 1> data(param->mutable_cpu_data(), s); Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); @@ -143,6 +158,8 @@ void NesterovUpdater::Update(int step, Param* param, float grad_scale) { } /***********************AdaGrad******************************/ void AdaGradUpdater::Update(int step, Param* param, float grad_scale) { + if (clip_high_ > clip_low_) + Clip(clip_low_, clip_high_, param); Shape<1> s = Shape1(param->size()); Tensor<cpu, 1> data(param->mutable_cpu_data(), s); Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/959ef705/src/worker.cc ---------------------------------------------------------------------- diff --git a/src/worker.cc b/src/worker.cc index 8495b5c..4e1dc75 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -29,6 +29,7 @@ #include "singa/utils/factory.h" #include "singa/utils/singleton.h" #include "singa/utils/context.h" +#include "singa/utils/math_blob.h" namespace singa { @@ -214,7 +215,6 @@ void Worker::InitNetParams(const JobProto& job_conf, NeuralNet* net) { } } - void Worker::Checkpoint(int step, const std::string& folder, NeuralNet* net) { BlobProtos bps; for (auto layer : net->layers()) { @@ -338,7 +338,7 @@ void BPWorker::Forward(int step, Phase phase, NeuralNet* net) { map<string, string> label; for (auto& layer : net->layers()) { if (layer->partition_id() == id_) { - if (phase == 
kTrain) { + if (phase == kTrain && layer->unroll_index() == 0) { // wait until param is updated for (Param* p : layer->GetParams()) { Collect(step, p); @@ -346,7 +346,7 @@ void BPWorker::Forward(int step, Phase phase, NeuralNet* net) { } // DLOG(ERROR) << "Forward " << layer->name(); layer->ComputeFeature(phase | kForward, net->srclayers(layer)); - if (job_conf_.debug() && grp_id_ == 0) + if (job_conf_.debug() && DisplayNow(step) && grp_id_ == 0) label[layer->name()] = layer->ToString(true, phase | kForward); } } @@ -364,7 +364,7 @@ void BPWorker::Backward(int step, NeuralNet* net) { Layer* layer = *it; if (layer->partition_id() == id_) { layer->ComputeGradient(kTrain | kBackward, net->srclayers(layer)); - if (job_conf_.debug() && grp_id_ == 0) + if (job_conf_.debug() && DisplayNow(step) && grp_id_ == 0) label[layer->name()] = layer->ToString(true, kTrain | kBackward); for (Param* p : layer->GetParams()) Update(step, p); @@ -377,6 +377,82 @@ void BPWorker::Backward(int step, NeuralNet* net) { } } +/***************************BPTTWorker*********************************/ +void BPTTWorker::Forward(int step, Phase phase, NeuralNet* net) { + map<string, string> label; + for (auto& layer : net->layers()) { + if (layer->partition_id() == id_) { + if (phase == kTrain && layer->unroll_index() == 0) { + // wait until param is updated + for (Param* p : layer->GetParams()) { + Collect(step, p); + Zero(p->mutable_grad()); + } + } + vector<Layer*> src = net->srclayers(layer); + // if full state rnn and not the starting of a new passing of the dataset, + // feed the hidden state of the last unit to the first unit. + if (layer->unroll_index() == 0 && full_state_ && !begin_) { + Layer* last = net->last_unroll_layer(layer); + if (last != layer) { + src.push_back(last); + } + } + // LOG(ERROR) << layer->name() << " forward"; + // int ret = + layer->ComputeFeature(phase | kForward, src); + /* + if ((phase & Phase::kTrain) && ret == Status::kEnd) + begin_ = true; + */ + + if (job_conf_.debug() && DisplayNow(step) && grp_id_ == 0) + label[layer->name()] = layer->ToString(true, phase | kForward); + } + } + if (label.size()) { + const string path = Cluster::Get()->vis_folder() + "/fp-step" + + std::to_string(step) +"-loc" + std::to_string(id_) + ".json"; + WriteStringToTextFile(path, net->ToGraph(false).ToJson(label)); + } +} + +void BPTTWorker::Backward(int step, NeuralNet* net) { + map<string, string> label; + auto& layers = net->layers(); + for (auto it = layers.rbegin(); it != layers.rend(); it++) { + Layer* layer = *it; + if (layer->partition_id() == id_) { + layer->ComputeGradient(kTrain | kBackward | kAggGrad, net->srclayers(layer)); + // LOG(ERROR) << layer->name() << " backward"; + if (job_conf_.debug() && DisplayNow(step) && grp_id_ == 0) + label[layer->name()] = layer->ToString(true, kTrain | kBackward); + // unrolled layers share parameter data and grad, just update the 1st one + if (layer->unroll_index() == 0) + for (Param* p : layer->GetParams()) + Update(step, p); + } + } + if (label.size()) { + const string path = Cluster::Get()->vis_folder() + "/bp-step" + + std::to_string(step) + "-loc" + std::to_string(id_) + ".json"; + WriteStringToTextFile(path, net->ToGraph(false).Reverse().ToJson(label)); + } +} +void BPTTWorker::Display(int flag, const std::string& prefix, NeuralNet* net) { + std::unordered_map<string, float> perf; + for (auto layer : net->layers()) { + if (layer->partition_id() == id_) { + const string& disp = layer->ToString(false, flag); + for (const auto& entry : 
GetMetricFromString(disp)) + perf[entry.first] += entry.second; + } + } + string disp = prefix + " "; + for (const auto& entry : perf) + disp += entry.first + " = " + std::to_string(entry.second) + ", "; + LOG(ERROR) << disp; +} /****************************CDWorker**********************************/ void CDWorker::TrainOneBatch(int step, NeuralNet* net) { const auto& layers = net->layers();
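The clip_low/clip_high fields added to UpdaterProto feed the new Updater::Clip shown in updater.cc above, which clamps every gradient element into [clip_low, clip_high] before SGDUpdater and AdaGradUpdater apply their updates; the clip_high_ > clip_low_ guard keeps clipping disabled by default, since both fields default to 0. Below is a minimal standalone sketch of that clip-then-update rule, with plain std::vector buffers and an illustrative learning rate standing in for the Param/Blob machinery.

// Sketch of the clip-then-update rule enabled by clip_low/clip_high:
// clamp each gradient element into [low, high], then take a plain SGD step
// (momentum and weight decay omitted). Buffers and lr are illustrative only.
#include <cstddef>
#include <vector>

void ClipThenSgdStep(std::vector<float>* data, std::vector<float>* grad,
                     float low, float high, float lr) {
  // Clipping is active only when the configured range is non-empty,
  // mirroring the `clip_high_ > clip_low_` guard in the updaters.
  if (high > low) {
    for (float& g : *grad) {
      if (g > high) g = high;
      else if (g < low) g = low;
    }
  }
  // Vanilla SGD: w <- w - lr * g.
  for (std::size_t i = 0; i < data->size(); ++i)
    (*data)[i] -= lr * (*grad)[i];
}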
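The BPTT changes hang together as follows: Param::ShareFrom now shares the gradient blob as well as the data blob, stub.cc registers Params only for layers with unroll_index() == 0, BPTTWorker::Forward zeroes the shared gradient at unroll index 0, Backward passes kAggGrad so gradients from all unrolled steps can be aggregated into that one buffer, and Update is invoked once per training step, again only for unroll index 0. The sketch below is a schematic illustration of that accumulate-once, update-once pattern; SharedParam is a hypothetical stand-in, not SINGA's Param API.

// Schematic illustration of why unrolled layers can share a single Param:
// data and grad buffers are shared, every unrolled step adds its gradient
// contribution into the same grad buffer, and only one update is applied.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct SharedParam {                      // hypothetical stand-in for singa::Param
  std::shared_ptr<std::vector<float>> data;
  std::shared_ptr<std::vector<float>> grad;
};

int main() {
  const int unroll_len = 3;               // plays the role of LayerProto::unroll_len
  SharedParam master{std::make_shared<std::vector<float>>(4, 1.0f),
                     std::make_shared<std::vector<float>>(4, 0.0f)};
  // Like ShareFrom(): every unrolled copy points at the same data AND grad buffers.
  std::vector<SharedParam> unrolled(unroll_len, master);

  std::fill(master.grad->begin(), master.grad->end(), 0.0f);  // zero at unroll index 0
  for (int t = 0; t < unroll_len; ++t)
    for (float& g : *unrolled[t].grad)
      g += 0.1f;                          // each step accumulates into the shared grad

  const float lr = 0.5f;
  for (std::size_t i = 0; i < master.data->size(); ++i)       // single update
    (*master.data)[i] -= lr * (*master.grad)[i];
  std::cout << (*master.data)[0] << std::endl;                // 1 - 0.5 * 0.3 = 0.85
  return 0;
}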
