SINGA-111 Add slice, concate and split layers

Remove the dedicated conf fields (slice_conf, concate_conf, split_conf);
use the generic partition_dim and num_partitions in LayerProto instead.
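
For reference, this is how a slice layer is now configured (mirroring the
updated tests below; the same partition_dim/num_partitions fields also
drive the concate and split layers; the values 0 and 4 are illustrative):

  LayerProto proto_slice;
  proto_slice.set_name("slice");
  proto_slice.set_partition_dim(0);   // slice along dimension 0
  proto_slice.set_num_partitions(4);  // produce 4 equal partitions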


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/7cdb22f6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/7cdb22f6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/7cdb22f6

Branch: refs/heads/master
Commit: 7cdb22f687b3a1c68e3ca2e28329e8eceba20017
Parents: 7e414e5
Author: WANG Sheng <[email protected]>
Authored: Thu Dec 10 16:19:11 2015 +0800
Committer: WANG Sheng <[email protected]>
Committed: Thu Dec 10 16:34:02 2015 +0800

----------------------------------------------------------------------
 .../singa/neuralnet/connection_layer/concate.h  |  3 -
 .../singa/neuralnet/connection_layer/slice.h    |  7 +-
 .../singa/neuralnet/connection_layer/split.h    |  5 +-
 include/singa/neuralnet/layer.h                 |  4 +-
 src/neuralnet/connection_layer/concate.cc       | 18 ++---
 src/neuralnet/connection_layer/slice.cc         | 70 +++++++++++---------
 src/neuralnet/connection_layer/split.cc         | 33 +++++----
 src/neuralnet/neuralnet.cc                      |  8 +--
 src/proto/job.proto                             | 20 ------
 src/test/test_connection_layers.cc              | 16 +++--
 10 files changed, 86 insertions(+), 98 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/include/singa/neuralnet/connection_layer/concate.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/connection_layer/concate.h b/include/singa/neuralnet/connection_layer/concate.h
index 6e40040..5875835 100644
--- a/include/singa/neuralnet/connection_layer/concate.h
+++ b/include/singa/neuralnet/connection_layer/concate.h
@@ -37,9 +37,6 @@ class ConcateLayer : public ConnectionLayer {
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-
- private:
-  int concate_dim_;
 };
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/include/singa/neuralnet/connection_layer/slice.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/connection_layer/slice.h b/include/singa/neuralnet/connection_layer/slice.h
index a2f715c..023ebc1 100644
--- a/include/singa/neuralnet/connection_layer/slice.h
+++ b/include/singa/neuralnet/connection_layer/slice.h
@@ -34,6 +34,7 @@ namespace singa {
  */
 class SliceLayer : public ConnectionLayer {
  public:
+  ~SliceLayer();
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
@@ -41,12 +42,6 @@ class SliceLayer : public ConnectionLayer {
   const Blob<float>& grad(const Layer* from) const override;
   Blob<float>* mutable_data(const Layer* from) override;
   Blob<float>* mutable_grad(const Layer* from) override;
-
- private:
-  std::vector<Blob<float>> datavec_;
-  std::vector<Blob<float>> gradvec_;
-  int slice_dim_;
-  int slice_num_;
 };
 
 }  // namespace singa
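
The removed datavec_/gradvec_ members resurface in slice.cc below as
vectors of Blob pointers, presumably hoisted into the ConnectionLayer
base class (its declaration is not part of this diff). A sketch of that
assumed layout:

  class ConnectionLayer : virtual public Layer {
   protected:
    // Assumed storage, inferred from slice.cc/split.cc in this commit;
    // the real declaration lives outside this diff.
    std::vector<Blob<float>*> datavec_;
    std::vector<Blob<float>*> gradvec_;
  };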

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/include/singa/neuralnet/connection_layer/split.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/connection_layer/split.h b/include/singa/neuralnet/connection_layer/split.h
index f4de238..959d1a3 100644
--- a/include/singa/neuralnet/connection_layer/split.h
+++ b/include/singa/neuralnet/connection_layer/split.h
@@ -35,15 +35,12 @@ namespace singa {
  */
 class SplitLayer : public ConnectionLayer {
  public:
+  ~SplitLayer();
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
   const Blob<float>& grad(const Layer* from) const override;
   Blob<float>* mutable_grad(const Layer* from) override;
-
- private:
-  std::vector<Blob<float>> grads_;
-  int split_num_;
 };
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/include/singa/neuralnet/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/layer.h b/include/singa/neuralnet/layer.h
index a78deb7..7a92ced 100644
--- a/include/singa/neuralnet/layer.h
+++ b/include/singa/neuralnet/layer.h
@@ -284,11 +284,11 @@ class InputLayer : virtual public Layer {
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override {}
   ConnectionType dst_layer_connection() const override { return kOneToMany; }
   Blob<float>* mutable_grad(const Layer* layer) override {
-    // LOG(FATAL) << "Input layer has no gradient blob";
+    LOG(FATAL) << "Input layer has no gradient blob";
     return nullptr;
   }
   const Blob<float>& grad(const Layer* from) const override {
-    // LOG(FATAL) << "Input layer has no gradient blob";
+    LOG(FATAL) << "Input layer has no gradient blob";
     return grad_;
   }
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/src/neuralnet/connection_layer/concate.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/connection_layer/concate.cc b/src/neuralnet/connection_layer/concate.cc
index 8f519d2..f9d6416 100644
--- a/src/neuralnet/connection_layer/concate.cc
+++ b/src/neuralnet/connection_layer/concate.cc
@@ -29,14 +29,14 @@ void ConcateLayer::Setup(const LayerProto& conf,
                          const vector<Layer*>& srclayers) {
   CHECK_GT(srclayers.size(), 1);
   Layer::Setup(conf, srclayers);
-  concate_dim_ = conf.concate_conf().concate_dim();
   vector<int> shape = srclayers[0]->data(this).shape();
-  CHECK_GE(concate_dim_, 0);
-  CHECK_LT(concate_dim_, shape.size());
+  CHECK_GE(partition_dim(), 0);
+  CHECK_LT(partition_dim(), shape.size());
+  CHECK_EQ(num_partitions(), srclayers.size());
   for (size_t i = 1; i < srclayers.size(); i++) {
     const vector<int>& src_shape = srclayers[i]->data(this).shape();
     for (size_t j = 0; j < shape.size(); j++)
-      if (static_cast<int>(j) == concate_dim_)
+      if (static_cast<int>(j) == partition_dim())
         shape[j] += src_shape[j];
       else
         CHECK_EQ(shape[j], src_shape[j]);
@@ -47,9 +47,10 @@ void ConcateLayer::Setup(const LayerProto& conf,
 
 void ConcateLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
   CHECK_GT(srclayers.size(), 1);
+  CHECK_EQ(num_partitions(), srclayers.size());
   // calculate step for each memcpy
-  int step = srclayers[0]->data(this).shape()[concate_dim_];
-  for (unsigned i = concate_dim_ + 1; i < data_.shape().size(); ++i)
+  int step = srclayers[0]->data(this).shape()[partition_dim()];
+  for (unsigned i = partition_dim() + 1; i < data_.shape().size(); ++i)
     step *= data_.shape()[i];
   int srclayer_offset = 0;
   int concate_offset = 0;
@@ -66,9 +67,10 @@ void ConcateLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
 
 void ConcateLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
   CHECK_GT(srclayers.size(), 1);
+  CHECK_EQ(num_partitions(), srclayers.size());
   // calculate step for each memcpy
-  int step = srclayers[0]->grad(this).shape()[concate_dim_];
-  for (unsigned i = concate_dim_ + 1; i < grad_.shape().size(); ++i)
+  int step = srclayers[0]->grad(this).shape()[partition_dim()];
+  for (unsigned i = partition_dim() + 1; i < grad_.shape().size(); ++i)
     step *= grad_.shape()[i];
   int srclayer_offset = 0;
   int concate_offset = 0;
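
To make the step arithmetic concrete, here is a standalone sketch (no
SINGA dependency) of the interleaved memcpy that ComputeFeature performs
when concatenating two 2x3 sources along partition_dim() == 1:

  #include <cstdio>
  #include <cstring>

  int main() {
    const float a[6] = {1, 2, 3, 4, 5, 6};     // source 0, shape (2, 3)
    const float b[6] = {7, 8, 9, 10, 11, 12};  // source 1, shape (2, 3)
    const float* srcs[2] = {a, b};
    float dst[12];       // concatenated result, shape (2, 6)
    const int step = 3;  // src shape[1]; no trailing dims to multiply in
    int src_offset = 0, dst_offset = 0;
    while (dst_offset < 12) {
      for (int i = 0; i < 2; ++i) {  // one memcpy per source per row
        std::memcpy(dst + dst_offset, srcs[i] + src_offset,
                    step * sizeof(float));
        dst_offset += step;
      }
      src_offset += step;
    }
    // prints: 1 2 3 7 8 9 4 5 6 10 11 12
    for (int j = 0; j < 12; ++j) std::printf("%g ", dst[j]);
    std::printf("\n");
    return 0;
  }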

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/src/neuralnet/connection_layer/slice.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/connection_layer/slice.cc b/src/neuralnet/connection_layer/slice.cc
index 66d3578..c69f797 100644
--- a/src/neuralnet/connection_layer/slice.cc
+++ b/src/neuralnet/connection_layer/slice.cc
@@ -25,25 +25,33 @@ namespace singa {
 
 using std::vector;
 
+SliceLayer::~SliceLayer() {
+  for (size_t i = 1; i < datavec_.size(); ++i) {
+    if (datavec_[i] != nullptr) delete datavec_[i];
+    if (gradvec_[i] != nullptr) delete gradvec_[i];
+  }
+}
+
 void SliceLayer::Setup(const LayerProto& conf,
                        const vector<Layer*>& srclayers) {
   CHECK_EQ(srclayers.size(), 1);
   Layer::Setup(conf, srclayers);
-  slice_dim_ = conf.slice_conf().slice_dim();
-  slice_num_ = conf.slice_conf().slice_num();
   vector<int> shape = srclayers[0]->data(this).shape();
-  CHECK_GE(slice_dim_, 0);
-  CHECK_LT(slice_dim_, shape.size());
-  CHECK_GT(slice_num_, 0);
-  datavec_.resize(slice_num_);
-  gradvec_.resize(slice_num_);
+  CHECK_GE(partition_dim(), 0);
+  CHECK_LT(partition_dim(), shape.size());
+  CHECK_GT(num_partitions(), 0);
+  // add num_partitions()-1 more blobs
+  for (int i = 1; i < num_partitions(); ++i) {
+    datavec_.push_back(new Blob<float>());
+    gradvec_.push_back(new Blob<float>());
+  }
   // TODO(wangsh): remove equal-size restrict later
-  CHECK_EQ(shape[slice_dim_] % slice_num_, 0);
-  shape[slice_dim_] /= slice_num_;
-  for (int i = 0; i < slice_num_; ++i) {
-    // if (i == slice_num - 1) shape[slice_dim_] += remain;
-    datavec_[i].Reshape(shape);
-    gradvec_[i].Reshape(shape);
+  CHECK_EQ(shape[partition_dim()] % num_partitions(), 0);
+  shape[partition_dim()] /= num_partitions();
+  for (int i = 0; i < num_partitions(); ++i) {
+    // if (i == slice_num - 1) shape[partition_dim()] += remain;
+    datavec_[i]->Reshape(shape);
+    gradvec_[i]->Reshape(shape);
   }
 }
 
@@ -51,15 +59,15 @@ void SliceLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
   CHECK_EQ(srclayers.size(), 1);
   const Blob<float>& blob = srclayers[0]->data(this);
   // calculate step for each memcpy
-  int step = datavec_[0].shape()[slice_dim_];
-  for (unsigned i = slice_dim_ + 1; i < datavec_[0].shape().size(); ++i)
-    step *= datavec_[0].shape()[i];
+  int step = datavec_[0]->shape()[partition_dim()];
+  for (unsigned i = partition_dim() + 1; i < datavec_[0]->shape().size(); ++i)
+    step *= datavec_[0]->shape()[i];
   int srclayer_offset = 0;
   int slice_offset = 0;
   while (srclayer_offset < blob.count()) {
-    for (int i = 0; i < slice_num_; ++i) {
+    for (int i = 0; i < num_partitions(); ++i) {
       const float* src = blob.cpu_data() + srclayer_offset;
-      float* dst = datavec_[i].mutable_cpu_data() + slice_offset;
+      float* dst = datavec_[i]->mutable_cpu_data() + slice_offset;
       memcpy(dst, src, step * sizeof(float));
       srclayer_offset += step;
     }
@@ -71,14 +79,14 @@ void SliceLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
   CHECK_EQ(srclayers.size(), 1);
   Blob<float>* blob = srclayers[0]->mutable_grad(this);
   // calculate step for each memcpy
-  int step = gradvec_[0].shape()[slice_dim_];
-  for (size_t i = slice_dim_ + 1; i < gradvec_[0].shape().size(); ++i)
-    step *= gradvec_[0].shape()[i];
+  int step = gradvec_[0]->shape()[partition_dim()];
+  for (size_t i = partition_dim() + 1; i < gradvec_[0]->shape().size(); ++i)
+    step *= gradvec_[0]->shape()[i];
   int srclayer_offset = 0;
   int slice_offset = 0;
   while (srclayer_offset < blob->count()) {
-    for (int i = 0; i < slice_num_; ++i) {
-      const float* src = gradvec_[i].cpu_data() + slice_offset;
+    for (int i = 0; i < num_partitions(); ++i) {
+      const float* src = gradvec_[i]->cpu_data() + slice_offset;
       float* dst = blob->mutable_cpu_data() + srclayer_offset;
       memcpy(dst, src, step * sizeof(float));
       srclayer_offset += step;
@@ -89,26 +97,26 @@ void SliceLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
 
 const Blob<float>& SliceLayer::data(const Layer* from) const {
   CHECK(from);
-  CHECK_LT(from->partition_id(), datavec_.size());
-  return datavec_[from->partition_id()];
+  CHECK_LT(from->partition_id(), num_partitions());
+  return *datavec_[from->partition_id()];
 }
 
 const Blob<float>& SliceLayer::grad(const Layer* from) const {
   CHECK(from);
-  CHECK_LT(from->partition_id(), gradvec_.size());
-  return gradvec_[from->partition_id()];
+  CHECK_LT(from->partition_id(), num_partitions());
+  return *gradvec_[from->partition_id()];
 }
 
 Blob<float>* SliceLayer::mutable_data(const Layer* from) {
   CHECK(from);
-  CHECK_LT(from->partition_id(), datavec_.size());
-  return &datavec_[from->partition_id()];
+  CHECK_LT(from->partition_id(), num_partitions());
+  return datavec_[from->partition_id()];
 }
 
 Blob<float>* SliceLayer::mutable_grad(const Layer* from) {
   CHECK(from);
-  CHECK_LT(from->partition_id(), gradvec_.size());
-  return &gradvec_[from->partition_id()];
+  CHECK_LT(from->partition_id(), num_partitions());
+  return gradvec_[from->partition_id()];
 }
 
 }  // namespace singa
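
The equal-size restriction in Setup is easy to check in isolation; a
minimal standalone sketch of the per-slice shape computation:

  #include <cassert>
  #include <cstdio>
  #include <vector>

  int main() {
    std::vector<int> shape = {8, 3};  // source blob shape
    const int partition_dim = 0;
    const int num_partitions = 4;
    // mirrors CHECK_EQ(shape[partition_dim()] % num_partitions(), 0)
    assert(shape[partition_dim] % num_partitions == 0);
    shape[partition_dim] /= num_partitions;
    std::printf("each slice: (%d, %d)\n", shape[0], shape[1]);  // (2, 3)
    return 0;
  }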

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/src/neuralnet/connection_layer/split.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/connection_layer/split.cc b/src/neuralnet/connection_layer/split.cc
index a9fd291..36b391c 100644
--- a/src/neuralnet/connection_layer/split.cc
+++ b/src/neuralnet/connection_layer/split.cc
@@ -25,16 +25,25 @@ namespace singa {
 
 using std::vector;
 
+SplitLayer::~SplitLayer() {
+  for (size_t i = 1; i < gradvec_.size(); ++i) {
+    if (gradvec_[i] != nullptr) delete gradvec_[i];
+  }
+}
+
 void SplitLayer::Setup(const LayerProto& conf,
                        const vector<Layer*>& srclayers) {
   CHECK_EQ(srclayers.size(), 1);
   Layer::Setup(conf, srclayers);
-  split_num_ = conf.split_conf().split_num();
   data_.Reshape(srclayers[0]->data(this).shape());
   data_.ShareData(srclayers[0]->data(this));
-  grads_.resize(split_num_);
-  for (int i = 0; i < split_num_; ++i)
-    grads_[i].Reshape(srclayers[0]->data(this).shape());
+  CHECK_GT(num_partitions(), 0);
+  // add num_partitions()-1 more grad blobs
+  for (int i = 1; i < num_partitions(); ++i) {
+    gradvec_.push_back(new Blob<float>());
+  }
+  for (int i = 0; i < num_partitions(); ++i)
+    gradvec_[i]->Reshape(srclayers[0]->data(this).shape());
 }
 
 void SplitLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
@@ -45,23 +54,23 @@ void SplitLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
 void SplitLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
   CHECK_EQ(srclayers.size(), 1);
   // aggregate all gradients to grad_[0]
-  for (int i = 1; i < split_num_; ++i)
-    for (int j = 0; j < grads_[0].count(); ++j)
-      grads_[0].mutable_cpu_data()[j] += grads_[i].cpu_data()[j];
+  for (int i = 1; i < num_partitions(); ++i)
+    for (int j = 0; j < gradvec_[0]->count(); ++j)
+      gradvec_[0]->mutable_cpu_data()[j] += gradvec_[i]->cpu_data()[j];
   // copy grad_[0] to srclayer's grad
-  srclayers[0]->mutable_grad(this)->CopyFrom(grads_[0]);
+  srclayers[0]->mutable_grad(this)->CopyFrom(*gradvec_[0]);
 }
 
 const Blob<float>& SplitLayer::grad(const Layer* from) const {
   CHECK(from);
-  CHECK_LT(from->partition_id(), grads_.size());
-  return grads_[from->partition_id()];
+  CHECK_LT(from->partition_id(), num_partitions());
+  return *gradvec_[from->partition_id()];
 }
 
 Blob<float>* SplitLayer::mutable_grad(const Layer* from) {
   CHECK(from);
-  CHECK_LT(from->partition_id(), grads_.size());
-  return &grads_[from->partition_id()];
+  CHECK_LT(from->partition_id(), num_partitions());
+  return gradvec_[from->partition_id()];
 }
 
 }  // namespace singa
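
As a standalone illustration of ComputeGradient's aggregation: every
branch of the split contributes a full-size gradient blob, and they are
summed element-wise into the first one before being copied back to the
source layer:

  #include <cstdio>

  int main() {
    const int kPartitions = 3, kCount = 4;
    float grads[kPartitions][kCount] = {
        {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}};
    for (int i = 1; i < kPartitions; ++i)  // aggregate into grads[0]
      for (int j = 0; j < kCount; ++j)
        grads[0][j] += grads[i][j];
    for (int j = 0; j < kCount; ++j)
      std::printf("%g ", grads[0][j]);     // prints: 6 6 6 6
    std::printf("\n");
    return 0;
  }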

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index 8ae1805..b7944e7 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -178,9 +178,8 @@ Node* SliceNode(Graph* graph, Node* srcnode,
   proto->set_type(LayerType::kSlice);
   proto->set_partition_id(
       static_cast<LayerProto*>(srcnode->proto)->partition_id());
-  auto conf = proto->mutable_slice_conf();
-  conf->set_slice_dim(
-      static_cast<LayerProto*>(dstnodes[0]->proto)->partition_dim());
+  proto->set_partition_dim(
+      static_cast<LayerProto*>(srcnode->proto)->partition_dim());
   Node* node = new Node(name, "##" + name, proto->partition_id(), proto);
   graph->AddNode(node);
   graph->AddEdge(srcnode, node);
@@ -198,8 +197,7 @@ Node* ConcateNodes(Graph* graph, const vector<Node*>& srcnodes, Node* dstnode) {
   proto->set_type(LayerType::kConcate);
   proto->set_partition_id(
       static_cast<LayerProto*>(dstnode->proto)->partition_id());
-  auto conf = proto->mutable_concate_conf();
-  conf->set_concate_dim(
+  proto->set_partition_dim(
       static_cast<LayerProto*>(srcnodes[0]->proto)->partition_dim());
   Node* node = new Node(name, "##" + name, proto->partition_id(), proto);
   graph->AddNode(node);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index 7a30b73..ae461f8 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -192,8 +192,6 @@ message LayerProto {
   optional ArgSortProto argsort_conf = 52;
   // configuration for convolution layer
   optional ConvolutionProto convolution_conf = 30;
-  // configuration for concatenation layer
-  optional ConcateProto concate_conf = 31;
   // configuration for dummy layer
   optional DummyProto dummy_conf = 53;
   // configuration for dropout layer
@@ -218,12 +216,8 @@ message LayerProto {
   optional RGBImageProto rgbimage_conf = 39;
   // configuration for data layer
   optional DataProto sharddata_conf = 32;
-  // configuration for slice layer
-  optional SliceProto slice_conf = 41;
   // configuration for softmax loss layer
   optional SoftmaxLossProto softmaxloss_conf = 40;
-  // configuration for split layer
-  optional SplitProto split_conf = 42;
   // configuration for store input layers
   optional StoreProto store_conf = 51;
 
@@ -319,10 +313,6 @@ message PrefetchProto {
   repeated LayerProto sublayers = 1;
 }
 
-message SplitProto {
-  optional int32 split_num = 1 [default = 1];
-}
-
 message StoreProto {
   optional string backend = 1;
   optional string path = 2;
@@ -363,11 +353,6 @@ message ConvolutionProto {
   optional bool bias_term = 32 [default = true];
 }
 
-message ConcateProto {
-  // on which dimension, starts from 0
-  required int32 concate_dim = 1;
-}
-
 message DataProto {
   // path to the data file/folder, absolute or relative to the workspace
   required string path = 2;
@@ -451,11 +436,6 @@ message PoolingProto {
   optional uint32 stride = 32 [default = 1];
 }
 
-message SliceProto {
-  required int32 slice_dim = 1;
-  required int32 slice_num = 2;
-}
-
 message ReLUProto {
   // Ref. Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013).
   // Rectifier nonlinearities improve neural network acoustic models.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cdb22f6/src/test/test_connection_layers.cc
----------------------------------------------------------------------
diff --git a/src/test/test_connection_layers.cc b/src/test/test_connection_layers.cc
index 4d8c984..2415fcd 100644
--- a/src/test/test_connection_layers.cc
+++ b/src/test/test_connection_layers.cc
@@ -178,8 +178,8 @@ TEST(ConnectionLayerTest, DataSliceTest) {
   src_slice.push_back(static_cast<Layer*>(&in));
   LayerProto proto_slice;
   proto_slice.set_name("slice");
-  proto_slice.mutable_slice_conf()->set_slice_dim(0);
-  proto_slice.mutable_slice_conf()->set_slice_num(K);
+  proto_slice.set_partition_dim(0);
+  proto_slice.set_num_partitions(K);
   SliceLayer slice;
   slice.Setup(proto_slice, src_slice);
   ASSERT_EQ(slice.data(static_cast<Layer*>(&slice)).shape(0), N / K);
@@ -235,8 +235,8 @@ TEST(ConnectionLayerTest, ModelSliceTest) {
   src_slice.push_back(static_cast<Layer*>(&in));
   LayerProto proto_slice;
   proto_slice.set_name("slice");
-  proto_slice.mutable_slice_conf()->set_slice_dim(1);
-  proto_slice.mutable_slice_conf()->set_slice_num(K);
+  proto_slice.set_partition_dim(1);
+  proto_slice.set_num_partitions(K);
   SliceLayer slice;
   slice.Setup(proto_slice, src_slice);
   ASSERT_EQ(slice.data(static_cast<Layer*>(&slice)).shape(0), N);
@@ -300,7 +300,8 @@ TEST(ConnectionLayerTest, DataConcateTest) {
     src_concate.push_back(static_cast<Layer*>(&in[i]));
   LayerProto proto_concate;
   proto_concate.set_name("concate");
-  proto_concate.mutable_concate_conf()->set_concate_dim(0);
+  proto_concate.set_partition_dim(0);
+  proto_concate.set_num_partitions(K);
   ConcateLayer concate;
   concate.Setup(proto_concate, src_concate);
   ASSERT_EQ(concate.data(static_cast<Layer*>(&concate)).shape(0), N);
@@ -357,7 +358,8 @@ TEST(ConnectionLayerTest, ModelConcateTest) {
     src_concate.push_back(static_cast<Layer*>(&in[i]));
   LayerProto proto_concate;
   proto_concate.set_name("concate");
-  proto_concate.mutable_concate_conf()->set_concate_dim(1);
+  proto_concate.set_partition_dim(1);
+  proto_concate.set_num_partitions(K);
   ConcateLayer concate;
   concate.Setup(proto_concate, src_concate);
   ASSERT_EQ(concate.data(static_cast<Layer*>(&concate)).shape(0), N);
@@ -414,7 +416,7 @@ TEST(ConnectionLayerTest, SplitTest) {
   src_split.push_back(static_cast<Layer*>(&in));
   LayerProto proto_split;
   proto_split.set_name("split");
-  proto_split.mutable_split_conf()->set_split_num(K);
+  proto_split.set_num_partitions(K);
   SplitLayer split;
   split.Setup(proto_split, src_split);
   ASSERT_EQ(split.data(static_cast<Layer*>(&split)).shape(0), N);
