Repository: incubator-singa
Updated Branches:
  refs/heads/master bb75a0be5 -> a2f4e4680


SINGA-120 - Implemented GRU and BPTT
1) Added the implementation of the GRU model;
2) Added a test for GRU functions


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ddf4e79a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ddf4e79a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ddf4e79a

Branch: refs/heads/master
Commit: ddf4e79aff5d8616f6758df18056b9443761405d
Parents: bb75a0b
Author: Ju Fan <[email protected]>
Authored: Fri Jan 1 10:41:59 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Wed Jan 6 01:50:48 2016 +0800

----------------------------------------------------------------------
 src/neuralnet/neuron_layer/gru.cc | 275 +++++++++++++++++++++++++++++++
 src/test/test_gru_layer.cc        | 286 +++++++++++++++++++++++++++++++++
 2 files changed, 561 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ddf4e79a/src/neuralnet/neuron_layer/gru.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/gru.cc 
b/src/neuralnet/neuron_layer/gru.cc
new file mode 100644
index 0000000..45d7873
--- /dev/null
+++ b/src/neuralnet/neuron_layer/gru.cc
@@ -0,0 +1,275 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "singa/neuralnet/neuron_layer.h"
+
+#include <glog/logging.h>
+#include "singa/utils/singleton.h"
+#include "singa/utils/math_blob.h"
+#include "singa/utils/singa_op.h"
+
+#include <iostream>
+using namespace std;
+
+namespace singa {
+
+using std::vector;
+
+GRULayer::~GRULayer() {
+  delete weight_z_hx_;
+  delete weight_z_hh_;
+  delete bias_z_;
+
+  delete weight_r_hx_;
+  delete weight_r_hh_;
+  delete bias_r_;
+
+  delete weight_c_hx_;
+  delete weight_c_hh_;
+  delete bias_c_;
+
+  delete update_gate;
+  delete reset_gate;
+  delete new_memory;
+}
+
+void GRULayer::Setup(const LayerProto& conf,
+    const vector<Layer*>& srclayers) {
+  Layer::Setup(conf, srclayers);
+  CHECK_LE(srclayers.size(), 2);
+  const auto& src = srclayers[0]->data(this);
+
+  batchsize_ = src.shape()[0]; // size of batch
+  vdim_ = src.count() / (batchsize_); // dimension of input
+
+  hdim_ = layer_conf_.gru_conf().dim_hidden(); // dimension of hidden state
+
+  data_.Reshape(vector<int>{batchsize_, hdim_});
+  grad_.ReshapeLike(data_);
+
+  // Initialize the parameters
+  weight_z_hx_ = Param::Create(conf.param(0));
+  weight_r_hx_ = Param::Create(conf.param(1));
+  weight_c_hx_ = Param::Create(conf.param(2));
+
+  weight_z_hh_ = Param::Create(conf.param(3));
+  weight_r_hh_ = Param::Create(conf.param(4));
+  weight_c_hh_ = Param::Create(conf.param(5));
+
+  if (conf.gru_conf().bias_term()) {
+         bias_z_ = Param::Create(conf.param(6));
+         bias_r_ = Param::Create(conf.param(7));
+         bias_c_ = Param::Create(conf.param(8));
+  }
+
+  weight_z_hx_->Setup(vector<int>{hdim_, vdim_});
+  weight_r_hx_->Setup(vector<int>{hdim_, vdim_});
+  weight_c_hx_->Setup(vector<int>{hdim_, vdim_});
+
+  weight_z_hh_->Setup(vector<int>{hdim_, hdim_});
+  weight_r_hh_->Setup(vector<int>{hdim_, hdim_});
+  weight_c_hh_->Setup(vector<int>{hdim_, hdim_});
+
+  if (conf.gru_conf().bias_term()) {
+         bias_z_->Setup(vector<int>{hdim_});
+         bias_r_->Setup(vector<int>{hdim_});
+         bias_c_->Setup(vector<int>{hdim_});
+  }
+
+  update_gate = new Blob<float>(batchsize_, hdim_);
+  reset_gate = new Blob<float>(batchsize_, hdim_);
+  new_memory = new Blob<float>(batchsize_, hdim_);
+
+}
+
+void GRULayer::ComputeFeature(int flag,
+    const vector<Layer*>& srclayers) {
+       CHECK_LE(srclayers.size(), 2);
+
+       // Do transpose
+       Blob<float> *w_z_hx_t = Transpose (weight_z_hx_->data());
+       Blob<float> *w_z_hh_t = Transpose (weight_z_hh_->data());
+       Blob<float> *w_r_hx_t = Transpose (weight_r_hx_->data());
+       Blob<float> *w_r_hh_t = Transpose (weight_r_hh_->data());
+       Blob<float> *w_c_hx_t = Transpose (weight_c_hx_->data());
+       Blob<float> *w_c_hh_t = Transpose (weight_c_hh_->data());
+
+       // Prepare the data input and the context
+       const auto& src = srclayers[0]->data(this);
+       const Blob<float> *context;
+       if (srclayers.size() == 1) { // only have data input
+               context = new Blob<float>(batchsize_, hdim_);
+       } else { // have data input & context
+               context = &srclayers[1]->data(this);
+       }
+
+       // Compute the update gate
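+       // i.e., z = sigmoid(x * W_z_hx^T + h_prev * W_z_hh^T + b_z)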
+       GEMM(1.0f, 0.0f, src,*w_z_hx_t,update_gate);
+       if (bias_z_ != nullptr)
+               MVAddRow(1.0f,1.0f,bias_z_->data(),update_gate);
+       Blob<float> zprev (batchsize_,hdim_);
+       GEMM(1.0f, 0.0f, *context,*w_z_hh_t, &zprev);
+       Add<float>(*update_gate, zprev, update_gate);
+       Map<op::Sigmoid<float>,float>(*update_gate, update_gate);
+
+       // Compute the reset gate
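+       // i.e., r = sigmoid(x * W_r_hx^T + h_prev * W_r_hh^T + b_r)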
+       GEMM(1.0f, 0.0f, src,*w_r_hx_t,reset_gate);
+       if (bias_r_ != nullptr)
+               MVAddRow(1.0f,1.0f,bias_r_->data(),reset_gate);
+       Blob<float> rprev (batchsize_, hdim_);
+       GEMM(1.0f, 0.0f, *context, *w_r_hh_t, &rprev);
+       Add<float>(*reset_gate, rprev, reset_gate);
+       Map<op::Sigmoid<float>,float>(*reset_gate, reset_gate);
+
+       // Compute the new memory
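+       // i.e., c = tanh(x * W_c_hx^T + b_c + r (*) (h_prev * W_c_hh^T)), with (*) denoting the element-wise product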
+       GEMM(1.0f, 0.0f, src, *w_c_hx_t, new_memory);
+       if (bias_c_ != nullptr)
+               MVAddRow(1.0f,1.0f,bias_c_->data(), new_memory);
+       Blob<float> cprev (batchsize_, hdim_);
+       GEMM(1.0f, 0.0f, *context, *w_c_hh_t, &cprev);
+       //Blob<float> new_cprev (batchsize_, hdim_);
+       Mult<float>(*reset_gate, cprev, &cprev);
+       Add<float>(*new_memory, cprev, new_memory);
+       Map<op::Tanh<float>,float>(*new_memory, new_memory);
+
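+       // Output h = (1 - z) (*) c + z (*) h_prev, computed below in two parts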
+       // Compute data - new memory part
+       Blob<float> z1 (batchsize_,hdim_);
+       for (int i = 0; i < z1.count(); i ++) {
+               z1.mutable_cpu_data()[i] = 1.0f; // generate a matrix with ones
+       }
+       AXPY<float>(-1.0f, *update_gate, &z1);
+       Mult<float>(z1, *new_memory, &data_);
+
+       // Compute data - context part
+       Blob<float> data_prev (batchsize_, hdim_);
+       Mult<float>(*update_gate,*context,&data_prev);
+       Add<float>(data_, data_prev, &data_);
+
+       // delete the pointers
+       if (srclayers.size() == 1) delete context;
+       else context = NULL;
+
+       delete w_z_hx_t;
+       delete w_z_hh_t;
+       delete w_r_hx_t;
+       delete w_r_hh_t;
+       delete w_c_hx_t;
+       delete w_c_hh_t;
+}
+
+void GRULayer::ComputeGradient(int flag,
+    const vector<Layer*>& srclayers) {
+       CHECK_LE(srclayers.size(), 2);
+
+       // Prepare the data input and the context
+       const Blob<float>& src = srclayers[0]->data(this);
+       const Blob<float> *context;
+       if (srclayers.size() == 1) { // only have data input
+               context = new Blob<float>(batchsize_, hdim_);
+       } else { // have data input & context
+               context = &srclayers[1]->data(this);
+       }
+
+       // Prepare gradient of output neurons
+       Blob<float> *grad_t = Transpose (grad_);
+
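+       // BPTT within one step: given h = (1 - z) (*) c + z (*) h_prev, the gradients computed below are
+       //   dL/dz = dL/dh (*) (h_prev - c) (*) sigmoid'(z)
+       //   dL/dc = dL/dh (*) (1 - z) (*) tanh'(c)
+       //   dL/dr = dL/dc (*) (h_prev * W_c_hh^T) (*) sigmoid'(r)
+       // and are then multiplied into the parameter and input gradients.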
+       // Compute intermediate gradients which are used for other computations
+       Blob<float> dugatedz (batchsize_, hdim_);
+       Map<singa::op::SigmoidGrad<float>, float>(*update_gate, &dugatedz);
+       Blob<float> drgatedr (batchsize_, hdim_);
+       Map<singa::op::SigmoidGrad<float>, float>(*reset_gate, &drgatedr);
+       Blob<float> dnewmdc (batchsize_, hdim_);
+       Map<singa::op::TanhGrad<float>, float>(*new_memory,&dnewmdc);
+
+       Blob<float> dLdz (batchsize_, hdim_);
+       Sub<float>(*context, *new_memory, &dLdz);
+       Mult<float>(dLdz, grad_, &dLdz);
+       Mult<float>(dLdz, dugatedz, &dLdz);
+
+       Blob<float> dLdc (batchsize_,hdim_);
+       Blob<float> z1 (batchsize_,hdim_);
+       for (int i = 0; i < z1.count(); i ++) {
+               z1.mutable_cpu_data()[i] = 1.0f; // generate a matrix with ones
+       }
+       AXPY<float>(-1.0f, *update_gate, &z1);
+       Mult(grad_,z1,&dLdc);
+       Mult(dLdc,dnewmdc,&dLdc);
+
+       Blob<float> reset_dLdc (batchsize_,hdim_);
+       Mult(dLdc, *reset_gate, &reset_dLdc);
+
+       Blob<float> dLdr (batchsize_, hdim_);
+       Blob<float> cprev (batchsize_, hdim_);
+       Blob<float> *w_c_hh_t = Transpose(weight_c_hh_->data());
+       GEMM(1.0f,0.0f,*context,*w_c_hh_t, &cprev);
+       delete w_c_hh_t;
+       Mult(dLdc,cprev,&dLdr);
+       Mult(dLdr,drgatedr,&dLdr);
+
+
+       // Compute gradients for parameters of update gate
+       Blob<float> *dLdz_t = Transpose(dLdz);
+       GEMM(1.0f,0.0f,*dLdz_t,src,weight_z_hx_->mutable_grad());
+       GEMM(1.0f,0.0f,*dLdz_t,*context,weight_z_hh_->mutable_grad());
+       if (bias_z_ != nullptr)
+               MVSumRow<float>(1.0f,0.0f,dLdz,bias_z_->mutable_grad());
+       delete dLdz_t;
+
+       // Compute gradients for parameters of reset gate
+       Blob<float> *dLdr_t = Transpose(dLdr);
+       GEMM(1.0f,0.0f,*dLdr_t,src,weight_r_hx_->mutable_grad());
+       GEMM(1.0f,0.0f,*dLdr_t,*context,weight_r_hh_->mutable_grad());
+       if (bias_r_ != nullptr)
+               MVSumRow(1.0f,0.0f,dLdr,bias_r_->mutable_grad());
+       delete dLdr_t;
+
+       // Compute gradients for parameters of new memory
+       Blob<float> *dLdc_t = Transpose(dLdc);
+       GEMM(1.0f,0.0f,*dLdc_t,src,weight_c_hx_->mutable_grad());
+       if (bias_c_ != nullptr)
+               MVSumRow(1.0f,0.0f,dLdc,bias_c_->mutable_grad());
+       delete dLdc_t;
+
+       Blob<float> *reset_dLdc_t = Transpose(reset_dLdc);
+       GEMM(1.0f,0.0f,*reset_dLdc_t,*context,weight_c_hh_->mutable_grad());
+       delete reset_dLdc_t;
+
+       // Compute gradients for data input layer
+       if (srclayers[0]->mutable_grad(this) != nullptr) {
+               GEMM(1.0f,0.0f,dLdc,weight_c_hx_->data(),srclayers[0]->mutable_grad(this));
+               GEMM(1.0f,1.0f,dLdz,weight_z_hx_->data(),srclayers[0]->mutable_grad(this));
+               GEMM(1.0f,1.0f,dLdr,weight_r_hx_->data(), srclayers[0]->mutable_grad(this));
+       }
+
+       if (srclayers.size() > 1 && srclayers[1]->mutable_grad(this) != nullptr) {
+               // Compute gradients for context layer
+               GEMM(1.0f,0.0f,reset_dLdc,weight_c_hh_->data(), srclayers[1]->mutable_grad(this));
+               GEMM(1.0f,1.0f,dLdr, weight_r_hh_->data(), srclayers[1]->mutable_grad(this));
+               GEMM(1.0f,1.0f,dLdz,weight_z_hh_->data(), srclayers[1]->mutable_grad(this));
+               Add(srclayers[1]->grad(this), *update_gate, srclayers[1]->mutable_grad(this));
+       }
+
+       if (srclayers.size() == 1) delete context;
+       else context = NULL;
+       delete grad_t;
+}
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ddf4e79a/src/test/test_gru_layer.cc
----------------------------------------------------------------------
diff --git a/src/test/test_gru_layer.cc b/src/test/test_gru_layer.cc
new file mode 100644
index 0000000..296b795
--- /dev/null
+++ b/src/test/test_gru_layer.cc
@@ -0,0 +1,286 @@
+/************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+#include <string>
+#include <vector>
+#include <fstream>
+#include <iostream>
+using namespace std;
+
+
+#include "gtest/gtest.h"
+#include "singa/neuralnet/neuron_layer.h"
+#include "singa/neuralnet/input_layer.h"
+#include "singa/driver.h"
+#include "singa/proto/job.pb.h"
+
+using namespace singa;
+
+class GRULayerTest: public ::testing::Test {
+protected:
+       virtual void SetUp() {
+               // Initialize the settings for the first input-layer
+               std::string path1 = "src/test/gru-in-1.csv"; // path of a csv file
+               std::ofstream ofs1(path1, std::ofstream::out);
+               ASSERT_TRUE(ofs1.is_open());
+               ofs1 << "0,0,0,1\n";
+               ofs1 << "0,0,1,0\n";
+               ofs1.close();
+               auto conf1 = in1_conf.mutable_store_conf();
+               conf1->set_path(path1);
+               conf1->set_batchsize(2);
+               conf1->add_shape(4);
+               conf1->set_backend("textfile");
+               conf1->set_has_label(false);
+
+
+               // Initialize the settings for the second input-layer
+               std::string path2 = "src/test/gru-in-2.csv"; // path of a csv file
+               std::ofstream ofs2(path2, std::ofstream::out);
+               ASSERT_TRUE(ofs2.is_open());
+               ofs2 << "0,1,0,0\n";
+               ofs2 << "1,0,0,0\n";
+               ofs2.close();
+               auto conf2 = in2_conf.mutable_store_conf();
+               conf2->set_path(path2);
+
+               conf2->set_batchsize(2);
+               conf2->add_shape(4);
+               conf2->set_backend("textfile");
+               conf2->set_has_label(false);
+
+
+               gru1_conf.mutable_gru_conf() -> set_dim_hidden(2);
+               gru1_conf.mutable_gru_conf() -> set_bias_term(true);
+               for (int i = 0; i < 9; i ++) {
+                       gru1_conf.add_param();
+               }
+
+
+               gru1_conf.mutable_param(0)->set_name("wzhx1");
+               gru1_conf.mutable_param(0)->set_type(kParam);
+               gru1_conf.mutable_param(0)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(0)->mutable_init()->set_value(0.5f);
+
+               gru1_conf.mutable_param(1)->set_name("wrhx1");
+               gru1_conf.mutable_param(1)->set_type(kParam);
+               gru1_conf.mutable_param(1)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(1)->mutable_init()->set_value(0.5f);
+
+               gru1_conf.mutable_param(2)->set_name("wchx1");
+               gru1_conf.mutable_param(2)->set_type(kParam);
+               gru1_conf.mutable_param(2)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(2)->mutable_init()->set_value(0.5f);
+
+               gru1_conf.mutable_param(3)->set_name("wzhh1");
+               gru1_conf.mutable_param(3)->set_type(kParam);
+               gru1_conf.mutable_param(3)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(3)->mutable_init()->set_value(0.5f);
+
+               gru1_conf.mutable_param(4)->set_name("wrhh1");
+               gru1_conf.mutable_param(4)->set_type(kParam);
+               gru1_conf.mutable_param(4)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(4)->mutable_init()->set_value(0.5f);
+
+               gru1_conf.mutable_param(5)->set_name("wchh1");
+               gru1_conf.mutable_param(5)->set_type(kParam);
+               gru1_conf.mutable_param(5)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(5)->mutable_init()->set_value(0.5f);
+
+               gru1_conf.mutable_param(6)->set_name("bz1");
+               gru1_conf.mutable_param(6)->set_type(kParam);
+               gru1_conf.mutable_param(6)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(6)->mutable_init()->set_value(0.5f);
+
+               gru1_conf.mutable_param(7)->set_name("br1");
+               gru1_conf.mutable_param(7)->set_type(kParam);
+               gru1_conf.mutable_param(7)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(7)->mutable_init()->set_value(0.5f);
+
+               gru1_conf.mutable_param(8)->set_name("bc1");
+               gru1_conf.mutable_param(8)->set_type(kParam);
+               gru1_conf.mutable_param(8)->mutable_init()->set_type(kConstant);
+               gru1_conf.mutable_param(8)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_gru_conf() -> set_dim_hidden(2);
+               gru2_conf.mutable_gru_conf() -> set_bias_term(true);
+               for (int i = 0; i < 9; i ++) {
+                       gru2_conf.add_param();
+               }
+
+               gru2_conf.mutable_param(0)->set_name("wzhx2");
+               gru2_conf.mutable_param(0)->set_type(kParam);
+               gru2_conf.mutable_param(0)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(0)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_param(1)->set_name("wrhx2");
+               gru2_conf.mutable_param(1)->set_type(kParam);
+               gru2_conf.mutable_param(1)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(1)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_param(2)->set_name("wchx2");
+               gru2_conf.mutable_param(2)->set_type(kParam);
+               gru2_conf.mutable_param(2)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(2)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_param(3)->set_name("wzhh2");
+               gru2_conf.mutable_param(3)->set_type(kParam);
+               gru2_conf.mutable_param(3)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(3)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_param(4)->set_name("wrhh2");
+               gru2_conf.mutable_param(4)->set_type(kParam);
+               gru2_conf.mutable_param(4)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(4)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_param(5)->set_name("wchh2");
+               gru2_conf.mutable_param(5)->set_type(kParam);
+               gru2_conf.mutable_param(5)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(5)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_param(6)->set_name("bz2");
+               gru2_conf.mutable_param(6)->set_type(kParam);
+               gru2_conf.mutable_param(6)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(6)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_param(7)->set_name("br2");
+               gru2_conf.mutable_param(7)->set_type(kParam);
+               gru2_conf.mutable_param(7)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(7)->mutable_init()->set_value(0.5f);
+
+               gru2_conf.mutable_param(8)->set_name("bc2");
+               gru2_conf.mutable_param(8)->set_type(kParam);
+               gru2_conf.mutable_param(8)->mutable_init()->set_type(kConstant);
+               gru2_conf.mutable_param(8)->mutable_init()->set_value(0.5f);
+
+       }
+       singa::LayerProto in1_conf;
+       singa::LayerProto in2_conf;
+       singa::LayerProto gru1_conf;
+       singa::LayerProto gru2_conf;
+};
+
+TEST_F(GRULayerTest, Setup) {
+       singa::Driver driver;
+       //driver.RegisterLayer<GRULayer, int> (kGRU);
+       driver.RegisterParam<Param>(0);
+       driver.RegisterParamGenerator<UniformGen>(kUniform);
+       driver.RegisterParamGenerator<ParamGenerator>(kConstant);
+
+       singa::CSVInputLayer in_layer_1;
+       singa::CSVInputLayer in_layer_2;
+
+       in_layer_1.Setup(in1_conf, std::vector<singa::Layer*> { });
+       EXPECT_EQ(2, static_cast<int>(in_layer_1.aux_data().size()));
+       EXPECT_EQ(8, in_layer_1.data(nullptr).count());
+
+       in_layer_2.Setup(in2_conf, std::vector<singa::Layer*>{ });
+       EXPECT_EQ(2, static_cast<int>(in_layer_2.aux_data().size()));
+       EXPECT_EQ(8, in_layer_2.data(nullptr).count());
+
+       singa::GRULayer gru_layer_1;
+       gru_layer_1.Setup(gru1_conf, std::vector<singa::Layer*>{&in_layer_1});
+       //EXPECT_EQ(2, gru_layer_1.hdim());
+       //EXPECT_EQ(4, gru_layer_1.vdim());
+
+       for (unsigned int i = 0; i < gru_layer_1.GetParams().size(); i ++) {
+               gru_layer_1.GetParams()[i]->InitValues();
+       }
+       EXPECT_EQ (0.5, gru_layer_1.GetParams()[0]->data().cpu_data()[0]);
+       //cout << "gru_layer_1: " << gru_layer_1.GetParams()[0]->data().cpu_data()[0] << endl;
+
+       singa::GRULayer gru_layer_2;
+       gru_layer_2.Setup(gru2_conf, std::vector<singa::Layer*>{&in_layer_2, &gru_layer_1});
+       //EXPECT_EQ(2, gru_layer_2.hdim());
+       //EXPECT_EQ(4, gru_layer_2.vdim());
+       for (unsigned int i = 0; i < gru_layer_2.GetParams().size(); i ++) {
+               gru_layer_2.GetParams()[i]->InitValues();
+       }
+       EXPECT_EQ (0.5, gru_layer_2.GetParams()[0]->data().cpu_data()[0]);
+}
+
+
+TEST_F(GRULayerTest, ComputeFeature) {
+       singa::CSVInputLayer in_layer_1;
+       singa::CSVInputLayer in_layer_2;
+
+       in_layer_1.Setup(in1_conf, std::vector<singa::Layer*> { });
+       in_layer_1.ComputeFeature(singa::kTrain, std::vector<singa::Layer*> { });
+       in_layer_2.Setup(in2_conf, std::vector<singa::Layer*>{ });
+       in_layer_2.ComputeFeature(singa::kTrain, std::vector<singa::Layer*> { });
+
+
+       singa::GRULayer gru_layer_1;
+       gru_layer_1.Setup(gru1_conf, std::vector<singa::Layer*>{&in_layer_1});
+       for (unsigned int i = 0; i < gru_layer_1.GetParams().size(); i ++) {
+               gru_layer_1.GetParams()[i]->InitValues();
+       }
+       gru_layer_1.ComputeFeature(singa::kTrain, std::vector<singa::Layer*>{&in_layer_1});
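+       // Sanity check on the expected value (derivation, assuming the constant init above):
+       // each input row sums to 1, all params are 0.5, and the initial context is zero,
+       // so z = r = sigmoid(0.5 + 0.5), c = tanh(0.5 + 0.5), and h = (1 - z) * c ~= 0.204824.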
+       for (int i = 0; i < gru_layer_1.data(nullptr).count(); i ++) {
+               EXPECT_GT(0.000001,abs(0.204824-gru_layer_1.data(nullptr).cpu_data()[i]));
+       }
+
+       singa::GRULayer gru_layer_2;
+       gru_layer_2.Setup(gru2_conf, std::vector<singa::Layer*>{&in_layer_2, &gru_layer_1});
+
+       for (unsigned int i = 0; i < gru_layer_2.GetParams().size(); i ++) {
+               gru_layer_2.GetParams()[i]->InitValues();
+       }
+       gru_layer_2.ComputeFeature(singa::kTrain, std::vector<singa::Layer*>{&in_layer_2, &gru_layer_1});
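+       // Sanity check (same assumptions): with h_prev ~= 0.204824 from layer 1,
+       // z = r = sigmoid(1.0 + 0.5 * sum(h_prev)), c = tanh(1.0 + r * 0.5 * sum(h_prev)),
+       // and h = (1 - z) * c + z * h_prev ~= 0.346753.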
+       for (int i = 0; i < gru_layer_2.data(nullptr).count(); i ++) {
+               EXPECT_GT(0.000001,abs(0.346753-gru_layer_2.data(nullptr).cpu_data()[i]));
+       }
+}
+
+
+TEST_F(GRULayerTest, ComputeGradient) {
+       singa::CSVInputLayer in_layer_1;
+       singa::CSVInputLayer in_layer_2;
+
+       in_layer_1.Setup(in1_conf, std::vector<singa::Layer*> { });
+       in_layer_1.ComputeFeature(singa::kTrain, std::vector<singa::Layer*> { });
+       in_layer_2.Setup(in2_conf, std::vector<singa::Layer*>{ });
+       in_layer_2.ComputeFeature(singa::kTrain, std::vector<singa::Layer*> { });
+
+
+       singa::GRULayer gru_layer_1;
+       gru_layer_1.Setup(gru1_conf, std::vector<singa::Layer*>{&in_layer_1});
+       for (unsigned int i = 0; i < gru_layer_1.GetParams().size(); i ++) {
+               gru_layer_1.GetParams()[i]->InitValues();
+       }
+       gru_layer_1.ComputeFeature(singa::kTrain, std::vector<singa::Layer*>{&in_layer_1});
+
+
+       singa::GRULayer gru_layer_2;
+       gru_layer_2.Setup(gru2_conf, std::vector<singa::Layer*>{&in_layer_2, &gru_layer_1});
+       for (unsigned int i = 0; i < gru_layer_2.GetParams().size(); i ++) {
+               gru_layer_2.GetParams()[i]->InitValues();
+       }
+       gru_layer_2.ComputeFeature(singa::kTrain, std::vector<singa::Layer*>{&in_layer_2, &gru_layer_1});
+
+       // For test purpose, we set dummy values for gru_layer_2.grad_
+       for (int i = 0; i < gru_layer_2.grad(nullptr).count(); i ++) {
+               gru_layer_2.mutable_grad(nullptr)->mutable_cpu_data()[i] = 1.0f;
+       }
+       gru_layer_2.ComputeGradient(singa::kTrain, std::vector<singa::Layer*>{&in_layer_2, &gru_layer_1});
+
+       gru_layer_1.ComputeGradient(singa::kTrain, std::vector<singa::Layer*>{&in_layer_1});
+
+}
