http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_table_server.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_table_server.cc 
b/src/test/dist_test/test_table_server.cc
new file mode 100644
index 0000000..5f3612c
--- /dev/null
+++ b/src/test/dist_test/test_table_server.cc
@@ -0,0 +1,357 @@
+//  Copyright © 2014 Anh Dinh. All Rights Reserved.
+
+#include "core/global-table.h"
+#include "core/common.h"
+#include "core/table.h"
+#include "core/table_server.h"
+#include "utils/global_context.h"
+#include "utils/common.h"
+#include <gflags/gflags.h>
+#include "proto/model.pb.h"
+#include "proto/common.pb.h"
+#include "worker.h"
+#include "coordinator.h"
+#include "utils/common.h"
+#include "utils/proto_helper.h"
+
+#include <cmath>
+#include <stdlib.h>
+#include <vector>
+#include <iostream>
+#include <fstream>
+
+
+/**
+ * Test for table server access. The table is of type <VKey, SGDValue>.
+ */
+DEFINE_bool(restore_mode, false, "restore from checkpoint file");
+using namespace lapis;
+using std::vector;
+
+DEFINE_int32(checkpoint_frequency, 5000, "frequency for cp");
+DEFINE_int32(checkpoint_after, 1, "cp after this steps");
+DEFINE_string(par_mode, "hybrid",  "time training algorithm");
+DEFINE_bool(restore, false, "restore from checkpoint file");
+
+DEFINE_string(db_backend, "lmdb", "backend db");
+DEFINE_string(system_conf, "examples/imagenet12/system.conf", "configuration 
file for node roles");
+DEFINE_string(model_conf, "examples/imagenet12/model.conf", "DL model 
configuration file");
+DEFINE_string(checkpoint_dir,"/data1/wangwei/lapis/","check point dir");
+DEFINE_int32(threshold,1000000, "max # of parameters in a vector");
+DEFINE_int32(iterations,5,"numer of get/put iterations");
+DEFINE_int32(workers,2,"numer of workers doing get/put");
+DECLARE_bool(checkpoint_enabled);
+
+
+DECLARE_bool(checkpoint_enabled);
+
+/**
+ * Get and update handler for VKey.
+ */
+struct AnhUpdateHandler: BaseUpdateHandler<VKey, SGDValue> {
+       bool Update(SGDValue *a, const SGDValue &b) {
+
+               float * adptr = 
a->mutable_data()->mutable_value()->mutable_data();
+               const float*bdptr = b.grad(0).value().data();
+               for (int i = 0; i < b.grad(0).value_size(); i++)
+                       adptr[i] += bdptr[i];
+
+               return true;
+       }
+
+       bool Get(const VKey k, const SGDValue &val, SGDValue *ret) {
+               *ret = val;
+               return true;
+       }
+
+       bool is_checkpointable(const VKey k, const SGDValue v) {
+               return false; // never checkpoint tuples in this test
+       }
+};
+
+typedef map<int, GlobalTable*> Map;
+Map tables;
+shared_ptr<NetworkThread> network;
+shared_ptr<GlobalContext> context;
+std::vector<ServerState*> server_states;
+TableServer *table_server;
+
+#define SIZE 16
+int tuple_sizes[SIZE] = {27448736, 16777216, 4096000, 1327104, 884736, 884736, 
614400,14112,4096,4096,1000,384,384,256,256,96};
+
+/**
+ * Initialize tables.
+ */
+void create_mem_table(int id, int num_shards){
+
+       TableDescriptor *info = new TableDescriptor(id, num_shards);
+         info->key_marshal = new Marshal<VKey>();
+         info->value_marshal = new Marshal<SGDValue>();
+         info->sharder = new VKeySharder;
+         info->accum = new AnhUpdateHandler;
+         info->partition_factory = new typename SparseTable<VKey, 
SGDValue>::Factory;
+         auto table=new TypedGlobalTable<VKey, SGDValue>();
+         table->Init(info);
+         tables[id] = table;
+}
+
+/**
+ * Coordinator assigns shards to processes.
+ * @param id table ID.
+ */
+void coordinator_assign_tables(int id) {
+
+       // wait for the servers to be up.
+       for (int i = 0; i < context->num_procs(); i++) {
+               RegisterWorkerRequest req;
+               int src = 0;
+               //  adding memory server.
+               if (context->IsTableServer(i)) {
+                       VLOG(3)<< "Waiting for message from table server " << i;
+                       network->Read(MPI::ANY_SOURCE, MTYPE_REGISTER_WORKER, 
&req, &src);
+                       server_states.push_back(new ServerState(i));
+               }
+       }
+
+       VLOG(3) << " All servers registered and started up. Ready to go";
+       VLOG(3) << "num of shards" << tables[id]->num_shards() << " for table " 
<< id;
+
+       // assign table to shard in round roubin fashion.
+       int server_idx = 0;
+       for (int shard = 0; shard < tables[id]->num_shards(); ++shard) {
+               ServerState &server = *server_states[server_idx];
+               VLOG(3) << "Assigning table (" << id << "," << shard << ") to 
server "
+                               << server_states[server_idx]->server_id;
+               server.shard_id = shard;
+               server.local_shards.insert(new TaskId(id, shard));
+               server_idx = (server_idx + 1) % server_states.size();
+       }
+       ShardAssignmentRequest req;
+       for (size_t i = 0; i < server_states.size(); ++i) {
+               ServerState &server = *server_states[i];
+               for (auto * task : server.local_shards) {
+                       ShardAssignment *s = req.add_assign();
+                       s->set_new_worker(server.server_id);
+                       s->set_table(task->table);
+                       s->set_shard(task->shard);
+                       //  update local tables
+                       GlobalTable *t = tables.at(task->table);
+                       t->get_partition_info(task->shard)->owner = 
server.server_id;
+                       delete task;
+               }
+       }
+
+       network->SyncBroadcast(MTYPE_SHARD_ASSIGNMENT, 
MTYPE_SHARD_ASSIGNMENT_DONE,
+                       req);
+       VLOG(3) << "done table assignment... ";
+}
+
+
+void table_init(){
+       table_server = new TableServer();
+       table_server->StartTableServer(tables);
+       VLOG(3) << "table server started on process "<< 
NetworkThread::Get()->id();
+}
+
+
+/**
+ * Coordinator loads data to the table.
+ * Loads SIZE tuples whose value sizes are given by tuple_sizes[].
+ */
+void coordinator_load_data() {
+       auto table = static_cast<TypedGlobalTable<VKey, SGDValue>*>(tables[0]);
+       for (int i = 0; i < SIZE; i++) {
+               VKey key;
+               SGDValue x;
+               DAryProto *data = x.mutable_data();
+               DAryProto *grad = x.add_grad();
+               for (int j = 0; j < tuple_sizes[i]; j++) {
+                       data->add_value(j * 1.0f);
+                       grad->add_value(j * 1.0f);
+               }
+               key.set_key(i);
+               table->put(key, x);
+       }
+       VLOG(3) << "Done loading " << SIZE << " tuples ...";
+}
+
+/**
+ * Worker gets tuples from the server.
+ * Requests all SIZE tuples asynchronously, then collects the responses.
+ */
+void get() {
+       auto table = static_cast<TypedGlobalTable<VKey,SGDValue>*>(tables[0]);
+       SGDValue value;
+       for (int i = 0; i < SIZE; i++) {
+               VKey key;
+               key.set_key(i);
+               table->async_get(key, &value);
+       }
+       VLOG(3) << "Done sending get requests ...";
+
+       for (int i = 0; i < SIZE; i++) {
+               VKey key;
+               while (!table->async_get_collect(&key, &value))
+                       Sleep(0.0001);
+       }
+}
+
+/**
+ * Worker updates tuples.
+ */
+void update() {
+       auto table = static_cast<TypedGlobalTable<VKey, SGDValue>*>(tables[0]);
+       for (int i = 0; i < SIZE; i++) {
+               VKey key;
+               key.set_key(i);
+
+               SGDValue x;
+               DAryProto *grad = x.add_grad();
+               for (int j = 0; j < tuple_sizes[i]; j++)
+                       grad->add_value(j * 1.0f);
+
+               table->update(key, x);
+       }
+       VLOG(3) << "Done updating " << SIZE << " tuples ...";
+}
+
+
+void worker_test_data() {
+       //get(size);
+       update();
+       update();
+       get();
+       /*
+       update(table, tuples);
+       update(table, tuples);
+       update(table, tuples);
+       get(table, tuples);
+       */
+}
+
+/**
+ * Shutdown the process.
+ */
+void shutdown() {
+       if (context->AmICoordinator()) {
+               EmptyMessage msg;
+               for (int i = 0; i < context->num_procs() - 1; i++)
+                       network->Read(MPI::ANY_SOURCE, MTYPE_WORKER_END, &msg);
+               EmptyMessage shutdown_msg;
+               for (int i = 0; i < network->size() - 1; i++) {
+                       network->Send(i, MTYPE_SHUTDOWN, shutdown_msg);
+               }
+               //network->Flush();
+               network->Shutdown();
+       } else {
+               //network->Flush();
+               network->Send(context->num_procs() - 1, MTYPE_WORKER_END,
+                               EmptyMessage());
+               EmptyMessage msg;
+               network->Read(context->num_procs() - 1, MTYPE_SHUTDOWN, &msg);
+
+               if (context->AmITableServer()){
+                       RequestDispatcher::Get()->PrintStats();
+                       table_server->ShutdownTableServer();
+               }
+
+               network->Shutdown();
+       }
+}
+
+/**
+ * Worker handle shard assignment from the coordinator.
+ */
+void HandleShardAssignment() {
+
+       ShardAssignmentRequest shard_req;
+       auto mpi = NetworkThread::Get();
+       mpi->Read(GlobalContext::kCoordinator, MTYPE_SHARD_ASSIGNMENT, 
&shard_req);
+
+       //  request read from coordinator
+       for (int i = 0; i < shard_req.assign_size(); i++) {
+               const ShardAssignment &a = shard_req.assign(i);
+               GlobalTable *t = tables.at(a.table());
+               t->get_partition_info(a.shard())->owner = a.new_worker();
+
+               //if local shard, create check-point files
+               if (FLAGS_checkpoint_enabled && t->is_local_shard(a.shard())) {
+                       string checkpoint_file = 
StringPrintf("%s/checkpoint_%d",
+                                       FLAGS_checkpoint_dir.c_str(), 
a.shard());
+                       char hostname[256];
+                       gethostname(hostname, sizeof(hostname));
+
+                       FILE *tmp_file = fopen(checkpoint_file.c_str(), "r");
+                       if (tmp_file) { //exists -> open to reading and writing
+                               fclose(tmp_file);
+                               auto cp = t->checkpoint_files();
+
+                               if (FLAGS_restore_mode) { //open in read mode 
to restore, then close
+                                       LogFile *file = new 
LogFile(checkpoint_file, "rw", 0);
+                                       int table_size = 
file->read_latest_table_size();
+                                       delete file;
+
+                                       double start = Now();
+                                       (*cp)[a.shard()] = new 
LogFile(checkpoint_file, "r",
+                                                       a.shard());
+                                       t->Restore(a.shard());
+                                       delete (*cp)[a.shard()];
+                                       double end = Now();
+                                       LOG(ERROR) << "restore time\t" << end - 
start << "\tfor\t"
+                                                       << table_size << 
"\tthreshold\t" << FLAGS_threshold;
+                               }
+                               char hostname[256];
+                               gethostname(hostname, sizeof(hostname));
+                               (*cp)[a.shard()] = new LogFile(checkpoint_file, 
"a", a.shard());
+                       } else { // not exist -> open to writing first time
+                               auto cp = t->checkpoint_files();
+                               (*cp)[a.shard()] = new LogFile(checkpoint_file, 
"w", a.shard());
+                       }
+               }
+       }
+
+       EmptyMessage empty;
+       mpi->Send(GlobalContext::kCoordinator, MTYPE_SHARD_ASSIGNMENT_DONE, 
empty);
+       VLOG(3) << "Done handling shard assignment ...";
+
+}
+
+
+int main(int argc, char **argv) {
+       FLAGS_logtostderr = 1;
+       int provided;
+       MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
+       google::InitGoogleLogging(argv[0]);
+       gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+       context = GlobalContext::Get(FLAGS_system_conf);
+       network = NetworkThread::Get();
+
+       ModelProto model;
+       ReadProtoFromTextFile(FLAGS_model_conf.c_str(), &model);
+
+       create_mem_table(0, context->num_table_servers());
+
+       if (context->AmICoordinator()) {
+               coordinator_assign_tables(0);
+               coordinator_load_data();
+               network->barrier();
+       } else {
+               if (context->AmITableServer()) {
+                       table_init();
+                       HandleShardAssignment();
+                       network->barrier();
+               } else {
+                       HandleShardAssignment();
+                       network->barrier();
+                       Sleep(1);
+                       VLOG(3) << "Worker cleared the barrier ...";
+                       worker_test_data();
+               }
+       }
+
+       shutdown();
+       return 0;
+}
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_tuple.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_tuple.cc b/src/test/dist_test/test_tuple.cc
new file mode 100644
index 0000000..727f8e3
--- /dev/null
+++ b/src/test/dist_test/test_tuple.cc
@@ -0,0 +1,258 @@
+#include <cstdio>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "server.h"
+#include "proto/worker.pb.h"
+#include "utils/network_service.h"
+#include "core/common.h"
+#include "core/network_queue.h"
+#include "proto/model.pb.h"
+#include "proto/common.pb.h"
+#include "utils/global_context.h"
+
+/**
+ * @file test_tuple.cc
+ *
+ * Test performance of TableServer put/get/update operations.
+ */
+DECLARE_double(sleep_time);
+
+using namespace lapis;
+using namespace std;
+using std::vector;
+
+#define NKEYS 1000
+#define TUPLE_SIZE 50000000
+
+#ifndef FLAGS_v
+  DEFINE_int32(v, 3, "vlog controller");
+#endif
+
+
+#define SIZE 16
+#define THRESHOLD 500000
+int tuple_sizes[SIZE] = {37448736, 16777216, 4096000, 1327104, 884736, 884736, 
614400,14112,4096,4096,1000,384,384,256,256,96};
+vector<int> valsizes;
+int collect_size;
+int num_tuples;
+
+void Put(int tid, int size, int version) {
+       RequestBase request;
+       request.set_table(0);
+       request.set_source(NetworkService::Get()->id());
+       PutRequest *put_req = request.MutableExtension(PutRequest::name);
+       int shard = tid % GlobalContext::Get()->num_servers();
+       put_req->set_shard(shard);
+       TableData *tuple = put_req->mutable_data();
+
+       TKey* key = tuple->mutable_key();
+       TVal* val = tuple->mutable_value();
+
+       key->set_id(tid);
+       key->set_version(version);
+
+       DAryProto *data = val->mutable_data();
+       for (int i = 0; i < size; i++){
+               data->add_value(0.0f);
+       }
+
+       // TODO check the msg type
+       NetworkService::Get()->Send(shard, MTYPE_REQUEST, request);
+}
+
+void Update(int tid, int size, int version) {
+       RequestBase request;
+       request.set_table(0);
+       request.set_source(NetworkService::Get()->id());
+       UpdateRequest *update_req = 
request.MutableExtension(UpdateRequest::name);
+       int shard = tid % GlobalContext::Get()->num_servers();
+       update_req->set_shard(shard);
+       TableData *tuple = update_req->mutable_data();
+
+       TKey* key = tuple->mutable_key();
+       TVal* val = tuple->mutable_value();
+
+       key->set_id(tid);
+       key->set_version(version);
+
+       DAryProto *data = val->mutable_grad();
+       for (int i = 0; i < size; i++)
+               data->add_value(1.0f);
+       // TODO check the msg type
+       NetworkService::Get()->Send(shard, MTYPE_REQUEST, request);
+}
+
+void print_result(TableData *data){
+       TKey *key = data->mutable_key();
+       TVal *val = data->mutable_value();
+       int k = key->id();
+       VLOG(3) << "key = " << k;
+       string s;
+       for (int i=0; i<TUPLE_SIZE; i++)
+               s.append(to_string(val->mutable_data()->value(i))).append(" ");
+       VLOG(3) << "val = " <<s;
+}
+
+void AsyncGet(int tid, int version) {
+       RequestBase request;
+       request.set_table(0);
+       request.set_source(GlobalContext::Get()->rank()); 
//NetworkService::Get()->id());
+       GetRequest *get_req = request.MutableExtension(GetRequest::name);
+       int shard = tid % GlobalContext::Get()->num_servers();
+       get_req->set_shard(shard);
+
+       TKey *key = get_req->mutable_key();
+       key->set_id(tid);
+       key->set_version(version);
+       NetworkService::Get()->Send(shard, MTYPE_REQUEST, request);
+
+}
+
+void Collect(){
+       int count = collect_size;
+       double start_collect = Now();
+       while (count){
+               while (true) {
+                               Message *resp = 
NetworkService::Get()->Receive();
+                               if (!resp)
+                                       Sleep(FLAGS_sleep_time);
+                               else{
+                                       delete resp;
+                                       break;
+                               }
+                       }
+               count--;
+       }
+       double end_collect = Now();
+       VLOG(3) << "Collected " << collect_size << " tuples in " << 
(end_collect-start_collect);
+}
+
+/**
+ * Workers wait for the barrier, then one of them send SHUTDOWN message
+ * to all table servers.
+ */
+void worker_send_shutdown(int id){
+       auto gc = lapis::GlobalContext::Get();
+       NetworkService *network_service_ = NetworkService::Get().get();
+       MPI_Barrier(gc->workergroup_comm());
+       if (gc->rank()==id){
+               for (int i=0; i<gc->num_procs(); i++){
+                       if (gc->IsTableServer(i)){
+                               EmptyMessage msg;
+                               network_service_->Send(i, MTYPE_SHUTDOWN,msg);
+                       }
+               }
+       }
+}
+
+/**
+ * One worker with the specific ID puts, others wait.
+ */
+void worker_load_data(int id){
+       auto gc = lapis::GlobalContext::Get();
+       for (int i = 0; i < SIZE; i++) {
+               int m = tuple_sizes[i];
+               if (m < THRESHOLD)
+                       valsizes.push_back(m);
+               else {
+                       for (int j = 0; j < m / THRESHOLD; j++)
+                               valsizes.push_back(THRESHOLD);
+                       if (m % THRESHOLD)
+                               valsizes.push_back(m%THRESHOLD);
+               }
+       }
+       num_tuples = (int)valsizes.size();
+       collect_size = 0;
+       for (int i=0; i<num_tuples; i++)
+               if (i%gc->group_size()==gc->worker_id())
+                       collect_size++;
+
+       if (gc->rank()==id){
+               for (size_t i=0; i<valsizes.size(); i++)
+                       Put(i,valsizes[i],0);
+               VLOG(3) << "Done loading data, num_keys = "<<valsizes.size() << 
" process " << id;
+       }
+       VLOG(3) << "Collect size = " << collect_size;
+       MPI_Barrier(gc->workergroup_comm());
+}
+
+void worker_update_data() {
+       auto gc = lapis::GlobalContext::Get();
+       for (int i = 0; i < num_tuples; i++)
+               if (i%gc->group_size()==gc->worker_id())
+                       Update(i,valsizes[i],0);
+
+       VLOG(3) << "Done update ... for "<<collect_size << " tuples ";
+}
+
+/*
+ * Async get.
+ */
+void worker_get_data(){
+       auto gc = lapis::GlobalContext::Get();
+       for (int i=0; i<num_tuples; i++)
+               if (i%gc->group_size()==gc->worker_id())
+                       AsyncGet(i,0);
+       Collect();
+       VLOG(3) << "Done collect ...";
+}
+
+void start_network_service_for_worker(){
+       NetworkService *network_service_ = NetworkService::Get().get();
+       network_service_->Init(GlobalContext::Get()->rank(), 
Network::Get().get(), new SimpleQueue());
+       network_service_->StartNetworkService();
+}
+
+int main(int argc, char **argv) {
+       google::InitGoogleLogging(argv[0]);
+       gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+       int provided;
+
+
+       MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
+
+
+       FLAGS_logtostderr = 1;
+
+
+       // Init GlobalContext
+       Cluster cluster;
+       cluster.set_server_start(0);
+       cluster.set_server_end(8);
+       cluster.set_worker_start(8);
+       cluster.set_worker_end(24);
+       cluster.set_group_size(8);
+       cluster.set_data_folder("/data1/wangwei/lapis");
+
+       auto gc = lapis::GlobalContext::Get(cluster);
+
+       // worker or table server
+       if (gc->AmITableServer()) {
+               lapis::TableServer server;
+               SGDProto sgd;
+               sgd.set_learning_rate(0.01);
+               sgd.set_momentum(0.9);
+               sgd.set_weight_decay(0.1);
+               sgd.set_gamma(0.5);
+               sgd.set_learning_rate_change_steps(1);
+               server.Start(sgd);
+       } else {
+               start_network_service_for_worker();
+               worker_load_data(cluster.worker_start());
+               for (int i=0; i<10; i++){
+                       worker_update_data();
+                       worker_get_data();
+               }
+               worker_send_shutdown(cluster.worker_start());
+               NetworkService::Get()->Shutdown();
+       }
+       gc->Finalize();
+       MPI_Finalize();
+       VLOG(3) << "End, process "<< gc->rank();
+       return 0;
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_blob.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_blob.cc b/src/test/model/test_blob.cc
new file mode 100644
index 0000000..75f1921
--- /dev/null
+++ b/src/test/model/test_blob.cc
@@ -0,0 +1,58 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-07-18 19:44
+#include <gtest/gtest.h>
+#include "proto/model.pb.h"
+#include "model/lapis.h"
+
+namespace lapis {
+class BlobTest : public ::testing::Test {
+ public:
+  BlobTest() : blob1(new Blob()), blob2(new Blob()) {}
+  ~BlobTest() {
+    delete blob1;
+    delete blob2;
+  }
+ protected:
+  Blob *blob1, *blob2;
+  Blob blob3, blob4;
+};
+
+TEST_F(BlobTest, Constructor) {
+  EXPECT_EQ(blob1->length(), 0);
+  EXPECT_EQ(blob1->width(), 0);
+  EXPECT_EQ(blob1->height(), 0);
+  EXPECT_EQ(blob3.length(), 0);
+  EXPECT_EQ(blob3.width(), 0);
+  EXPECT_EQ(blob3.height(), 0);
+  EXPECT_TRUE(blob2->dptr == nullptr);
+  EXPECT_TRUE(blob4.dptr == nullptr);
+}
+
+TEST_F(BlobTest, TestResize) {
+  blob1->Resize(10,1,1,1);
+  EXPECT_EQ(blob1->length(), 10);
+  EXPECT_EQ(blob1->num(), 10);
+  EXPECT_EQ(blob1->height(), 1);
+  EXPECT_EQ(blob1->width(), 1);
+  EXPECT_TRUE(blob1->dptr != nullptr);
+  blob2->Resize(4,1,1,3);
+  EXPECT_EQ(blob2->length(), 12);
+  EXPECT_EQ(blob2->num(), 4);
+  EXPECT_EQ(blob2->height(), 1);
+  EXPECT_EQ(blob2->width(), 3);
+  EXPECT_TRUE(blob2->dptr != nullptr);
+  blob3.Resize(5,1,4,3);
+  EXPECT_EQ(blob3.length(), 60);
+  EXPECT_EQ(blob3.num(), 5);
+  EXPECT_EQ(blob3.height(), 4);
+  EXPECT_EQ(blob3.width(), 3);
+  EXPECT_TRUE(blob3.dptr != nullptr);
+  blob4.Resize(6,5,4,3);
+  EXPECT_EQ(blob4.length(), 360);
+  EXPECT_EQ(blob4.num(), 6);
+  EXPECT_EQ(blob4.height(), 4);
+  EXPECT_EQ(blob4.width(), 3);
+  EXPECT_TRUE(blob4.dptr != nullptr);
+}
+
+}  // namespace lapis

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_data_layer.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_data_layer.cc 
b/src/test/model/test_data_layer.cc
new file mode 100644
index 0000000..49519a5
--- /dev/null
+++ b/src/test/model/test_data_layer.cc
@@ -0,0 +1,178 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-08-01 16:09
+
+#include <gtest/gtest.h>
+#include <glog/logging.h>
+#include <map>
+#include <vector>
+
+#include "model/data_layer.h"
+#include "model/trainer.h"
+#include "model/sgd_trainer.h"
+#include "model/conv_edge.h"
+#include "model/relu_layer.h"
+#include "proto/model.pb.h"
+
+#include "utils/proto_helper.h"
+
+namespace lapis {
+class ModelTest : public ::testing::Test {
+ public:
+  ModelTest () {
+    ReadProtoFromTextFile("src/test/data/model.conf", &model_proto);
+  }
+ protected:
+  ModelProto model_proto;
+};
+/**********************************************************************
+ * DataLayer Test
+ **********************************************************************/
+class DataLayerTest : public ModelTest {
+ public:
+   DataLayerTest() {
+     label_layer.Init(model_proto.net().layer(0));
+     img_layer.Init(model_proto.net().layer(1));
+     Trainer::InitDataSource(model_proto.trainer().train_data(), &sources);
+     EXPECT_EQ(2, sources.size());
+     sources[0]->LoadData(nullptr);
+     sources[1]->LoadData(nullptr);
+     DLOG(INFO)<<"after init datasources";
+     label_layer.Setup(2, TrainerProto::kBackPropagation, sources);
+     DLOG(INFO)<<"after setup label layer";
+     img_layer.Setup(2, TrainerProto::kBackPropagation, sources);
+     DLOG(INFO)<<"after setup img layer";
+   }
+   ~DataLayerTest() {
+     for(auto& source: sources)
+       delete source;
+   }
+ protected:
+  DataLayer img_layer, label_layer;
+  std::vector<DataSource*> sources;
+};
+
+TEST_F(DataLayerTest, InitSetupForward) {
+  EXPECT_TRUE(label_layer.HasInput());
+  EXPECT_TRUE(img_layer.HasInput());
+  EXPECT_STREQ("DataLayer", DataLayer::kType.c_str());
+
+  EXPECT_EQ(2, label_layer.feature(nullptr).num());
+  EXPECT_EQ(1, label_layer.feature(nullptr).channels());
+  EXPECT_EQ(1, label_layer.feature(nullptr).height());
+  EXPECT_EQ(1, label_layer.feature(nullptr).width());
+
+  EXPECT_EQ(2, img_layer.feature(nullptr).num());
+  EXPECT_EQ(3, img_layer.feature(nullptr).channels());
+  EXPECT_EQ(227, img_layer.feature(nullptr).height());
+  EXPECT_EQ(227, img_layer.feature(nullptr).width());
+
+  img_layer.Forward();
+}
+// TODO(wangwei) test this after outgoing edges are tested
+
+/**********************************************************************
+ * ConvEdge Test
+ **********************************************************************/
+class ConvEdgeTest : public DataLayerTest {
+ public:
+  ConvEdgeTest() {
+    relu.Init(model_proto.net().layer(2));
+    DLOG(INFO)<<"init both layers";
+    layer_map["input_img"]=&img_layer;
+    layer_map["hidden1_relu"]=&relu;
+
+    edge_proto=model_proto.net().edge(0);
+    convedge.Init(edge_proto, layer_map);
+    convedge.Setup(true);
+  }
+ protected:
+  std::map<std::string, Layer*> layer_map;
+  ConvEdge convedge;
+  EdgeProto edge_proto;
+  ReLULayer relu;
+};
+
+TEST_F(ConvEdgeTest, InitSetupForward) {
+  Layer* dest=layer_map.at("hidden1_relu");
+  Blob &b=dest->feature(&convedge);
+  EXPECT_EQ(0,b.num());
+  convedge.SetupTopBlob(&b);
+  int conv_height = (227 + 2 * edge_proto.pad() - edge_proto.kernel_size())
+    / edge_proto.stride() + 1;
+  int conv_width=conv_height;
+  CHECK_EQ(2, b.num());
+  CHECK_EQ(edge_proto.num_output(), b.channels());
+  CHECK_EQ(conv_height, b.height());
+  CHECK_EQ(conv_width, b.width());
+  DLOG(INFO)<<"after shape check";
+
+  Layer* src=layer_map["input_img"];
+  convedge.Forward(src->feature(&convedge), &b, true);
+}
+
+/**********************************************************************
+ * ReLULayer Test
+ **********************************************************************/
+class ReLULayerTest : public ConvEdgeTest {
+ public:
+  ReLULayerTest() {
+    relu.Setup(2, TrainerProto::kBackPropagation, sources);
+    relu_proto=model_proto.net().layer(3);
+  }
+ protected:
+  LayerProto relu_proto;
+};
+
+TEST_F(ReLULayerTest, ForwardWithoutDropout) {
+  EXPECT_EQ(2, relu.feature(&convedge).num());
+  EXPECT_EQ(2, relu.gradient(&convedge).num());
+
+  relu.Forward();
+}
+/**********************************************************************
+ * PoolingEdge Test
+class PoolingEdgeTest : public ReLULayerTest {
+ public:
+  PoolingEdgeTest() {
+    linearlayer.Init(model.net().layer(3));
+    pooledge.Init(model.net().edge(1));
+  }
+
+ protected:
+  PoolingEdge pooledge;
+  LinearLayer linearlayer;
+}
+ **********************************************************************/
+/**********************************************************************
+ * LinearLayer Test
+ **********************************************************************/
+
+/**********************************************************************
+ * LRNEdge Test
+ **********************************************************************/
+
+/**********************************************************************
+ * InnerProductEdge Test
+ **********************************************************************/
+
+/**********************************************************************
+ * SoftmaxLayerLossEdge Test
+ **********************************************************************/
+
+
+
+
+/**********************************************************************
+ * SGDTrainer Test
+ **********************************************************************/
+class SGDTrainerTest : public ModelTest {
+ protected:
+  SGDTrainer sgd;
+};
+
+TEST_F(SGDTrainerTest, Init) {
+  sgd.Init(model_proto.trainer());
+  EXPECT_TRUE(Trainer::phase==Phase::kInit);
+}
+
+}  // namespace lapis

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_label_source.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_label_source.cc 
b/src/test/model/test_label_source.cc
new file mode 100644
index 0000000..9b25c2a
--- /dev/null
+++ b/src/test/model/test_label_source.cc
@@ -0,0 +1,59 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-07-21 19:40
+
+#include <gtest/gtest.h>
+#include <glog/logging.h>
+#include "proto/model.pb.h"
+#include "disk/label_source.h"
+
+namespace lapis {
+class LabelSourceTest : public ::testing::Test {
+ public:
+  LabelSourceTest() {
+    DataSourceProto ds;
+    ds.set_path("src/test/data/label_source.dat");
+    ds.set_size(12);
+    ds.set_name("label source");
+    ls.Init(ds);
+  }
+
+ protected:
+  LabelSource ls;
+};
+
+TEST_F(LabelSourceTest, LoadData) {
+  auto ptr2names = ls.LoadData(nullptr);
+  EXPECT_EQ(12, ptr2names->size());
+  EXPECT_STREQ("img0.JPEG", ptr2names->at(0).c_str());
+  EXPECT_STREQ("img1.JPEG", ptr2names->at(1).c_str());
+  EXPECT_STREQ("img5.JPEG", ptr2names->at(5).c_str());
+  EXPECT_STREQ("img10.JPEG", ptr2names->at(10).c_str());
+  EXPECT_STREQ("img11.JPEG", ptr2names->at(11).c_str());
+}
+
+TEST_F(LabelSourceTest, GetData) {
+  ls.LoadData(nullptr);
+  Blob b;
+  b.Resize(1, 1, 1, 5);
+  ls.GetData(&b);
+  const float *val = b.dptr;
+  EXPECT_EQ(0.0f, val[0]);
+  EXPECT_EQ(1.0f, val[1]);
+  EXPECT_EQ(4.0f, val[2]);
+  EXPECT_EQ(9.0f, val[3]);
+  EXPECT_EQ(16.0f, val[4]);
+  ls.GetData(&b);
+  EXPECT_EQ(4.0f, val[0]);
+  EXPECT_EQ(5.0f, val[1]);
+  EXPECT_EQ(6.0f, val[2]);
+  EXPECT_EQ(7.0f, val[3]);
+  EXPECT_EQ(8.0f, val[4]);
+  ls.GetData(&b);
+  EXPECT_EQ(1.0f, val[0]);
+  EXPECT_EQ(2.0f, val[1]);
+  EXPECT_EQ(0.0f, val[2]);
+  EXPECT_EQ(1.0f, val[3]);
+  EXPECT_EQ(4.0f, val[4]);
+}
+
+}  // namespace lapis

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_param.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_param.cc b/src/test/model/test_param.cc
new file mode 100644
index 0000000..520fbe2
--- /dev/null
+++ b/src/test/model/test_param.cc
@@ -0,0 +1,138 @@
+#include <gtest/gtest.h>
+#include <glog/logging.h>
+#include "proto/model.pb.h"
+
+#include "utils/param.h"
+
+using namespace singa;
+
+class ParamTest : public ::testing::Test {
+ public:
+  ParamTest() {
+    wp.set_name("weight");
+    wp.add_shape(3);
+    wp.add_shape(4);
+    bp.set_name("bias");
+    bp.add_shape(4);
+  }
+ protected:
+  Param w, b;
+  ParamProto wp, bp;
+};
+
+TEST_F(ParamTest, ConstantInit) {
+  bp.set_init_method(ParamProto::kConstant);
+  bp.set_value(0.5);
+  b.Init(bp);
+  const float *val = b.content().dptr;
+  EXPECT_EQ(0.5f, val[0]);
+  EXPECT_EQ(0.5f, val[1]);
+  EXPECT_EQ(0.5f, val[2]);
+  EXPECT_EQ(0.5f, val[3]);
+  wp.set_init_method(ParamProto::kConstant);
+  wp.set_value(1.5);
+  w.Init(wp);
+  val = w.content().dptr;
+  EXPECT_EQ(1.5f, val[0]);
+  EXPECT_EQ(1.5f, val[3]);
+  EXPECT_EQ(1.5f, val[4]);
+  EXPECT_EQ(1.5f, val[11]);
+}
+
+TEST_F(ParamTest, UniformInit) {
+  bp.set_init_method(ParamProto::kUniform);
+  bp.set_value(1.0f);
+  b.Init(bp);
+  const float *val = b.content().dptr;
+  EXPECT_TRUE(val[0] >= -1 && val[0] <= 1);
+  // BUG fix: upper bound previously tested val[2] instead of val[1],
+  // so val[1] was never checked against the upper limit.
+  EXPECT_TRUE(val[1] >= -1 && val[1] <= 1);
+  EXPECT_TRUE(val[2] >= -1 && val[2] <= 1);
+  EXPECT_TRUE(val[3] >= -1 && val[3] <= 1);
+  wp.set_init_method(ParamProto::kUniform);
+  wp.set_value(1.0f);
+  w.Init(wp);
+  val = w.content().dptr;
+  EXPECT_TRUE(val[0] >= -1 && val[0] <= 1);
+  EXPECT_TRUE(val[3] >= -1 && val[3] <= 1);
+  EXPECT_TRUE(val[4] >= -1 && val[4] <= 1);
+  EXPECT_TRUE(val[11] >= -1 && val[11] <= 1);
+}
+
+TEST_F(ParamTest, UniformSqrtFanInInit) {
+  wp.set_init_method(ParamProto::kUniformSqrtFanIn);
+  wp.set_value(2.0f);
+  w.Init(wp);
+  const float *val = w.content().dptr;
+  EXPECT_TRUE(val[0] >= -2 && val[0] <= 2);
+  EXPECT_TRUE(val[3] >= -2 && val[3] <= 2);
+  EXPECT_TRUE(val[4] >= -2 && val[4] <= 2);
+  EXPECT_TRUE(val[11] >= -2 && val[11] <= 2);
+}
+
+
+TEST_F(ParamTest, UniformSqrtFanInOutInit) {
+  wp.set_init_method(ParamProto::kUniformSqrtFanInOut);
+  wp.set_value(1.0f);
+  float low=1.0f, high=5.0f;
+  wp.set_low(low);
+  wp.set_high(high);
+  w.Init(wp);
+  const float *val = w.content().dptr;
+  /*
+  LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3];
+  LOG(INFO) << val[4] << " " << val[5] << " " << val[6] << " " << val[7];
+  LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11];
+  */
+  float factor = wp.value() / sqrt(wp.shape(0) + wp.shape(1));
+  low=low*factor;
+  high=high*factor;
+  LOG(INFO)<<low<<" "<<high;
+  EXPECT_TRUE(val[0] >= low && val[0] <= high);
+  EXPECT_TRUE(val[3] >= low && val[3] <= high);
+  EXPECT_TRUE(val[4] >= low && val[4] <= high);
+  EXPECT_TRUE(val[11] >= low && val[11] <= high);
+}
+
+TEST_F(ParamTest, GaussianInit) {
+  // BUG fix: mean/std were declared int in `int len=5000, mean=0.0f, ...`,
+  // silently truncating the float initializers. Harmless for 0.0f/1.0f but
+  // wrong for any other configuration values.
+  int len=5000;
+  float mean=0.0f, std=1.0f;
+  ParamProto p;
+  p.set_name("bias");
+  p.add_shape(1);
+  p.add_shape(len);
+  p.set_init_method(ParamProto::kGaussain);
+  p.set_value(1.0f);
+  p.set_mean(mean);
+  p.set_std(std);
+  w.Init(p);
+
+  const float *val = w.content().dptr;
+  float dmean=0.0f;
+  for(int i=0;i<len;i++)
+    dmean+=val[i];
+  dmean/=len;
+  float dstd=0.0f;
+  for(int i=0;i<len;i++)
+    dstd+=(dmean-val[i])*(dmean-val[i]);
+  dstd/=len;
+  EXPECT_TRUE(std::abs(mean-dmean)<0.1);
+  // NOTE(review): dstd is the sample *variance*, not the std deviation;
+  // comparing it to `std` only works here because std==1 implies var==1.
+  EXPECT_TRUE(std::abs(std-dstd)<0.1);
+  /*
+  LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3];
+  LOG(INFO) << val[4] << " " << val[5] << " " << val[6] << " " << val[7];
+  LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11];
+  */
+}
+
+TEST_F(ParamTest, GaussianSqrtFanInInit) {
+  wp.set_init_method(ParamProto::kGaussainSqrtFanIn);
+  wp.set_value(1.0f);
+  wp.set_mean(0);
+  wp.set_std(1.0f);
+  w.Init(wp);
+  //const float *val = w.content().dptr;
+  /*
+  LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3];
+  LOG(INFO) << val[4] << " " << val[5] << " " << val[6] << " " << val[7];
+  LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11];
+  */
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_proto.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_proto.cc b/src/test/model/test_proto.cc
new file mode 100644
index 0000000..f6d81fd
--- /dev/null
+++ b/src/test/model/test_proto.cc
@@ -0,0 +1,67 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-07-15 21:54
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include "proto/model.pb.h"
+#include "utils/proto_helper.h"
+namespace lapis {
+
+// use const Message& m=..., otherwise may lead to segment fault
+TEST(ProtoTest, ReadFromFile) {
+  ModelProto model;
+  LOG(INFO)<<"start....";
+  lapis::ReadProtoFromTextFile("src/test/data/model.conf", &model);
+  LOG(INFO)<<"after reading file...";
+  EXPECT_STREQ("caffe_config", model.name().c_str());
+
+  // layer and edge size
+  const NetProto& net = model.net();
+  EXPECT_EQ(15, net.layer().size());
+  EXPECT_EQ(14, net.edge().size());
+  LOG(INFO)<<"after size check...";
+
+  // layer config
+  LayerProto layer1 = net.layer().Get(1);
+  EXPECT_STREQ("input_img", layer1.name().c_str());
+  EXPECT_STREQ("DataLayer", layer1.type().c_str());
+  LOG(INFO)<<"after datalayer check...";
+  // edge config
+  EdgeProto edge0 = net.edge().Get(0);
+  EXPECT_STREQ("input_img-hidden1_relu", edge0.name().c_str());
+  EXPECT_STREQ("ConvEdge", edge0.type().c_str());
+  EXPECT_EQ(2, edge0.param().size());
+  LOG(INFO)<<"after first edge check...";
+  // param config
+  ParamProto param1 = edge0.param().Get(0);
+  EXPECT_TRUE(ParamProto::kGaussain == param1.init_method());
+  EXPECT_EQ(0.0f, param1.mean());
+  EXPECT_EQ(0.01f, param1.std());
+  EXPECT_EQ(1.0f, param1.learning_rate_multiplier());
+  LOG(INFO)<<"after param of first edge check...";
+
+  ParamProto param2 = edge0.param().Get(1);
+  EXPECT_TRUE(ParamProto::kConstant == param2.init_method());
+  EXPECT_EQ(0.0f, param2.value());
+  EXPECT_EQ(0.0f, param2.weight_decay_multiplier());
+  LOG(INFO)<<"after param of second edge check...";
+
+  // trainer config
+  const TrainerProto& trainer = model.trainer();
+  const SGDProto& sgd=trainer.sgd();
+  EXPECT_EQ(227, sgd.train_batchsize());
+  EXPECT_EQ(0.01f, sgd.base_learning_rate());
+  EXPECT_TRUE(SGDProto::kStep== sgd.learning_rate_change());
+  LOG(INFO)<<"after sgd check...";
+
+  // data source config
+  EXPECT_EQ(2,trainer.train_data().size());
+  LOG(INFO)<<"after size check...";
+  const DataSourceProto& data=trainer.train_data(0);
+  LOG(INFO)<<"after get data...";
+  EXPECT_STREQ("RGBDirSource", data.type().c_str());
+  LOG(INFO)<<"after type check...";
+  EXPECT_EQ(50000, data.size());
+  EXPECT_EQ(3, data.channels());
+  LOG(INFO)<<"after data source check...";
+}
+} // namespace lapis

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_rgb_dir_source.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_rgb_dir_source.cc 
b/src/test/model/test_rgb_dir_source.cc
new file mode 100644
index 0000000..36ac21a
--- /dev/null
+++ b/src/test/model/test_rgb_dir_source.cc
@@ -0,0 +1,63 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-07-21 21:52
+
+#include <gtest/gtest.h>
+#include <glog/logging.h>
+#include <algorithm>
+
+#include "proto/model.pb.h"
+#include "disk/rgb_dir_source.h"
+#include "disk/label_source.h"
+
+namespace lapis {
+class RGBDirSourceTest : public ::testing::Test {
+ public:
+  RGBDirSourceTest() {
+    DataSourceProto ds;
+    ds.set_path("src/test/data/rgb_dir");
+    ds.set_mean_file("src/test/data/imagenet_mean.binaryproto");
+    ds.set_size(3);
+    ds.set_height(256);
+    ds.set_width(256);
+    ds.set_offset(2);
+    ds.set_name("rgb dir source");
+    rgbs.Init(ds);
+  }
+
+ protected:
+  RGBDirSource rgbs;
+};
+
+TEST_F(RGBDirSourceTest, LoadDataNoInputKeys) {
+  // BUG fix: `auto &` cannot bind a non-const lvalue reference to the value
+  // returned by LoadData; every sibling call site (LoadDataWithInputKeys,
+  // test_label_source.cc) takes the result by value, so do the same here.
+  auto ptr2names = rgbs.LoadData(nullptr);
+  EXPECT_EQ(3, ptr2names->size());
+  sort(ptr2names->begin(), ptr2names->end());
+  EXPECT_STREQ("img0.JPEG", ptr2names->at(0).c_str());
+  EXPECT_STREQ("img1.JPEG", ptr2names->at(1).c_str());
+  EXPECT_STREQ("img2.JPEG", ptr2names->at(2).c_str());
+}
+
+TEST_F(RGBDirSourceTest, LoadDataWithInputKeys) {
+  LabelSource ls;
+  DataSourceProto ds;
+  ds.set_path("src/test/data/label_source.dat");
+  ds.set_name("label source");
+  ds.set_size(3);
+  ls.Init(ds);
+  auto ptr2names1 = ls.LoadData(nullptr);
+  auto ptr2names2 = rgbs.LoadData(ptr2names1);
+  EXPECT_EQ(3, ptr2names2->size());
+  for (int i = 0; i < 3; i++)
+    EXPECT_STREQ(ptr2names1->at(i).c_str(), ptr2names2->at(i).c_str());
+}
+
+TEST_F(RGBDirSourceTest, GetData) {
+  Blob b;
+  b.Resize(256,256,3,2);
+  rgbs.LoadData(nullptr);
+  rgbs.GetData(&b);
+  rgbs.GetData(&b);
+  rgbs.GetData(&b);
+}
+}  // namespace lapis
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_cluster.cc
----------------------------------------------------------------------
diff --git a/src/test/test_cluster.cc b/src/test/test_cluster.cc
new file mode 100644
index 0000000..d86463a
--- /dev/null
+++ b/src/test/test_cluster.cc
@@ -0,0 +1,95 @@
+#include <fstream>
+#include "gtest/gtest.h"
+#include "proto/cluster.pb.h"
+#include "utils/cluster.h"
+
+using namespace singa;
+
+string folder="src/test/data/";
+/*
+ClusterProto GenClusterProto(){
+  ClusterProto proto;
+  int nworker=6, nserver=4;
+  proto.set_nworkers(nworker);
+  proto.set_nservers(nserver);
+  proto.set_nworkers_per_group(3);
+  proto.set_nservers_per_group(2);
+  proto.set_nthreads_per_worker(1);
+  proto.set_nthreads_per_server(2);
+
+  proto.set_hostfile(folder+"/hostfile");
+
+  std::ofstream fout(folder+"/hostfile", std::ofstream::out);
+  for(int i=0;i<nworker+nserver;i++){
+    char tmp[20];
+    sprintf(tmp, "awan-0-%02d-0", i);
+    fout<<tmp<<std::endl;
+  }
+  fout.flush();
+  fout.close();
+  return proto;
+}
+
+TEST(ClusterTest, NoServer){
+  ClusterProto proto=GenClusterProto();
+  proto.set_nservers(0);
+  auto cluster=Cluster::Get(proto, 0);
+  ASSERT_EQ(proto.nworkers(),cluster->nworkers());
+  ASSERT_EQ(0, cluster->nservers());
+  ASSERT_EQ(proto.nworkers_per_group(),cluster->nworkers_per_group());
+  ASSERT_EQ(proto.nservers_per_group(),cluster->nservers_per_group());
+  ASSERT_FALSE(cluster->AmIServer());
+  ASSERT_TRUE(cluster->AmIWorker());
+  ASSERT_EQ(0,cluster->group_procs_id());
+  ASSERT_EQ(0,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(0, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-00-0", cluster->host_addr().c_str());
+
+  cluster=Cluster::Get(proto, 5);
+  ASSERT_EQ(2,cluster->group_procs_id());
+  ASSERT_EQ(1,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(0, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-05-0", cluster->host_addr().c_str());
+}
+
+TEST(ClusterTest, SingleServerGroup){
+  ClusterProto proto=GenClusterProto();
+  proto.set_nservers(2);
+  auto cluster=Cluster::Get(proto, 3);
+  ASSERT_FALSE(cluster->AmIServer());
+  ASSERT_TRUE(cluster->AmIWorker());
+  ASSERT_EQ(0,cluster->group_procs_id());
+  ASSERT_EQ(1,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(1, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-03-0", cluster->host_addr().c_str());
+
+  cluster=Cluster::Get(proto, 7);
+  ASSERT_EQ(1,cluster->group_procs_id());
+  ASSERT_EQ(0,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(1, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-07-0", cluster->host_addr().c_str());
+}
+
+TEST(ClusterTest, MultiServerGroups){
+  ClusterProto proto=GenClusterProto();
+  auto cluster=Cluster::Get(proto, 7);
+  ASSERT_EQ(1,cluster->group_procs_id());
+  ASSERT_EQ(0,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(2, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-07-0", cluster->host_addr().c_str());
+
+  cluster=Cluster::Get(proto, 8);
+  ASSERT_TRUE(cluster->AmIServer());
+  ASSERT_FALSE(cluster->AmIWorker());
+  ASSERT_EQ(0,cluster->group_procs_id());
+  ASSERT_EQ(1,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(2, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-08-0", cluster->host_addr().c_str());
+}
+*/

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_communication.cc
----------------------------------------------------------------------
diff --git a/src/test/test_communication.cc b/src/test/test_communication.cc
new file mode 100644
index 0000000..c9c035f
--- /dev/null
+++ b/src/test/test_communication.cc
@@ -0,0 +1,158 @@
+#include <thread>
+#include <vector>
+#include "gtest/gtest.h"
+#include "communication/msg.h"
+#include "communication/socket.h"
+using std::vector;
+using namespace singa;
+
+const char* ping="PING",*pong="PONG";
+/**
+ * Connect dealer with (gid, id, flag) to stub router
+ */
+void Connect(Dealer* dealer, int gid, int id, int flag){
+  dealer->Connect("inproc://router");
+  Msg msg;
+  msg.set_src(gid, id, flag);
+  msg.set_dst(0,0,2);
+  msg.set_type(0);
+  msg.add_frame(ping, 4);
+  dealer->Send(&msg);
+}
+
+/**
+ * Dealer thread, ping-pong with the stub router
+ */
+void DealerPingPong(int id){
+  Dealer* dealer=new Dealer();
+  Connect(dealer, 0, id, 0);
+  Msg* msg=dealer->Receive();
+  int flag=msg->src_flag();
+  ASSERT_EQ(2, flag);
+  ASSERT_EQ(0, msg->dst_group_id());
+  ASSERT_EQ(id, msg->dst_id());
+  ASSERT_STREQ(pong, (char*)msg->frame_data());
+  delete msg;
+  delete dealer;
+}
+
+/**
+ * Worker thread, connect to router and communicate with server thread
+ */
+void WorkerDealer(int sid, int did){
+  Dealer* dealer=new Dealer();
+  Connect(dealer, 0, sid, 0);
+  for(int i=0;i<2;i++){
+    {
+      Msg msg;
+      msg.set_src(0, sid, 0);
+      msg.set_dst(0, did, 1);
+      msg.set_type(3);
+      msg.set_target(i);
+      dealer->Send(&msg);
+    }
+    {
+      Msg *msg=dealer->Receive();
+      ASSERT_EQ(0, msg->src_group_id());
+      ASSERT_EQ(did, msg->src_id());
+      ASSERT_EQ(1, msg->src_flag());
+      delete msg;
+    }
+  }
+  delete dealer;
+}
+
+/**
+ * Server thread, connect to router and communicate with worker thread
+ */
+void ServerDealer(int id, int n){
+  Dealer* dealer=new Dealer();
+  Connect(dealer, 0, id, 1);
+  for(int i=0;i<n;i++){
+    Msg *msg=dealer->Receive();
+    Msg reply;
+    reply.set_dst(msg->src_group_id(), msg->src_id(), msg->src_flag());
+    reply.set_src(0, id, 1);
+    dealer->Send(&reply);
+    delete msg;
+  }
+  delete dealer;
+}
+
+TEST(CommunicationTest, DealerRouterPingPong){
+  int n=2;
+  vector<std::thread> threads;
+  for(int i=0;i<n;i++)
+    threads.push_back(std::thread(DealerPingPong, i));
+  Router* router=new Router();
+  router->Bind("");
+  for(int k=0;k<n;k++){
+    Msg* msg=router->Receive();
+    ASSERT_EQ(0, msg->src_group_id());
+    ASSERT_EQ(2, msg->dst_flag());
+    ASSERT_STREQ(ping, (char*)msg->frame_data());
+
+    Msg reply;
+    reply.set_src(0,0,2);
+    reply.set_dst(msg->src_group_id(), msg->src_id(), msg->src_flag());
+    reply.add_frame(pong, 4);
+    router->Send(&reply);
+    delete msg;
+  }
+
+  delete router;
+  for(auto& thread:threads)
+    thread.join();
+}
+
+TEST(CommunicationTest, nWorkers1Server){
+  int nworker=2;
+  vector<std::thread> threads;
+  for(int i=0;i<nworker;i++)
+    threads.push_back(std::thread(WorkerDealer, i, 0));
+  //threads.push_back(std::thread(ServerDealer, 0, 4));
+  Router* router=new Router();
+  router->Bind("");
+  int nmsg=4*nworker;
+  int k=0;
+  while(nmsg>0){
+    Msg* msg=router->Receive();
+    if(2== msg->dst_flag()){
+      ASSERT_STREQ(ping, (char*)msg->frame_data());
+      k++;
+      if(k==nworker)
+        threads.push_back(std::thread(ServerDealer, 0, 2*nworker));
+    }else{
+      nmsg--;
+      router->Send(msg);
+    }
+    delete msg;
+  }
+  delete router;
+  for(auto& thread:threads)
+    thread.join();
+}
+
+TEST(CommunicationTest, 2Workers2Server){
+  vector<std::thread> threads;
+  threads.push_back(std::thread(WorkerDealer, 0, 0));
+  threads.push_back(std::thread(WorkerDealer, 1, 1));
+  threads.push_back(std::thread(ServerDealer, 0, 2));
+  threads.push_back(std::thread(ServerDealer, 1, 2));
+  Router* router=new Router();
+  router->Bind("");
+  int n=8;
+  while(n>0){
+    Msg* msg=router->Receive();
+    if(2== msg->dst_flag()){
+      ASSERT_STREQ(ping, (char*)msg->frame_data());
+    }else{
+      n--;
+      router->Send(msg);
+    }
+    delete msg;
+  }
+  delete router;
+  for(auto& thread:threads)
+    thread.join();
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_shard.cc
----------------------------------------------------------------------
diff --git a/src/test/test_shard.cc b/src/test/test_shard.cc
new file mode 100644
index 0000000..c96d876
--- /dev/null
+++ b/src/test/test_shard.cc
@@ -0,0 +1,56 @@
+#include <gtest/gtest.h>
+#include <sys/stat.h>
+
+#include "utils/data_shard.h"
+
+std::string key[]={"firstkey","secondkey","3key", "key4", "key5"};
+std::string tuple[]={"firsttuple","2th-tuple","thridtuple", "tuple4", 
"tuple5"};
+
+using namespace singa;
+
+TEST(DataShardTest, CreateDataShard){
+  std::string path="src/test/data/shard_test";
+  mkdir(path.c_str(), 0755);
+  DataShard shard(path, DataShard::kCreate, 50);
+  shard.Insert(key[0], tuple[0]);
+  shard.Insert(key[1], tuple[1]);
+  shard.Insert(key[2], tuple[2]);
+  shard.Flush();
+}
+
+TEST(DataShardTest, AppendDataShard){
+  std::string path="src/test/data/shard_test";
+  DataShard shard(path, DataShard::kAppend, 50);
+  shard.Insert(key[3], tuple[3]);
+  shard.Insert(key[4], tuple[4]);
+  shard.Flush();
+}
+TEST(DataShardTest, CountDataShard){
+  std::string path="src/test/data/shard_test";
+  DataShard shard(path, DataShard::kRead, 50);
+  int count=shard.Count();
+  ASSERT_EQ(5, count);
+}
+
+TEST(DataShardTest, ReadDataShard){
+  std::string path="src/test/data/shard_test";
+  DataShard shard(path, DataShard::kRead, 50);
+  std::string k, t;
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_STREQ(key[0].c_str(), k.c_str());
+  ASSERT_STREQ(tuple[0].c_str(), t.c_str());
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_STREQ(key[1].c_str(), k.c_str());
+  ASSERT_STREQ(tuple[1].c_str(), t.c_str());
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_STREQ(key[4].c_str(), k.c_str());
+  ASSERT_STREQ(tuple[4].c_str(), t.c_str());
+
+  ASSERT_FALSE(shard.Next(&k, &t));
+  shard.SeekToFirst();
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_STREQ(key[0].c_str(), k.c_str());
+  ASSERT_STREQ(tuple[0].c_str(), t.c_str());
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/pm_server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/pm_server.cc b/src/trainer/pm_server.cc
new file mode 100644
index 0000000..28fa28d
--- /dev/null
+++ b/src/trainer/pm_server.cc
@@ -0,0 +1,99 @@
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include "trainer/pm_server.h"
+#include "utils/singleton.h"
+#include "utils/factory.h"
+#include <vector>
+
+using std::vector;
+
+namespace singa{
+void PMServer::Setup(int group_id, int server_id, shared_ptr<ParamShard> shard,
+      const UpdaterProto& proto){
+  group_id_=group_id;
+  server_id_=server_id;
+  shard_=shard;
+  updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
+      ->Create("Updater"));
+  updater_->Init(proto);
+}
+
+PMServer::~PMServer(){
+}
+
+bool PMServer::SyncNow(){
+  return false;
+}
+Msg* PMServer::HandlePut(Msg **msg){
+  int id=(*msg)->target();
+  shared_ptr<Param> param=nullptr;
+  if(shard_->find(id)!=shard_->end()){
+    LOG(ERROR)<<"Param ("<<id<<") is put more than once";
+    param=shard_->at(id);
+  }else{
+    param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance()
+        ->Create("Param"));
+    param->set_id(id);
+    (*shard_)[id]=param;
+  }
+  return param->HandlePutMsg(msg);
+}
+
+Msg* PMServer::HandleGet(Msg **msg){
+  int id=(*msg)->target();
+  shared_ptr<Param> param=nullptr;
+  if(shard_->find(id)!=shard_->end()){
+    param=shard_->at(id);
+    return param->HandleGetMsg(msg);
+       } else {
+               //re-construct msg to be re-queued.
+               //the calling function will send this message off
+    return *msg;
+       }
+}
+
+// Applies a worker's update message to the locally-hosted Param `id`:
+// parses the gradient, runs the updater, bumps the version, and returns a
+// response addressed back to the requester. If the Param is not hosted on
+// this server, the original message is returned so the caller can re-queue
+// (forward) it elsewhere.
+Msg* PMServer::HandleUpdate(Msg **msg) {
+  int id=(*msg)->target();
+  shared_ptr<Param> param=nullptr;
+  if(shard_->find(id)!=shard_->end()){
+		//response format: <identity><type: kData><paramId><param content>
+    param=shard_->at(id);
+    // Save the sender's address before ParseUpdateMsg consumes the message.
+    Msg* tmp=static_cast<Msg*>((*msg)->CopyAddr());
+    param->ParseUpdateMsg(msg);
+    updater_->Update(param->version(), param);
+    param->set_version(param->version()+1);
+    auto response=param->GenUpdateResponseMsg();
+    // Swap src/dst on the saved address so the response routes back.
+    tmp->SwapAddr();
+    response->SetAddr(tmp);
+    delete tmp;
+    return response;
+	} else {
+    LOG(ERROR)<<"Param ("<<id<<") is not maintained by server ("<<group_id_
+      <<", "<<server_id_<<")";
+		//re-construct msg to be re-queued.
+		return *msg;
+	}
+}
+
+Msg* PMServer::HandleSyncRequest(Msg **msg){
+  int id=(*msg)->target();
+  shared_ptr<Param> param=nullptr;
+  if(shard_->find(id)!=shard_->end()){
+               //repsonse of the format: <identity><type: 
kData><paramId><param content>
+    param=shard_->at(id);
+    return param->HandleSyncMsg(msg);
+       } else {
+               //re-construct msg to be re-queued.
+    return *msg;
+       }
+}
+
+int PMServer::HandleSyncResponse(Msg **msg){
+  int id=(*msg)->target();
+  CHECK(shard_->find(id)!=shard_->end());
+  return shard_->at(id)->ParseSyncResponseMsg(msg);
+}
+
+} // namespace singa
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/pm_worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/pm_worker.cc b/src/trainer/pm_worker.cc
new file mode 100644
index 0000000..7269578
--- /dev/null
+++ b/src/trainer/pm_worker.cc
@@ -0,0 +1,344 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include "gflags/gflags.h"
+#include <glog/logging.h>
+#include "proto/model.pb.h"
+#include "trainer/pm_worker.h"
+#include "mshadow/tensor.h"
+#include "utils/cluster.h"
+
+
+namespace singa{
+
+void PMWorker::Setup(int group_id, int worker_id,
+    shared_ptr<ParamShard> shard){
+  group_id_=group_id;
+  worker_id_=worker_id;
+  shard_=shard;
+}
+int PMWorker::Sharding(int param_id){
+  return param_id%Cluster::Get()->nservers_per_group();
+}
+/*
+int PMWorker::Sharding(int param_id){
+  static map<int, int> id2procs;
+  if(id2procs.find(param_id)==id2procs.end()){
+  auto cluster=Cluster::Get();
+  int server_group=group_id_%cluster->nserver_groups();
+  int nprocs_per_server_group=
+    cluster->nservers_per_group()/cluster->nservers_per_procs();
+  int procsid=server_group*nprocs_per_server_group+
+    param_id%nprocs_per_server_group;
+  procsid= cluster->server_worker_separate()?
+    cluster->nworker_procs()+procsid:procsid;
+  id2procs[param_id]=procsid;
+  }
+  return id2procs[param_id];
+}
+*/
+
+Msg* PMWorker::Put(Msg** msg){
+  return *msg;
+}
+
+Msg* PMWorker::Put(shared_ptr<Param> param, int step){
+  param->set_version(step);
+  // only owner can put shared parameter
+  if(param->owner()<0||param->owner()==param->id()){
+    Msg* msg= param->GenPutMsg(&step);
+    msg->set_src(group_id_, worker_id_, kWorkerParam);
+    msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+        Sharding(param->id()), kServer);
+    msg->set_type(kPut);
+    msg->set_target(param->id());
+    return msg;
+  }else
+    return nullptr;
+}
+
+Msg* PMWorker::Get(Msg** msg){
+  return *msg;
+}
+
+Msg* PMWorker::Get(shared_ptr<Param> param, int step){
+  param->set_version(step);
+  bool send=false;
+  int id=param->id();
+  shared_ptr<ParamCounter> entry=nullptr;
+  if(param->owner()>=0){
+    entry=shard_->at(id);
+    entry->nGet++;
+    send=entry->nGet/entry->nLocal==step;
+  }
+  if(param->owner()<0||send){
+    Msg* msg=nullptr;
+    if(param->owner()<0){
+      msg=param->GenGetMsg(&step);
+      msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+          Sharding(id), kServer);
+    } else {
+      msg=entry->param->GenGetMsg(&step);
+      msg->set_dst(entry->owner_procs,kStub);
+    }
+    msg->set_src(group_id_, worker_id_, kWorkerParam);
+    msg->set_type(kGet);
+    msg->set_target(id);
+    return msg;
+  }else
+    return nullptr;
+}
+
+Msg* PMWorker::Update(Msg** msg){
+  return *msg;
+}
+// Generates (or defers) an update message for `param` at training step
+// `step`. For a locally-shared param (owner()>=0) gradients are aggregated
+// into the owner's buffer and a message is sent only once all local
+// replicas have contributed; otherwise the message goes straight to the
+// responsible server. Returns nullptr when nothing should be sent yet.
+Msg* PMWorker::Update(shared_ptr<Param> param, int step){
+  param->set_version(step);
+  bool send=false;
+  int id=param->id();
+  shared_ptr<ParamCounter> entry;
+  if(param->owner()>=0){
+    entry=shard_->at(param->id());
+    // NOTE(review): increments nGet, mirroring Get(); confirm whether a
+    // separate update counter was intended here.
+    entry->nGet++;
+    send=entry->nGet/entry->nLocal==step;
+    auto shape=mshadow::Shape1(param->size());
+    mshadow::Tensor<mshadow::cpu,1> grad(param->mutable_cpu_grad(), shape);
+    mshadow::Tensor<mshadow::cpu,1> agg(entry->param->mutable_cpu_grad(), 
shape);
+    agg+=grad;
+  }
+  if(param->owner()<0||send){
+    Msg* msg=nullptr;
+    if(param->owner()<0){
+      msg=param->GenUpdateMsg(&step);
+      msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+          Sharding(id), kServer);
+    } else {
+      // BUG fix: the generated message was discarded, leaving msg==nullptr
+      // and crashing on the set_dst/set_type calls below.
+      msg=entry->param->GenUpdateMsg(&step);
+      msg->set_dst(entry->owner_procs,kStub);
+      memset(param->mutable_cpu_data(), 0, sizeof(float)*param->size());
+    }
+    msg->set_type(kUpdate);
+    msg->set_target(id);
+    msg->set_src(group_id_, worker_id_, kWorkerParam);
+    return msg;
+  }else
+    return nullptr;
+}
+
+Msg* PMWorker::Collect(Msg** msg){
+  int id=(*msg)->target();
+  int type=(*msg)->type();
+  auto pp=shard_->at(id)->param;
+  if(type==kRGet){
+    pp->ParseGetResponseMsg(msg);
+  }else if(type==kRUpdate){
+    pp->ParseUpdateResponseMsg(msg);
+  }
+  if(pp->owner()>=0){
+    // forwarding to workers on other procs
+  }
+  delete (*msg);
+  *msg=nullptr;
+  return nullptr;
+}
+
+/*
+//id is the global worker id
+SingaClient::SingaClient(int global_id, Topology &topology, vector<string> 
&hosts) {
+       //Read the config files and store endpoints
+       id_ = global_id;
+
+       int n_workers = hosts.size() - topology.nservers();
+       int n_worker_groups = topology.nworker_groups();
+       int group_size = n_workers/n_worker_groups;
+       int server_group_size = 
topology.nservers()/topology.server_group_size();
+       FLAGS_client_threads = topology.worker_threads();
+
+       local_id_ = (id_-topology.nservers())%group_size;//local worker id.
+       group_id_ = (id_-topology.nservers())/group_size;
+
+       VLOG(3) << "Parsing client config for "<<hosts[id_];
+
+       //connect to all server in the server group group_id_
+       int start_server_idx = group_id_*server_group_size;
+       int end_server_idx = start_server_idx+server_group_size;
+
+       for (int i = start_server_idx; i < end_server_idx; i++) {
+               char *neighbor_endpoint = (char*) malloc(256);
+               sprintf(neighbor_endpoint, "tcp://%s:%d", hosts[i].c_str(), 
topology.port());
+               neighbors_.push_back(neighbor_endpoint);
+               VLOG(3) << "Worker neighbor (server): "<<neighbor_endpoint;
+       }
+
+       sprintf(backend_endpoint_, "inproc://singanus%d",id_);
+
+       //Create shared paramshard
+       param_shard_ = new ParamShard(id_,0);
+}
+
+void SingaClient::StartClient(){
+       //Create and connect sockets to the server
+       vector<void *> server_sockets;
+       zctx_t *context = zctx_new();
+       int nservers = neighbors_.size();
+       int rc;
+       for (int i=0; i<nservers; i++){
+               void *socket = zsocket_new(context, ZMQ_DEALER);
+               rc = zsocket_connect(socket, neighbors_[i]);
+               VLOG(3) << "Connected to neighbor " <<neighbors_[i];
+               assert(rc==0);
+               server_sockets.push_back(socket);
+       }
+
+       //Create and bind backend socket
+       void *backend = zsocket_new(context, ZMQ_ROUTER);
+       rc = zsocket_bind(backend, backend_endpoint_);
+       assert(rc==0);
+
+       //Start client threads
+       for (int i=0; i<FLAGS_client_threads; i++){
+               void * socket = zthread_fork(context, ClientThread, this);
+               zmsg_t *control_msg = zmsg_new();
+               if (i==0 && local_id_==0)
+                       zmsg_pushstr(control_msg,POPULATE);
+               else
+                       zmsg_pushstr(control_msg, WAIT);
+               zmsg_send(&control_msg, socket);
+       }
+
+	//Start the message loop
+       bool is_running = true;
+       int nsockets= nservers+1;
+       while (is_running) {
+               zmq_pollitem_t items[nsockets];
+               for (int i = 0; i < nsockets-1; i++)
+                       items[i] = {server_sockets[i], 0, ZMQ_POLLIN, 0};
+               items[nsockets-1] = {backend, 0, ZMQ_POLLIN, 0};
+
+               int rc = zmq_poll(items,nsockets,-1);
+               if (rc<0) break;
+
+               for (int i=0; i<nsockets-1; i++){
+                       if (items[i].revents & ZMQ_POLLIN){
+                               zmsg_t *msg = zmsg_recv(server_sockets[i]);
+                               if (!msg){
+                                       is_running = false;
+                                       break;
+                               }
+                               //forward to backend
+                               zmsg_send(&msg, backend);
+                       }
+               }
+               if (items[nsockets-1].revents & ZMQ_POLLIN){
+                       //compute serverId from paramId and forward to the 
socket
+                       zmsg_t *msg = zmsg_recv(backend);
+                       if (!msg) is_running=false;
+                       zframe_t *identity = zmsg_pop(msg);
+                       zframe_t *type = zmsg_pop(msg);
+                       int paramId;
+                       sscanf(zmsg_popstr(msg), "%d", &paramId);
+                       zmsg_pushstrf(msg,"%d",paramId);
+                       zmsg_prepend(msg,&type);
+                       zmsg_prepend(msg,&identity);
+                       zmsg_send(&msg, 
server_sockets[param_to_server_id(paramId)]);
+               }
+       }
+
+       zsocket_destroy(context, backend);
+       for (int i=0; i<nsockets-1; i++)
+               zsocket_destroy(context, server_sockets[i]);
+       zctx_destroy(&context);
+}
+
+vector<Param*> gen_random_params() {
+       int size[] = { 1960000, 2500, 5000000, 2000, 3000000, 1500, 1500000, 
1000, 500000, 500, 5000, 10 };
+       vector<Param*> params;
+       for (int i = 0; i < 12; i++) {
+               ParamProto proto;
+               proto.set_id(i);
+               proto.set_init_method(ParamProto::kGaussain);
+               Param* p = new Param();
+               p->Setup(proto, vector<int> { size[i] }, 0);
+               p->Init();
+               params.push_back(p);
+       }
+       return params;
+}
+
+//simple mapping
+int SingaClient::param_to_server_id(int paramId){
+       return paramId % neighbors_.size();
+}
+
+void ClientThread(void *args, zctx_t *ctx, void *pipe){
+       SingaClient *client = static_cast<SingaClient*>(args);
+
+       //Create back-end socket and connect to the main thread
+       void *backend = zsocket_new(ctx, ZMQ_DEALER);
+       int rc = zsocket_connect(backend, client->backend_endpoint());
+       assert(rc==0);
+       //Create PMClient object
+       PMClient *pmclient = new PMClient(client->id(), client->param_shard(), 
backend);
+
+       //FOR TESTING ONLY. REMOVE THIS!
+       //wait for control from main thread
+       vector<Param*> params = gen_random_params();
+       zmsg_t *control_msg = zmsg_recv(pipe);
+       zframe_t *msg = zmsg_pop(control_msg);
+       if (zframe_streq(msg,WAIT))
+               zclock_sleep(2000); //2s
+       else{
+               for (int i=0; i<params.size(); i++){
+                       pmclient->Put(i, params[i]);
+               }
+               VLOG(3)<<"Done PUT requests for populating servers.";
+               zclock_sleep(2000);
+       }
+       zframe_destroy(&msg);
+       //END TESTING
+       LOG(ERROR) << "Done putting";
+
+       //first, get the params
+
+       test_get(pmclient);
+       test_collect(pmclient);
+
+
+       int iterations = 1;
+       while (iterations<=200){
+               VLOG(3) << "Iteration "<<iterations;
+               test_update(pmclient, params);
+               test_collect(pmclient);
+               iterations++;
+       }
+
+       zsocket_destroy(ctx, backend);
+}
+
+void test_get(PMClient *client){
+       for (int i=0; i<12; i++){
+               Param pm;
+               int status = client->Get(i, &pm);
+               assert(status==NON_LOCAL);
+       }
+}
+
+void test_collect(PMClient *client){
+       for (int i=0; i<12; i++){
+               Param pm;
+               int64_t start_time = zclock_time();
+               while (!client->Collect(&pm))
+                       zclock_sleep(1);
+               int64_t end_time = zclock_time();
+               VLOG(3) << "Collected: " <<(end_time-start_time);
+       }
+}
+
+void test_update(PMClient *client, vector<Param*> params){
+       for (int i=0; i<params.size(); i++)
+               client->Update(i, params[i]);
+}
+*/
+
+
+} //namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
new file mode 100644
index 0000000..bf0ad03
--- /dev/null
+++ b/src/trainer/server.cc
@@ -0,0 +1,68 @@
+#include <list>
+#include <tuple>
+#include <queue>
+#include "trainer/server.h"
+#include "utils/param.h"
+#include "utils/singleton.h"
+#include "utils/factory.h"
+#include "utils/cluster.h"
+
+
+namespace singa {
// Construct a parameter server identified by its (server group, server) pair.
Server::Server(int group_id, int server_id):
  group_id_(group_id), server_id_(server_id){}
+
// Initialize the server: create the PMServer request handler from the
// factory, bind it to this server's (group, id) slice of the ParamShard,
// and keep the dealer socket used to talk to the stub/router.
// @param proto   updater configuration forwarded to the PMServer
// @param shard   dictionary of Param objects shared by servers in this procs
// @param dealer  socket already connected to the in-proc router
void Server::Setup(const UpdaterProto& proto,
    shared_ptr<PMServer::ParamShard> shard,
    shared_ptr<Dealer> dealer){
  pmserver_=shared_ptr<PMServer>(Singleton<Factory<PMServer>>::Instance()
      ->Create("PMServer"));
  pmserver_->Setup(group_id_, server_id_, shard, proto);
  dealer_=dealer;
}
+
// Server main loop: announce this server thread to the stub, then serve
// parameter requests until the socket is closed.
void Server::Run(){
  // Handshake: tell the stub (router) that this server thread is online.
  Msg* ping=new Msg();
  ping->set_src(group_id_, server_id_, kServer);
  ping->set_dst(0,0,kStub);
  ping->set_type(kConnect);
  dealer_->Send(ping);
  // NOTE(review): timeout and poller are initialized but never consulted
  // below; Receive() blocks indefinitely — confirm whether a timed
  // poller.Wait(timeout) loop was intended here.
  int timeout=Cluster::Get()->server_timeout();
  Poller poller;
  poller.Add(dealer_.get());
  // Receive loop: dispatch each request to the matching PMServer handler.
  // Handlers take Msg** and are presumed to consume/reuse the message —
  // TODO confirm ownership contract in PMServer.
  while (true){
    Msg* msg=dealer_->Receive();
    if (msg==nullptr)  // socket closed or interrupted: exit the loop
      break;
    Msg* response=nullptr;
    int type=msg->type();
    switch (type){
      case kPut:
        response = pmserver_->HandlePut(&msg);
        break;
      case kGet:
        response = pmserver_->HandleGet(&msg);
        break;
      case kUpdate:
        response = pmserver_->HandleUpdate(&msg);
        break;
      case kSyncRequest:
        VLOG(3)<<"Handle SYNC-REQUEST";
        response = pmserver_->HandleSyncRequest(&msg);
        break;
      case kSyncResponse:
        VLOG(3) << "Handle SYNC response";
        pmserver_->HandleSyncResponse(&msg);
        break;
      // NOTE(review): unrecognized types fall through without deleting msg —
      // potential leak if any other message kind reaches a server.
    }

    if (response!=nullptr)
      dealer_->Send(response);
  }
}
+
+
+
+} /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
new file mode 100644
index 0000000..3621b7e
--- /dev/null
+++ b/src/trainer/trainer.cc
@@ -0,0 +1,206 @@
+#include <thread>
+#include <vector>
+#include <map>
+#include <glog/logging.h>
+#include "trainer/trainer.h"
+using std::vector;
+using std::map;
+
+namespace singa {
+int ProcsIDOf(int group_id, int id, int flag){
+  int procsid;
+  auto cluster=Cluster::Get();
+  if(flag==kServer){
+    procsid=group_id*cluster->nservers_per_group()/
+      cluster->nservers_per_procs()+id/cluster->nservers_per_procs();
+    if(cluster->server_worker_separate())
+      procsid+=cluster->nworker_procs();
+  }else if(flag==kWorkerLayer || flag==kWorkerParam){
+    procsid=group_id*cluster->nworkers_per_group()
+      /cluster->nworkers_per_procs();
+    if(cluster->nworkers_per_group()>cluster->nworkers_per_procs())
+      procsid+=id/cluster->nworkers_per_procs();
+  }else{
+    LOG(ERROR)<<"Unkown flag ("<<flag<<")";
+  }
+  return procsid;
+}
+
+void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
+  // register all layers appearing in the neural net
+  singa::NeuralNet::RegisterLayers();
+  Singleton<Factory<singa::Param>>::Instance()->Register(
+      "Param", CreateInstance(singa::Param, singa::Param));
+  Singleton<Factory<singa::Updater>>::Instance() ->Register(
+      "Updater", CreateInstance(singa::SGDUpdater, singa::Updater));
+  Singleton<Factory<singa::PMWorker>>::Instance() ->Register(
+      "PMWorker", CreateInstance(singa::PMWorker, singa::PMWorker));
+  Singleton<Factory<singa::PMServer>>::Instance() ->Register(
+      "PMServer", CreateInstance(singa::PMServer, singa::PMServer));
+  Singleton<Factory<singa::PMServer>>::Instance() ->Register(
+      "PMServer", CreateInstance(singa::PMServer, singa::PMServer));
+}
+
// Launch every server and worker thread assigned to this process, then run
// the stub (message router) on the calling thread until shutdown.
// @param mproto    model configuration (net, updater, algorithm, test steps)
// @param cproto    cluster topology configuration
// @param procs_id  id of this process within the cluster
void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
    int procs_id){
  RegisterDefaultClasses(mproto);

  auto cluster=Cluster::Get(cproto, procs_id);
  // create servers
  vector<shared_ptr<Server>> servers;
  int nSocket=1; // the first socket is the router
  if(cluster->has_server()){
    int pid=cluster->procs_id();
    if(cluster->server_worker_separate())
      pid-=cluster->nworker_procs();  // server procs are numbered after worker procs
    int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group();
    int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group();
    // NOTE(review): 'end' spans a whole server group; when a group is split
    // across processes this looks like it should be
    // start+cluster->nservers_per_procs() — confirm intended semantics.
    int end=start+cluster->nservers_per_group();
    // the ParamShard for servers consists of a dictionary of Param objects
    auto shard=make_shared<PMServer::ParamShard>();
    for(int sid=start;sid<end;sid++){
      auto server=make_shared<Server>(gid, sid);
      auto dealer=make_shared<Dealer>(nSocket++);
      dealer->Connect(kInprocRouterEndpoint);
      server->Setup(mproto.updater(), shard, dealer);
      servers.push_back(server);
    }
  }

  // create workers
  vector<shared_ptr<Worker>> workers;
  if(cluster->has_worker()){
    auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain);
    int pid=cluster->procs_id();
    // [gstart, gend): worker groups hosted by this procs
    // [wstart, wend): worker ids (within a group) hosted by this procs
    int gstart, gend, wstart, wend;
    if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){
      // all workers in this procs are from the same group
      gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group();
      gend=gstart+1;
      wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group();
      // NOTE(review): wend covers the whole group; expected
      // wstart+cluster->nworkers_per_procs() when a group spans multiple
      // procs — confirm.
      wend=wstart+cluster->nworkers_per_group();
    }else{
      // there are multiple groups in this procs
      CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0);
      int groups_per_procs=
        cluster->nworkers_per_procs()/cluster->nworkers_per_group();
      gstart=pid*groups_per_procs;
      gend=(pid+1)*groups_per_procs;
      wstart=0;
      wend=cluster->nworkers_per_group();
    }
    for(int gid=gstart;gid<gend;gid++){
      shared_ptr<NeuralNet> train_net, test_net, validation_net;
      if(gid==gstart)
        train_net=net;
      else{
        train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain);
        // the train net for other groups may share parameter values from the
        // first group
        if(mproto.hogwild())
          train_net->ShareParams(net, kValueOnly);
      }
      if(gid==0){
        // validation and test are performed only by the first group
        if(mproto.test_steps()){
          test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest);
          if(test_net!=nullptr)
            test_net->ShareParams(train_net, kValueOnly);
        }
        if(mproto.validation_steps()){
          validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kValidation);
          if(validation_net!=nullptr)
            validation_net->ShareParams(train_net, kValueOnly);
        }
      }
      // create ParamShard for the workers: counts, per Param id, how many
      // local replicas exist and which procs owns the authoritative copy
      auto shard=make_shared<PMWorker::ParamShard>();
      for(auto layer: train_net->layers()){
        int procsid=ProcsIDOf(gid, layer->locationid(),kWorkerParam);
        int local=procsid==cluster->procs_id();
        for(auto param: layer->GetParams()){
          int owner=param->owner()<0||param->owner()==param->id()?procsid:-1;
          if(shard->find(param->id())==shard->end())
            (*shard)[param->id()]=make_shared<ParamCounter>(param, local, owner);
          else
            shard->at(param->id())->AddParam(param, local, owner);
        }
      }
      for(int wid=wstart;wid<wend;wid++){
        shared_ptr<Worker> worker=nullptr;
        if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation)
          worker=make_shared<BPWorker>(gid, wid);
        else{
        // TODO add CDWorker
        }
        // NOTE(review): worker remains nullptr for non-BP algorithms and
        // worker->Setup below would dereference null — add a CHECK or
        // implement CDWorker before enabling other algorithms.
        auto layer_dealer=make_shared<Dealer>(nSocket++);
        auto param_dealer=make_shared<Dealer>(nSocket++);
        layer_dealer->Connect(kInprocRouterEndpoint);
        param_dealer->Connect(kInprocRouterEndpoint);
        worker->Setup(mproto, train_net, shard, layer_dealer, param_dealer);
        worker->set_test_net(test_net);
        worker->set_validation_net(validation_net);
        workers.push_back(worker);
      }
    }
  }

