Repository: incubator-singa
Updated Branches:
  refs/heads/dev 8e3d3df61 -> 4e7f3c13b


SINGA-226 Add parallel training on a single machine for singa v1.0

Move cifar-10 parallel training from its separate folder into examples/cifar10.
Retain a Compile() method in feed_forward_net that receives an Optimizer
argument, so the previous single-card version of alexnet.cc can stay unchanged.
Add an updater folder under src/model.
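
For context, a minimal usage sketch (not part of the patch) of the two
Compile() variants; it mirrors the Train() functions in the cifar10 examples
and assumes the same CreateNet() helper inside namespace singa:

  // Single-device training: pass an Optimizer; Compile() wraps it in a basic
  // Updater internally and registers the parameters (see alexnet.cc).
  SGD sgd;
  SoftmaxCrossEntropy loss;
  Accuracy acc;
  auto net = CreateNet();
  net.Compile(true, &sgd, &loss, &acc);

  // Parallel training: replicas share one LocalUpdater; only the first
  // Compile() call registers the parameters (see alexnet-parallel.cc).
  SoftmaxCrossEntropy loss_1, loss_2;
  std::shared_ptr<Updater> updater = std::make_shared<LocalUpdater>(2, &sgd);
  auto net_1 = CreateNet();
  auto net_2 = CreateNet();
  net_1.Compile(true, true, updater, &loss_1, &acc);
  net_2.Compile(true, false, updater, &loss_2, &acc);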


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/0184fac3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/0184fac3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/0184fac3

Branch: refs/heads/dev
Commit: 0184fac30b9c4a62925d5b15138ed8658b5e1e38
Parents: d45715d
Author: WANG Ji <[email protected]>
Authored: Tue Jul 19 16:02:02 2016 +0800
Committer: WANG Ji <[email protected]>
Committed: Thu Jul 21 16:33:23 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                                |   2 +-
 examples/CMakeLists.txt                       |  20 +-
 examples/cifar10-parallel/alexnet-parallel.cc | 286 --------------------
 examples/cifar10-parallel/cifar10.h           |  98 -------
 examples/cifar10-parallel/download_data.py    |  52 ----
 examples/cifar10-parallel/run.sh              |   2 -
 examples/cifar10/CMakeLists.txt               |  13 +
 examples/cifar10/alexnet-parallel.cc          | 287 +++++++++++++++++++++
 examples/cifar10/alexnet.cc                   |  31 +--
 examples/cifar10/run-parallel.sh              |   2 +
 include/singa/model/feed_forward_net.h        |  24 +-
 include/singa/model/updater.h                 |  42 ++-
 src/CMakeLists.txt                            |   1 +
 src/model/feed_forward_net.cc                 |  11 +-
 src/model/updater.cc                          |  87 -------
 src/model/updater/local_updater.cc            |  82 ++++++
 src/model/updater/updater.cc                  |  32 +++
 17 files changed, 485 insertions(+), 587 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dca8a30..d3cd776 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6)
 
 PROJECT(singa)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -lpthread")
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
 
 LIST(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Thirdparty)
 #message(STATUS "module path: ${CMAKE_MODULE_PATH}")

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 26d5cdd..3490c38 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,19 +1 @@
-INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})
-INCLUDE_DIRECTORIES(${CMAKE_BINARY_DIR}/include)
-
-AUX_SOURCE_DIRECTORY(cifar10 cifar_source)
-
-IF(USE_CUDNN)
-ADD_EXECUTABLE(alexnet ${cifar_source})
-ADD_DEPENDENCIES(alexnet singa_core singa_model singa_utils)
-TARGET_LINK_LIBRARIES(alexnet singa_core singa_utils singa_model protobuf 
${SINGA_LIBKER_LIBS})
-ENDIF(USE_CUDNN)
-
-AUX_SOURCE_DIRECTORY(cifar10-parallel cifar_parallel_source)
-
-IF(USE_CUDNN)
-ADD_EXECUTABLE(alexnet-parallel ${cifar_parallel_source})
-ADD_DEPENDENCIES(alexnet-parallel singa_core singa_model singa_utils)
-TARGET_LINK_LIBRARIES(alexnet-parallel singa_core singa_utils singa_model 
protobuf ${SINGA_LIBKER_LIBS})
-SET_TARGET_PROPERTIES(alexnet-parallel PROPERTIES LINK_FLAGS "${LINK_FLAGS} 
-pthread")
-ENDIF(USE_CUDNN)
+ADD_SUBDIRECTORY(cifar10)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/cifar10-parallel/alexnet-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10-parallel/alexnet-parallel.cc 
b/examples/cifar10-parallel/alexnet-parallel.cc
deleted file mode 100644
index cf581ea..0000000
--- a/examples/cifar10-parallel/alexnet-parallel.cc
+++ /dev/null
@@ -1,286 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#include "cifar10.h"
-#include "singa/model/feed_forward_net.h"
-#include "singa/model/optimizer.h"
-#include "singa/model/updater.h"
-#include "singa/model/initializer.h"
-#include "singa/model/metric.h"
-#include "singa/utils/channel.h"
-#include "singa/utils/string.h"
-#include "singa/core/memory.h"
-#include "../../src/model/layer/cudnn_convolution.h"
-#include "../../src/model/layer/cudnn_activation.h"
-#include "../../src/model/layer/cudnn_pooling.h"
-#include "../../src/model/layer/cudnn_lrn.h"
-#include "../../src/model/layer/dense.h"
-#include "../../src/model/layer/flatten.h"
-#include <thread>
-namespace singa {
-
-LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
-                      int pad, float std) {
-  LayerConf conf;
-  conf.set_name(name);
-  conf.set_type("CudnnConvolution");
-  ConvolutionConf *conv = conf.mutable_convolution_conf();
-  conv->set_num_output(nb_filter);
-  conv->add_kernel_size(kernel);
-  conv->add_stride(stride);
-  conv->add_pad(pad);
-  conv->set_bias_term(true);
-
-  ParamSpec *wspec = conf.add_param();
-  wspec->set_name(name + "_weight");
-  auto wfill = wspec->mutable_filler();
-  wfill->set_type("Gaussian");
-  wfill->set_std(std);
-
-  ParamSpec *bspec = conf.add_param();
-  bspec->set_name(name + "_bias");
-  bspec->set_lr_mult(2);
-  //  bspec->set_decay_mult(0);
-  return conf;
-}
-
-LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
-                         int pad) {
-  LayerConf conf;
-  conf.set_name(name);
-  conf.set_type("CudnnPooling");
-  PoolingConf *pool = conf.mutable_pooling_conf();
-  pool->set_kernel_size(kernel);
-  pool->set_stride(stride);
-  pool->set_pad(pad);
-  if (!max_pool) pool->set_pool(PoolingConf_PoolMethod_AVE);
-  return conf;
-}
-
-LayerConf GenReLUConf(string name) {
-  LayerConf conf;
-  conf.set_name(name);
-  conf.set_type("RELU");
-  return conf;
-}
-
-LayerConf GenDenseConf(string name, int num_output, float std, float wd) {
-  LayerConf conf;
-  conf.set_name(name);
-  conf.set_type("Dense");
-  DenseConf *dense = conf.mutable_dense_conf();
-  dense->set_num_output(num_output);
-
-  ParamSpec *wspec = conf.add_param();
-  wspec->set_name(name + "_weight");
-  wspec->set_decay_mult(wd);
-  auto wfill = wspec->mutable_filler();
-  wfill->set_type("Gaussian");
-  wfill->set_std(std);
-
-  ParamSpec *bspec = conf.add_param();
-  bspec->set_name(name + "_bias");
-  bspec->set_lr_mult(2);
-  bspec->set_decay_mult(0);
-
-  return conf;
-}
-
-LayerConf GenLRNConf(string name) {
-  LayerConf conf;
-  conf.set_name(name);
-  conf.set_type("CudnnLRN");
-  LRNConf *lrn = conf.mutable_lrn_conf();
-  lrn->set_local_size(3);
-  lrn->set_alpha(5e-05);
-  lrn->set_beta(0.75);
-  return conf;
-}
-
-LayerConf GenFlattenConf(string name) {
-  LayerConf conf;
-  conf.set_name(name);
-  conf.set_type("Flatten");
-  return conf;
-}
-
-FeedForwardNet CreateNet() {
-  FeedForwardNet net;
-  Shape s{3, 32, 32};
-
-  net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2, 0.0001),
-          &s);
-  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 1));
-  net.Add(new CudnnActivation(), GenReLUConf("relu1"));
-  net.Add(new CudnnLRN(), GenLRNConf("lrn1"));
-  net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2, 0.01));
-  net.Add(new CudnnActivation(), GenReLUConf("relu2"));
-  net.Add(new CudnnPooling(), GenPoolingConf("pool2", false, 3, 2, 1));
-  net.Add(new CudnnLRN(), GenLRNConf("lrn2"));
-  net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2, 0.01));
-  net.Add(new CudnnActivation(), GenReLUConf("relu3"));
-  net.Add(new CudnnPooling(), GenPoolingConf("pool3", false, 3, 2, 1));
-  net.Add(new Flatten(), GenFlattenConf("flat"));
-  net.Add(new Dense(), GenDenseConf("ip", 10, 0.01, 250));
-  return net;
-}
-
-void Train(float lr, int num_epoch, string data_dir) {
-  Cifar10 data(data_dir);
-  Tensor train_x, train_y, test_x, test_y;
-  Tensor train_x_1, train_x_2, train_y_1, train_y_2;
-  {
-    auto train = data.ReadTrainData();
-    size_t nsamples = train.first.shape(0);
-    auto mtrain =
-        Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
-    const Tensor &mean = Average(mtrain, 0);
-    SubRow(mean, &mtrain);
-    train_x = Reshape(mtrain, train.first.shape());
-    train_y = train.second;
-
-    LOG(INFO) << "Slicing training data...";
-    train_x_1.Reshape(Shape{nsamples / 2, train.first.shape(1),
-        train.first.shape(2), train.first.shape(3)});
-    LOG(INFO) << "Copying first data slice...";
-    CopyDataToFrom(&train_x_1, train_x, train_x.Size() / 2);
-    train_x_2.Reshape(Shape{nsamples / 2, train.first.shape(1),
-        train.first.shape(2), train.first.shape(3)});
-    LOG(INFO) << "Copying second data slice...";
-    CopyDataToFrom(&train_x_2, train_x, train_x.Size() / 2, 0,
-                   train_x.Size() / 2);
-    train_y_1.Reshape(Shape{nsamples / 2});
-    train_y_1.AsType(kInt);
-    LOG(INFO) << "Copying first label slice...";
-    CopyDataToFrom(&train_y_1, train_y, train_y.Size() / 2);
-    train_y_2.Reshape(Shape{nsamples / 2});
-    train_y_2.AsType(kInt);
-    LOG(INFO) << "Copying second label slice...";
-    CopyDataToFrom(&train_y_2, train_y, train_y.Size() / 2, 0,
-                   train_y.Size() / 2);
-
-    auto test = data.ReadTestData();
-    nsamples = test.first.shape(0);
-    auto mtest =
-        Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples});
-    SubRow(mean, &mtest);
-    test_x = Reshape(mtest, test.first.shape());
-    test_y = test.second;
-  }
-
-  CHECK_EQ(train_x.shape(0), train_y.shape(0));
-  CHECK_EQ(test_x.shape(0), test_y.shape(0));
-  LOG(INFO) << "Total Training samples = " << train_y.shape(0)
-            << ", Total Test samples = " << test_y.shape(0);
-  CHECK_EQ(train_x_1.shape(0), train_y_1.shape(0));
-  LOG(INFO) << "On net 1, Training samples = " << train_y_1.shape(0)
-            << ", Test samples = " << test_y.shape(0);
-  CHECK_EQ(train_x_2.shape(0), train_y_2.shape(0));
-  LOG(INFO) << "On net 2, Training samples = " << train_y_2.shape(0);
-
-  auto net_1 = CreateNet();
-  auto net_2 = CreateNet();
-
-  SGD sgd;
-  OptimizerConf opt_conf;
-  opt_conf.set_momentum(0.9);
-  auto reg = opt_conf.mutable_regularizer();
-  reg->set_coefficient(0.004);
-  sgd.Setup(opt_conf);
-  sgd.SetLearningRateGenerator([lr](int step) {
-    if (step <= 120)
-      return 0.001;
-    else if (step <= 130)
-      return 0.0001;
-    else
-      return 0.00001;
-  });
-
-  SoftmaxCrossEntropy loss_1, loss_2;
-  Accuracy acc_1, acc_2;
-  /// Create updater aggregating gradient on CPU
-  Updater updater(2, &sgd);
-
-  /// Only need to register parameter once.
-  net_1.Compile(true, true, &updater, &loss_1, &acc_1);
-  net_2.Compile(true, false, &updater, &loss_2, &acc_1);
-
-  MemPoolConf mem_conf;
-  mem_conf.add_device(0);
-  mem_conf.add_device(1);
-  std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf));
-  std::shared_ptr<CudaGPU> cuda_1(new CudaGPU(0, mem_pool));
-  std::shared_ptr<CudaGPU> cuda_2(new CudaGPU(1, mem_pool));
-  net_1.ToDevice(cuda_1);
-  net_2.ToDevice(cuda_2);
-
-  /*
-  // this does not work for net_2
-  train_x_2.ResetLike(train_x);
-  train_y_2.ResetLike(train_y);
-  test_x_2.ResetLike(test_x);
-  test_y_2.ResetLike(test_y);
-
-  train_x.ToDevice(cuda_1);
-  train_y.ToDevice(cuda_1);
-  test_x.ToDevice(cuda_1);
-  test_y.ToDevice(cuda_1);
-
-  train_x_2.ToDevice(cuda_2);
-  train_y_2.ToDevice(cuda_2);
-  test_x_2.ToDevice(cuda_2);
-  test_y_2.ToDevice(cuda_2);
-  */
-
-  train_x_1.ToDevice(cuda_1);
-  train_y_1.ToDevice(cuda_1);
-  test_x.ToDevice(cuda_1);
-  test_y.ToDevice(cuda_1);
-  train_x_2.ToDevice(cuda_2);
-  train_y_2.ToDevice(cuda_2);
-
-  // net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
-
-  LOG(INFO) << "Launching thread...";
-  std::thread t1 =
-      net_1.TrainThread(50, num_epoch, train_x_1, train_y_1, test_x, test_y);
-  std::thread t2 = net_2.TrainThread(50, num_epoch, train_x_2, train_y_2);
-  t1.join();
-  t2.join();
-}
-}
-
-int main(int argc, char **argv) {
-  singa::InitChannel(nullptr);
-  int pos = singa::ArgPos(argc, argv, "-epoch");
-  int nEpoch = 1;
-  if (pos != -1) nEpoch = atoi(argv[pos + 1]);
-  pos = singa::ArgPos(argc, argv, "-lr");
-  float lr = 0.001;
-  if (pos != -1) lr = atof(argv[pos + 1]);
-  pos = singa::ArgPos(argc, argv, "-data");
-  string data = "cifar-10-batches-bin";
-  if (pos != -1) data = argv[pos + 1];
-
-  LOG(INFO) << "Start training";
-  singa::Train(lr, nEpoch, data);
-  LOG(INFO) << "End training";
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/cifar10-parallel/cifar10.h
----------------------------------------------------------------------
diff --git a/examples/cifar10-parallel/cifar10.h 
b/examples/cifar10-parallel/cifar10.h
deleted file mode 100644
index d2b9225..0000000
--- a/examples/cifar10-parallel/cifar10.h
+++ /dev/null
@@ -1,98 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-#include <fstream>
-#include <string>
-#include <cstdint>
-#include <iostream>
-#include "singa/core/tensor.h"
-using std::string;
-namespace singa {
-/// For reading cifar10 binary data as tensors.
-class Cifar10 {
- public:
-  /// 'dir_path': path to the folder including the *.bin files
-  Cifar10(string dir_path, bool normalize = true) : dir_path_(dir_path) {}
-
-  /// read all training data into an image Tensor and a label Tensor
-  const std::pair<Tensor, Tensor> ReadTrainData();
-  /// read all test data into an image Tensor and a label Tensor
-  const std::pair<Tensor, Tensor> ReadTestData();
-  /// read data from one file into an image Tensor and a label Tensor
-  const std::pair<Tensor, Tensor> ReadFile(string file);
-
-  void ReadImage(std::ifstream* file, int* label, char* buffer);
-
- private:
-  const size_t kImageSize = 32;
-  const size_t kImageVol = 3072;
-  const size_t kBatchSize = 10000;
-  const size_t kTrainFiles = 5;
-
-  string dir_path_;
-};
-
-void Cifar10::ReadImage(std::ifstream* file, int* label, char* buffer) {
-  char label_char;
-  file->read(&label_char, 1);
-  *label = static_cast<int>(label_char);
-  file->read(buffer, kImageVol);
-  return;
-}
-const std::pair<Tensor, Tensor> Cifar10::ReadFile(string file) {
-  Tensor images(Shape{kBatchSize, 3, kImageSize, kImageSize});
-  Tensor labels(Shape{kBatchSize}, kInt);
-  if (dir_path_.back() != '/') dir_path_.push_back('/');
-  LOG(INFO) << "Reading file " << dir_path_ + file;
-  std::ifstream data_file((dir_path_ + file).c_str(),
-                          std::ios::in | std::ios::binary);
-  CHECK(data_file.is_open()) << "Unable to open file " << dir_path_ + file;
-  int label;
-  char image[kImageVol];
-  float float_image[kImageVol];
-  int tmplabels[kBatchSize];
-  for (size_t itemid = 0; itemid < kBatchSize; ++itemid) {
-    // LOG(INFO) << "reading " << itemid << "-th image";
-    ReadImage(&data_file, &label, image);
-    for (size_t i = 0; i < kImageVol; i++)
-      float_image[i] = static_cast<float>(static_cast<uint8_t>(image[i]));
-    images.CopyDataFromHostPtr(float_image, kImageVol, itemid * kImageVol);
-    tmplabels[itemid] = label;
-  }
-  labels.CopyDataFromHostPtr(tmplabels, kBatchSize);
-  return std::make_pair(images, labels);
-}
-
-const std::pair<Tensor, Tensor> Cifar10::ReadTrainData() {
-  Tensor images(Shape{kBatchSize * kTrainFiles, 3, kImageSize, kImageSize});
-  Tensor labels(Shape{kBatchSize * kTrainFiles}, kInt);
-  for (size_t fileid = 0; fileid < kTrainFiles; ++fileid) {
-    string file = "data_batch_" + std::to_string(fileid + 1) + ".bin";
-    const auto ret = ReadFile(file);
-    CopyDataToFrom(&images, ret.first, ret.first.Size(),
-                   fileid * ret.first.Size());
-    CopyDataToFrom(&labels, ret.second, kBatchSize, fileid * kBatchSize);
-  }
-  return std::make_pair(images, labels);
-}
-const std::pair<Tensor, Tensor> Cifar10::ReadTestData() {
-  return ReadFile("test_batch.bin");
-}
-}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/cifar10-parallel/download_data.py
----------------------------------------------------------------------
diff --git a/examples/cifar10-parallel/download_data.py 
b/examples/cifar10-parallel/download_data.py
deleted file mode 100755
index ce0ee4f..0000000
--- a/examples/cifar10-parallel/download_data.py
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python
-import urllib
-import tarfile
-import os
-import sys
-import argparse
-
-
-def extract_tarfile(filepath):
-    if os.path.exists(filepath):
-        print 'The tar file does exist. Extracting it now..'
-        with tarfile.open(filepath, 'r') as f:
-            f.extractall('.')
-        print 'Finished!'
-        sys.exit(0)
-
-
-def check_dir_exist(dirpath):
-    if os.path.exists(dirpath):
-        print 'Directory %s does exist. To redownload the files, '\
-            'remove the existing directory and %s.tar.gz' % (dirpath, dirpath)
-        return True
-    else:
-        return False
-
-
-def do_download(dirpath, gzfile, url):
-    if check_dir_exist(dirpath):
-        sys.exit(0)
-    print 'Downloading CIFAR10 from %s' % (url)
-    urllib.urlretrieve(url, gzfile)
-    extract_tarfile(gzfile)
-    print 'Finished!'
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Download Cifar10 datasets')
-    parser.add_argument(
-        'file',
-        type=str,
-        choices=['py', 'bin'])
-    args = parser.parse_args()
-    if args.file == 'bin':
-        dirpath = 'cifar-10-batches-bin'
-        gzfile = 'cifar-10-binary' + '.tar.gz'
-        url = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
-        do_download(dirpath, gzfile, url)
-    else:
-        dirpath = 'cifar-10-batches-py'
-        gzfile = 'cifar-10-python' + '.tar.gz'
-        url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
-        do_download(dirpath, gzfile, url)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/cifar10-parallel/run.sh
----------------------------------------------------------------------
diff --git a/examples/cifar10-parallel/run.sh b/examples/cifar10-parallel/run.sh
deleted file mode 100755
index 379f847..0000000
--- a/examples/cifar10-parallel/run.sh
+++ /dev/null
@@ -1,2 +0,0 @@
-#!/usr/bin/env sh
-../../build/bin/alexnet-parallel -epoch 140

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/cifar10/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/examples/cifar10/CMakeLists.txt b/examples/cifar10/CMakeLists.txt
new file mode 100644
index 0000000..92f884c
--- /dev/null
+++ b/examples/cifar10/CMakeLists.txt
@@ -0,0 +1,13 @@
+INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})
+INCLUDE_DIRECTORIES(${CMAKE_BINARY_DIR}/include)
+
+IF(USE_CUDNN)
+ADD_EXECUTABLE(alexnet alexnet.cc)
+ADD_DEPENDENCIES(alexnet singa_core singa_model singa_utils)
+TARGET_LINK_LIBRARIES(alexnet singa_core singa_utils singa_model protobuf ${SINGA_LIBKER_LIBS})
+
+ADD_EXECUTABLE(alexnet-parallel alexnet-parallel.cc)
+ADD_DEPENDENCIES(alexnet-parallel singa_core singa_model singa_utils)
+TARGET_LINK_LIBRARIES(alexnet-parallel singa_core singa_utils singa_model protobuf ${SINGA_LIBKER_LIBS})
+SET_TARGET_PROPERTIES(alexnet-parallel PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread")
+ENDIF(USE_CUDNN)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/cifar10/alexnet-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet-parallel.cc 
b/examples/cifar10/alexnet-parallel.cc
new file mode 100644
index 0000000..15ef58e
--- /dev/null
+++ b/examples/cifar10/alexnet-parallel.cc
@@ -0,0 +1,287 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "cifar10.h"
+#include "singa/model/feed_forward_net.h"
+#include "singa/model/optimizer.h"
+#include "singa/model/updater.h"
+#include "singa/model/initializer.h"
+#include "singa/model/metric.h"
+#include "singa/utils/channel.h"
+#include "singa/utils/string.h"
+#include "singa/core/memory.h"
+#include "../../src/model/layer/cudnn_convolution.h"
+#include "../../src/model/layer/cudnn_activation.h"
+#include "../../src/model/layer/cudnn_pooling.h"
+#include "../../src/model/layer/cudnn_lrn.h"
+#include "../../src/model/layer/dense.h"
+#include "../../src/model/layer/flatten.h"
+#include <thread>
+#include <memory>
+namespace singa {
+
+LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
+                      int pad, float std) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnConvolution");
+  ConvolutionConf *conv = conf.mutable_convolution_conf();
+  conv->set_num_output(nb_filter);
+  conv->add_kernel_size(kernel);
+  conv->add_stride(stride);
+  conv->add_pad(pad);
+  conv->set_bias_term(true);
+
+  ParamSpec *wspec = conf.add_param();
+  wspec->set_name(name + "_weight");
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+  //  bspec->set_decay_mult(0);
+  return conf;
+}
+
+LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
+                         int pad) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnPooling");
+  PoolingConf *pool = conf.mutable_pooling_conf();
+  pool->set_kernel_size(kernel);
+  pool->set_stride(stride);
+  pool->set_pad(pad);
+  if (!max_pool) pool->set_pool(PoolingConf_PoolMethod_AVE);
+  return conf;
+}
+
+LayerConf GenReLUConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("RELU");
+  return conf;
+}
+
+LayerConf GenDenseConf(string name, int num_output, float std, float wd) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Dense");
+  DenseConf *dense = conf.mutable_dense_conf();
+  dense->set_num_output(num_output);
+
+  ParamSpec *wspec = conf.add_param();
+  wspec->set_name(name + "_weight");
+  wspec->set_decay_mult(wd);
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+  bspec->set_decay_mult(0);
+
+  return conf;
+}
+
+LayerConf GenLRNConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnLRN");
+  LRNConf *lrn = conf.mutable_lrn_conf();
+  lrn->set_local_size(3);
+  lrn->set_alpha(5e-05);
+  lrn->set_beta(0.75);
+  return conf;
+}
+
+LayerConf GenFlattenConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Flatten");
+  return conf;
+}
+
+FeedForwardNet CreateNet() {
+  FeedForwardNet net;
+  Shape s{3, 32, 32};
+
+  net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2, 0.0001),
+          &s);
+  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 1));
+  net.Add(new CudnnActivation(), GenReLUConf("relu1"));
+  net.Add(new CudnnLRN(), GenLRNConf("lrn1"));
+  net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2, 0.01));
+  net.Add(new CudnnActivation(), GenReLUConf("relu2"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool2", false, 3, 2, 1));
+  net.Add(new CudnnLRN(), GenLRNConf("lrn2"));
+  net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2, 0.01));
+  net.Add(new CudnnActivation(), GenReLUConf("relu3"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool3", false, 3, 2, 1));
+  net.Add(new Flatten(), GenFlattenConf("flat"));
+  net.Add(new Dense(), GenDenseConf("ip", 10, 0.01, 250));
+  return net;
+}
+
+void Train(float lr, int num_epoch, string data_dir) {
+  Cifar10 data(data_dir);
+  Tensor train_x, train_y, test_x, test_y;
+  Tensor train_x_1, train_x_2, train_y_1, train_y_2;
+  {
+    auto train = data.ReadTrainData();
+    size_t nsamples = train.first.shape(0);
+    auto mtrain =
+        Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
+    const Tensor &mean = Average(mtrain, 0);
+    SubRow(mean, &mtrain);
+    train_x = Reshape(mtrain, train.first.shape());
+    train_y = train.second;
+
+    LOG(INFO) << "Slicing training data...";
+    train_x_1.Reshape(Shape{nsamples / 2, train.first.shape(1),
+        train.first.shape(2), train.first.shape(3)});
+    LOG(INFO) << "Copying first data slice...";
+    CopyDataToFrom(&train_x_1, train_x, train_x.Size() / 2);
+    train_x_2.Reshape(Shape{nsamples / 2, train.first.shape(1),
+        train.first.shape(2), train.first.shape(3)});
+    LOG(INFO) << "Copying second data slice...";
+    CopyDataToFrom(&train_x_2, train_x, train_x.Size() / 2, 0,
+                   train_x.Size() / 2);
+    train_y_1.Reshape(Shape{nsamples / 2});
+    train_y_1.AsType(kInt);
+    LOG(INFO) << "Copying first label slice...";
+    CopyDataToFrom(&train_y_1, train_y, train_y.Size() / 2);
+    train_y_2.Reshape(Shape{nsamples / 2});
+    train_y_2.AsType(kInt);
+    LOG(INFO) << "Copying second label slice...";
+    CopyDataToFrom(&train_y_2, train_y, train_y.Size() / 2, 0,
+                   train_y.Size() / 2);
+
+    auto test = data.ReadTestData();
+    nsamples = test.first.shape(0);
+    auto mtest =
+        Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples});
+    SubRow(mean, &mtest);
+    test_x = Reshape(mtest, test.first.shape());
+    test_y = test.second;
+  }
+
+  CHECK_EQ(train_x.shape(0), train_y.shape(0));
+  CHECK_EQ(test_x.shape(0), test_y.shape(0));
+  LOG(INFO) << "Total Training samples = " << train_y.shape(0)
+            << ", Total Test samples = " << test_y.shape(0);
+  CHECK_EQ(train_x_1.shape(0), train_y_1.shape(0));
+  LOG(INFO) << "On net 1, Training samples = " << train_y_1.shape(0)
+            << ", Test samples = " << test_y.shape(0);
+  CHECK_EQ(train_x_2.shape(0), train_y_2.shape(0));
+  LOG(INFO) << "On net 2, Training samples = " << train_y_2.shape(0);
+
+  auto net_1 = CreateNet();
+  auto net_2 = CreateNet();
+
+  SGD sgd;
+  OptimizerConf opt_conf;
+  opt_conf.set_momentum(0.9);
+  auto reg = opt_conf.mutable_regularizer();
+  reg->set_coefficient(0.004);
+  sgd.Setup(opt_conf);
+  sgd.SetLearningRateGenerator([lr](int step) {
+    if (step <= 120)
+      return 0.001;
+    else if (step <= 130)
+      return 0.0001;
+    else
+      return 0.00001;
+  });
+
+  SoftmaxCrossEntropy loss_1, loss_2;
+  Accuracy acc_1, acc_2;
+  /// Create an updater that aggregates gradients on the CPU.
+  std::shared_ptr<Updater> updater = std::make_shared<LocalUpdater>(2, &sgd);
+
+  /// Parameters only need to be registered once.
+  net_1.Compile(true, true, updater, &loss_1, &acc_1);
+  net_2.Compile(true, false, updater, &loss_2, &acc_1);
+
+  MemPoolConf mem_conf;
+  mem_conf.add_device(0);
+  mem_conf.add_device(1);
+  std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf));
+  std::shared_ptr<CudaGPU> cuda_1(new CudaGPU(0, mem_pool));
+  std::shared_ptr<CudaGPU> cuda_2(new CudaGPU(1, mem_pool));
+  net_1.ToDevice(cuda_1);
+  net_2.ToDevice(cuda_2);
+
+  /*
+  // this does not work for net_2
+  train_x_2.ResetLike(train_x);
+  train_y_2.ResetLike(train_y);
+  test_x_2.ResetLike(test_x);
+  test_y_2.ResetLike(test_y);
+
+  train_x.ToDevice(cuda_1);
+  train_y.ToDevice(cuda_1);
+  test_x.ToDevice(cuda_1);
+  test_y.ToDevice(cuda_1);
+
+  train_x_2.ToDevice(cuda_2);
+  train_y_2.ToDevice(cuda_2);
+  test_x_2.ToDevice(cuda_2);
+  test_y_2.ToDevice(cuda_2);
+  */
+
+  train_x_1.ToDevice(cuda_1);
+  train_y_1.ToDevice(cuda_1);
+  test_x.ToDevice(cuda_1);
+  test_y.ToDevice(cuda_1);
+  train_x_2.ToDevice(cuda_2);
+  train_y_2.ToDevice(cuda_2);
+
+  // net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
+
+  LOG(INFO) << "Launching thread...";
+  std::thread t1 =
+      net_1.TrainThread(50, num_epoch, train_x_1, train_y_1, test_x, test_y);
+  std::thread t2 = net_2.TrainThread(50, num_epoch, train_x_2, train_y_2);
+  t1.join();
+  t2.join();
+}
+}
+
+int main(int argc, char **argv) {
+  singa::InitChannel(nullptr);
+  int pos = singa::ArgPos(argc, argv, "-epoch");
+  int nEpoch = 1;
+  if (pos != -1) nEpoch = atoi(argv[pos + 1]);
+  pos = singa::ArgPos(argc, argv, "-lr");
+  float lr = 0.001;
+  if (pos != -1) lr = atof(argv[pos + 1]);
+  pos = singa::ArgPos(argc, argv, "-data");
+  string data = "cifar-10-batches-bin";
+  if (pos != -1) data = argv[pos + 1];
+
+  LOG(INFO) << "Start training";
+  singa::Train(lr, nEpoch, data);
+  LOG(INFO) << "End training";
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/cifar10/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.cc b/examples/cifar10/alexnet.cc
index b237fb6..6480557 100644
--- a/examples/cifar10/alexnet.cc
+++ b/examples/cifar10/alexnet.cc
@@ -22,19 +22,16 @@
 #include "./cifar10.h"
 #include "singa/model/feed_forward_net.h"
 #include "singa/model/optimizer.h"
-#include "singa/model/updater.h"
 #include "singa/model/initializer.h"
 #include "singa/model/metric.h"
 #include "singa/utils/channel.h"
 #include "singa/utils/string.h"
-#include "singa/core/memory.h"
 #include "../../src/model/layer/cudnn_convolution.h"
 #include "../../src/model/layer/cudnn_activation.h"
 #include "../../src/model/layer/cudnn_pooling.h"
 #include "../../src/model/layer/cudnn_lrn.h"
 #include "../../src/model/layer/dense.h"
 #include "../../src/model/layer/flatten.h"
-#include <thread>
 namespace singa {
 
 LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
@@ -58,7 +55,7 @@ LayerConf GenConvConf(string name, int nb_filter, int kernel, 
int stride,
   ParamSpec *bspec = conf.add_param();
   bspec->set_name(name + "_bias");
   bspec->set_lr_mult(2);
-  //  bspec->set_decay_mult(0);
+//  bspec->set_decay_mult(0);
   return conf;
 }
 
@@ -151,11 +148,10 @@ void Train(float lr, int num_epoch, string data_dir) {
     size_t nsamples = train.first.shape(0);
     auto mtrain =
         Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
-    const Tensor &mean = Average(mtrain, 0);
+    const Tensor& mean = Average(mtrain, 0);
     SubRow(mean, &mtrain);
     train_x = Reshape(mtrain, train.first.shape());
     train_y = train.second;
-
     auto test = data.ReadTestData();
     nsamples = test.first.shape(0);
     auto mtest =
@@ -164,14 +160,11 @@ void Train(float lr, int num_epoch, string data_dir) {
     test_x = Reshape(mtest, test.first.shape());
     test_y = test.second;
   }
-
   CHECK_EQ(train_x.shape(0), train_y.shape(0));
   CHECK_EQ(test_x.shape(0), test_y.shape(0));
-  LOG(INFO) << "Total Training samples = " << train_y.shape(0)
-            << ", Total Test samples = " << test_y.shape(0);
-
+  LOG(INFO) << "Training samples = " << train_y.shape(0)
+            << ", Test samples = " << test_y.shape(0);
   auto net = CreateNet();
-
   SGD sgd;
   OptimizerConf opt_conf;
   opt_conf.set_momentum(0.9);
@@ -187,24 +180,16 @@ void Train(float lr, int num_epoch, string data_dir) {
       return 0.00001;
   });
 
-  SoftmaxCrossEntropy loss_1, loss_2;
-  Accuracy acc_1, acc_2;
-  /// Create updater aggregating gradient on CPU
-  Updater updater(1, &sgd);
+  SoftmaxCrossEntropy loss;
+  Accuracy acc;
+  net.Compile(true, &sgd, &loss, &acc);
 
-  net.Compile(true, true, &updater, &loss_1, &acc_1);
-
-  MemPoolConf mem_conf;
-  mem_conf.add_device(0);
-  std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf));
-  std::shared_ptr<CudaGPU> cuda(new CudaGPU(0, mem_pool));
+  auto cuda = std::make_shared<CudaGPU>();
   net.ToDevice(cuda);
-
   train_x.ToDevice(cuda);
   train_y.ToDevice(cuda);
   test_x.ToDevice(cuda);
   test_y.ToDevice(cuda);
-
   net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
 }
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/examples/cifar10/run-parallel.sh
----------------------------------------------------------------------
diff --git a/examples/cifar10/run-parallel.sh b/examples/cifar10/run-parallel.sh
new file mode 100755
index 0000000..6a9109a
--- /dev/null
+++ b/examples/cifar10/run-parallel.sh
@@ -0,0 +1,2 @@
+#!/usr/bin/env sh
+../../build/bin/alexnet-parallel -epoch 4

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/include/singa/model/feed_forward_net.h
----------------------------------------------------------------------
diff --git a/include/singa/model/feed_forward_net.h 
b/include/singa/model/feed_forward_net.h
index 6eeb34d..8adc259 100644
--- a/include/singa/model/feed_forward_net.h
+++ b/include/singa/model/feed_forward_net.h
@@ -22,6 +22,7 @@
 #include "singa/model/metric.h"
 #include "singa/model/updater.h"
 #include <thread>
+#include <memory>
 namespace singa {
 
 /// The feed-forward neural net.
@@ -54,13 +55,25 @@ class FeedForwardNet {
   Layer* Add(Layer* layer, const LayerConf& conf,
              const Shape* sample_shape = nullptr);
   /// Set some fields used for training and evaluating the neural net.
+  /// This method instantiates an Updater internally, wraps the given
+  /// Optimizer inside it, and always registers the parameters of the net.
   /// If the neural net is constructed for evaluation only, then 'opt' is not
   /// necessary; But for training, both 'opt' and 'loss' are necessary.
   /// 'shuffle' indicates shuffling training samples within one epoch it is
   /// valid using Train(). If to_register is set true, parameter will be
-  /// registered in Updater;
-  void Compile(bool shuffle, bool to_register, Updater* updater, Loss* loss,
-               Metric* metric);
+  /// registered in the Updater.
+  void Compile(bool shuffle, Optimizer* opt, Loss* loss, Metric* metric);
+  /// Set some fields used for training and evaluating the neural net.
+  /// This method is mainly used in parallel training, where multiple
+  /// neural net instances share one Updater.
+  /// If the neural net is constructed for evaluation only, then 'updater' is
+  /// not necessary; but for training, both 'updater' and 'loss' are necessary.
+  /// 'shuffle' indicates whether training samples are shuffled within one
+  /// epoch; it is only valid when using Train(). If 'to_register' is true,
+  /// the parameters will be registered in the Updater.
+  void Compile(bool shuffle, bool to_register, std::shared_ptr<Updater> updater,
+               Loss* loss, Metric* metric);
 
   /// Conduct the training giving the training data 'x' and label 'y'.
   /// 'val_split' of training data is used for
@@ -124,7 +137,8 @@ class FeedForwardNet {
   std::thread TrainThread(size_t batchsize, int nb_epoch, const Tensor& x,
                           const Tensor& y, const Tensor& val_x,
                           const Tensor& val_y) {
-    return std::thread([=]() { Train(batchsize, nb_epoch, x, y, val_x, val_y); 
});
+    return std::thread(
+        [=]() { Train(batchsize, nb_epoch, x, y, val_x, val_y); });
   }
 
   /// A wrapper method to spawn a thread to execute Train() method.
@@ -140,7 +154,7 @@ class FeedForwardNet {
 
  protected:
   vector<Layer*> layers_;
-  Updater* updater_;
+  std::shared_ptr<Updater> updater_;
   Loss* loss_;
   Metric* metric_;
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/include/singa/model/updater.h
----------------------------------------------------------------------
diff --git a/include/singa/model/updater.h b/include/singa/model/updater.h
index dbd91e8..ef7c32a 100644
--- a/include/singa/model/updater.h
+++ b/include/singa/model/updater.h
@@ -32,20 +32,17 @@
 #include <unordered_map>
 
 namespace singa {
+/// The basic Updater class simply forwards all method calls
+/// to the wrapped Optimizer.
 class Updater {
  public:
-  Updater(int total_num, Optimizer* opt,
-          std::shared_ptr<Device> dev = defaultDevice)
-      : total_num_{total_num}, opt_{opt}, dev_(dev) {}
+  explicit Updater(Optimizer* opt) : opt_{opt} {}
   virtual ~Updater() {}
-
   /// Forward Setup() to Optimizer.
   virtual void Setup(const OptimizerConf& conf);
   /// Forward Register() to Optimizer.
   virtual void Register(const string& name, const ParamSpec& specs);
-  /// Update parameter value based on given gradient by invoking optimizer
-  /// algoritim. When tranining net call this function will be blocked until
-  /// all the partial gradients are aggrageted in a synchronized style 
training.
+  /// Forward Apply() to Optimizer.
   virtual void Apply(int step, const string& name, Tensor& grad, Tensor& 
value);
   Optimizer* GetOptimizer() { return opt_; }
 
@@ -54,14 +51,35 @@ class Updater {
   void operator=(const Updater&) = delete;
 
  protected:
-  int total_num_;
   Optimizer* opt_;
+};
+
+/// LocalUpdater aggregates gradients from multiple net instances and updates
+/// the parameters by calling the wrapped Optimizer on a specific device
+/// (i.e., CPU or GPU).
+class LocalUpdater : public Updater {
+ public:
+  LocalUpdater(int total_num, Optimizer* opt,
+               std::shared_ptr<Device> dev = defaultDevice)
+      : Updater(opt), total_num_{total_num}, dev_(dev) {}
+  virtual ~LocalUpdater() override {}
+  /// Forward Register() to Optimizer and initialize aggregation counters.
+  virtual void Register(const string& name, const ParamSpec& specs) override;
+  /// Update the parameter value based on the given gradient by invoking the
+  /// optimizer algorithm. In synchronous training, a call to this function
+  /// blocks until all partial gradients have been aggregated.
+  virtual void Apply(int step, const string& name, Tensor& grad,
+                     Tensor& value) override;
+
+ private:
+  int total_num_;
   std::shared_ptr<Device> dev_;
-  std::mutex mtx_;
-  std::condition_variable aggr_count_eq_total_num_;
   std::unordered_map<std::string, int> aggr_count_, copy_count_;
-  std::unordered_map<std::string, Tensor> buffer_, partial_sum_;
-  std::unordered_map<std::string, bool> has_averaged_, has_init_;
+  std::unordered_map<std::string, Tensor> grad_buffer_,
+    param_buffer_, partial_sum_;
+  std::unordered_map<std::string, bool> has_init_;
+  std::unordered_map<std::string, std::mutex> mtx_;
+  std::unordered_map<std::string, std::condition_variable>
+    aggr_count_eq_total_num_;
 };
 }  //  namespace singa
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/src/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index aa3ab36..65a81fc 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -61,6 +61,7 @@ AUX_SOURCE_DIRECTORY(model/layer model_source)
 AUX_SOURCE_DIRECTORY(model/optimizer model_source)
 AUX_SOURCE_DIRECTORY(model/loss model_source)
 AUX_SOURCE_DIRECTORY(model/metric model_source)
+AUX_SOURCE_DIRECTORY(model/updater model_source)
 #MESSAGE(STATUS "MODEL ${model_source}")
 ADD_LIBRARY(singa_model SHARED ${model_source})
 TARGET_LINK_LIBRARIES(singa_model ${SINGA_LINKER_LIBS})

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/src/model/feed_forward_net.cc
----------------------------------------------------------------------
diff --git a/src/model/feed_forward_net.cc b/src/model/feed_forward_net.cc
index b30d24a..297dda0 100644
--- a/src/model/feed_forward_net.cc
+++ b/src/model/feed_forward_net.cc
@@ -73,8 +73,15 @@ const vector<ParamSpec> FeedForwardNet::GetParamSpecs() 
const {
   return specs;
 }
 
-void FeedForwardNet::Compile(bool shuffle, bool to_register, Updater* updater,
-                             Loss* loss, Metric* metric) {
+void FeedForwardNet::Compile(bool shuffle, Optimizer* opt, Loss* loss,
+                             Metric* metric) {
+  std::shared_ptr<Updater> updater = std::make_shared<Updater>(opt);
+  Compile(shuffle, true, updater, loss, metric);
+}
+
+void FeedForwardNet::Compile(bool shuffle, bool to_register,
+                             std::shared_ptr<Updater> updater, Loss* loss,
+                             Metric* metric) {
   shuffle_ = shuffle;
   bool train = (updater != nullptr) && (loss != nullptr);
   bool test = metric != nullptr;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/src/model/updater.cc
----------------------------------------------------------------------
diff --git a/src/model/updater.cc b/src/model/updater.cc
deleted file mode 100644
index e6c1cdb..0000000
--- a/src/model/updater.cc
+++ /dev/null
@@ -1,87 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "singa/model/updater.h"
-
-namespace singa {
-
-void Updater::Setup(const OptimizerConf& conf) { opt_->Setup(conf); }
-
-void Updater::Register(const string& name, const ParamSpec& specs) {
-  opt_->Register(name, specs);
-  aggr_count_[name] = 0;
-  copy_count_[name] = 0;
-  has_averaged_[name] = false;
-  has_init_[name] = false;
-}
-
-void Updater::Apply(int step, const string& name, Tensor& grad, Tensor& value) 
{
-  CHECK(aggr_count_.count(name) == 1) << "Parameter " << name
-                                         << " has not been registered before.";
-  /// This lock is aimed to protect aggregation counter, data transfering 
buffer,
-  /// and partial aggregation result. However, the data transfering can be 
moved out
-  /// of the critial section to improve performance in the future.
-  std::unique_lock<std::mutex> lock(mtx_);
-  if (aggr_count_[name] == 0) {
-    if (!has_init_[name]) {
-      Tensor tmp(grad.shape(), dev_, grad.data_type());
-      partial_sum_[name] = tmp;
-      buffer_[name].ResetLike(tmp);
-      has_init_[name] = true;
-    } else {
-      partial_sum_[name].SetValue(.0f);
-    }
-  }
-  buffer_[name].CopyData(grad);
-  Add(partial_sum_[name], buffer_[name], &partial_sum_[name]);
-  ++aggr_count_[name];
-
-  /// Block this thread when we have not gotten enough paritial gradients.
-  if (aggr_count_[name] != total_num_) {
-    while (aggr_count_[name] != total_num_) {
-      aggr_count_eq_total_num_.wait(lock);
-    }
-  } else {
-    aggr_count_eq_total_num_.notify_all();
-  }
-
-  /// Now we get enought paritial gradient from all neural net instances,
-  /// then we calcuate the average gradient. The first notified thread
-  /// finish the averaging once.
-  if (!has_averaged_[name]) {
-    Div(partial_sum_[name], static_cast<float>(total_num_),
-        &partial_sum_[name]);
-    copy_count_[name] = 0;
-    has_averaged_[name] = true;
-  }
-
-  /// For now, gradient aggregation and SGD algorithm run on the same device.
-  /// TODO(wangji): It is possible to make them run separately.
-  buffer_[name].CopyData(value);
-  /// Apply optimization algorithm based on the averaged gradient.
-  opt_->Apply(step, name, partial_sum_[name], buffer_[name]);
-  value.CopyData(buffer_[name]);
-
-  /// The last thread finishing copy should set aggregation counter back to 0.
-  ++copy_count_[name];
-  if (copy_count_[name] == total_num_) {
-    aggr_count_[name] = 0;
-    has_averaged_[name] = false;
-  }
-}
-}  // namesapce singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/src/model/updater/local_updater.cc
----------------------------------------------------------------------
diff --git a/src/model/updater/local_updater.cc 
b/src/model/updater/local_updater.cc
new file mode 100644
index 0000000..575f7ab
--- /dev/null
+++ b/src/model/updater/local_updater.cc
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/updater.h"
+
+namespace singa {
+
+void LocalUpdater::Register(const string& name, const ParamSpec& specs) {
+  opt_->Register(name, specs);
+  aggr_count_[name] = 0;
+  copy_count_[name] = 0;
+  has_init_[name] = false;
+}
+
+void LocalUpdater::Apply(int step, const string& name, Tensor& grad, Tensor& value) {
+  CHECK(aggr_count_.count(name) == 1) << "Parameter " << name
+                                         << " has not been registered before.";
+  /// This lock protects the aggregation counter, the data transfer buffers,
+  /// and the partial aggregation result; the data transfer could be moved out
+  /// of the critical section in the future to improve performance.
+  std::unique_lock<std::mutex> lock(mtx_[name]);
+  if (aggr_count_[name] == 0) {
+    switch (has_init_[name]) {
+      case (false): {
+        Tensor tmp(grad.shape(), dev_, grad.data_type());
+        partial_sum_[name] = tmp;
+        param_buffer_[name].ResetLike(tmp);
+        grad_buffer_[name].ResetLike(tmp);
+        has_init_[name] = true;
+        /* No break: intended fall-through */
+      }
+      case (true):
+        param_buffer_[name].CopyData(value);
+        partial_sum_[name].SetValue(.0f);
+    }
+  }
+  grad_buffer_[name].CopyData(grad);
+  Add(partial_sum_[name], grad_buffer_[name], &partial_sum_[name]);
+  ++aggr_count_[name];
+
+  /// Block this thread until all partial gradients have arrived.
+  if (aggr_count_[name] != total_num_) {
+    while (aggr_count_[name] != total_num_) {
+      aggr_count_eq_total_num_[name].wait(lock);
+    }
+  } else {
+  /// All partial gradients from the net instances have arrived, so the last
+  /// arriving thread averages the sum and applies the optimizer exactly once.
+    Div(partial_sum_[name], static_cast<float>(total_num_),
+        &partial_sum_[name]);
+    copy_count_[name] = 0;
+  /// For now, gradient aggregation and SGD algorithm run on the same device.
+  /// TODO(wangji): It is possible to make them run separately.
+  /// Apply optimization algorithm based on the averaged gradient.
+    opt_->Apply(step, name, partial_sum_[name], param_buffer_[name]);
+    aggr_count_eq_total_num_[name].notify_all();
+  }
+
+  value.CopyData(param_buffer_[name]);
+
+  /// The last thread to finish copying resets the aggregation counter to 0.
+  ++copy_count_[name];
+  if (copy_count_[name] == total_num_)
+    aggr_count_[name] = 0;
+}
+}  // namespace singa
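
For readers unfamiliar with the pattern above: the synchronization inside
LocalUpdater::Apply() is a per-parameter counting barrier built from a mutex
and a condition variable. The following stand-alone sketch (hypothetical
names, not SINGA code, and one-shot for brevity; the real implementation also
resets the counter once every replica has copied the updated value) shows the
same idea:

#include <condition_variable>
#include <functional>
#include <mutex>

struct AggregationBarrier {
  int total;        // number of net replicas contributing gradients
  int arrived = 0;  // how many replicas have contributed so far
  std::mutex mtx;
  std::condition_variable all_arrived;

  // 'aggregate' adds this thread's partial gradient to a shared sum;
  // 'finalize' averages the sum and applies the optimizer exactly once.
  void Sync(const std::function<void()>& aggregate,
            const std::function<void()>& finalize) {
    std::unique_lock<std::mutex> lock(mtx);
    aggregate();
    if (++arrived < total) {
      // Not the last thread: wait until every replica has contributed.
      while (arrived < total) all_arrived.wait(lock);
    } else {
      // The last thread does the averaging/update and wakes the rest.
      finalize();
      all_arrived.notify_all();
    }
  }
};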

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0184fac3/src/model/updater/updater.cc
----------------------------------------------------------------------
diff --git a/src/model/updater/updater.cc b/src/model/updater/updater.cc
new file mode 100644
index 0000000..d386d30
--- /dev/null
+++ b/src/model/updater/updater.cc
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/updater.h"
+
+namespace singa {
+
+void Updater::Setup(const OptimizerConf& conf) { opt_->Setup(conf); }
+
+void Updater::Register(const string& name, const ParamSpec& specs) {
+  opt_->Register(name, specs);
+}
+
+void Updater::Apply(int step, const string& name, Tensor& grad, Tensor& value) {
+  opt_->Apply(step, name, grad, value);
+}
+}  // namespace singa
