SINGA-100 Implement layers using CUDNN for GPU training

1. Fix bugs from the rebase onto the latest master.
The LOG(FATAL) calls inside mutable_grad and mutable_data for
input/output/loss layers are commented out, because other layers may need
to check whether mutable_xxx() == nullptr.
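
A minimal sketch of the call pattern that motivates returning nullptr
(illustrative only; this caller code is not part of the patch):

    // A downstream layer probes whether its source layer exposes a
    // gradient blob. With the LOG(FATAL) in place, the probe itself
    // would abort; with it commented out, the caller can branch:
    Blob<float>* g = srclayers[0]->mutable_grad(this);
    if (g == nullptr) {
      // source is an input/output/loss layer; nothing to back-propagate
    } else {
      // safe to write gradients into *g
    }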

2. Update Makefile.am; run make test.
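
Note: the rebase also changes Blob::ShareData and Param::ShareFrom to take
pointer arguments instead of const references (see the diff below). A
minimal sketch of the new call style, assuming a source layer src whose
feature blob already has the right shape:

    // Share the source layer's feature memory instead of copying it.
    // The default cpu_only=true shares only the cpu data; this is
    // required when training with multiple gpu cards, since gpu memory
    // cannot be shared across devices.
    data_.Reshape(src->data(this).shape());
    data_.ShareData(src->mutable_data(this));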


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f8be9afa
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f8be9afa
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f8be9afa

Branch: refs/heads/master
Commit: f8be9afac09882ae85e905c9e679b6effd642f5c
Parents: 5728c24
Author: Wei Wang <[email protected]>
Authored: Fri Dec 11 10:59:33 2015 +0800
Committer: Wei Wang <[email protected]>
Committed: Fri Dec 11 11:48:24 2015 +0800

----------------------------------------------------------------------
 Makefile.am                                     | 29 ++++---
 examples/cifar10/job.conf                       |  7 +-
 include/singa/neuralnet/connection_layer.h      | 17 ++--
 .../singa/neuralnet/connection_layer/slice.h    | 49 -----------
 .../singa/neuralnet/connection_layer/split.h    | 48 -----------
 include/singa/neuralnet/layer.h                 | 12 +--
 include/singa/neuralnet/neuron_layer.h          | 34 +++++---
 include/singa/neuralnet/neuron_layer/dummy.h    | 51 -----------
 include/singa/neuralnet/neuron_layer/rbm.h      | 89 --------------------
 include/singa/neuralnet/neuron_layer/sigmoid.h  | 44 ----------
 include/singa/utils/blob.h                      |  2 +-
 include/singa/utils/param.h                     |  2 +-
 src/neuralnet/connection_layer/bridge.cc        |  7 +-
 src/neuralnet/connection_layer/split.cc         |  2 +-
 src/neuralnet/loss_layer/cudnn_softmaxloss.cc   |  2 +-
 src/neuralnet/neuralnet.cc                      |  6 +-
 src/neuralnet/neuron_layer/dummy.cc             |  2 +-
 src/proto/job.proto                             |  4 +-
 src/test/test_connection_layers.cc              | 10 +--
 src/utils/blob.cc                               |  8 +-
 src/utils/param.cc                              | 24 +++---
 21 files changed, 90 insertions(+), 359 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/Makefile.am
----------------------------------------------------------------------
diff --git a/Makefile.am b/Makefile.am
index bbb3497..470ea8a 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -18,7 +18,7 @@
 #* under the License.
 #*
 #*************************************************************/
-       
+
 
 ACLOCAL_AMFLAGS = -I config
 AUTOMAKE_OPTIONS = foreign subdir-objects
@@ -51,7 +51,7 @@ PROTO_PYS := tool/python/pb2/singa_pb2.py \
 CUDA_SRCS := src/utils/math_kernel.cu
 
 PY_SRCS := tool/python/singa/driver_wrap.cxx
-                  src/driver.cc        
+                  src/driver.cc
 
 SINGA_SRCS := src/driver.cc \
               src/server.cc \
@@ -68,22 +68,24 @@ SINGA_SRCS := src/driver.cc \
               src/neuralnet/input_layer/record.cc \
               src/neuralnet/input_layer/deprecated.cc \
               src/neuralnet/input_layer/store.cc \
+              src/neuralnet/output_layer/accuracy.cc \
               src/neuralnet/output_layer/argsort.cc \
               src/neuralnet/output_layer/csv.cc \
               src/neuralnet/output_layer/record.cc \
               src/neuralnet/loss_layer/euclidean.cc \
               src/neuralnet/loss_layer/softmax.cc \
+              src/neuralnet/neuron_layer/activation.cc \
               src/neuralnet/neuron_layer/convolution.cc \
-              src/neuralnet/neuron_layer/dummy.cc \
               src/neuralnet/neuron_layer/dropout.cc \
+              src/neuralnet/neuron_layer/dummy.cc \
               src/neuralnet/neuron_layer/inner_product.cc \
               src/neuralnet/neuron_layer/lrn.cc \
               src/neuralnet/neuron_layer/pooling.cc \
               src/neuralnet/neuron_layer/rbm.cc \
               src/neuralnet/neuron_layer/relu.cc \
               src/neuralnet/neuron_layer/sigmoid.cc \
-              src/neuralnet/neuron_layer/stanh.cc \
               src/neuralnet/neuron_layer/softmax.cc \
+              src/neuralnet/neuron_layer/stanh.cc \
               src/neuralnet/neuralnet.cc \
               src/comm/socket.cc \
               src/comm/msg.cc \
@@ -155,11 +157,11 @@ TEST_SRCS := include/gtest/gtest_main.cc \
                                                 src/test/test_csv_input_layer.cc
 
 #EXTRA_PROGRAMS = $(PROGS)
-EXTRA_PROGRAMS = singatest test 
+EXTRA_PROGRAMS = singatest test
 #EXTRA_LTLIBRARIES = $(LTLIBS)
 EXTRA_LTLIBRARIES = libgtest.la _driver.la
 
-lib_LTLIBRARIES = libsinga.la $(LTLIBS) 
+lib_LTLIBRARIES = libsinga.la $(LTLIBS)
 bin_PROGRAMS = singa singatool $(PROGS)
 pydir = $(CURDIR)/tool/python/singa/
 py_LTLIBRARIES = $(PY_PROGS)
@@ -167,7 +169,7 @@ py_LTLIBRARIES = $(PY_PROGS)
 
 #lib_LTLIBRARIES = libsinga.la
 libsinga_la_SOURCES = $(PROTO_SRCS) $(SINGA_SRCS)
-libsinga_la_CXXFLAGS = $(DEFAULT_FLAGS) -msse3 -fpermissive -I$(top_srcdir)/include 
+libsinga_la_CXXFLAGS = $(DEFAULT_FLAGS) -msse3 -fpermissive -I$(top_srcdir)/include
 if LMDB
 libsinga_la_CXXFLAGS += -DUSE_LMDB
 endif
@@ -181,7 +183,7 @@ endif
 
 #bin_PROGRAMS = singa
 singa_SOURCES = src/main.cc
-singa_CXXFLAGS = $(DEFAULT_FLAGS) -MMD -I$(top_srcdir)/include 
+singa_CXXFLAGS = $(DEFAULT_FLAGS) -MMD -I$(top_srcdir)/include
 singa_LDFLAGS = -lsinga \
                 -lglog  \
                 -lprotobuf \
@@ -203,7 +205,7 @@ endif
 #bin_PROGRAMS += singatool
 singatool_SOURCES = src/utils/tool.cc
 singatool_CXXFLAGS = -Wall -pthread -fPIC -std=c++11 -MMD -Wno-unknown-pragmas \
-                     -funroll-loops -DTHREADED -I$(top_srcdir)/include 
+                     -funroll-loops -DTHREADED -I$(top_srcdir)/include
 singatool_LDFLAGS = -lsinga \
                     -lglog  \
                     -lprotobuf \
@@ -220,7 +222,7 @@ endif
 #bin_PROGRAMS += test
 
 singatest_SOURCES = $(GTEST_HDRS) $(TEST_SRCS)
-singatest_CXXFLAGS = $(DEFAULT_FLAGS) -I$(top_srcdir)/include 
+singatest_CXXFLAGS = $(DEFAULT_FLAGS) -I$(top_srcdir)/include
 singatest_LDADD = ./libgtest.la
 singatest_LDFLAGS = -lsinga \
                 -lglog  \
@@ -241,8 +243,8 @@ singatest_LDFLAGS += $(CUDA_LDFLAGS) $(CUDA_LIBS)
 endif
 
 _driver_la_SOURCES = $(PY_SRCS)
-_driver_la_CXXFLAGS = $(DEFAULT_FLAGS) $(MSHADOW_FLAGS) -I$(top_srcdir)/include $(PYFLAGS) 
-_driver_la_LDFLAGS = -lsinga -module -shared $(PYLIBS) -avoid-version -rpath $(pydir) 
+_driver_la_CXXFLAGS = $(DEFAULT_FLAGS) $(MSHADOW_FLAGS) -I$(top_srcdir)/include $(PYFLAGS)
+_driver_la_LDFLAGS = -lsinga -module -shared $(PYLIBS) -avoid-version -rpath $(pydir)
 
 if DCUDA
 _driver_la_CXXFLAGS += $(CUDA_CFLAGS)
@@ -281,7 +283,7 @@ install-pyLTLIBRARIES: $(py_LTLIBRARIES)
        touch tool/python/singa/__init__.py
        cp -f .libs/_driver.so tool/python/singa/
 
-uninstall-pyLTLIBRARIES: 
+uninstall-pyLTLIBRARIES:
        rm -f tool/python/singa/__init__.py
        rm -f tool/python/singa/_driver.so
 
@@ -296,3 +298,4 @@ $(PROTO_HDRS) $(PROTO_SRCS): $(PROTOS)
        mkdir -p $(top_srcdir)/include/singa/proto/
        cp $(top_srcdir)/src/proto/*.pb.h $(top_srcdir)/include/singa/proto/
        @echo
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/examples/cifar10/job.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/job.conf b/examples/cifar10/job.conf
index 1dad0f7..d20b452 100644
--- a/examples/cifar10/job.conf
+++ b/examples/cifar10/job.conf
@@ -1,11 +1,10 @@
 name: "cifar10-convnet"
-train_steps: 5
+train_steps: 1000
 test_steps: 100
-test_freq: 0
+test_freq: 200
 #validate_steps: 100
 #validate_freq: 300
-disp_freq: 1
-debug: true
+disp_freq: 50
 #checkpoint_path: "examples/cifar10/checkpoint/step1000-worker0"
 train_one_batch {
   alg: kBP

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/neuralnet/connection_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/connection_layer.h b/include/singa/neuralnet/connection_layer.h
index 0b14a94..0cbe940 100644
--- a/include/singa/neuralnet/connection_layer.h
+++ b/include/singa/neuralnet/connection_layer.h
@@ -93,15 +93,14 @@ class ConcateLayer : public ConnectionLayer {
  */
 class SliceLayer : public ConnectionLayer {
  public:
+  ~SliceLayer();
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-
- private:
-  std::vector<Blob<float>> datavec_;
-  std::vector<Blob<float>> gradvec_;
-  int slice_dim_;
-  int slice_num_;
+  const Blob<float>& data(const Layer* from) const override;
+  const Blob<float>& grad(const Layer* from) const override;
+  Blob<float>* mutable_data(const Layer* from) override;
+  Blob<float>* mutable_grad(const Layer* from) override;
 };
 
 /**
@@ -113,12 +112,12 @@ class SliceLayer : public ConnectionLayer {
  */
 class SplitLayer : public ConnectionLayer {
  public:
+  ~SplitLayer();
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-
- protected:
-  Blob<float> grads_;
+  const Blob<float>& grad(const Layer* from) const override;
+  Blob<float>* mutable_grad(const Layer* from) override;
 };
 
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/neuralnet/connection_layer/slice.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/connection_layer/slice.h b/include/singa/neuralnet/connection_layer/slice.h
deleted file mode 100644
index 023ebc1..0000000
--- a/include/singa/neuralnet/connection_layer/slice.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_NEURALNET_CONNECTION_LAYER_SLICE_H_
-#define SINGA_NEURALNET_CONNECTION_LAYER_SLICE_H_
-
-#include <vector>
-#include "singa/neuralnet/layer.h"
-
-namespace singa {
-/**
- * Connect a single (src) layer with multiple (dst) layers.
- *
- * It slices the feature Blob (i.e., matrix) of the src layer on one dimension.
- * The sliced feature Blobs will be fed into dst layers.
- */
-class SliceLayer : public ConnectionLayer {
- public:
-  ~SliceLayer();
-  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
-  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-  const Blob<float>& data(const Layer* from) const override;
-  const Blob<float>& grad(const Layer* from) const override;
-  Blob<float>* mutable_data(const Layer* from) override;
-  Blob<float>* mutable_grad(const Layer* from) override;
-};
-
-}  // namespace singa
-
-#endif  // SINGA_NEURALNET_CONNECTION_LAYER_SLICE_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/neuralnet/connection_layer/split.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/connection_layer/split.h b/include/singa/neuralnet/connection_layer/split.h
deleted file mode 100644
index 959d1a3..0000000
--- a/include/singa/neuralnet/connection_layer/split.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_NEURALNET_CONNECTION_LAYER_SPLIT_H_
-#define SINGA_NEURALNET_CONNECTION_LAYER_SPLIT_H_
-
-#include <vector>
-#include "singa/neuralnet/layer.h"
-
-namespace singa {
-/**
- * Connect a single (src) layer with multiple dst layers.
- *
- * It replicates the feature Blob of the src layer.
- * Each replicated feature Blob will be fed into one dst layer.
- * It aggregates gradients set by all dst layers and set it to the src layer.
- */
-class SplitLayer : public ConnectionLayer {
- public:
-  ~SplitLayer();
-  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
-  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-  const Blob<float>& grad(const Layer* from) const override;
-  Blob<float>* mutable_grad(const Layer* from) override;
-};
-
-}  // namespace singa
-
-#endif  // SINGA_NEURALNET_CONNECTION_LAYER_SPLIT_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/neuralnet/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/layer.h b/include/singa/neuralnet/layer.h
index 599203f..3dbb3fc 100644
--- a/include/singa/neuralnet/layer.h
+++ b/include/singa/neuralnet/layer.h
@@ -283,12 +283,12 @@ class InputLayer : virtual public Layer {
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override {}
   ConnectionType dst_layer_connection() const override { return kOneToMany; }
   Blob<float>* mutable_grad(const Layer* layer) override {
-    LOG(FATAL) << "Input layer has no gradient blob";
     return nullptr;
+    // LOG(FATAL) << "Input layer has no gradient blob";
   }
   const Blob<float>& grad(const Layer* from) const override {
-    LOG(FATAL) << "Input layer has no gradient blob";
     return grad_;
+    // LOG(FATAL) << "Input layer has no gradient blob";
   }
 };
 
@@ -309,12 +309,12 @@ class NeuronLayer : virtual public Layer {
 class LossLayer : virtual public Layer {
  public:
   Blob<float>* mutable_grad(const Layer* layer) override {
-    LOG(FATAL) << "Loss layer has no gradient blob";
     return nullptr;
+    // LOG(FATAL) << "Loss layer has no gradient blob";
   }
   const Blob<float>& grad(const Layer* from) const override {
-    LOG(FATAL) << "Loss layer has no gradient blob";
     return grad_;
+    // LOG(FATAL) << "Loss layer has no gradient blob";
   }
 };
 
@@ -325,12 +325,12 @@ class OutputLayer : virtual public Layer {
  public:
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override {}
   Blob<float>* mutable_grad(const Layer* layer) override {
-    LOG(FATAL) << "Output layer has no gradient blob";
     return nullptr;
+    // LOG(FATAL) << "Output layer has no gradient blob";
   }
   const Blob<float>& grad(const Layer* from) const override {
-    LOG(FATAL) << "Output layer has no gradient blob";
     return grad_;
+    // LOG(FATAL) << "Output layer has no gradient blob";
   }
 };
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h
index 39f4d69..17891e1 100644
--- a/include/singa/neuralnet/neuron_layer.h
+++ b/include/singa/neuralnet/neuron_layer.h
@@ -109,6 +109,24 @@ class DropoutLayer : public NeuronLayer {
    */
   Blob<float> mask_;
 };
+/**
+ * This layer is dummy and do no real work.
+ * It is used for testing purpose only.
+ *
+ * Use it as input layer, it will generate random data;
+ * Use it as output layer, it will generate random grad;
+ * Use it as neuron layer, it will replicates data and grad.
+ */
+class DummyLayer: public Layer {
+ public:
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+ private:
+  bool input_ = false;  // use as input layer
+  bool output_ = false;  // use as output layer
+};
+
 
 /**
  * Layer that applys linear transformations as
@@ -356,16 +374,10 @@ class CudnnSoftmaxLayer : public SoftmaxLayer, public CudnnLayer {
 /**
  * Base layer for RBM models.
  */
-class RBMLayer: virtual public NeuronLayer {
+class RBMLayer: virtual public Layer {
  public:
   virtual ~RBMLayer() {}
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  const Blob<float>& neg_data(const Layer* layer) {
-    return neg_data_;
-  }
-  Blob<float>* mutable_neg_data(const Layer* layer) {
-    return &neg_data_;
-  }
   const std::vector<Param*> GetParams() const override {
     std::vector<Param*> params{weight_, bias_};
     return params;
@@ -382,16 +394,16 @@ class RBMLayer: virtual public NeuronLayer {
   int batchsize_;
   bool first_gibbs_;
   Param* weight_, *bias_;
-
+  Blob<float> pos_data_;
   Blob<float> neg_data_;
   Blob<float> neg_sample_;
-  Blob<float> sample_;
+  Blob<float> pos_sample_;
 };
 
 /**
  * RBM visible layer
  */
-class RBMVisLayer: public RBMLayer {
+class RBMVisLayer: public RBMLayer, public LossLayer {
  public:
   ~RBMVisLayer();
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
@@ -402,9 +414,9 @@ class RBMVisLayer: public RBMLayer {
  private:
   RBMLayer* hid_layer_;
   Layer* input_layer_;
-
   float error_ = 0.0f;
   int counter_ = 0;
+
 };
 /**
  * RBM hidden layer

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/neuralnet/neuron_layer/dummy.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer/dummy.h b/include/singa/neuralnet/neuron_layer/dummy.h
deleted file mode 100644
index 3177b7e..0000000
--- a/include/singa/neuralnet/neuron_layer/dummy.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_NEURALNET_NEURON_LAYER_DUMMY_H_
-#define SINGA_NEURALNET_NEURON_LAYER_DUMMY_H_
-
-#include <random>
-#include <vector>
-#include "singa/neuralnet/layer.h"
-#include "singa/proto/job.pb.h"
-
-namespace singa {
-/**
- * This layer is dummy and do no real work.
- * It is used for testing purpose only.
- *
- * Use it as input layer, it will generate random data;
- * Use it as output layer, it will generate random grad;
- * Use it as neuron layer, it will replicates data and grad.
- */
-class DummyLayer: public Layer {
- public:
-  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
-  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
- private:
-  bool input_ = false;  // use as input layer
-  bool output_ = false;  // use as output layer
-};
-
-}  // namespace singa
-
-#endif  // SINGA_NEURALNET_NEURON_LAYER_DUMMY_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/neuralnet/neuron_layer/rbm.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer/rbm.h b/include/singa/neuralnet/neuron_layer/rbm.h
deleted file mode 100644
index 432c499..0000000
--- a/include/singa/neuralnet/neuron_layer/rbm.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_NEURALNET_NEURON_LAYER_RBM_H_
-#define SINGA_NEURALNET_NEURON_LAYER_RBM_H_
-
-#include <vector>
-#include "singa/neuralnet/layer.h"
-#include "singa/proto/job.pb.h"
-
-namespace singa {
-/**
- * Base layer for RBM models.
- */
-class RBMLayer: virtual public Layer {
- public:
-  virtual ~RBMLayer() {}
-  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  const std::vector<Param*> GetParams() const override {
-    std::vector<Param*> params{weight_, bias_};
-    return params;
-  }
-  virtual Blob<float>* Sample(int flat);
-
- protected:
-  //! if ture, sampling according to guassian distribution
-  bool gaussian_;
-  //! dimension of the hidden layer
-  int hdim_;
-  //! dimension of the visible layer
-  int vdim_;
-  int batchsize_;
-  bool first_gibbs_;
-  Param* weight_, *bias_;
-  Blob<float> pos_data_;
-  Blob<float> neg_data_;
-  Blob<float> neg_sample_;
-  Blob<float> pos_sample_;
-};
-
-/**
- * RBM visible layer
- */
-class RBMVisLayer: public RBMLayer, public LossLayer {
- public:
-  ~RBMVisLayer();
-  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
-  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-
- private:
-  RBMLayer* hid_layer_;
-  Layer* input_layer_;
-};
-/**
- * RBM hidden layer
- */
-class RBMHidLayer: public RBMLayer {
- public:
-  ~RBMHidLayer();
-  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
-  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-
- private:
-  RBMLayer *vis_layer_;
-};
-
-}  // namespace singa
-
-#endif  // SINGA_NEURALNET_NEURON_LAYER_RBM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/neuralnet/neuron_layer/sigmoid.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer/sigmoid.h b/include/singa/neuralnet/neuron_layer/sigmoid.h
deleted file mode 100644
index 3cf80e7..0000000
--- a/include/singa/neuralnet/neuron_layer/sigmoid.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_NEURALNET_NEURON_LAYER_SIGMOID_H_
-#define SINGA_NEURALNET_NEURON_LAYER_SIGMOID_H_
-
-#include <vector>
-#include "singa/neuralnet/layer.h"
-#include "singa/proto/job.pb.h"
-
-namespace singa {
-/**
- * This layer apply Sigmoid function to neuron activations.
- * f(x)=1/(1+exp(-x))
- * f'(x)=f(x)*(1-f(x))
- */
-class SigmoidLayer: public Layer {
- public:
-  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
-  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
-  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-};
-
-}  // namespace singa
-
-#endif  // SINGA_NEURALNET_NEURON_LAYER_SIGMOID_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/blob.h b/include/singa/utils/blob.h
index b260862..7e1e516 100644
--- a/include/singa/utils/blob.h
+++ b/include/singa/utils/blob.h
@@ -247,7 +247,7 @@ class Blob {
    * data_ field. For training with multi-gpu cards, cpu_only must be true,
    * becuase gpu memory cannot be shared among different devices.
    */
-  void ShareData(const Blob& other, bool cpu_only = true);
+  void ShareData(Blob* other, bool cpu_only = true);
 
   /*
   void Swap(Blob& other);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/include/singa/utils/param.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/param.h b/include/singa/utils/param.h
index 33b61ff..415490e 100644
--- a/include/singa/utils/param.h
+++ b/include/singa/utils/param.h
@@ -146,7 +146,7 @@ class Param {
    * @param cpu_only if true, share only cpu memory (used for training with
    * multi-gpu cards); else, share both cpu and gpu memory.
    */
-  void ShareFrom(const Param& other, bool cpu_only);
+  void ShareFrom(Param* other, bool cpu_only);
   /**
    * Init param values from checkpoint blob.
    */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/neuralnet/connection_layer/bridge.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/connection_layer/bridge.cc b/src/neuralnet/connection_layer/bridge.cc
index 200a3f9..a2302ab 100644
--- a/src/neuralnet/connection_layer/bridge.cc
+++ b/src/neuralnet/connection_layer/bridge.cc
@@ -19,8 +19,7 @@
 *
 *************************************************************/
 
-<<<<<<< HEAD
-#include "singa/neuralnet/connection_layer/bridge.h"
+#include "singa/neuralnet/connection_layer.h"
 #include "singa/comm/msg.h"
 
 namespace singa {
@@ -70,8 +69,8 @@ void BridgeSrcLayer::Setup(const LayerProto& conf,
   Layer::Setup(conf, srclayers);
   data_.Reshape(srclayers[0]->data(this).shape());
   grad_.ReshapeLike(data_);
-  data_.ShareData(srclayers[0]->data(this));
-  grad_.ShareData(srclayers[0]->grad(this));
+  data_.ShareData(srclayers[0]->mutable_data(this));
+  grad_.ShareData(srclayers[0]->mutable_grad(this));
 }
 
 void BridgeSrcLayer::ComputeFeature(int flag, const vector<Layer*>& srcs) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/neuralnet/connection_layer/split.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/connection_layer/split.cc b/src/neuralnet/connection_layer/split.cc
index 7ee24fa..5b9db5b 100644
--- a/src/neuralnet/connection_layer/split.cc
+++ b/src/neuralnet/connection_layer/split.cc
@@ -36,7 +36,7 @@ void SplitLayer::Setup(const LayerProto& conf,
   CHECK_EQ(srclayers.size(), 1);
   Layer::Setup(conf, srclayers);
   data_.Reshape(srclayers[0]->data(this).shape());
-  data_.ShareData(srclayers[0]->data(this));
+  data_.ShareData(srclayers[0]->mutable_data(this));
   CHECK_GT(num_partitions(), 0);
   // add num_partitions()-1 more grad blobs
   for (int i = 1; i < num_partitions(); ++i) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/loss_layer/cudnn_softmaxloss.cc b/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
index 78a035a..1fe228b 100644
--- a/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
+++ b/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
@@ -29,7 +29,7 @@ void CudnnSoftmaxLossLayer::Setup(const LayerProto& conf,
     const vector<Layer*>& srclayers) {
   softmax_.Setup(conf, vector<Layer*> {srclayers.at(0)});
   data_.Reshape(softmax_.data(this).shape());
-  data_.ShareData(*softmax_.mutable_data(this), false);
+  data_.ShareData(softmax_.mutable_data(this), false);
   batchsize_ = data_.shape(0);
   dim_ = data_.count() / batchsize_;
   LOG(ERROR) << batchsize_ << " " << dim_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index aabc361..eb8c270 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -163,7 +163,7 @@ void NeuralNet::ShareParamsFrom(NeuralNet* other, bool cpu_only) {
       const auto& params = layer->GetParams();
       CHECK_EQ(params.size(), otherparams.size());
       for (size_t i = 0; i < params.size(); i++) {
-        params[i]->ShareFrom(*otherparams[i], cpu_only);
+        params[i]->ShareFrom(otherparams[i], cpu_only);
       }
     }
   }
@@ -416,7 +416,7 @@ void NeuralNet::CreateNetFromGraph(Graph* graph, int npartitions) {
     const string share_from = param->share_from();
     if (param->share_from() != "") {
       if (name2param.find(share_from) != name2param.end()) {
-        param->ShareFrom(*name2param.at(param->share_from()), false);
+        param->ShareFrom(name2param.at(param->share_from()), false);
       } else {
         LOG(FATAL) << "No param with the name (share_from) " << share_from;
       }
@@ -430,7 +430,7 @@ void NeuralNet::CreateNetFromGraph(Graph* graph, int npartitions) {
       auto params = (*it)->GetParams();
       CHECK_EQ(params.size(), owner_params.size());
       for (size_t i = 0; i < params.size(); i++)
-        params.at(i)->ShareFrom(*owner_params.at(i), true);
+        params.at(i)->ShareFrom(owner_params.at(i), true);
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/neuralnet/neuron_layer/dummy.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/dummy.cc b/src/neuralnet/neuron_layer/dummy.cc
index 8d165f7..9b352dd 100644
--- a/src/neuralnet/neuron_layer/dummy.cc
+++ b/src/neuralnet/neuron_layer/dummy.cc
@@ -19,7 +19,7 @@
 *
 *************************************************************/
 
-#include "singa/neuralnet/neuron_layer/dummy.h"
+#include "singa/neuralnet/neuron_layer.h"
 #include <glog/logging.h>
 
 namespace singa {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index ff27c05..80752c3 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -191,14 +191,14 @@ message LayerProto {
   optional string user_type =21;
 
   // proto for the specific layer
-  // configuration for activation layer
+  // configuration for input layers
   optional ActivationProto activation_conf = 54;
   // configuration for argsort layer
   optional ArgSortProto argsort_conf = 52;
   // configuration for convolution layer
   optional ConvolutionProto convolution_conf = 30;
   // configuration for dummy layer
-  optional DummyProto dummy_conf = 53;
+  optional DummyProto dummy_conf = 55;
   // configuration for dropout layer
   optional DropoutProto dropout_conf = 33;
   // configuration for inner product layer

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/test/test_connection_layers.cc
----------------------------------------------------------------------
diff --git a/src/test/test_connection_layers.cc b/src/test/test_connection_layers.cc
index 2415fcd..a10d3be 100644
--- a/src/test/test_connection_layers.cc
+++ b/src/test/test_connection_layers.cc
@@ -25,11 +25,11 @@
 #include "gtest/gtest.h"
 #include "singa/comm/msg.h"
 #include "singa/comm/socket.h"
-#include "singa/neuralnet/connection_layer/bridge.h"
-#include "singa/neuralnet/connection_layer/concate.h"
-#include "singa/neuralnet/connection_layer/slice.h"
-#include "singa/neuralnet/connection_layer/split.h"
-#include "singa/neuralnet/neuron_layer/dummy.h"
+#include "singa/neuralnet/connection_layer.h"
+#include "singa/neuralnet/connection_layer.h"
+#include "singa/neuralnet/connection_layer.h"
+#include "singa/neuralnet/connection_layer.h"
+#include "singa/neuralnet/neuron_layer.h"
 #include "singa/proto/job.pb.h"
 
 using namespace singa;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/utils/blob.cc
----------------------------------------------------------------------
diff --git a/src/utils/blob.cc b/src/utils/blob.cc
index 9607683..735226d 100644
--- a/src/utils/blob.cc
+++ b/src/utils/blob.cc
@@ -265,12 +265,12 @@ void Blob<Dtype>::SetValue(Dtype v) {
     ptr[i] = v;
 }
 template <typename Dtype>
-void Blob<Dtype>::ShareData(const Blob& other, bool cpu_only) {
-  CHECK_EQ(count_, other.count());
+void Blob<Dtype>::ShareData(Blob* other, bool cpu_only) {
+  CHECK_EQ(count_, other->count());
   if (cpu_only)
-    data_->set_cpu_data(other.cpu_data());
+    data_->set_cpu_data(other->mutable_cpu_data());
   else
-    data_ = other.data_;
+    data_ = other->data_;
 }
 
 /*

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f8be9afa/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 09f519b..097fa61 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -166,21 +166,21 @@ void Param::InitValues(int version) {
   set_version(version);
 }
 
-void Param::ShareFrom(const Param& other, bool cpu_only) {
-  proto_.set_owner(other.owner());
-  CHECK(data_.shape() == other.data_.shape());
-  data_.ShareData(other.data_, cpu_only);
+void Param::ShareFrom(Param* other, bool cpu_only) {
+  proto_.set_owner(other->owner());
+  CHECK(data_.shape() == other->data_.shape());
+  data_.ShareData(&(other->data_), cpu_only);
   if (grad_.count() == 0)
     grad_.Reshape(data_.shape());
-  version_ = other.version_;
-  last_version_ = other.last_version_;
-  slice_start_ = other.slice_start_;
-  num_slices_ = other.num_slices_;
-  slice_offset_ = other.slice_offset_;
-  slice_size_ = other.slice_size_;
+  version_ = other->version_;
+  last_version_ = other->last_version_;
+  slice_start_ = other->slice_start_;
+  num_slices_ = other->num_slices_;
+  slice_offset_ = other->slice_offset_;
+  slice_size_ = other->slice_size_;
   // change pending list size equal to slice size
-  pending_get_.resize(other.pending_get_.size());
-  pending_update_.resize(other.pending_update_.size());
+  pending_get_.resize(other->pending_get_.size());
+  pending_update_.resize(other->pending_update_.size());
 }
 
 void Param::FromProto(const BlobProto& blob) {