#ifdef USE_MPI
  for(int i=0;i<nSocket;i++){
    MPIQueues.push_back(make_shared<SafeQueue>());
  }
#endif
  // spawn one thread per server/worker, run the stub on this thread,
  // then wait for all of them to finish
  vector<std::thread> threads;
  for(auto server: servers)
    threads.push_back(std::thread(&Server::Run,server));
  for(auto worker: workers)
    threads.push_back(std::thread(&Worker::Run,worker));
  Run();
  for(auto& thread: threads)
    thread.join();
}
+
+void Trainer::Run(){
+  auto cluster=Cluster::Get();
+  auto router=make_shared<Router>();
+  router->Bind(kInprocRouterEndpoint);
+  if(cluster->nprocs()>1)
+    router->Bind(cluster->endpoint());
+
+  map<int, shared_ptr<Dealer>> interprocs_dealers;
+  Poller poller;
+  poller.Add(router.get());
+  int timeout=cluster->stub_timeout();
+  while(true){
+    Msg* msg=router->Receive();
+    if(msg==nullptr){
+      LOG(ERROR)<<"Connection broken!";
+      exit(0);
+    }
+    int dst_flag=msg->dst_flag();
+    int type=msg->type();
+    int group_id, id, procs_id;
+    switch (dst_flag){ // TODO process other requests, e.g. RESTful
+      case kStub:
+        if(type==kConnect){
+          delete msg;
+        }else{
+          // TODO processing requests for worker group spanning multiple procs.
+          LOG(ERROR)<<"Unkown message type ("<<type<<") to stub";
+        }
+        break;
+      default:
+        group_id=msg->dst_group_id();
+        id=msg->dst_id();
+        procs_id=ProcsIDOf(group_id, id, dst_flag);
+        if(procs_id!=cluster->procs_id()){
+          if (interprocs_dealers.find(procs_id)==interprocs_dealers.end())
+            interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id);
+          interprocs_dealers[procs_id]->Send(msg);
+        } else
+          router->Send(msg);
+        break;
+    }
+  }
+}
+} /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
new file mode 100644
index 0000000..047ec2d
--- /dev/null
+++ b/src/trainer/worker.cc
@@ -0,0 +1,299 @@
+#include <glog/logging.h>
+#include <thread>
+#include <memory>
+#include <iostream>
+#include "utils/singleton.h"
+#include "utils/factory.h"
+#include "trainer/worker.h"
+#include "proto/model.pb.h"
+using std::thread;
+namespace singa {
// Construct a worker identified by its (worker group, worker) pair.
Worker::Worker( int group_id, int worker_id):
   group_id_(group_id), worker_id_(worker_id){
}
+
// Wire up the worker: store the training net and sockets, create the
// PMWorker (parameter-management client) from the factory, and — for the
// first worker group only — initialize and publish the Params this worker
// is responsible for.
// @param model         model configuration (also supplies the start step)
// @param train_net     neural net this worker trains
// @param shard         per-procs Param bookkeeping shared by local workers
// @param layer_dealer  socket for inter-worker layer (blob) traffic
// @param param_dealer  socket for parameter requests/responses
void Worker::Setup(const ModelProto& model,
    shared_ptr<NeuralNet> train_net,
    shared_ptr<PMWorker::ParamShard> shard,
    shared_ptr<Dealer> layer_dealer,
    shared_ptr<Dealer> param_dealer){
  train_net_=train_net;
  modelproto_=model;
  layer_dealer_=layer_dealer;
  param_dealer_=param_dealer;
  // dealers may be null in single-threaded/test setups; only poll real ones
  if(layer_dealer_!=nullptr)
    layer_poller_.Add(layer_dealer_.get());
  if(param_dealer_!=nullptr)
    param_poller_.Add(param_dealer_.get());
  pmworker_=shared_ptr<PMWorker>(Singleton<Factory<PMWorker>>::Instance()
      ->Create("PMWorker"));
  pmworker_->Setup(group_id_, worker_id_, shard);
  step_=modelproto_.step();
  // init params: only group 0 initializes and Puts the owned Params; every
  // worker then Gets a fresh copy of the params on its layers
  for(auto layer: train_net->layers())
    if(group_id_==0&&layer->locationid()==worker_id_)
      for(auto param: layer->GetParams()){
        // a Param owns its value if it has no owner or owns itself
        if(param->owner()<0||param->owner()==param->id()){
          param->Init();
          Put(param, step_);
        }
        Get(param, step_);
      }
}
+
+void Worker::Run(){
+  step_=modelproto_.step();
+  Performance perf(train_net_);
+  try{
+    while(!StopNow(step_)){
+      RunOneBatch(step_, &perf);
+      step_++;
+    }
+  }catch(WorkerException& e){
+    LOG(ERROR)<<e.what();
+  }
+}
+int Worker::Put(shared_ptr<Param> param, int step){
+  auto msg=pmworker_->Put(param, step);
+  if(msg!=nullptr)
+    param_dealer_->Send(msg);
+  return 1;
+}
+int Worker::Get(shared_ptr<Param> param, int step){
+  if(param->version()<step){
+    auto msg=pmworker_->Get(param, step);
+    if(msg!=nullptr)
+      param_dealer_->Send(msg);
+  }
+  return 1;
+}
+int Worker::Update(shared_ptr<Param> param, int step){
+  auto msg=pmworker_->Update(param, step);
+  if(msg!=nullptr)
+    param_dealer_->Send(msg);
+  return 1;
+}
+int Worker::Collect(shared_ptr<Param> param, int step){
+  while(param->version()<step){
+    Msg* msg=param_dealer_->Receive();
+    if(msg==nullptr)
+      return 0;
+    pmworker_->Collect(&msg);
+  }
+  return 1;
+}
+
// Run one training step: possibly validate/test first (schedule permitting),
// then train one batch and update/display performance metrics.
// @param step  current global step
// @param perf  accumulator for training metrics; may be null to skip tracking
void Worker::RunOneBatch(int step, Performance* perf){
  //DLOG(ERROR)<<"Step "<<step;
  // Test will call Pull which updates the sync time
  // Hence we store the sync time, and restore it later
  //float tSyncData=tSyncData_, tSyncParam=tSyncParam_;
  if(ValidateNow(step)){
    LOG(ERROR)<<"Validation at step "<<step;
    // perf!=nullptr doubles as "display results" flag for Test()
    Test(validation_net_, modelproto_.validation_steps(), perf!=nullptr);
  }
  if(TestNow(step)){
    LOG(ERROR)<<"Test at step "<<step;
    Test(test_net_, modelproto_.test_steps(), perf!=nullptr);
  }
  //tSyncData_=tSyncData; tSyncParam_=tSyncParam;

  TrainOneBatch(step);
  if(perf!=nullptr){
    perf->Update();
    if(DisplayNow(step)){
      LOG(ERROR)<<"Training at step "<<step;
      LOG(ERROR)<<"\t"<<perf->ToString();
      perf->Reset();
      //LOG(ERROR)<<"\t"<<TimerInfo();
    }
  }

  /*
  if(CheckpointNow(step)){
    pm_->Checkpoint(cluster_->workspace()+"/snapshot-"+std::to_string(step));
  }
  */
}
+
// Receive feature/gradient blobs for bridge layers from peer workers.
// NOTE(review): entire body is commented out (legacy zmq implementation);
// currently a no-op — bridge-layer communication is not yet ported.
void Worker::ReceiveBlobs(shared_ptr<NeuralNet> net){
  /*
  int type;
  char *name;
  int64_t tick=zclock_mono();
  zframe_t* frame=zframe_new_empty();

  zsock_recv(pull_, "isf", &type, &name, &frame);
  if(type==kDataFrame){
    auto* dst=static_cast<BridgeDstLayer*>(
        net->name2layer(string(name)).get());
    memcpy(dst->mutable_data()->mutable_cpu_data(), zframe_data(frame),
        zframe_size(frame));
    dst->set_ready(true);
  }else if(type==kGradFrame){
    auto* src=static_cast<BridgeSrcLayer*>(net->name2layer(string(name)).get());
    memcpy(src->mutable_grad()->mutable_cpu_data(), zframe_data(frame),
        zframe_size(frame));
    src->set_ready(true);
  }
  zframe_destroy(&frame);
  delete name;
  tSyncData_+=zclock_mono()-tick;
  */
}

// Send a blob to a peer worker. Placeholder — not yet implemented.
void Worker::SendBlob(){

}
+
+void Worker::Test(shared_ptr<NeuralNet> net, int nsteps, bool disperf){
+  Performance perf(net);
+  for(int step=0;step<nsteps;step++){
+    TestOneBatch(net, step, kTest);
+    if(disperf)
+      perf.Update();
+  }
+  if(disperf)
+    LOG(ERROR)<<"\t"<<perf.ToString();
+}
+
+/****************************BPWorker**********************************/
+
// Forward pass over the layers assigned to this worker.
// During training, waits (Collect) for fresh parameter values before each
// layer computes; throws WorkerException if the parameter socket closes.
// @param net       net to run forward
// @param step      current step (used for param versioning and debug display)
// @param training  true for training (collect params, log data norms)
void BPWorker::Forward(shared_ptr<NeuralNet> net, int step,  bool training){
  auto& layers=net->layers();
  for(auto& layer: layers){
    if(layer->locationid()==worker_id_){
      if(layer->is_bridgedstlayer()){
        //auto* dst=static_cast<BridgeDstLayer*>(layer.get());
        // receive fea blobs
      }
      if(training){
        // block until this layer's params are fresh enough for `step`
        for(shared_ptr<Param> p: layer->GetParams()){
          if(Collect(p, step)==0){
            throw WorkerException();
          }
        }
      }
      layer->ComputeFeature(training);
      if(layer->is_bridgesrclayer()){
        // send fea blobs
      }
      if(training&&DisplayDebugInfo(step)&&layer->mutable_data()!=nullptr){
        LOG(INFO)<<StringPrintf("Forward layer  %10s data norm1 %13.9f",
            layer->name().c_str(), layer->data().asum_data());
      }
    }
  }
}
+
// Backward pass: iterate layers in reverse, compute gradients, and push
// each layer's parameter gradients to the servers (Update) as soon as the
// layer finishes — overlapping communication with remaining computation.
// @param net   net to run backward
// @param step  current step (used for param versioning and debug display)
void BPWorker::Backward(shared_ptr<NeuralNet> net, int step){
  auto& layers=net->layers();
  for (auto it = layers.rbegin(); it != layers.rend(); it++){
    shared_ptr<Layer> layer=*it;
    if(layer->locationid()==worker_id_){
      if(layer->is_bridgesrclayer()){
        //auto* src=static_cast<BridgeSrcLayer*>(layer.get());
        // receive grad blobs
      }
      layer->ComputeGradient();
      if(DisplayDebugInfo(step)&&layer->mutable_grad()!=nullptr){
        LOG(INFO)<<StringPrintf("Backward layer %10s grad norm1 %13.9f\t",
            layer->name().c_str(), layer->grad().asum_data());
        for(shared_ptr<Param> p: layer->GetParams())
          LOG(INFO)<<StringPrintf("param id %2d, name %10s,\
              value norm1 %13.9f, grad norm1 %13.9f",
              p->id(), p->name().c_str(),
              p->data().asum_data(), p->grad().asum_data());
      }
      // send this layer's gradients to the servers
      for(shared_ptr<Param> p: layer->GetParams()){
        Update(p, step);
      }
      if(layer->is_bridgedstlayer()){
        // send grad blobs
      }
    }
  }
}
+
// One training iteration = forward pass + backward pass on the train net.
void BPWorker::TrainOneBatch(int step){
  Forward(train_net_, step, true);
  Backward(train_net_, step);
}

// One evaluation iteration: forward only, no parameter collection/updates.
void BPWorker::TestOneBatch(shared_ptr<NeuralNet> net,int step, Phase phase){
  Forward(net, step, false);
}
+
+/*********************Implementation for Performance class*******************/
/*********************Implementation for Performance class*******************/
// Track averaged metrics of the net's loss layers; one float vector per
// loss layer, sized to that layer's metric blob.
Performance::Performance(shared_ptr<NeuralNet> net):net_(net), counter_(0){
  for(auto& layer: net->losslayers()){
    name_.push_back(layer->name());
    metric_.push_back(vector<float>{});
    metric_.back().resize(layer->metric().count(),0.f);
  }
}

// Accumulate the current metric values of every loss layer; counter_ counts
// the number of accumulated batches (used for averaging in ToString).
void Performance::Update(){
  const auto& losslayers=net_->losslayers();
  for(size_t i=0;i<losslayers.size();i++){
    const float * ptr=losslayers[i]->metric().cpu_data();
    vector<float>& m=metric_.at(i);
    for(int j=0;j<losslayers[i]->metric().count();j++)
      m[j]+=ptr[j];
  }
  counter_++;
}

// Zero all accumulated metrics and the batch counter.
void Performance::Reset(){
  for(auto& m: metric_)
    for(auto& x: m)
      x=0.f;
  counter_=0;
}
+
+string Performance::ToString(){
+  string disp="";
+  for(size_t i=0;i<metric_.size();i++){
+    disp+="Output from "+name_[i]+" layer ";
+    vector<float> m=metric_.at(i);
+    for(size_t j=0;j<m.size();j++)
+        disp+=std::to_string(j)+" : "+std::to_string(m[j]/counter_)+"\t";
+    disp+="\n";
+  }
+  return disp;
+}
+/*
+void Executor::Setup(int local_threadid, const ModelProto& model){
+  tForward_=tBackward_=tSyncData_=tSyncParam_=0;
+  modelproto_=model;
+  local_threadid_=local_threadid;
+  if(model.prefetch()){
+    for(auto& layer: train_net_->datalayers()){
+      if(cluster_->group_threadid(local_threadid_)==layer->locationid())
+        localDataLayers_.push_back(layer);
+    }
+    if(localDataLayers_.size())
+      prefetch_thread_=std::thread(Executor::PrefetchData,
+          std::ref(localDataLayers_), true,1);
+  }
+  int gthreadid=cluster_->group_threadid(local_threadid);
+}
+
+void Executor::PrefetchData(const vector<DataLayer*>& datalayers, bool 
training,
+    int steps){
+  if(datalayers.size()==0)
+    return;
+  for(int i=0;i<steps;i++){
+    for(auto& layer: datalayers){
+      layer->Prefetching(training);
+      for(auto& dstlayer: layer->dstlayers()){
+        CHECK(dstlayer->is_parserlayer());
+        auto parserlayer=static_cast<ParserLayer*>(dstlayer.get());
+        parserlayer->Prefetching(training);
+      }
+    }
+  }
+}
+*/
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/blob.cc
----------------------------------------------------------------------
diff --git a/src/utils/blob.cc b/src/utils/blob.cc
new file mode 100644
index 0000000..92fc989
--- /dev/null
+++ b/src/utils/blob.cc
@@ -0,0 +1,330 @@
+/**
+ * The code is adapted from that of Caffe whose license is attached.
+ *
+ * COPYRIGHT
+ * All contributions by the University of California:
+ * Copyright (c) 2014, The Regents of the University of California (Regents)
+ * All rights reserved.
+ * All other contributions:
+ * Copyright (c) 2014, the respective contributors
+ * All rights reserved.
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ * LICENSE
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * 1. Redistributions of source code must retain the above copyright notice, 
this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 
FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 
THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * CONTRIBUTION AGREEMENT
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ */
+#include <utility>
+#include <math.h>
+#include <cblas.h>
+#include "utils/blob.h"
+/*********************SyncedMemory implementation************************/
+
+#define NO_GPU LOG(FATAL) << "CPU-only Mode: cannot make GPU call."
+// Instantiate a class with float and double specifications.
+#define INSTANTIATE_CLASS(classname) \
+  template class classname<float>; \
+  template class classname<double>
+// Disable the copy and assignment operator for a class.
+#define DISABLE_COPY_AND_ASSIGN(classname) \
+private:\
+  classname(const classname&);\
+  classname& operator=(const classname&)
+
+#ifndef CPU_ONLY
+// CUDA: various checks for different function calls.
+#define CUDA_CHECK(condition) \
+  /* Code block avoids redefinition of cudaError_t error */ \
+  do { \
+    cudaError_t error = condition; \
+    CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
+  } while (0)
+
+#define CUBLAS_CHECK(condition) \
+  do { \
+    cublasStatus_t status = condition; \
+    CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << " " \
+      << caffe::cublasGetErrorString(status); \
+  } while (0)
+
+#define CURAND_CHECK(condition) \
+  do { \
+    curandStatus_t status = condition; \
+    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << " " \
+      << caffe::curandGetErrorString(status); \
+  } while (0)
+
+#endif // CPU_ONLY
+
+
// Release the host buffer (only if we allocated it — not if it was injected
// via set_cpu_data) and, in GPU builds, the device buffer.
SyncedMemory::~SyncedMemory() {
  if (cpu_ptr_ && own_cpu_data_) {
    FreeHost(cpu_ptr_);
  }

#ifndef CPU_ONLY
  if (gpu_ptr_) {
    CUDA_CHECK(cudaFree(gpu_ptr_));
  }
#endif  // CPU_ONLY
}
+
// Ensure the host copy is valid: allocate + zero on first touch, or copy
// back from the device if the GPU holds the newest data. No-op if the CPU
// copy is already current.
inline void SyncedMemory::to_cpu() {
  switch (head_) {
  case UNINITIALIZED:
    MallocHost(&cpu_ptr_, size_);
    memset(cpu_ptr_,0, size_);  // fresh memory starts zeroed
    head_ = HEAD_AT_CPU;
    own_cpu_data_ = true;
    break;
  case HEAD_AT_GPU:
#ifndef CPU_ONLY
    if (cpu_ptr_ == NULL) {
      MallocHost(&cpu_ptr_, size_);
      own_cpu_data_ = true;
    }
    CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault));
    head_ = SYNCED;  // both copies now agree
#else
    NO_GPU;
#endif
    break;
  case HEAD_AT_CPU:
  case SYNCED:
    break;  // host copy already valid
  }
}
+
+inline void SyncedMemory::to_gpu() {
+#ifndef CPU_ONLY
+  switch (head_) {
+  case UNINITIALIZED:
+    CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+    CUDA_CHECK(cudaMemset(gpu_ptr_, 0, N));  // NOLINT(caffe/alt_fn)
+    head_ = HEAD_AT_GPU;
+    break;
+  case HEAD_AT_CPU:
+    if (gpu_ptr_ == NULL) {
+      CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+    }
+    CUDA_CHECK(cudaMemcpy( gpu_ptr_,cpu_ptr_, size_, cudaMemcpyDefault));
+    head_ = SYNCED;
+    break;
+  case HEAD_AT_GPU:
+  case SYNCED:
+    break;
+  }
+#else
+  NO_GPU;
+#endif
+}
+
// Read-only host pointer; syncs from GPU first if needed.
const void* SyncedMemory::cpu_data() {
  to_cpu();
  return (const void*)cpu_ptr_;
}

