SINGA-226 Add parallel training on a single machine for singa v1.0

This commit implements an Updater class for parallel training
in local-cpu and local-gpu mode. (The mode specification is described
in https://issues.apache.org/jira/browse/SINGA-226.)

The Updater class is a wrapper around the Optimizer class. It controls
the communication pattern among workers. When initializing an Updater,
the user provides the total number of workers and the device on which
the Updater performs gradient aggregation and computation.
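
As a rough sketch (condensed from the new include/singa/model/updater.h
and the parallel example; device setup and error handling omitted), an
Updater wrapping a shared SGD optimizer for two workers looks like:

  SGD sgd;                       // the underlying optimization algorithm
  OptimizerConf opt_conf;
  opt_conf.set_momentum(0.9);
  sgd.Setup(opt_conf);
  // Two workers share one Updater; gradients are aggregated on the
  // default (host) device before the optimizer step is applied.
  Updater updater(2, &sgd);
  // Each worker later calls updater.Apply(step, name, grad, value),
  // which blocks until the partial gradients from all workers arrive.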

The changed files are described as follows:
* Add a new folder named cifar10-parallel under examples, containing
  a single-machine multi-GPU parallel training example.
* Replace the Optimizer pointer in the feed_forward_net class with an
  Updater pointer, since Updater wraps Optimizer; the Compile() method
  is changed accordingly (see the sketch after this list).
* Add a helper method TrainThread() to feed_forward_net that launches
  a training thread.
* Adapt alexnet.cc in the original cifar10 example to the new Compile()
  signature.
* Fix a bug in memory.cc that occurred while initializing the cnmem
  memory pool (the device settings array was indexed by the device id
  instead of by the loop index).
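
For reference, a rough sketch of the intended two-GPU usage of the new
Compile() and TrainThread() APIs, condensed from
examples/cifar10-parallel/alexnet-parallel.cc (data slicing and device
placement omitted; CreateNet(), sgd, num_epoch and the sliced tensors
are defined as in that file):

  auto net_1 = CreateNet();            // two replicas of the same net
  auto net_2 = CreateNet();
  SoftmaxCrossEntropy loss_1, loss_2;
  Accuracy acc_1, acc_2;
  Updater updater(2, &sgd);            // 2 workers, aggregation on CPU

  // Only the first Compile() call registers the parameters with the Updater.
  net_1.Compile(true, true, &updater, &loss_1, &acc_1);
  net_2.Compile(true, false, &updater, &loss_2, &acc_2);

  // Each replica trains on its own data slice in its own thread.
  std::thread t1 =
      net_1.TrainThread(50, num_epoch, train_x_1, train_y_1, test_x, test_y);
  std::thread t2 = net_2.TrainThread(50, num_epoch, train_x_2, train_y_2);
  t1.join();
  t2.join();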


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/d45715da
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/d45715da
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/d45715da

Branch: refs/heads/dev
Commit: d45715da07a65e38e5e8f437461c37da4092a9c3
Parents: 8e3d3df
Author: WANG Ji <[email protected]>
Authored: Thu Jul 14 16:59:16 2016 +0800
Committer: WANG Ji <[email protected]>
Committed: Thu Jul 21 16:33:23 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                                |   2 +-
 examples/CMakeLists.txt                       |   9 +
 examples/cifar10-parallel/alexnet-parallel.cc | 286 +++++++++++++++++++++
 examples/cifar10-parallel/cifar10.h           |  98 +++++++
 examples/cifar10-parallel/download_data.py    |  52 ++++
 examples/cifar10-parallel/run.sh              |   2 +
 examples/cifar10/alexnet.cc                   |  31 ++-
 include/singa/model/feed_forward_net.h        |  26 +-
 include/singa/model/updater.h                 |  68 +++++
 src/core/memory/memory.cc                     |   2 +-
 src/model/feed_forward_net.cc                 |  36 +--
 src/model/updater.cc                          |  87 +++++++
 12 files changed, 667 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d3cd776..dca8a30 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6)
 
 PROJECT(singa)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -lpthread")
 
 LIST(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Thirdparty)
 #message(STATUS "module path: ${CMAKE_MODULE_PATH}")

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/examples/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index c2ec7e9..26d5cdd 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -8,3 +8,12 @@ ADD_EXECUTABLE(alexnet ${cifar_source})
 ADD_DEPENDENCIES(alexnet singa_core singa_model singa_utils)
 TARGET_LINK_LIBRARIES(alexnet singa_core singa_utils singa_model protobuf ${SINGA_LIBKER_LIBS})
 ENDIF(USE_CUDNN)
+
+AUX_SOURCE_DIRECTORY(cifar10-parallel cifar_parallel_source)
+
+IF(USE_CUDNN)
+ADD_EXECUTABLE(alexnet-parallel ${cifar_parallel_source})
+ADD_DEPENDENCIES(alexnet-parallel singa_core singa_model singa_utils)
+TARGET_LINK_LIBRARIES(alexnet-parallel singa_core singa_utils singa_model protobuf ${SINGA_LIBKER_LIBS})
+SET_TARGET_PROPERTIES(alexnet-parallel PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread")
+ENDIF(USE_CUDNN)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/examples/cifar10-parallel/alexnet-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10-parallel/alexnet-parallel.cc b/examples/cifar10-parallel/alexnet-parallel.cc
new file mode 100644
index 0000000..cf581ea
--- /dev/null
+++ b/examples/cifar10-parallel/alexnet-parallel.cc
@@ -0,0 +1,286 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "cifar10.h"
+#include "singa/model/feed_forward_net.h"
+#include "singa/model/optimizer.h"
+#include "singa/model/updater.h"
+#include "singa/model/initializer.h"
+#include "singa/model/metric.h"
+#include "singa/utils/channel.h"
+#include "singa/utils/string.h"
+#include "singa/core/memory.h"
+#include "../../src/model/layer/cudnn_convolution.h"
+#include "../../src/model/layer/cudnn_activation.h"
+#include "../../src/model/layer/cudnn_pooling.h"
+#include "../../src/model/layer/cudnn_lrn.h"
+#include "../../src/model/layer/dense.h"
+#include "../../src/model/layer/flatten.h"
+#include <thread>
+namespace singa {
+
+LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
+                      int pad, float std) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnConvolution");
+  ConvolutionConf *conv = conf.mutable_convolution_conf();
+  conv->set_num_output(nb_filter);
+  conv->add_kernel_size(kernel);
+  conv->add_stride(stride);
+  conv->add_pad(pad);
+  conv->set_bias_term(true);
+
+  ParamSpec *wspec = conf.add_param();
+  wspec->set_name(name + "_weight");
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+  //  bspec->set_decay_mult(0);
+  return conf;
+}
+
+LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
+                         int pad) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnPooling");
+  PoolingConf *pool = conf.mutable_pooling_conf();
+  pool->set_kernel_size(kernel);
+  pool->set_stride(stride);
+  pool->set_pad(pad);
+  if (!max_pool) pool->set_pool(PoolingConf_PoolMethod_AVE);
+  return conf;
+}
+
+LayerConf GenReLUConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("RELU");
+  return conf;
+}
+
+LayerConf GenDenseConf(string name, int num_output, float std, float wd) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Dense");
+  DenseConf *dense = conf.mutable_dense_conf();
+  dense->set_num_output(num_output);
+
+  ParamSpec *wspec = conf.add_param();
+  wspec->set_name(name + "_weight");
+  wspec->set_decay_mult(wd);
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+  bspec->set_decay_mult(0);
+
+  return conf;
+}
+
+LayerConf GenLRNConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnLRN");
+  LRNConf *lrn = conf.mutable_lrn_conf();
+  lrn->set_local_size(3);
+  lrn->set_alpha(5e-05);
+  lrn->set_beta(0.75);
+  return conf;
+}
+
+LayerConf GenFlattenConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Flatten");
+  return conf;
+}
+
+FeedForwardNet CreateNet() {
+  FeedForwardNet net;
+  Shape s{3, 32, 32};
+
+  net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2, 0.0001),
+          &s);
+  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 1));
+  net.Add(new CudnnActivation(), GenReLUConf("relu1"));
+  net.Add(new CudnnLRN(), GenLRNConf("lrn1"));
+  net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2, 0.01));
+  net.Add(new CudnnActivation(), GenReLUConf("relu2"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool2", false, 3, 2, 1));
+  net.Add(new CudnnLRN(), GenLRNConf("lrn2"));
+  net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2, 0.01));
+  net.Add(new CudnnActivation(), GenReLUConf("relu3"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool3", false, 3, 2, 1));
+  net.Add(new Flatten(), GenFlattenConf("flat"));
+  net.Add(new Dense(), GenDenseConf("ip", 10, 0.01, 250));
+  return net;
+}
+
+void Train(float lr, int num_epoch, string data_dir) {
+  Cifar10 data(data_dir);
+  Tensor train_x, train_y, test_x, test_y;
+  Tensor train_x_1, train_x_2, train_y_1, train_y_2;
+  {
+    auto train = data.ReadTrainData();
+    size_t nsamples = train.first.shape(0);
+    auto mtrain =
+        Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
+    const Tensor &mean = Average(mtrain, 0);
+    SubRow(mean, &mtrain);
+    train_x = Reshape(mtrain, train.first.shape());
+    train_y = train.second;
+
+    LOG(INFO) << "Slicing training data...";
+    train_x_1.Reshape(Shape{nsamples / 2, train.first.shape(1),
+        train.first.shape(2), train.first.shape(3)});
+    LOG(INFO) << "Copying first data slice...";
+    CopyDataToFrom(&train_x_1, train_x, train_x.Size() / 2);
+    train_x_2.Reshape(Shape{nsamples / 2, train.first.shape(1),
+        train.first.shape(2), train.first.shape(3)});
+    LOG(INFO) << "Copying second data slice...";
+    CopyDataToFrom(&train_x_2, train_x, train_x.Size() / 2, 0,
+                   train_x.Size() / 2);
+    train_y_1.Reshape(Shape{nsamples / 2});
+    train_y_1.AsType(kInt);
+    LOG(INFO) << "Copying first label slice...";
+    CopyDataToFrom(&train_y_1, train_y, train_y.Size() / 2);
+    train_y_2.Reshape(Shape{nsamples / 2});
+    train_y_2.AsType(kInt);
+    LOG(INFO) << "Copying second label slice...";
+    CopyDataToFrom(&train_y_2, train_y, train_y.Size() / 2, 0,
+                   train_y.Size() / 2);
+
+    auto test = data.ReadTestData();
+    nsamples = test.first.shape(0);
+    auto mtest =
+        Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples});
+    SubRow(mean, &mtest);
+    test_x = Reshape(mtest, test.first.shape());
+    test_y = test.second;
+  }
+
+  CHECK_EQ(train_x.shape(0), train_y.shape(0));
+  CHECK_EQ(test_x.shape(0), test_y.shape(0));
+  LOG(INFO) << "Total Training samples = " << train_y.shape(0)
+            << ", Total Test samples = " << test_y.shape(0);
+  CHECK_EQ(train_x_1.shape(0), train_y_1.shape(0));
+  LOG(INFO) << "On net 1, Training samples = " << train_y_1.shape(0)
+            << ", Test samples = " << test_y.shape(0);
+  CHECK_EQ(train_x_2.shape(0), train_y_2.shape(0));
+  LOG(INFO) << "On net 2, Training samples = " << train_y_2.shape(0);
+
+  auto net_1 = CreateNet();
+  auto net_2 = CreateNet();
+
+  SGD sgd;
+  OptimizerConf opt_conf;
+  opt_conf.set_momentum(0.9);
+  auto reg = opt_conf.mutable_regularizer();
+  reg->set_coefficient(0.004);
+  sgd.Setup(opt_conf);
+  sgd.SetLearningRateGenerator([lr](int step) {
+    if (step <= 120)
+      return 0.001;
+    else if (step <= 130)
+      return 0.0001;
+    else
+      return 0.00001;
+  });
+
+  SoftmaxCrossEntropy loss_1, loss_2;
+  Accuracy acc_1, acc_2;
+  /// Create updater aggregating gradient on CPU
+  Updater updater(2, &sgd);
+
+  /// Only need to register parameter once.
+  net_1.Compile(true, true, &updater, &loss_1, &acc_1);
+  net_2.Compile(true, false, &updater, &loss_2, &acc_2);
+
+  MemPoolConf mem_conf;
+  mem_conf.add_device(0);
+  mem_conf.add_device(1);
+  std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf));
+  std::shared_ptr<CudaGPU> cuda_1(new CudaGPU(0, mem_pool));
+  std::shared_ptr<CudaGPU> cuda_2(new CudaGPU(1, mem_pool));
+  net_1.ToDevice(cuda_1);
+  net_2.ToDevice(cuda_2);
+
+  /*
+  // this does not work for net_2
+  train_x_2.ResetLike(train_x);
+  train_y_2.ResetLike(train_y);
+  test_x_2.ResetLike(test_x);
+  test_y_2.ResetLike(test_y);
+
+  train_x.ToDevice(cuda_1);
+  train_y.ToDevice(cuda_1);
+  test_x.ToDevice(cuda_1);
+  test_y.ToDevice(cuda_1);
+
+  train_x_2.ToDevice(cuda_2);
+  train_y_2.ToDevice(cuda_2);
+  test_x_2.ToDevice(cuda_2);
+  test_y_2.ToDevice(cuda_2);
+  */
+
+  train_x_1.ToDevice(cuda_1);
+  train_y_1.ToDevice(cuda_1);
+  test_x.ToDevice(cuda_1);
+  test_y.ToDevice(cuda_1);
+  train_x_2.ToDevice(cuda_2);
+  train_y_2.ToDevice(cuda_2);
+
+  // net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
+
+  LOG(INFO) << "Launching thread...";
+  std::thread t1 =
+      net_1.TrainThread(50, num_epoch, train_x_1, train_y_1, test_x, test_y);
+  std::thread t2 = net_2.TrainThread(50, num_epoch, train_x_2, train_y_2);
+  t1.join();
+  t2.join();
+}
+}
+
+int main(int argc, char **argv) {
+  singa::InitChannel(nullptr);
+  int pos = singa::ArgPos(argc, argv, "-epoch");
+  int nEpoch = 1;
+  if (pos != -1) nEpoch = atoi(argv[pos + 1]);
+  pos = singa::ArgPos(argc, argv, "-lr");
+  float lr = 0.001;
+  if (pos != -1) lr = atof(argv[pos + 1]);
+  pos = singa::ArgPos(argc, argv, "-data");
+  string data = "cifar-10-batches-bin";
+  if (pos != -1) data = argv[pos + 1];
+
+  LOG(INFO) << "Start training";
+  singa::Train(lr, nEpoch, data);
+  LOG(INFO) << "End training";
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/examples/cifar10-parallel/cifar10.h
----------------------------------------------------------------------
diff --git a/examples/cifar10-parallel/cifar10.h b/examples/cifar10-parallel/cifar10.h
new file mode 100644
index 0000000..d2b9225
--- /dev/null
+++ b/examples/cifar10-parallel/cifar10.h
@@ -0,0 +1,98 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#include <fstream>
+#include <string>
+#include <cstdint>
+#include <iostream>
+#include "singa/core/tensor.h"
+using std::string;
+namespace singa {
+/// For reading cifar10 binary data as tensors.
+class Cifar10 {
+ public:
+  /// 'dir_path': path to the folder including the *.bin files
+  Cifar10(string dir_path, bool normalize = true) : dir_path_(dir_path) {}
+
+  /// read all training data into an image Tensor and a label Tensor
+  const std::pair<Tensor, Tensor> ReadTrainData();
+  /// read all test data into an image Tensor and a label Tensor
+  const std::pair<Tensor, Tensor> ReadTestData();
+  /// read data from one file into an image Tensor and a label Tensor
+  const std::pair<Tensor, Tensor> ReadFile(string file);
+
+  void ReadImage(std::ifstream* file, int* label, char* buffer);
+
+ private:
+  const size_t kImageSize = 32;
+  const size_t kImageVol = 3072;
+  const size_t kBatchSize = 10000;
+  const size_t kTrainFiles = 5;
+
+  string dir_path_;
+};
+
+void Cifar10::ReadImage(std::ifstream* file, int* label, char* buffer) {
+  char label_char;
+  file->read(&label_char, 1);
+  *label = static_cast<int>(label_char);
+  file->read(buffer, kImageVol);
+  return;
+}
+const std::pair<Tensor, Tensor> Cifar10::ReadFile(string file) {
+  Tensor images(Shape{kBatchSize, 3, kImageSize, kImageSize});
+  Tensor labels(Shape{kBatchSize}, kInt);
+  if (dir_path_.back() != '/') dir_path_.push_back('/');
+  LOG(INFO) << "Reading file " << dir_path_ + file;
+  std::ifstream data_file((dir_path_ + file).c_str(),
+                          std::ios::in | std::ios::binary);
+  CHECK(data_file.is_open()) << "Unable to open file " << dir_path_ + file;
+  int label;
+  char image[kImageVol];
+  float float_image[kImageVol];
+  int tmplabels[kBatchSize];
+  for (size_t itemid = 0; itemid < kBatchSize; ++itemid) {
+    // LOG(INFO) << "reading " << itemid << "-th image";
+    ReadImage(&data_file, &label, image);
+    for (size_t i = 0; i < kImageVol; i++)
+      float_image[i] = static_cast<float>(static_cast<uint8_t>(image[i]));
+    images.CopyDataFromHostPtr(float_image, kImageVol, itemid * kImageVol);
+    tmplabels[itemid] = label;
+  }
+  labels.CopyDataFromHostPtr(tmplabels, kBatchSize);
+  return std::make_pair(images, labels);
+}
+
+const std::pair<Tensor, Tensor> Cifar10::ReadTrainData() {
+  Tensor images(Shape{kBatchSize * kTrainFiles, 3, kImageSize, kImageSize});
+  Tensor labels(Shape{kBatchSize * kTrainFiles}, kInt);
+  for (size_t fileid = 0; fileid < kTrainFiles; ++fileid) {
+    string file = "data_batch_" + std::to_string(fileid + 1) + ".bin";
+    const auto ret = ReadFile(file);
+    CopyDataToFrom(&images, ret.first, ret.first.Size(),
+                   fileid * ret.first.Size());
+    CopyDataToFrom(&labels, ret.second, kBatchSize, fileid * kBatchSize);
+  }
+  return std::make_pair(images, labels);
+}
+const std::pair<Tensor, Tensor> Cifar10::ReadTestData() {
+  return ReadFile("test_batch.bin");
+}
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/examples/cifar10-parallel/download_data.py
----------------------------------------------------------------------
diff --git a/examples/cifar10-parallel/download_data.py b/examples/cifar10-parallel/download_data.py
new file mode 100755
index 0000000..ce0ee4f
--- /dev/null
+++ b/examples/cifar10-parallel/download_data.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+import urllib
+import tarfile
+import os
+import sys
+import argparse
+
+
+def extract_tarfile(filepath):
+    if os.path.exists(filepath):
+        print 'The tar file does exist. Extracting it now..'
+        with tarfile.open(filepath, 'r') as f:
+            f.extractall('.')
+        print 'Finished!'
+        sys.exit(0)
+
+
+def check_dir_exist(dirpath):
+    if os.path.exists(dirpath):
+        print 'Directory %s does exist. To redownload the files, '\
+            'remove the existing directory and %s.tar.gz' % (dirpath, dirpath)
+        return True
+    else:
+        return False
+
+
+def do_download(dirpath, gzfile, url):
+    if check_dir_exist(dirpath):
+        sys.exit(0)
+    print 'Downloading CIFAR10 from %s' % (url)
+    urllib.urlretrieve(url, gzfile)
+    extract_tarfile(gzfile)
+    print 'Finished!'
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Download Cifar10 datasets')
+    parser.add_argument(
+        'file',
+        type=str,
+        choices=['py', 'bin'])
+    args = parser.parse_args()
+    if args.file == 'bin':
+        dirpath = 'cifar-10-batches-bin'
+        gzfile = 'cifar-10-binary' + '.tar.gz'
+        url = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
+        do_download(dirpath, gzfile, url)
+    else:
+        dirpath = 'cifar-10-batches-py'
+        gzfile = 'cifar-10-python' + '.tar.gz'
+        url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+        do_download(dirpath, gzfile, url)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/examples/cifar10-parallel/run.sh
----------------------------------------------------------------------
diff --git a/examples/cifar10-parallel/run.sh b/examples/cifar10-parallel/run.sh
new file mode 100755
index 0000000..379f847
--- /dev/null
+++ b/examples/cifar10-parallel/run.sh
@@ -0,0 +1,2 @@
+#!/usr/bin/env sh
+../../build/bin/alexnet-parallel -epoch 140

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/examples/cifar10/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.cc b/examples/cifar10/alexnet.cc
index 6480557..b237fb6 100644
--- a/examples/cifar10/alexnet.cc
+++ b/examples/cifar10/alexnet.cc
@@ -22,16 +22,19 @@
 #include "./cifar10.h"
 #include "singa/model/feed_forward_net.h"
 #include "singa/model/optimizer.h"
+#include "singa/model/updater.h"
 #include "singa/model/initializer.h"
 #include "singa/model/metric.h"
 #include "singa/utils/channel.h"
 #include "singa/utils/string.h"
+#include "singa/core/memory.h"
 #include "../../src/model/layer/cudnn_convolution.h"
 #include "../../src/model/layer/cudnn_activation.h"
 #include "../../src/model/layer/cudnn_pooling.h"
 #include "../../src/model/layer/cudnn_lrn.h"
 #include "../../src/model/layer/dense.h"
 #include "../../src/model/layer/flatten.h"
+#include <thread>
 namespace singa {
 
 LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
@@ -55,7 +58,7 @@ LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
   ParamSpec *bspec = conf.add_param();
   bspec->set_name(name + "_bias");
   bspec->set_lr_mult(2);
-//  bspec->set_decay_mult(0);
+  //  bspec->set_decay_mult(0);
   return conf;
 }
 
@@ -148,10 +151,11 @@ void Train(float lr, int num_epoch, string data_dir) {
     size_t nsamples = train.first.shape(0);
     auto mtrain =
         Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
-    const Tensor& mean = Average(mtrain, 0);
+    const Tensor &mean = Average(mtrain, 0);
     SubRow(mean, &mtrain);
     train_x = Reshape(mtrain, train.first.shape());
     train_y = train.second;
+
     auto test = data.ReadTestData();
     nsamples = test.first.shape(0);
     auto mtest =
@@ -160,11 +164,14 @@ void Train(float lr, int num_epoch, string data_dir) {
     test_x = Reshape(mtest, test.first.shape());
     test_y = test.second;
   }
+
   CHECK_EQ(train_x.shape(0), train_y.shape(0));
   CHECK_EQ(test_x.shape(0), test_y.shape(0));
-  LOG(INFO) << "Training samples = " << train_y.shape(0)
-            << ", Test samples = " << test_y.shape(0);
+  LOG(INFO) << "Total Training samples = " << train_y.shape(0)
+            << ", Total Test samples = " << test_y.shape(0);
+
   auto net = CreateNet();
+
   SGD sgd;
   OptimizerConf opt_conf;
   opt_conf.set_momentum(0.9);
@@ -180,16 +187,24 @@ void Train(float lr, int num_epoch, string data_dir) {
       return 0.00001;
   });
 
-  SoftmaxCrossEntropy loss;
-  Accuracy acc;
-  net.Compile(true, &sgd, &loss, &acc);
+  SoftmaxCrossEntropy loss_1, loss_2;
+  Accuracy acc_1, acc_2;
+  /// Create updater aggregating gradient on CPU
+  Updater updater(1, &sgd);
 
-  auto cuda = std::make_shared<CudaGPU>();
+  net.Compile(true, true, &updater, &loss_1, &acc_1);
+
+  MemPoolConf mem_conf;
+  mem_conf.add_device(0);
+  std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf));
+  std::shared_ptr<CudaGPU> cuda(new CudaGPU(0, mem_pool));
   net.ToDevice(cuda);
+
   train_x.ToDevice(cuda);
   train_y.ToDevice(cuda);
   test_x.ToDevice(cuda);
   test_y.ToDevice(cuda);
+
   net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
 }
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/include/singa/model/feed_forward_net.h
----------------------------------------------------------------------
diff --git a/include/singa/model/feed_forward_net.h b/include/singa/model/feed_forward_net.h
index 36cbe00..6eeb34d 100644
--- a/include/singa/model/feed_forward_net.h
+++ b/include/singa/model/feed_forward_net.h
@@ -20,7 +20,8 @@
 #include "singa/model/layer.h"
 #include "singa/model/loss.h"
 #include "singa/model/metric.h"
-#include "singa/model/optimizer.h"
+#include "singa/model/updater.h"
+#include <thread>
 namespace singa {
 
 /// The feed-forward neural net.
@@ -56,8 +57,9 @@ class FeedForwardNet {
   /// If the neural net is constructed for evaluation only, then 'opt' is not
   /// necessary; But for training, both 'opt' and 'loss' are necessary.
   /// 'shuffle' indicates shuffling training samples within one epoch it is
-  /// valid using Train();
-  void Compile(bool shuffle, Optimizer* opt, Loss* loss,
+  /// valid using Train(). If to_register is true, the parameters will be
+  /// registered with the Updater;
+  void Compile(bool shuffle, bool to_register, Updater* updater, Loss* loss,
                Metric* metric);
 
   /// Conduct the training giving the training data 'x' and label 'y'.
@@ -118,6 +120,19 @@ class FeedForwardNet {
   /// Set the data type of each layer.
   void AsType(DataType dtype);
 
+  /// A wrapper method to spawn a thread to execute Train() method.
+  std::thread TrainThread(size_t batchsize, int nb_epoch, const Tensor& x,
+                          const Tensor& y, const Tensor& val_x,
+                          const Tensor& val_y) {
+    return std::thread([=]() { Train(batchsize, nb_epoch, x, y, val_x, val_y); });
+  }
+
+  /// A wrapper method to spawn a thread to execute Train() method.
+  std::thread TrainThread(size_t batchsize, int nb_epoch, const Tensor& x,
+                          const Tensor& y) {
+    return std::thread([=]() { Train(batchsize, nb_epoch, x, y); });
+  }
+
   const vector<Layer*> layers() const { return layers_; }
   const vector<string> GetParamNames() const;
   const vector<ParamSpec> GetParamSpecs() const;
@@ -125,7 +140,7 @@ class FeedForwardNet {
 
  protected:
   vector<Layer*> layers_;
-  Optimizer* opt_;
+  Updater* updater_;
   Loss* loss_;
   Metric* metric_;
 
@@ -134,7 +149,6 @@ class FeedForwardNet {
   DataType dtype_ = kFloat32;
 };
 
-}  /* singa */
-
+} /* singa */
 
 #endif  // SINGA_MODEL_FEED_FORWARD_NET_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/include/singa/model/updater.h
----------------------------------------------------------------------
diff --git a/include/singa/model/updater.h b/include/singa/model/updater.h
new file mode 100644
index 0000000..dbd91e8
--- /dev/null
+++ b/include/singa/model/updater.h
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SINGA_MODEL_UPDATER_H_
+#define SINGA_MODEL_UPDATER_H_
+
+#include "singa/model/optimizer.h"
+#include "singa/core/device.h"
+#include "singa/core/tensor.h"
+#include "singa/utils/logging.h"
+
+#include <memory>
+#include <vector>
+#include <mutex>
+#include <condition_variable>
+#include <string>
+#include <unordered_map>
+
+namespace singa {
+class Updater {
+ public:
+  Updater(int total_num, Optimizer* opt,
+          std::shared_ptr<Device> dev = defaultDevice)
+      : total_num_{total_num}, opt_{opt}, dev_(dev) {}
+  virtual ~Updater() {}
+
+  /// Forward Setup() to Optimizer.
+  virtual void Setup(const OptimizerConf& conf);
+  /// Forward Register() to Optimizer.
+  virtual void Register(const string& name, const ParamSpec& specs);
+  /// Update the parameter value based on the given gradient by invoking the
+  /// optimizer algorithm. In synchronous training, a call to this function
+  /// blocks until all the partial gradients have been aggregated.
+  virtual void Apply(int step, const string& name, Tensor& grad, Tensor& value);
+  Optimizer* GetOptimizer() { return opt_; }
+
+  // No copy allowed.
+  Updater(const Updater&) = delete;
+  void operator=(const Updater&) = delete;
+
+ protected:
+  int total_num_;
+  Optimizer* opt_;
+  std::shared_ptr<Device> dev_;
+  std::mutex mtx_;
+  std::condition_variable aggr_count_eq_total_num_;
+  std::unordered_map<std::string, int> aggr_count_, copy_count_;
+  std::unordered_map<std::string, Tensor> buffer_, partial_sum_;
+  std::unordered_map<std::string, bool> has_averaged_, has_init_;
+};
+}  //  namespace singa
+
+#endif  //  SINGA_MODEL_UPDATER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/src/core/memory/memory.cc
----------------------------------------------------------------------
diff --git a/src/core/memory/memory.cc b/src/core/memory/memory.cc
index fa4b305..cb33a48 100644
--- a/src/core/memory/memory.cc
+++ b/src/core/memory/memory.cc
@@ -55,7 +55,7 @@ void CnMemPool::Init() {
     int i = 0;
     for (auto device : conf_.device()) {
       settingPtr[i].device = device;
-      settingPtr[device].size = conf_.init_size() * kNBytesPerMB;
+      settingPtr[i].size = conf_.init_size() * kNBytesPerMB;
       settingPtr[i].numStreams = 0;
       settingPtr[i].streams = NULL;
       settingPtr[i].streamSizes = 0;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/src/model/feed_forward_net.cc
----------------------------------------------------------------------
diff --git a/src/model/feed_forward_net.cc b/src/model/feed_forward_net.cc
index 09a06d7..b30d24a 100644
--- a/src/model/feed_forward_net.cc
+++ b/src/model/feed_forward_net.cc
@@ -73,20 +73,22 @@ const vector<ParamSpec> FeedForwardNet::GetParamSpecs() const {
   return specs;
 }
 
-void FeedForwardNet::Compile(bool shuffle, Optimizer* opt, Loss* loss,
-                             Metric* metric) {
+void FeedForwardNet::Compile(bool shuffle, bool to_register, Updater* updater,
+                             Loss* loss, Metric* metric) {
   shuffle_ = shuffle;
-  bool train = (opt != nullptr) && (loss != nullptr);
+  bool train = (updater != nullptr) && (loss != nullptr);
   bool test = metric != nullptr;
-  CHECK(train || test) << "Must set opt and loss, or set metric";
-  opt_ = opt;
+  CHECK(train || test) << "Must set updater and loss, or set metric";
+  updater_ = updater;
   loss_ = loss;
   metric_ = metric;
   const auto specs = GetParamSpecs();
   auto params = GetParamValues();
   CHECK_EQ(specs.size(), params.size());
   for (size_t k = 0; k < specs.size(); k++) {
-    opt_->Register(specs[k].name(), specs[k]);
+    if (to_register) {
+      updater_->Register(specs[k].name(), specs[k]);
+    }
     auto init = CreateInitializer(specs[k].filler());
     init->Fill(params[k]);
     LOG(INFO) << specs[k].name() << " : " << params[k].L1();
@@ -147,8 +149,8 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
   int num_extra_samples = x.shape(0) % batchsize;
   if (num_extra_samples != 0)
     LOG(WARNING) << "Pls set batchsize to make num_total_samples "
-      << "% batchsize == 0. Otherwise, the last " << num_extra_samples
-      << " samples would not be used";
+                 << "% batchsize == 0. Otherwise, the last "
+                 << num_extra_samples << " samples would not be used";
   Channel* train_ch = GetChannel("train_perf");
   train_ch->EnableDestStderr(true);
   Channel* val_ch = GetChannel("val_perf");
@@ -167,12 +169,14 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
       loss += ret.first;
       metric += ret.second;
     }
+    if (val_x.Size() == 0) continue;
     loss /= b;
     metric /= b;
-    train_ch->Send("Epoch " + std::to_string(epoch) + ", training loss = " +
-                   std::to_string(loss) + ", accuracy = " +
-                   std::to_string(metric) + ", lr = " +
-                   std::to_string(opt_->GetLearningRate(epoch)));
+    train_ch->Send(
+        "Epoch " + std::to_string(epoch) + ", training loss = " +
+        std::to_string(loss) + ", accuracy = " + std::to_string(metric) +
+        ", lr = " +
+        std::to_string(updater_->GetOptimizer()->GetLearningRate(epoch)));
     if (val_x.Size() && val_y.Size()) {
       const auto val_perf = Evaluate(val_x, val_y, batchsize);
       val_ch->Send("Epoch " + std::to_string(epoch) + ", val loss = " +
@@ -195,7 +199,7 @@ const std::pair<float, float> FeedForwardNet::TrainOnBatch(int epoch,
   auto names = GetParamNames();
   auto values = GetParamValues();
   for (size_t k = 0; k < grads.size(); k++) {
-    opt_->Apply(epoch, names[k], grads[k], values.at(k));
+    updater_->Apply(epoch, names[k], grads[k], values.at(k));
   }
   return std::make_pair(loss, metric);
 }
@@ -203,7 +207,7 @@ const std::pair<float, float> FeedForwardNet::TrainOnBatch(int epoch,
 const Tensor FeedForwardNet::Forward(int flag, const Tensor& data) {
   Tensor input = data, output;
   for (auto layer : layers_) {
-//    LOG(INFO) << layer->name() << ": " << input.L1();
+    //    LOG(INFO) << layer->name() << ": " << input.L1();
     output = layer->Forward(flag, input);
     // LOG(INFO) << layer->name() << ": " << output.L2();
     input = output;
@@ -216,13 +220,13 @@ const vector<Tensor> FeedForwardNet::Backward(int flag, const Tensor& grad) {
   std::stack<Tensor> buf;
   Tensor tmp = grad;
   for (int i = layers_.size() - 1; i >= 0; i--) {
- //   LOG(INFO) << layers_.at(i)->name() << " : " << tmp.L1();
+    //   LOG(INFO) << layers_.at(i)->name() << " : " << tmp.L1();
     auto ret = layers_.at(i)->Backward(flag, tmp);
     tmp = ret.first;
     if (ret.second.size()) {
       for (int k = ret.second.size() - 1; k >= 0; k--) {
         buf.push(ret.second[k]);
- //       LOG(INFO) <<  "      " << buf.top().L1();
+        //       LOG(INFO) <<  "      " << buf.top().L1();
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d45715da/src/model/updater.cc
----------------------------------------------------------------------
diff --git a/src/model/updater.cc b/src/model/updater.cc
new file mode 100644
index 0000000..e6c1cdb
--- /dev/null
+++ b/src/model/updater.cc
@@ -0,0 +1,87 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/updater.h"
+
+namespace singa {
+
+void Updater::Setup(const OptimizerConf& conf) { opt_->Setup(conf); }
+
+void Updater::Register(const string& name, const ParamSpec& specs) {
+  opt_->Register(name, specs);
+  aggr_count_[name] = 0;
+  copy_count_[name] = 0;
+  has_averaged_[name] = false;
+  has_init_[name] = false;
+}
+
+void Updater::Apply(int step, const string& name, Tensor& grad, Tensor& value) {
+  CHECK(aggr_count_.count(name) == 1) << "Parameter " << name
+                                         << " has not been registered before.";
+  /// This lock protects the aggregation counter, the data transfer buffer,
+  /// and the partial aggregation result. The data transfer could be moved out
+  /// of the critical section in the future to improve performance.
+  std::unique_lock<std::mutex> lock(mtx_);
+  if (aggr_count_[name] == 0) {
+    if (!has_init_[name]) {
+      Tensor tmp(grad.shape(), dev_, grad.data_type());
+      partial_sum_[name] = tmp;
+      buffer_[name].ResetLike(tmp);
+      has_init_[name] = true;
+    } else {
+      partial_sum_[name].SetValue(.0f);
+    }
+  }
+  buffer_[name].CopyData(grad);
+  Add(partial_sum_[name], buffer_[name], &partial_sum_[name]);
+  ++aggr_count_[name];
+
+  /// Block this thread until we have received enough partial gradients.
+  if (aggr_count_[name] != total_num_) {
+    while (aggr_count_[name] != total_num_) {
+      aggr_count_eq_total_num_.wait(lock);
+    }
+  } else {
+    aggr_count_eq_total_num_.notify_all();
+  }
+
+  /// Now we have the partial gradients from all neural net instances,
+  /// so we compute the average gradient. Only the first notified thread
+  /// performs the averaging.
+  if (!has_averaged_[name]) {
+    Div(partial_sum_[name], static_cast<float>(total_num_),
+        &partial_sum_[name]);
+    copy_count_[name] = 0;
+    has_averaged_[name] = true;
+  }
+
+  /// For now, gradient aggregation and SGD algorithm run on the same device.
+  /// TODO(wangji): It is possible to make them run separately.
+  buffer_[name].CopyData(value);
+  /// Apply optimization algorithm based on the averaged gradient.
+  opt_->Apply(step, name, partial_sum_[name], buffer_[name]);
+  value.CopyData(buffer_[name]);
+
+  /// The last thread finishing copy should set aggregation counter back to 0.
+  ++copy_count_[name];
+  if (copy_count_[name] == total_num_) {
+    aggr_count_[name] = 0;
+    has_averaged_[name] = false;
+  }
+}
+}  // namespace singa

