SINGA-120 - Implemented GRU and BPTT: 1) Implemented the unrolling function for BPTT; 2) Added tests for unrolling
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/777dfb6a Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/777dfb6a Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/777dfb6a Branch: refs/heads/master Commit: 777dfb6a6fc4058b35368dfbe1fba4a27a93e828 Parents: 473c985 Author: Ju Fan <[email protected]> Authored: Fri Jan 1 10:50:20 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Wed Jan 6 01:53:05 2016 +0800 ---------------------------------------------------------------------- include/singa/neuralnet/neuralnet.h | 1 + src/neuralnet/neuralnet.cc | 121 ++++++++-- src/test/test_connection_layers.cc | 8 +- src/test/test_unrolling.cc | 398 +++++++++++++++++++++++++++++++ 4 files changed, 510 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/777dfb6a/include/singa/neuralnet/neuralnet.h ---------------------------------------------------------------------- diff --git a/include/singa/neuralnet/neuralnet.h b/include/singa/neuralnet/neuralnet.h index bc1a7d8..be8f5c8 100644 --- a/include/singa/neuralnet/neuralnet.h +++ b/include/singa/neuralnet/neuralnet.h @@ -58,6 +58,7 @@ class NeuralNet { static NeuralNet* Create(const NetProto& net_conf, Phase phase, int npartitions); + static const NetProto Unrolling(const NetProto& net_conf); /** * construct the net structure from protocol buffer. * @param netproto neural net config http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/777dfb6a/src/neuralnet/neuralnet.cc ---------------------------------------------------------------------- diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc index 9cdaff4..6bb0ecd 100644 --- a/src/neuralnet/neuralnet.cc +++ b/src/neuralnet/neuralnet.cc @@ -24,6 +24,8 @@ #include <algorithm> #include <queue> #include "singa/utils/singleton.h" +#include <unordered_map> +using namespace std; namespace singa { @@ -36,9 +38,6 @@ NeuralNet* NeuralNet::Create(const NetProto& net_conf, Phase phase, NetProto conf; conf.CopyFrom(net_conf); conf.clear_layer(); - // for sharing param conf - std::unordered_map<string, ParamProto*> name2param; - std::vector<ParamProto*> shares; // flag=0: neither exclude nor include field appears // flag=1: exclude field appears // flag=2: include field appears @@ -78,16 +77,25 @@ NeuralNet* NeuralNet::Create(const NetProto& net_conf, Phase phase, // using net partition if layer partition is not set if (!layer_conf->has_partition_dim()) layer_conf->set_partition_dim(net_conf.partition_dim()); - for (int i = 0; i < layer_conf->param_size(); i++) { - ParamProto* param = layer_conf->mutable_param(i); - if (param->has_name() && param->name() != "") { - CHECK(name2param.find(param->name()) == name2param.end()) - << "param name is repeated: " << param->name(); - name2param[param->name()] = param; - } - if (param->has_share_from() && param->share_from() != "") - shares.push_back(param); - } + } + //LOG(INFO) << "Before unrolling: \n" << conf.DebugString(); + conf = Unrolling (conf); + + // Copy shared parameters for sharing param conf + std::unordered_map<string, ParamProto*> name2param; + std::vector<ParamProto*> shares; + for (int index = 0; index < conf.layer_size();index ++) { + LayerProto* layer = conf.mutable_layer(index); + for (int i = 0; i < layer->param_size(); i++) { + ParamProto* param = layer->mutable_param(i); + if (param->has_name() && param->name() != "") { + CHECK(name2param.find(param->name()) == name2param.end()) + << "param name is repeated: " << param->name(); + name2param[param->name()] = param; + } + if (param->has_share_from() && param->share_from() != "") + shares.push_back(param); + } } for (auto param : shares) { const std::string from = param->share_from(); @@ -104,6 +112,91 @@ NeuralNet* NeuralNet::Create(const NetProto& net_conf, Phase phase, return new NeuralNet(conf, npartitions); } +const NetProto NeuralNet::Unrolling(const NetProto& net_conf) { + // Step 1: Unroll each layer & set parameter sharing + NetProto conf; + + std::vector<std::vector<int>> layer_groups; + std::unordered_map<string,int> org_layer_names; + for (int index = 0; index < net_conf.layer_size(); index ++) { + const LayerProto& org_layer = net_conf.layer(index); + org_layer_names[org_layer.name()] = index; // layer_name -> index + + std::vector<int> layer_group; + for (int i = 0; i < org_layer.unroll_len(); i ++) { // unroll + LayerProto* unroll_layer = conf.add_layer(); + unroll_layer->CopyFrom(org_layer); // create a new layer conf + if (org_layer.unroll_len() > 1) { + // update layer names + std::stringstream sstm; + sstm << unroll_layer->name() << "_" << i; + unroll_layer->set_name(sstm.str()); + // update layer parameter sharing + for (int j = 0; j < unroll_layer->param_size(); j ++) { + ParamProto* param = unroll_layer->mutable_param(j); + if (i == 0) continue; // no need to rename parameters in the i-th unrolled layer + if (!param->has_share_from() || param->share_from() == "") {// not shared from others + param->set_share_from(param->name()); + } + std::stringstream sstm1; + sstm1 << param->name() << "_" << i; + param->set_name(sstm1.str()); + } + } + // clear unrolling related fields + unroll_layer->clear_unroll_len(); + unroll_layer->clear_unroll_conn_type(); + unroll_layer->clear_shift(); + unroll_layer->clear_srclayers(); + + layer_group.push_back(conf.layer_size() - 1); + } + layer_groups.push_back(layer_group); + } + // Step 2: Connect unrolled layers by setting `srclayers` + for (int index = 0; index < net_conf.layer_size(); index ++) { + const LayerProto& org_layer = net_conf.layer(index); + if (org_layer.srclayers_size() == 0) continue; // no src layer + //TODO(fanju): add LSTM when it is ready + if (org_layer.type() == kGRU) { // connect GRU layers + for (unsigned int j = 1; j < layer_groups[index].size(); j ++) { + LayerProto* unroll_layer = conf.mutable_layer(layer_groups[index][j]); + unroll_layer->add_srclayers(conf.layer(layer_groups[index][j-1]).name()); + } + } + for (int i = 0; i < org_layer.srclayers_size(); i ++) { + const string& org_layer_src = org_layer.srclayers(i); + + singa::UnrollConnType unroll_conn_type = kUnrollOneToOne; // Default value + if (i < org_layer.unroll_conn_type_size()) unroll_conn_type = org_layer.unroll_conn_type(i); + unsigned int shift = 0; // Default shift value + if (i < org_layer.shift_size()) shift = org_layer.shift(i); + + const std::vector<int> unroll_layer_srcs = layer_groups[org_layer_names[org_layer_src]]; + + for (unsigned int j = 0; j < layer_groups[index].size(); j ++) { + LayerProto* unroll_layer = conf.mutable_layer(layer_groups[index][j]); + // Update src layers of `unroll_layer` by considering the types + if (unroll_conn_type == kUnrollOneToAll) { + for (int unroll_layer_src : unroll_layer_srcs) { + unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); + } + } else if (unroll_conn_type == kUnrollOneToOne) { + if (j < shift) continue; // no need to connect with the src + int unroll_layer_src = unroll_layer_srcs[j - shift]; + unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); + } else if (unroll_conn_type == kUnrollFirstToLast) { + if (j > 0) break; + int unroll_layer_src = unroll_layer_srcs[unroll_layer_srcs.size() - 1]; + unroll_layer->add_srclayers(conf.layer(unroll_layer_src).name()); + } + } + } + } + return conf; +} + + NeuralNet::NeuralNet(NetProto netproto, int npartitions) { LOG(INFO) << "Constructing NeuralNet..."; auto graph = CreateGraph(netproto, npartitions); @@ -260,7 +353,7 @@ NetProto NeuralNet::AddPartitionConnectionLayers(const NetProto& netproto, * (NO) src_pdim = dst_pdim ? * (YES) Direct Connection * (NO) Slice -> Concate - */ + */ for (const LayerProto& origin_layer : netproto.layer()) { LayerProto* dst_layer = name2proto[origin_layer.name()]; int dst_pdim = dst_layer->partition_dim(); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/777dfb6a/src/test/test_connection_layers.cc ---------------------------------------------------------------------- diff --git a/src/test/test_connection_layers.cc b/src/test/test_connection_layers.cc index 5517bde..6529840 100644 --- a/src/test/test_connection_layers.cc +++ b/src/test/test_connection_layers.cc @@ -179,8 +179,8 @@ TEST(ConnectionLayerTest, DataSliceTest) { proto_slice.mutable_slice_conf()->set_num_slices(K); SliceLayer slice; slice.Setup(proto_slice, src_slice); - ASSERT_EQ(slice.data(static_cast<Layer*>(&slice)).shape(0), N / K); - ASSERT_EQ(slice.data(static_cast<Layer*>(&slice)).shape(1), M); + ASSERT_EQ(slice.data(nullptr).shape(0), N / K); + ASSERT_EQ(slice.data(nullptr).shape(1), M); // use dummy as output layers LayerProto proto_out[K]; @@ -236,8 +236,8 @@ TEST(ConnectionLayerTest, ModelSliceTest) { proto_slice.mutable_slice_conf()->set_num_slices(K); SliceLayer slice; slice.Setup(proto_slice, src_slice); - ASSERT_EQ(slice.data(static_cast<Layer*>(&slice)).shape(0), N); - ASSERT_EQ(slice.data(static_cast<Layer*>(&slice)).shape(1), M / K); + ASSERT_EQ(slice.data(nullptr).shape(0), N); + ASSERT_EQ(slice.data(nullptr).shape(1), M / K); // use dummy as output layers LayerProto proto_out[K]; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/777dfb6a/src/test/test_unrolling.cc ---------------------------------------------------------------------- diff --git a/src/test/test_unrolling.cc b/src/test/test_unrolling.cc new file mode 100644 index 0000000..e32c528 --- /dev/null +++ b/src/test/test_unrolling.cc @@ -0,0 +1,398 @@ +/************************************************************ + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + *************************************************************/ +#include <string> +#include <vector> +#include <fstream> +#include <iostream> +using namespace std; + +#include "gtest/gtest.h" +#include "singa/neuralnet/input_layer.h" +#include "singa/neuralnet/neuron_layer.h" +#include "singa/neuralnet/neuralnet.h" +#include "singa/neuralnet/connection_layer.h" +#include "singa/driver.h" +#include "singa/proto/job.pb.h" +#include "singa/utils/common.h" + +using namespace singa; + +class UnrollingTest: public ::testing::Test { +protected: + virtual void SetUp() { + NetProto* net_conf1 = job_conf1.mutable_neuralnet(); + + LayerProto* data_layer1 = net_conf1->add_layer(); + data_layer1->set_name("data"); + data_layer1->set_type(kRecordInput); + + LayerProto* embedding_layer1 = net_conf1->add_layer(); + embedding_layer1->set_name("embedding"); + embedding_layer1->set_type(kDummy); + embedding_layer1->add_srclayers("data"); + embedding_layer1->set_unroll_len(3); + embedding_layer1->add_unroll_conn_type(kUnrollOneToAll); + embedding_layer1->add_shift(0); + + LayerProto* gru_layer1 = net_conf1->add_layer(); + gru_layer1->set_name("gru"); + gru_layer1->set_type(kGRU); + gru_layer1->add_srclayers("embedding"); + gru_layer1->mutable_gru_conf()->set_dim_hidden(20); + gru_layer1->mutable_gru_conf()->set_bias_term(false); + gru_layer1->add_param()->set_name("w_z_hx"); + gru_layer1->add_param()->set_name("w_r_hx"); + gru_layer1->add_param()->set_name("w_c_hx"); + gru_layer1->add_param()->set_name("w_z_hh"); + gru_layer1->add_param()->set_name("w_r_hh"); + gru_layer1->add_param()->set_name("w_c_hh"); + gru_layer1->set_unroll_len(3); + gru_layer1->add_unroll_conn_type(kUnrollOneToOne); + gru_layer1->add_shift(0); + + LayerProto* out_layer1 = net_conf1->add_layer(); + out_layer1->set_name("out"); + out_layer1->set_type(kInnerProduct); + out_layer1->add_srclayers("gru"); + out_layer1->mutable_innerproduct_conf()->set_num_output(100); + out_layer1->add_param()->set_name("w"); + out_layer1->add_param()->set_name("b"); + out_layer1->set_unroll_len(3); + out_layer1->add_unroll_conn_type(kUnrollOneToOne); + out_layer1->add_shift(0); + + LayerProto* softmax_layer1 = net_conf1->add_layer(); + softmax_layer1->set_name("softmax"); + softmax_layer1->set_type(kSoftmax); + softmax_layer1->add_srclayers("out"); + softmax_layer1->set_unroll_len(3); + softmax_layer1->add_unroll_conn_type(kUnrollOneToOne); + softmax_layer1->add_shift(0); + + LayerProto* loss_layer1 = net_conf1->add_layer(); + loss_layer1->set_name("loss"); + loss_layer1->set_type(kSoftmaxLoss); + loss_layer1->add_srclayers("softmax"); + loss_layer1->add_srclayers("data"); + loss_layer1->set_unroll_len(3); + loss_layer1->add_unroll_conn_type(kUnrollOneToOne); + loss_layer1->add_shift(0); + loss_layer1->add_unroll_conn_type(kUnrollOneToAll); + loss_layer1->add_shift(0); + + + /* + * Initialize job conf 2 + */ + NetProto* net_conf2 = job_conf2.mutable_neuralnet(); + + LayerProto* data_layer2 = net_conf2->add_layer(); + data_layer2->set_name("data"); + data_layer2->set_type(kRecordInput); + + LayerProto* embedding_layer2 = net_conf2->add_layer(); + embedding_layer2->set_name("embedding"); + embedding_layer2->set_type(kDummy); + embedding_layer2->add_srclayers("data"); + embedding_layer2->add_srclayers("softmax"); + embedding_layer2->set_unroll_len(3); + embedding_layer2->add_unroll_conn_type(kUnrollOneToAll); + embedding_layer2->add_shift(0); + embedding_layer2->add_unroll_conn_type(kUnrollOneToOne); + embedding_layer2->add_shift(1); + + LayerProto* gru_layer2 = net_conf2->add_layer(); + gru_layer2->set_name("gru"); + gru_layer2->set_type(kGRU); + gru_layer2->add_srclayers("embedding"); + gru_layer2->mutable_gru_conf()->set_dim_hidden(20); + gru_layer2->mutable_gru_conf()->set_bias_term(false); + gru_layer2->add_param()->set_name("w_z_hx"); + gru_layer2->add_param()->set_name("w_r_hx"); + gru_layer2->add_param()->set_name("w_c_hx"); + gru_layer2->add_param()->set_name("w_z_hh"); + gru_layer2->add_param()->set_name("w_r_hh"); + gru_layer2->add_param()->set_name("w_c_hh"); + gru_layer2->set_unroll_len(3); + gru_layer2->add_unroll_conn_type(kUnrollOneToOne); + gru_layer2->add_shift(0); + + LayerProto* out_layer2 = net_conf2->add_layer(); + out_layer2->set_name("out"); + out_layer2->set_type(kInnerProduct); + out_layer2->add_srclayers("gru"); + out_layer2->mutable_innerproduct_conf()->set_num_output(100); + out_layer2->add_param()->set_name("w"); + out_layer2->add_param()->set_name("b"); + out_layer2->set_unroll_len(3); + out_layer2->add_unroll_conn_type(kUnrollOneToOne); + out_layer2->add_shift(0); + + LayerProto* softmax_layer2 = net_conf2->add_layer(); + softmax_layer2->set_name("softmax"); + softmax_layer2->set_type(kSoftmax); + softmax_layer2->add_srclayers("out"); + softmax_layer2->set_unroll_len(3); + softmax_layer2->add_unroll_conn_type(kUnrollOneToOne); + softmax_layer2->add_shift(0); + + LayerProto* loss_layer2 = net_conf2->add_layer(); + loss_layer2->set_name("loss"); + loss_layer2->set_type(kSoftmaxLoss); + loss_layer2->add_srclayers("softmax"); + loss_layer2->add_srclayers("data"); + loss_layer2->set_unroll_len(3); + loss_layer2->add_unroll_conn_type(kUnrollOneToOne); + loss_layer2->add_shift(0); + loss_layer2->add_unroll_conn_type(kUnrollOneToAll); + loss_layer2->add_shift(0); + } + + singa::JobProto job_conf1; + singa::JobProto job_conf2; +}; + +TEST_F(UnrollingTest, GRULanguageModelTrain) { + NetProto net; + net.CopyFrom(job_conf1.neuralnet()); + NetProto unrolled_net = NeuralNet::Unrolling(net); + EXPECT_EQ("data", unrolled_net.layer(0).name()); + + EXPECT_EQ("embedding_0", unrolled_net.layer(1).name()); + EXPECT_EQ(1, unrolled_net.layer(1).srclayers_size()); + EXPECT_EQ("data", unrolled_net.layer(1).srclayers(0)); + + EXPECT_EQ("embedding_1", unrolled_net.layer(2).name()); + EXPECT_EQ(1, unrolled_net.layer(2).srclayers_size()); + EXPECT_EQ("data", unrolled_net.layer(2).srclayers(0)); + + EXPECT_EQ("embedding_2", unrolled_net.layer(3).name()); + EXPECT_EQ(1, unrolled_net.layer(3).srclayers_size()); + EXPECT_EQ("data", unrolled_net.layer(3).srclayers(0)); + + EXPECT_EQ("gru_0", unrolled_net.layer(4).name()); + EXPECT_EQ(1, unrolled_net.layer(4).srclayers_size()); + EXPECT_EQ("embedding_0", unrolled_net.layer(4).srclayers(0)); + EXPECT_EQ("w_z_hx", unrolled_net.layer(4).param(0).name()); + EXPECT_EQ("w_r_hx", unrolled_net.layer(4).param(1).name()); + EXPECT_EQ("w_c_hx", unrolled_net.layer(4).param(2).name()); + EXPECT_EQ("w_z_hh", unrolled_net.layer(4).param(3).name()); + EXPECT_EQ("w_r_hh", unrolled_net.layer(4).param(4).name()); + EXPECT_EQ("w_c_hh", unrolled_net.layer(4).param(5).name()); + + EXPECT_EQ("gru_1", unrolled_net.layer(5).name()); + EXPECT_EQ(2, unrolled_net.layer(5).srclayers_size()); + EXPECT_EQ("gru_0", unrolled_net.layer(5).srclayers(0)); + EXPECT_EQ("embedding_1", unrolled_net.layer(5).srclayers(1)); + EXPECT_EQ("w_z_hx_1", unrolled_net.layer(5).param(0).name()); + EXPECT_EQ("w_z_hx", unrolled_net.layer(5).param(0).share_from()); + EXPECT_EQ("w_r_hx_1", unrolled_net.layer(5).param(1).name()); + EXPECT_EQ("w_r_hx", unrolled_net.layer(5).param(1).share_from()); + EXPECT_EQ("w_c_hx_1", unrolled_net.layer(5).param(2).name()); + EXPECT_EQ("w_c_hx", unrolled_net.layer(5).param(2).share_from()); + EXPECT_EQ("w_z_hh_1", unrolled_net.layer(5).param(3).name()); + EXPECT_EQ("w_z_hh", unrolled_net.layer(5).param(3).share_from()); + EXPECT_EQ("w_r_hh_1", unrolled_net.layer(5).param(4).name()); + EXPECT_EQ("w_r_hh", unrolled_net.layer(5).param(4).share_from()); + EXPECT_EQ("w_c_hh_1", unrolled_net.layer(5).param(5).name()); + EXPECT_EQ("w_c_hh", unrolled_net.layer(5).param(5).share_from()); + + EXPECT_EQ("gru_2", unrolled_net.layer(6).name()); + EXPECT_EQ(2, unrolled_net.layer(6).srclayers_size()); + EXPECT_EQ("gru_1", unrolled_net.layer(6).srclayers(0)); + EXPECT_EQ("embedding_2", unrolled_net.layer(6).srclayers(1)); + EXPECT_EQ("w_z_hx_2", unrolled_net.layer(6).param(0).name()); + EXPECT_EQ("w_z_hx", unrolled_net.layer(6).param(0).share_from()); + EXPECT_EQ("w_r_hx_2", unrolled_net.layer(6).param(1).name()); + EXPECT_EQ("w_r_hx", unrolled_net.layer(6).param(1).share_from()); + EXPECT_EQ("w_c_hx_2", unrolled_net.layer(6).param(2).name()); + EXPECT_EQ("w_c_hx", unrolled_net.layer(6).param(2).share_from()); + EXPECT_EQ("w_z_hh_2", unrolled_net.layer(6).param(3).name()); + EXPECT_EQ("w_z_hh", unrolled_net.layer(6).param(3).share_from()); + EXPECT_EQ("w_r_hh_2", unrolled_net.layer(6).param(4).name()); + EXPECT_EQ("w_r_hh", unrolled_net.layer(6).param(4).share_from()); + EXPECT_EQ("w_c_hh_2", unrolled_net.layer(6).param(5).name()); + EXPECT_EQ("w_c_hh", unrolled_net.layer(6).param(5).share_from()); + + EXPECT_EQ("out_0", unrolled_net.layer(7).name()); + EXPECT_EQ(1, unrolled_net.layer(7).srclayers_size()); + EXPECT_EQ("gru_0", unrolled_net.layer(7).srclayers(0)); + EXPECT_EQ("w", unrolled_net.layer(7).param(0).name()); + EXPECT_EQ("b", unrolled_net.layer(7).param(1).name()); + + EXPECT_EQ("out_1", unrolled_net.layer(8).name()); + EXPECT_EQ(1, unrolled_net.layer(8).srclayers_size()); + EXPECT_EQ("gru_1", unrolled_net.layer(8).srclayers(0)); + EXPECT_EQ("w_1", unrolled_net.layer(8).param(0).name()); + EXPECT_EQ("w", unrolled_net.layer(8).param(0).share_from()); + EXPECT_EQ("b_1", unrolled_net.layer(8).param(1).name()); + EXPECT_EQ("b", unrolled_net.layer(8).param(1).share_from()); + + EXPECT_EQ("out_2", unrolled_net.layer(9).name()); + EXPECT_EQ(1, unrolled_net.layer(9).srclayers_size()); + EXPECT_EQ("gru_2", unrolled_net.layer(9).srclayers(0)); + EXPECT_EQ("w_2", unrolled_net.layer(9).param(0).name()); + EXPECT_EQ("w", unrolled_net.layer(9).param(0).share_from()); + EXPECT_EQ("b_2", unrolled_net.layer(9).param(1).name()); + EXPECT_EQ("b", unrolled_net.layer(9).param(1).share_from()); + + EXPECT_EQ("softmax_0", unrolled_net.layer(10).name()); + EXPECT_EQ(1, unrolled_net.layer(10).srclayers_size()); + EXPECT_EQ("out_0", unrolled_net.layer(10).srclayers(0)); + + EXPECT_EQ("softmax_1", unrolled_net.layer(11).name()); + EXPECT_EQ(1, unrolled_net.layer(11).srclayers_size()); + EXPECT_EQ("out_1", unrolled_net.layer(11).srclayers(0)); + + EXPECT_EQ("softmax_2", unrolled_net.layer(12).name()); + EXPECT_EQ(1, unrolled_net.layer(12).srclayers_size()); + EXPECT_EQ("out_2", unrolled_net.layer(12).srclayers(0)); + + EXPECT_EQ("loss_0", unrolled_net.layer(13).name()); + EXPECT_EQ(2, unrolled_net.layer(13).srclayers_size()); + EXPECT_EQ("softmax_0", unrolled_net.layer(13).srclayers(0)); + EXPECT_EQ("data", unrolled_net.layer(13).srclayers(1)); + + EXPECT_EQ("loss_1", unrolled_net.layer(14).name()); + EXPECT_EQ(2, unrolled_net.layer(14).srclayers_size()); + EXPECT_EQ("softmax_1", unrolled_net.layer(14).srclayers(0)); + EXPECT_EQ("data", unrolled_net.layer(14).srclayers(1)); + + EXPECT_EQ("loss_2", unrolled_net.layer(15).name()); + EXPECT_EQ(2, unrolled_net.layer(15).srclayers_size()); + EXPECT_EQ("softmax_2", unrolled_net.layer(15).srclayers(0)); + EXPECT_EQ("data", unrolled_net.layer(15).srclayers(1)); +} + +TEST_F(UnrollingTest, GRULanguageModelTest) { + NetProto net; + net.CopyFrom(job_conf2.neuralnet()); + NetProto unrolled_net = NeuralNet::Unrolling(net); + + EXPECT_EQ("data", unrolled_net.layer(0).name()); + + EXPECT_EQ("embedding_0", unrolled_net.layer(1).name()); + EXPECT_EQ(1, unrolled_net.layer(1).srclayers_size()); + EXPECT_EQ("data", unrolled_net.layer(1).srclayers(0)); + + EXPECT_EQ("embedding_1", unrolled_net.layer(2).name()); + EXPECT_EQ(2, unrolled_net.layer(2).srclayers_size()); + EXPECT_EQ("data", unrolled_net.layer(2).srclayers(0)); + EXPECT_EQ("softmax_0", unrolled_net.layer(2).srclayers(1)); + + EXPECT_EQ("embedding_2", unrolled_net.layer(3).name()); + EXPECT_EQ(2, unrolled_net.layer(3).srclayers_size()); + EXPECT_EQ("data", unrolled_net.layer(3).srclayers(0)); + EXPECT_EQ("softmax_1", unrolled_net.layer(3).srclayers(1)); + + EXPECT_EQ("gru_0", unrolled_net.layer(4).name()); + EXPECT_EQ(1, unrolled_net.layer(4).srclayers_size()); + EXPECT_EQ("embedding_0", unrolled_net.layer(4).srclayers(0)); + EXPECT_EQ("w_z_hx", unrolled_net.layer(4).param(0).name()); + EXPECT_EQ("w_r_hx", unrolled_net.layer(4).param(1).name()); + EXPECT_EQ("w_c_hx", unrolled_net.layer(4).param(2).name()); + EXPECT_EQ("w_z_hh", unrolled_net.layer(4).param(3).name()); + EXPECT_EQ("w_r_hh", unrolled_net.layer(4).param(4).name()); + EXPECT_EQ("w_c_hh", unrolled_net.layer(4).param(5).name()); + + EXPECT_EQ("gru_1", unrolled_net.layer(5).name()); + EXPECT_EQ(2, unrolled_net.layer(5).srclayers_size()); + EXPECT_EQ("gru_0", unrolled_net.layer(5).srclayers(0)); + EXPECT_EQ("embedding_1", unrolled_net.layer(5).srclayers(1)); + EXPECT_EQ("w_z_hx_1", unrolled_net.layer(5).param(0).name()); + EXPECT_EQ("w_z_hx", unrolled_net.layer(5).param(0).share_from()); + EXPECT_EQ("w_r_hx_1", unrolled_net.layer(5).param(1).name()); + EXPECT_EQ("w_r_hx", unrolled_net.layer(5).param(1).share_from()); + EXPECT_EQ("w_c_hx_1", unrolled_net.layer(5).param(2).name()); + EXPECT_EQ("w_c_hx", unrolled_net.layer(5).param(2).share_from()); + EXPECT_EQ("w_z_hh_1", unrolled_net.layer(5).param(3).name()); + EXPECT_EQ("w_z_hh", unrolled_net.layer(5).param(3).share_from()); + EXPECT_EQ("w_r_hh_1", unrolled_net.layer(5).param(4).name()); + EXPECT_EQ("w_r_hh", unrolled_net.layer(5).param(4).share_from()); + EXPECT_EQ("w_c_hh_1", unrolled_net.layer(5).param(5).name()); + EXPECT_EQ("w_c_hh", unrolled_net.layer(5).param(5).share_from()); + + EXPECT_EQ("gru_2", unrolled_net.layer(6).name()); + EXPECT_EQ(2, unrolled_net.layer(6).srclayers_size()); + EXPECT_EQ("gru_1", unrolled_net.layer(6).srclayers(0)); + EXPECT_EQ("embedding_2", unrolled_net.layer(6).srclayers(1)); + EXPECT_EQ("w_z_hx_2", unrolled_net.layer(6).param(0).name()); + EXPECT_EQ("w_z_hx", unrolled_net.layer(6).param(0).share_from()); + EXPECT_EQ("w_r_hx_2", unrolled_net.layer(6).param(1).name()); + EXPECT_EQ("w_r_hx", unrolled_net.layer(6).param(1).share_from()); + EXPECT_EQ("w_c_hx_2", unrolled_net.layer(6).param(2).name()); + EXPECT_EQ("w_c_hx", unrolled_net.layer(6).param(2).share_from()); + EXPECT_EQ("w_z_hh_2", unrolled_net.layer(6).param(3).name()); + EXPECT_EQ("w_z_hh", unrolled_net.layer(6).param(3).share_from()); + EXPECT_EQ("w_r_hh_2", unrolled_net.layer(6).param(4).name()); + EXPECT_EQ("w_r_hh", unrolled_net.layer(6).param(4).share_from()); + EXPECT_EQ("w_c_hh_2", unrolled_net.layer(6).param(5).name()); + EXPECT_EQ("w_c_hh", unrolled_net.layer(6).param(5).share_from()); + + EXPECT_EQ("out_0", unrolled_net.layer(7).name()); + EXPECT_EQ(1, unrolled_net.layer(7).srclayers_size()); + EXPECT_EQ("gru_0", unrolled_net.layer(7).srclayers(0)); + EXPECT_EQ("w", unrolled_net.layer(7).param(0).name()); + EXPECT_EQ("b", unrolled_net.layer(7).param(1).name()); + + EXPECT_EQ("out_1", unrolled_net.layer(8).name()); + EXPECT_EQ(1, unrolled_net.layer(8).srclayers_size()); + EXPECT_EQ("gru_1", unrolled_net.layer(8).srclayers(0)); + EXPECT_EQ("w_1", unrolled_net.layer(8).param(0).name()); + EXPECT_EQ("w", unrolled_net.layer(8).param(0).share_from()); + EXPECT_EQ("b_1", unrolled_net.layer(8).param(1).name()); + EXPECT_EQ("b", unrolled_net.layer(8).param(1).share_from()); + + EXPECT_EQ("out_2", unrolled_net.layer(9).name()); + EXPECT_EQ(1, unrolled_net.layer(9).srclayers_size()); + EXPECT_EQ("gru_2", unrolled_net.layer(9).srclayers(0)); + EXPECT_EQ("w_2", unrolled_net.layer(9).param(0).name()); + EXPECT_EQ("w", unrolled_net.layer(9).param(0).share_from()); + EXPECT_EQ("b_2", unrolled_net.layer(9).param(1).name()); + EXPECT_EQ("b", unrolled_net.layer(9).param(1).share_from()); + + EXPECT_EQ("softmax_0", unrolled_net.layer(10).name()); + EXPECT_EQ(1, unrolled_net.layer(10).srclayers_size()); + EXPECT_EQ("out_0", unrolled_net.layer(10).srclayers(0)); + + EXPECT_EQ("softmax_1", unrolled_net.layer(11).name()); + EXPECT_EQ(1, unrolled_net.layer(11).srclayers_size()); + EXPECT_EQ("out_1", unrolled_net.layer(11).srclayers(0)); + + EXPECT_EQ("softmax_2", unrolled_net.layer(12).name()); + EXPECT_EQ(1, unrolled_net.layer(12).srclayers_size()); + EXPECT_EQ("out_2", unrolled_net.layer(12).srclayers(0)); + + EXPECT_EQ("loss_0", unrolled_net.layer(13).name()); + EXPECT_EQ(2, unrolled_net.layer(13).srclayers_size()); + EXPECT_EQ("softmax_0", unrolled_net.layer(13).srclayers(0)); + EXPECT_EQ("data", unrolled_net.layer(13).srclayers(1)); + + EXPECT_EQ("loss_1", unrolled_net.layer(14).name()); + EXPECT_EQ(2, unrolled_net.layer(14).srclayers_size()); + EXPECT_EQ("softmax_1", unrolled_net.layer(14).srclayers(0)); + EXPECT_EQ("data", unrolled_net.layer(14).srclayers(1)); + + EXPECT_EQ("loss_2", unrolled_net.layer(15).name()); + EXPECT_EQ(2, unrolled_net.layer(15).srclayers_size()); + EXPECT_EQ("softmax_2", unrolled_net.layer(15).srclayers(0)); + EXPECT_EQ("data", unrolled_net.layer(15).srclayers(1)); +}