// Adopt an externally-owned host buffer. The caller retains ownership
// (own_cpu_data_ is cleared, so the destructor will not free it).
void SyncedMemory::set_cpu_data(void* data) {
  CHECK(data);
  if (own_cpu_data_) {
    FreeHost(cpu_ptr_);  // release our previous buffer before adopting
  }
  cpu_ptr_ = data;
  head_ = HEAD_AT_CPU;
  own_cpu_data_ = false;
}

// Read-only device pointer; syncs from CPU first if needed.
// Fatal in CPU-only builds.
const void* SyncedMemory::gpu_data() {
#ifndef CPU_ONLY
  to_gpu();
  return (const void*)gpu_ptr_;
#else
  NO_GPU;
#endif
  return nullptr;
}

// Writable host pointer; marks the CPU copy as the freshest afterwards.
void* SyncedMemory::mutable_cpu_data() {
  to_cpu();
  head_ = HEAD_AT_CPU;
  return cpu_ptr_;
}

// Writable device pointer; marks the GPU copy as the freshest afterwards.
// Fatal in CPU-only builds.
void* SyncedMemory::mutable_gpu_data() {
#ifndef CPU_ONLY
  to_gpu();
  head_ = HEAD_AT_GPU;
  return gpu_ptr_;
#else
  NO_GPU;
#endif
  return nullptr;
}
+
+/*********************Blob implementation************************/
+
// Construct a blob with the given shape; allocation happens in Reshape.
template <typename Dtype>
Blob<Dtype>::Blob(const vector<int>& shape)
  // capacity_ must be initialized before calling Reshape
  : capacity_(0) {
  Reshape(shape);
}

// Set a new shape; reallocates only when the element count grows beyond
// the current capacity (shrinking keeps the existing buffer).
// Note: CHECK(shape[i]) rejects zero-sized dimensions.
template <typename Dtype>
void Blob<Dtype>::Reshape(const vector<int>& shape) {
  count_=1;
  shape_=shape;
  for(size_t i=0;i<shape.size();i++){
    CHECK(shape[i]);
    count_*=shape[i];
  }
  if (count_ > capacity_) {
    capacity_ = count_;
    data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
  }
}

// Adopt another blob's shape (convenience wrapper around Reshape).
template <typename Dtype>
void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) {
  Reshape(other.shape());
}
+
// Read-only host pointer to the blob's elements (syncs from GPU if needed).
template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
  CHECK(data_);
  return (const Dtype*)data_->cpu_data();
}

// Adopt an externally-owned host buffer (ownership stays with the caller;
// see SyncedMemory::set_cpu_data).
template <typename Dtype>
void Blob<Dtype>::set_cpu_data(Dtype* data) {
  CHECK(data);
  data_->set_cpu_data(data);
}

// Read-only device pointer (syncs from CPU if needed).
template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_data() const {
  CHECK(data_);
  return (const Dtype*)data_->gpu_data();
}

// Writable host pointer; marks the CPU copy as freshest.
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_data() {
  CHECK(data_);
  return static_cast<Dtype*>(data_->mutable_cpu_data());
}

// Writable device pointer; marks the GPU copy as freshest.
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_data() {
  CHECK(data_);
  return static_cast<Dtype*>(data_->mutable_gpu_data());
}

// Share the underlying storage of another blob of equal element count
// (both blobs then see each other's writes).
template <typename Dtype>
void Blob<Dtype>::ShareData(const Blob& other) {
  CHECK_EQ(count_, other.count());
  data_ = other.data();
}
+
// Mean absolute value of all elements (L1 norm / count); 0 for an empty blob.
template <> float Blob<float>::asum_data() const {
  if(count()==0)
    return 0.f;
  return cblas_sasum(count(), cpu_data(), 1)/count();
}
// Mean (signed) value of all elements; 0 for an empty blob.
template <> float Blob<float>::sum_data() const {
  if(count()==0)
    return 0.f;
  float sum=0.f;
  const float *dptr=cpu_data();
  for(int i=0;i<count();i++)
    sum+=dptr[i];
  return sum/count();
}
// Integer blobs do not support asum_data (no BLAS routine); fatal if called.
template <> unsigned int Blob<unsigned int>::asum_data() const {
  NOT_IMPLEMENTED;
  return 0;
}

template <> int Blob<int>::asum_data() const {
  NOT_IMPLEMENTED;
  return 0;
}
+
// Exchange storage with another blob of identical shape. Only data_ and
// capacity_ are swapped; shape_/count_ are already equal by the CHECKs.
template <typename Dtype>
void Blob<Dtype>::Swap(Blob& other){
  CHECK_EQ(other.count(), count());
  CHECK(std::equal(shape_.begin(), shape_.end(), other.shape_.begin()));
  std::swap(data_, other.data_);
  std::swap(capacity_, other.capacity_);
}
+
+template <typename Dtype>
+void Blob<Dtype>::CopyFrom(const Blob& source, bool reshape) {
+  if (!std::equal(shape_.begin(),shape_.end(),source.shape_.begin())) {
+    if (reshape) {
+      Reshape(source.shape_);
+    } else {
+      LOG(FATAL) << "Trying to copy blobs of different sizes.";
+    }
+  }
+#ifndef CPU_ONLY
+  CUDA_CHECK(cudaMemcpy(static_cast<Dtype*>(data_->mutable_gpu_data()),
+            source.gpu_data(), sizeof(Dtype) * count_, cudaMemcpyDefault));
+#endif
+  memcpy(static_cast<Dtype*>(data_->mutable_cpu_data()),source.cpu_data(),
+        sizeof(Dtype)*count_);
+}
+
+/*
+template <typename Dtype>
+void Blob<Dtype>::FromProto(const BlobProto& proto) {
+  Reshape();
+  // copy data
+  Dtype* data_vec = mutable_cpu_data();
+  for (int i = 0; i < count_; ++i) {
+    data_vec[i] = proto.data(i);
+  }
+}
+*/
+
+template <typename Dtype>
+void Blob<Dtype>::ToProto(singa::BlobProto* proto) const {
+  proto->set_num(shape_[0]);
+  if(shape_.size()>1)
+    proto->set_channels(shape_[1]);
+  if(shape_.size()>2)
+    proto->set_height(shape_[2]);
+  if(shape_.size()>3)
+    proto->set_width(shape_[3]);
+  proto->clear_data();
+  const Dtype* data_vec = cpu_data();
+  for (int i = 0; i < count_; ++i) {
+    proto->add_data(data_vec[i]);
+  }
+}
+
+INSTANTIATE_CLASS(Blob);
+template class Blob<int>;
+template class Blob<unsigned int>;

Reply via email to