http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_table_server.cc ---------------------------------------------------------------------- diff --git a/src/test/dist_test/test_table_server.cc b/src/test/dist_test/test_table_server.cc new file mode 100644 index 0000000..5f3612c --- /dev/null +++ b/src/test/dist_test/test_table_server.cc @@ -0,0 +1,357 @@ +// Copyright © 2014 Anh Dinh. All Rights Reserved. + +#include "core/global-table.h" +#include "core/common.h" +#include "core/table.h" +#include "core/table_server.h" +#include "utils/global_context.h" +#include "utils/common.h" +#include <gflags/gflags.h> +#include "proto/model.pb.h" +#include "proto/common.pb.h" +#include "worker.h" +#include "coordinator.h" +#include "utils/common.h" +#include "utils/proto_helper.h" + +#include <cmath> +#include <stdlib.h> +#include <vector> +#include <iostream> +#include <fstream> + + +/** + * Test for table server access. The table is of type <VKey,int> + */ +DEFINE_bool(restore_mode, false, "restore from checkpoint file"); +using namespace lapis; +using std::vector; + +DEFINE_int32(checkpoint_frequency, 5000, "frequency for cp"); +DEFINE_int32(checkpoint_after, 1, "cp after this steps"); +DEFINE_string(par_mode, "hybrid", "time training algorithm"); +DEFINE_bool(restore, false, "restore from checkpoint file"); + +DEFINE_string(db_backend, "lmdb", "backend db"); +DEFINE_string(system_conf, "examples/imagenet12/system.conf", "configuration file for node roles"); +DEFINE_string(model_conf, "examples/imagenet12/model.conf", "DL model configuration file"); +DEFINE_string(checkpoint_dir,"/data1/wangwei/lapis/","check point dir"); +DEFINE_int32(threshold,1000000, "max # of parameters in a vector"); +DEFINE_int32(iterations,5,"numer of get/put iterations"); +DEFINE_int32(workers,2,"numer of workers doing get/put"); +DECLARE_bool(checkpoint_enabled); + + +DECLARE_bool(checkpoint_enabled); + +/** + * Get and update handler for VKey. 
+ */ +struct AnhUpdateHandler: BaseUpdateHandler<VKey, SGDValue> { + bool Update(SGDValue *a, const SGDValue &b) { + + float * adptr = a->mutable_data()->mutable_value()->mutable_data(); + const float*bdptr = b.grad(0).value().data(); + for (int i = 0; i < b.grad(0).value_size(); i++) + adptr[i] += bdptr[i]; + + return true; + } + + bool Get(const VKey k, const SGDValue &val, SGDValue *ret) { + *ret = val; + return true; + } + + bool is_checkpointable(const VKey k, const SGDValue v) { + return false; //always checkpoint + } +}; + +typedef map<int, GlobalTable*> Map; +Map tables; +shared_ptr<NetworkThread> network; +shared_ptr<GlobalContext> context; +std::vector<ServerState*> server_states; +TableServer *table_server; + +#define SIZE 16 +int tuple_sizes[SIZE] = {27448736, 16777216, 4096000, 1327104, 884736, 884736, 614400,14112,4096,4096,1000,384,384,256,256,96}; + +/** + * Initialize tables. + */ +void create_mem_table(int id, int num_shards){ + + TableDescriptor *info = new TableDescriptor(id, num_shards); + info->key_marshal = new Marshal<VKey>(); + info->value_marshal = new Marshal<SGDValue>(); + info->sharder = new VKeySharder; + info->accum = new AnhUpdateHandler; + info->partition_factory = new typename SparseTable<VKey, SGDValue>::Factory; + auto table=new TypedGlobalTable<VKey, SGDValue>(); + table->Init(info); + tables[id] = table; +} + +/** + * Coordinator assigns shards to processes. + * @param id table ID. + */ +void coordinator_assign_tables(int id) { + + // wait for the servers to be up. + for (int i = 0; i < context->num_procs(); i++) { + RegisterWorkerRequest req; + int src = 0; + // adding memory server. + if (context->IsTableServer(i)) { + VLOG(3)<< "Waiting for message from table server " << i; + network->Read(MPI::ANY_SOURCE, MTYPE_REGISTER_WORKER, &req, &src); + server_states.push_back(new ServerState(i)); + } + } + + VLOG(3) << " All servers registered and started up. 
Ready to go"; + VLOG(3) << "num of shards" << tables[id]->num_shards() << " for table " << id; + + // assign table to shard in round roubin fashion. + int server_idx = 0; + for (int shard = 0; shard < tables[id]->num_shards(); ++shard) { + ServerState &server = *server_states[server_idx]; + VLOG(3) << "Assigning table (" << id << "," << shard << ") to server " + << server_states[server_idx]->server_id; + server.shard_id = shard; + server.local_shards.insert(new TaskId(id, shard)); + server_idx = (server_idx + 1) % server_states.size(); + } + ShardAssignmentRequest req; + for (size_t i = 0; i < server_states.size(); ++i) { + ServerState &server = *server_states[i]; + for (auto * task : server.local_shards) { + ShardAssignment *s = req.add_assign(); + s->set_new_worker(server.server_id); + s->set_table(task->table); + s->set_shard(task->shard); + // update local tables + GlobalTable *t = tables.at(task->table); + t->get_partition_info(task->shard)->owner = server.server_id; + delete task; + } + } + + network->SyncBroadcast(MTYPE_SHARD_ASSIGNMENT, MTYPE_SHARD_ASSIGNMENT_DONE, + req); + VLOG(3) << "done table assignment... "; +} + + +void table_init(){ + table_server = new TableServer(); + table_server->StartTableServer(tables); + VLOG(3) << "table server started on process "<< NetworkThread::Get()->id(); +} + + +/** + * Coordinator loads data to the table. + * @param size number of tuples. + */ +void coordinator_load_data() { + auto table = static_cast<TypedGlobalTable<VKey, SGDValue>*>(tables[0]); + for (int i = 0; i < SIZE; i++) { + VKey key; + SGDValue x; + DAryProto *data = x.mutable_data(); + DAryProto *grad = x.add_grad(); + for (int j = 0; j < tuple_sizes[i]; j++) { + data->add_value(j * 1.0f); + grad->add_value(j * 1.0f); + } + key.set_key(i); + table->put(key, x); + } + VLOG(3) << "Done loading " << SIZE << " tuples ..."; +} + +/** + * Worker gets tuples from the server. + * @param size number of tuples to be requested. 
+ */ +void get() { + auto table = static_cast<TypedGlobalTable<VKey,SGDValue>*>(tables[0]); + SGDValue value; + for (int i = 0; i < SIZE; i++) { + VKey key; + key.set_key(i); + table->async_get(key, &value); + } + VLOG(3) << "Done sending get requests ..."; + + for (int i = 0; i < SIZE; i++) { + VKey key; + while (!table->async_get_collect(&key, &value)) + Sleep(0.0001); + } +} + +/** + * Worker updates tuples. + */ +void update() { + auto table = static_cast<TypedGlobalTable<VKey, SGDValue>*>(tables[0]); + for (int i = 0; i < SIZE; i++) { + VKey key; + key.set_key(i); + + SGDValue x; + DAryProto *grad = x.add_grad(); + for (int j = 0; j < tuple_sizes[i]; j++) + grad->add_value(j * 1.0f); + + table->update(key, x); + } + VLOG(3) << "Done updating " << SIZE << " tuples ..."; +} + + +void worker_test_data() { + //get(size); + update(); + update(); + get(); + /* + update(table, tuples); + update(table, tuples); + update(table, tuples); + get(table, tuples); + */ +} + +/** + * Shutdown the process. + */ +void shutdown() { + if (context->AmICoordinator()) { + EmptyMessage msg; + for (int i = 0; i < context->num_procs() - 1; i++) + network->Read(MPI::ANY_SOURCE, MTYPE_WORKER_END, &msg); + EmptyMessage shutdown_msg; + for (int i = 0; i < network->size() - 1; i++) { + network->Send(i, MTYPE_SHUTDOWN, shutdown_msg); + } + //network->Flush(); + network->Shutdown(); + } else { + //network->Flush(); + network->Send(context->num_procs() - 1, MTYPE_WORKER_END, + EmptyMessage()); + EmptyMessage msg; + network->Read(context->num_procs() - 1, MTYPE_SHUTDOWN, &msg); + + if (context->AmITableServer()){ + RequestDispatcher::Get()->PrintStats(); + table_server->ShutdownTableServer(); + } + + network->Shutdown(); + } +} + +/** + * Worker handle shard assignment from the coordinator. 
+ */ +void HandleShardAssignment() { + + ShardAssignmentRequest shard_req; + auto mpi = NetworkThread::Get(); + mpi->Read(GlobalContext::kCoordinator, MTYPE_SHARD_ASSIGNMENT, &shard_req); + + // request read from coordinator + for (int i = 0; i < shard_req.assign_size(); i++) { + const ShardAssignment &a = shard_req.assign(i); + GlobalTable *t = tables.at(a.table()); + t->get_partition_info(a.shard())->owner = a.new_worker(); + + //if local shard, create check-point files + if (FLAGS_checkpoint_enabled && t->is_local_shard(a.shard())) { + string checkpoint_file = StringPrintf("%s/checkpoint_%d", + FLAGS_checkpoint_dir.c_str(), a.shard()); + char hostname[256]; + gethostname(hostname, sizeof(hostname)); + + FILE *tmp_file = fopen(checkpoint_file.c_str(), "r"); + if (tmp_file) { //exists -> open to reading and writing + fclose(tmp_file); + auto cp = t->checkpoint_files(); + + if (FLAGS_restore_mode) { //open in read mode to restore, then close + LogFile *file = new LogFile(checkpoint_file, "rw", 0); + int table_size = file->read_latest_table_size(); + delete file; + + double start = Now(); + (*cp)[a.shard()] = new LogFile(checkpoint_file, "r", + a.shard()); + t->Restore(a.shard()); + delete (*cp)[a.shard()]; + double end = Now(); + LOG(ERROR) << "restore time\t" << end - start << "\tfor\t" + << table_size << "\tthreshold\t" << FLAGS_threshold; + } + char hostname[256]; + gethostname(hostname, sizeof(hostname)); + (*cp)[a.shard()] = new LogFile(checkpoint_file, "a", a.shard()); + } else { // not exist -> open to writing first time + auto cp = t->checkpoint_files(); + (*cp)[a.shard()] = new LogFile(checkpoint_file, "w", a.shard()); + } + } + } + + EmptyMessage empty; + mpi->Send(GlobalContext::kCoordinator, MTYPE_SHARD_ASSIGNMENT_DONE, empty); + VLOG(3) << "Done handling shard assignment ..."; + +} + + +int main(int argc, char **argv) { + FLAGS_logtostderr = 1; + int provided; + MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); + 
google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + context = GlobalContext::Get(FLAGS_system_conf); + network = NetworkThread::Get(); + + ModelProto model; + ReadProtoFromTextFile(FLAGS_model_conf.c_str(), &model); + + create_mem_table(0, context->num_table_servers()); + + if (context->AmICoordinator()) { + coordinator_assign_tables(0); + coordinator_load_data(); + network->barrier(); + } else { + if (context->AmITableServer()) { + table_init(); + HandleShardAssignment(); + network->barrier(); + } else { + HandleShardAssignment(); + network->barrier(); + Sleep(1); + VLOG(3) << "Worker cleared the barrier ..."; + worker_test_data(); + } + } + + shutdown(); + return 0; +} + +
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_tuple.cc ---------------------------------------------------------------------- diff --git a/src/test/dist_test/test_tuple.cc b/src/test/dist_test/test_tuple.cc new file mode 100644 index 0000000..727f8e3 --- /dev/null +++ b/src/test/dist_test/test_tuple.cc @@ -0,0 +1,258 @@ +#include <cstdio> +#include <iostream> +#include <sstream> +#include <string> +#include <vector> + +#include "server.h" +#include "proto/worker.pb.h" +#include "utils/network_service.h" +#include "core/common.h" +#include "core/network_queue.h" +#include "proto/model.pb.h" +#include "proto/common.pb.h" +#include "utils/global_context.h" + +/** + * @file test_tuple.cc + * + * Test performance of TableServer put/get/update operations. + */ +DECLARE_double(sleep_time); + +using namespace lapis; +using namespace std; +using std::vector; + +#define NKEYS 1000 +#define TUPLE_SIZE 50000000 + +#ifndef FLAGS_v + DEFINE_int32(v, 3, "vlog controller"); +#endif + + +#define SIZE 16 +#define THRESHOLD 500000 +int tuple_sizes[SIZE] = {37448736, 16777216, 4096000, 1327104, 884736, 884736, 614400,14112,4096,4096,1000,384,384,256,256,96}; +vector<int> valsizes; +int collect_size; +int num_tuples; + +void Put(int tid, int size, int version) { + RequestBase request; + request.set_table(0); + request.set_source(NetworkService::Get()->id()); + PutRequest *put_req = request.MutableExtension(PutRequest::name); + int shard = tid % GlobalContext::Get()->num_servers(); + put_req->set_shard(shard); + TableData *tuple = put_req->mutable_data(); + + TKey* key = tuple->mutable_key(); + TVal* val = tuple->mutable_value(); + + key->set_id(tid); + key->set_version(version); + + DAryProto *data = val->mutable_data(); + for (int i = 0; i < size; i++){ + data->add_value(0.0f); + } + + // TODO check the msg type + NetworkService::Get()->Send(shard, MTYPE_REQUEST, request); +} + +void Update(int tid, int size, int version) { + RequestBase 
request; + request.set_table(0); + request.set_source(NetworkService::Get()->id()); + UpdateRequest *update_req = request.MutableExtension(UpdateRequest::name); + int shard = tid % GlobalContext::Get()->num_servers(); + update_req->set_shard(shard); + TableData *tuple = update_req->mutable_data(); + + TKey* key = tuple->mutable_key(); + TVal* val = tuple->mutable_value(); + + key->set_id(tid); + key->set_version(version); + + DAryProto *data = val->mutable_grad(); + for (int i = 0; i < size; i++) + data->add_value(1.0f); + // TODO check the msg type + NetworkService::Get()->Send(shard, MTYPE_REQUEST, request); +} + +void print_result(TableData *data){ + TKey *key = data->mutable_key(); + TVal *val = data->mutable_value(); + int k = key->id(); + VLOG(3) << "key = " << k; + string s; + for (int i=0; i<TUPLE_SIZE; i++) + s.append(to_string(val->mutable_data()->value(i))).append(" "); + VLOG(3) << "val = " <<s; +} + +void AsyncGet(int tid, int version) { + RequestBase request; + request.set_table(0); + request.set_source(GlobalContext::Get()->rank()); //NetworkService::Get()->id()); + GetRequest *get_req = request.MutableExtension(GetRequest::name); + int shard = tid % GlobalContext::Get()->num_servers(); + get_req->set_shard(shard); + + TKey *key = get_req->mutable_key(); + key->set_id(tid); + key->set_version(version); + NetworkService::Get()->Send(shard, MTYPE_REQUEST, request); + +} + +void Collect(){ + int count = collect_size; + double start_collect = Now(); + while (count){ + while (true) { + Message *resp = NetworkService::Get()->Receive(); + if (!resp) + Sleep(FLAGS_sleep_time); + else{ + delete resp; + break; + } + } + count--; + } + double end_collect = Now(); + VLOG(3) << "Collected " << collect_size << " tuples in " << (end_collect-start_collect); +} + +/** + * Workers wait for the barrier, then one of them send SHUTDOWN message + * to all table servers. 
+ */ +void worker_send_shutdown(int id){ + auto gc = lapis::GlobalContext::Get(); + NetworkService *network_service_ = NetworkService::Get().get(); + MPI_Barrier(gc->workergroup_comm()); + if (gc->rank()==id){ + for (int i=0; i<gc->num_procs(); i++){ + if (gc->IsTableServer(i)){ + EmptyMessage msg; + network_service_->Send(i, MTYPE_SHUTDOWN,msg); + } + } + } +} + +/** + * One worker with the specific ID puts, others wait. + */ +void worker_load_data(int id){ + auto gc = lapis::GlobalContext::Get(); + for (int i = 0; i < SIZE; i++) { + int m = tuple_sizes[i]; + if (m < THRESHOLD) + valsizes.push_back(m); + else { + for (int j = 0; j < m / THRESHOLD; j++) + valsizes.push_back(THRESHOLD); + if (m % THRESHOLD) + valsizes.push_back(m%THRESHOLD); + } + } + num_tuples = (int)valsizes.size(); + collect_size = 0; + for (int i=0; i<num_tuples; i++) + if (i%gc->group_size()==gc->worker_id()) + collect_size++; + + if (gc->rank()==id){ + for (size_t i=0; i<valsizes.size(); i++) + Put(i,valsizes[i],0); + VLOG(3) << "Done loading data, num_keys = "<<valsizes.size() << " process " << id; + } + VLOG(3) << "Collect size = " << collect_size; + MPI_Barrier(gc->workergroup_comm()); +} + +void worker_update_data() { + auto gc = lapis::GlobalContext::Get(); + for (int i = 0; i < num_tuples; i++) + if (i%gc->group_size()==gc->worker_id()) + Update(i,valsizes[i],0); + + VLOG(3) << "Done update ... for "<<collect_size << " tuples "; +} + +/* + * Async get. 
+ */ +void worker_get_data(){ + auto gc = lapis::GlobalContext::Get(); + for (int i=0; i<num_tuples; i++) + if (i%gc->group_size()==gc->worker_id()) + AsyncGet(i,0); + Collect(); + VLOG(3) << "Done collect ..."; +} + +void start_network_service_for_worker(){ + NetworkService *network_service_ = NetworkService::Get().get(); + network_service_->Init(GlobalContext::Get()->rank(), Network::Get().get(), new SimpleQueue()); + network_service_->StartNetworkService(); +} + +int main(int argc, char **argv) { + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + int provided; + + + MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); + + + FLAGS_logtostderr = 1; + + + // Init GlobalContext + Cluster cluster; + cluster.set_server_start(0); + cluster.set_server_end(8); + cluster.set_worker_start(8); + cluster.set_worker_end(24); + cluster.set_group_size(8); + cluster.set_data_folder("/data1/wangwei/lapis"); + + auto gc = lapis::GlobalContext::Get(cluster); + + // worker or table server + if (gc->AmITableServer()) { + lapis::TableServer server; + SGDProto sgd; + sgd.set_learning_rate(0.01); + sgd.set_momentum(0.9); + sgd.set_weight_decay(0.1); + sgd.set_gamma(0.5); + sgd.set_learning_rate_change_steps(1); + server.Start(sgd); + } else { + start_network_service_for_worker(); + worker_load_data(cluster.worker_start()); + for (int i=0; i<10; i++){ + worker_update_data(); + worker_get_data(); + } + worker_send_shutdown(cluster.worker_start()); + NetworkService::Get()->Shutdown(); + } + gc->Finalize(); + MPI_Finalize(); + VLOG(3) << "End, process "<< gc->rank(); + return 0; +} + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_blob.cc ---------------------------------------------------------------------- diff --git a/src/test/model/test_blob.cc b/src/test/model/test_blob.cc new file mode 100644 index 0000000..75f1921 --- /dev/null +++ b/src/test/model/test_blob.cc @@ -0,0 +1,58 @@ +// 
// Copyright © 2014 Wei Wang. All Rights Reserved.
// 2014-07-18 19:44
#include <gtest/gtest.h>
#include "proto/model.pb.h"
#include "model/lapis.h"

namespace lapis {
// Fixture providing two heap-allocated (blob1/blob2) and two
// stack-allocated (blob3/blob4) Blobs for the tests below.
class BlobTest : public ::testing::Test {
 public:
  BlobTest() : blob1(new Blob()), blob2(new Blob()) {}
  ~BlobTest() {
    delete blob1;
    delete blob2;
  }
 protected:
  Blob *blob1, *blob2;
  Blob blob3, blob4;
};

// A default-constructed Blob has zero length/width/height and a null
// data pointer, regardless of heap or stack allocation.
TEST_F(BlobTest, Constructor) {
  EXPECT_EQ(blob1->length(), 0);
  EXPECT_EQ(blob1->width(), 0);
  EXPECT_EQ(blob1->height(), 0);
  EXPECT_EQ(blob3.length(), 0);
  EXPECT_EQ(blob3.width(), 0);
  EXPECT_EQ(blob3.height(), 0);
  EXPECT_TRUE(blob2->dptr == nullptr);
  EXPECT_TRUE(blob4.dptr == nullptr);
}

// Resize allocates storage; length equals the product of the four
// dimension arguments (per the expectations below: 10*1*1*1=10,
// 4*1*1*3=12, 5*1*4*3=60, 6*5*4*3=360).
TEST_F(BlobTest, TestResize) {
  blob1->Resize(10, 1, 1, 1);
  EXPECT_EQ(blob1->length(), 10);
  EXPECT_EQ(blob1->num(), 10);
  EXPECT_EQ(blob1->height(), 1);
  EXPECT_EQ(blob1->width(), 1);
  EXPECT_TRUE(blob1->dptr != nullptr);
  blob2->Resize(4, 1, 1, 3);
  EXPECT_EQ(blob2->length(), 12);
  EXPECT_EQ(blob2->num(), 4);
  EXPECT_EQ(blob2->height(), 1);
  EXPECT_EQ(blob2->width(), 3);
  EXPECT_TRUE(blob2->dptr != nullptr);
  blob3.Resize(5, 1, 4, 3);
  EXPECT_EQ(blob3.length(), 60);
  EXPECT_EQ(blob3.num(), 5);
  EXPECT_EQ(blob3.height(), 4);
  EXPECT_EQ(blob3.width(), 3);
  EXPECT_TRUE(blob3.dptr != nullptr);
  blob4.Resize(6, 5, 4, 3);
  EXPECT_EQ(blob4.length(), 360);
  EXPECT_EQ(blob4.num(), 6);
  EXPECT_EQ(blob4.height(), 4);
  EXPECT_EQ(blob4.width(), 3);
  EXPECT_TRUE(blob4.dptr != nullptr);
}

}  // namespace lapis

// http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_data_layer.cc
// Copyright © 2014 Wei Wang. All Rights Reserved.
// 2014-08-01 16:09

#include <gtest/gtest.h>
#include <glog/logging.h>
#include <map>
#include <vector>

#include "model/data_layer.h"
#include "model/trainer.h"
#include "model/sgd_trainer.h"
#include "model/conv_edge.h"
#include "model/relu_layer.h"
#include "proto/model.pb.h"

#include "utils/proto_helper.h"

namespace lapis {
// Base fixture: loads the shared model configuration once per test.
class ModelTest : public ::testing::Test {
 public:
  ModelTest () {
    ReadProtoFromTextFile("src/test/data/model.conf", &model_proto);
  }
 protected:
  ModelProto model_proto;
};
/**********************************************************************
 * DataLayer Test
 **********************************************************************/
// Sets up a label layer (net layer 0) and an image layer (net layer 1)
// backed by the training data sources from the model config.
class DataLayerTest : public ModelTest {
 public:
  DataLayerTest() {
    label_layer.Init(model_proto.net().layer(0));
    img_layer.Init(model_proto.net().layer(1));
    Trainer::InitDataSource(model_proto.trainer().train_data(), &sources);
    EXPECT_EQ(2, sources.size());
    sources[0]->LoadData(nullptr);
    sources[1]->LoadData(nullptr);
    DLOG(INFO)<<"after init datasources";
    // Batch size 2, back-propagation algorithm.
    label_layer.Setup(2, TrainerProto::kBackPropagation, sources);
    DLOG(INFO)<<"after setup label layer";
    img_layer.Setup(2, TrainerProto::kBackPropagation, sources);
    DLOG(INFO)<<"after setup img layer";
  }
  ~DataLayerTest() {
    // The fixture owns the data sources created by InitDataSource.
    for(auto& source: sources)
      delete source;
  }
 protected:
  DataLayer img_layer, label_layer;
  std::vector<DataSource*> sources;
};

// Label feature is 2x1x1x1; image feature is 2x3x227x227 per the config.
TEST_F(DataLayerTest, InitSetupForward) {
  EXPECT_TRUE(label_layer.HasInput());
  EXPECT_TRUE(img_layer.HasInput());
  EXPECT_STREQ("DataLayer", DataLayer::kType.c_str());

  EXPECT_EQ(2, label_layer.feature(nullptr).num());
  EXPECT_EQ(1, label_layer.feature(nullptr).channels());
  EXPECT_EQ(1, label_layer.feature(nullptr).height());
  EXPECT_EQ(1, label_layer.feature(nullptr).width());

  EXPECT_EQ(2, img_layer.feature(nullptr).num());
  EXPECT_EQ(3, img_layer.feature(nullptr).channels());
  EXPECT_EQ(227, img_layer.feature(nullptr).height());
  EXPECT_EQ(227, img_layer.feature(nullptr).width());

  img_layer.Forward();
}
// TODO(wangwei) test this after outgoing edges are tested

/**********************************************************************
 * ConvEdge Test
 **********************************************************************/
// Connects the image layer to a ReLU layer through the first conv edge.
class ConvEdgeTest : public DataLayerTest {
 public:
  ConvEdgeTest() {
    relu.Init(model_proto.net().layer(2));
    DLOG(INFO)<<"init both layers";
    layer_map["input_img"]=&img_layer;
    layer_map["hidden1_relu"]=&relu;

    edge_proto=model_proto.net().edge(0);
    convedge.Init(edge_proto, layer_map);
    convedge.Setup(true);
  }
 protected:
  std::map<std::string, Layer*> layer_map;
  ConvEdge convedge;
  EdgeProto edge_proto;
  ReLULayer relu;
};

// SetupTopBlob must shape the destination blob according to the standard
// convolution output-size formula.
TEST_F(ConvEdgeTest, InitSetupForward) {
  Layer* dest=layer_map.at("hidden1_relu");
  Blob &b=dest->feature(&convedge);
  EXPECT_EQ(0,b.num());
  convedge.SetupTopBlob(&b);
  // out = (in + 2*pad - kernel) / stride + 1
  int conv_height = (227 + 2 * edge_proto.pad() - edge_proto.kernel_size())
      / edge_proto.stride() + 1;
  int conv_width=conv_height;
  CHECK_EQ(2, b.num());
  CHECK_EQ(edge_proto.num_output(), b.channels());
  CHECK_EQ(conv_height, b.height());
  CHECK_EQ(conv_width, b.width());
  DLOG(INFO)<<"after shape check";

  Layer* src=layer_map["input_img"];
  convedge.Forward(src->feature(&convedge), &b, true);
}

/**********************************************************************
 * ReLULayer Test
 **********************************************************************/
class ReLULayerTest : public ConvEdgeTest {
 public:
  ReLULayerTest() {
    relu.Setup(2, TrainerProto::kBackPropagation, sources);
    relu_proto=model_proto.net().layer(3);
  }
 protected:
  LayerProto relu_proto;
};

TEST_F(ReLULayerTest, ForwardWithoutDropout) {
  EXPECT_EQ(2, relu.feature(&convedge).num());
  EXPECT_EQ(2, relu.gradient(&convedge).num());

  relu.Forward();
}
/**********************************************************************
 * PoolingEdge Test
class PoolingEdgeTest : public ReLULayerTest {
 public:
  PoolingEdgeTest() {
    linearlayer.Init(model.net().layer(3));
    pooledge.Init(model.net().edge(1));
  }

 protected:
  PoolingEdge pooledge;
  LinearLayer linearlayer;
}
 **********************************************************************/
/**********************************************************************
 * LinearLayer Test
 **********************************************************************/

/**********************************************************************
 * LRNEdge Test
 **********************************************************************/

/**********************************************************************
 * InnerProductEdge Test
 **********************************************************************/

/**********************************************************************
 * SoftmaxLayerLossEdge Test
 **********************************************************************/




/**********************************************************************
 * SGDTrainer Test
 **********************************************************************/
class SGDTrainerTest : public ModelTest {
 protected:
  SGDTrainer sgd;
};

TEST_F(SGDTrainerTest, Init) {
  sgd.Init(model_proto.trainer());
  EXPECT_TRUE(Trainer::phase==Phase::kInit);
}

}  // namespace lapis

// http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_label_source.cc
// Copyright © 2014 Wei Wang. All Rights Reserved.
// 2014-07-21 19:40

#include <gtest/gtest.h>
#include <glog/logging.h>
#include "proto/model.pb.h"
#include "disk/label_source.h"

namespace lapis {
// Fixture: a LabelSource over a 12-record label file.
class LabelSourceTest : public ::testing::Test {
 public:
  LabelSourceTest() {
    DataSourceProto ds;
    ds.set_path("src/test/data/label_source.dat");
    ds.set_size(12);
    ds.set_name("label source");
    ls.Init(ds);
  }

 protected:
  LabelSource ls;
};

// LoadData returns the image-name list in file order.
TEST_F(LabelSourceTest, LoadData) {
  auto ptr2names = ls.LoadData(nullptr);
  EXPECT_EQ(12, ptr2names->size());
  EXPECT_STREQ("img0.JPEG", ptr2names->at(0).c_str());
  EXPECT_STREQ("img1.JPEG", ptr2names->at(1).c_str());
  EXPECT_STREQ("img5.JPEG", ptr2names->at(5).c_str());
  EXPECT_STREQ("img10.JPEG", ptr2names->at(10).c_str());
  EXPECT_STREQ("img11.JPEG", ptr2names->at(11).c_str());
}

// GetData fills the blob with consecutive labels; successive calls
// continue (and eventually wrap around) through the 12 records.
TEST_F(LabelSourceTest, GetData) {
  ls.LoadData(nullptr);
  Blob b;
  b.Resize(1, 1, 1, 5);
  ls.GetData(&b);
  const float *val = b.dptr;
  EXPECT_EQ(0.0f, val[0]);
  EXPECT_EQ(1.0f, val[1]);
  EXPECT_EQ(4.0f, val[2]);
  EXPECT_EQ(9.0f, val[3]);
  EXPECT_EQ(16.0f, val[4]);
  ls.GetData(&b);
  EXPECT_EQ(4.0f, val[0]);
  EXPECT_EQ(5.0f, val[1]);
  EXPECT_EQ(6.0f, val[2]);
  EXPECT_EQ(7.0f, val[3]);
  EXPECT_EQ(8.0f, val[4]);
  ls.GetData(&b);
  EXPECT_EQ(1.0f, val[0]);
  EXPECT_EQ(2.0f, val[1]);
  EXPECT_EQ(0.0f, val[2]);
  EXPECT_EQ(1.0f, val[3]);
  EXPECT_EQ(4.0f, val[4]);
}

}  // namespace lapis

// http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_param.cc
#include <gtest/gtest.h>
#include <glog/logging.h>
#include "proto/model.pb.h"

#include "utils/param.h"

using namespace singa;

// Fixture: a 3x4 weight Param and a length-4 bias Param.
class ParamTest : public ::testing::Test {
 public:
  ParamTest() {
wp.set_name("weight"); + wp.add_shape(3); + wp.add_shape(4); + bp.set_name("bias"); + bp.add_shape(4); + } + protected: + Param w, b; + ParamProto wp, bp; +}; + +TEST_F(ParamTest, ConstantInit) { + bp.set_init_method(ParamProto::kConstant); + bp.set_value(0.5); + b.Init(bp); + const float *val = b.content().dptr; + EXPECT_EQ(0.5f, val[0]); + EXPECT_EQ(0.5f, val[1]); + EXPECT_EQ(0.5f, val[2]); + EXPECT_EQ(0.5f, val[3]); + wp.set_init_method(ParamProto::kConstant); + wp.set_value(1.5); + w.Init(wp); + val = w.content().dptr; + EXPECT_EQ(1.5f, val[0]); + EXPECT_EQ(1.5f, val[3]); + EXPECT_EQ(1.5f, val[4]); + EXPECT_EQ(1.5f, val[11]); +} + +TEST_F(ParamTest, UniformInit) { + bp.set_init_method(ParamProto::kUniform); + bp.set_value(1.0f); + b.Init(bp); + const float *val = b.content().dptr; + EXPECT_TRUE(val[0] >= -1 && val[0] <= 1); + EXPECT_TRUE(val[1] >= -1 && val[2] <= 1); + EXPECT_TRUE(val[2] >= -1 && val[2] <= 1); + EXPECT_TRUE(val[3] >= -1 && val[3] <= 1); + wp.set_init_method(ParamProto::kUniform); + wp.set_value(1.0f); + w.Init(wp); + val = w.content().dptr; + EXPECT_TRUE(val[0] >= -1 && val[0] <= 1); + EXPECT_TRUE(val[3] >= -1 && val[3] <= 1); + EXPECT_TRUE(val[4] >= -1 && val[4] <= 1); + EXPECT_TRUE(val[11] >= -1 && val[11] <= 1); +} + +TEST_F(ParamTest, UniformSqrtFanInInit) { + wp.set_init_method(ParamProto::kUniformSqrtFanIn); + wp.set_value(2.0f); + w.Init(wp); + const float *val = w.content().dptr; + EXPECT_TRUE(val[0] >= -2 && val[0] <= 2); + EXPECT_TRUE(val[3] >= -2 && val[3] <= 2); + EXPECT_TRUE(val[4] >= -2 && val[4] <= 2); + EXPECT_TRUE(val[11] >= -2 && val[11] <= 2); +} + + +TEST_F(ParamTest, UniformSqrtFanInOutInit) { + wp.set_init_method(ParamProto::kUniformSqrtFanInOut); + wp.set_value(1.0f); + float low=1.0f, high=5.0f; + wp.set_low(low); + wp.set_high(high); + w.Init(wp); + const float *val = w.content().dptr; + /* + LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3]; + LOG(INFO) << val[4] << " " << val[5] << " " << val[6] 
<< " " << val[7]; + LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11]; + */ + float factor = wp.value() / sqrt(wp.shape(0) + wp.shape(1)); + low=low*factor; + high=high*factor; + LOG(INFO)<<low<<" "<<high; + EXPECT_TRUE(val[0] >= low && val[0] <= high); + EXPECT_TRUE(val[3] >= low && val[3] <= high); + EXPECT_TRUE(val[4] >= low && val[4] <= high); + EXPECT_TRUE(val[11] >= low && val[11] <= high); +} + +TEST_F(ParamTest, GaussianInit) { + int len=5000, mean=0.0f, std=1.0f; + ParamProto p; + p.set_name("bias"); + p.add_shape(1); + p.add_shape(len); + p.set_init_method(ParamProto::kGaussain); + p.set_value(1.0f); + p.set_mean(mean); + p.set_std(std); + w.Init(p); + + const float *val = w.content().dptr; + float dmean=0.0f; + for(int i=0;i<len;i++) + dmean+=val[i]; + dmean/=len; + float dstd=0.0f; + for(int i=0;i<len;i++) + dstd+=(dmean-val[i])*(dmean-val[i]); + dstd/=len; + EXPECT_TRUE(std::abs(mean-dmean)<0.1); + EXPECT_TRUE(std::abs(std-dstd)<0.1); + /* + LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3]; + LOG(INFO) << val[4] << " " << val[5] << " " << val[6] << " " << val[7]; + LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11]; + */ +} + +TEST_F(ParamTest, GaussianSqrtFanInInit) { + wp.set_init_method(ParamProto::kGaussainSqrtFanIn); + wp.set_value(1.0f); + wp.set_mean(0); + wp.set_std(1.0f); + w.Init(wp); + //const float *val = w.content().dptr; + /* + LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3]; + LOG(INFO) << val[4] << " " << val[5] << " " << val[6] << " " << val[7]; + LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11]; + */ +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_proto.cc ---------------------------------------------------------------------- diff --git a/src/test/model/test_proto.cc b/src/test/model/test_proto.cc new file mode 100644 index 0000000..f6d81fd --- /dev/null +++ 
// b/src/test/model/test_proto.cc
// Copyright © 2014 Wei Wang. All Rights Reserved.
// 2014-07-15 21:54
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "proto/model.pb.h"
#include "utils/proto_helper.h"
namespace lapis {

// use const Message& m=..., otherwise may lead to segment fault
// Smoke-tests parsing of the shared model.conf text-format proto and
// spot-checks layer, edge, param, trainer and data-source fields.
TEST(ProtoTest, ReadFromFile) {
  ModelProto model;
  LOG(INFO)<<"start....";
  lapis::ReadProtoFromTextFile("src/test/data/model.conf", &model);
  LOG(INFO)<<"after reading file...";
  EXPECT_STREQ("caffe_config", model.name().c_str());

  // layer and edge size
  const NetProto& net = model.net();
  EXPECT_EQ(15, net.layer().size());
  EXPECT_EQ(14, net.edge().size());
  LOG(INFO)<<"after size check...";

  // layer config
  LayerProto layer1 = net.layer().Get(1);
  EXPECT_STREQ("input_img", layer1.name().c_str());
  EXPECT_STREQ("DataLayer", layer1.type().c_str());
  LOG(INFO)<<"after datalayer check...";
  // edge config
  EdgeProto edge0 = net.edge().Get(0);
  EXPECT_STREQ("input_img-hidden1_relu", edge0.name().c_str());
  EXPECT_STREQ("ConvEdge", edge0.type().c_str());
  EXPECT_EQ(2, edge0.param().size());
  LOG(INFO)<<"after first edge check...";
  // param config
  ParamProto param1 = edge0.param().Get(0);
  EXPECT_TRUE(ParamProto::kGaussain == param1.init_method());
  EXPECT_EQ(0.0f, param1.mean());
  EXPECT_EQ(0.01f, param1.std());
  EXPECT_EQ(1.0f, param1.learning_rate_multiplier());
  LOG(INFO)<<"after param of first edge check...";

  ParamProto param2 = edge0.param().Get(1);
  EXPECT_TRUE(ParamProto::kConstant == param2.init_method());
  EXPECT_EQ(0.0f, param2.value());
  EXPECT_EQ(0.0f, param2.weight_decay_multiplier());
  LOG(INFO)<<"after param of second edge check...";

  // trainer config
  const TrainerProto& trainer = model.trainer();
  const SGDProto& sgd=trainer.sgd();
  EXPECT_EQ(227, sgd.train_batchsize());
  EXPECT_EQ(0.01f, sgd.base_learning_rate());
  EXPECT_TRUE(SGDProto::kStep== sgd.learning_rate_change());
  LOG(INFO)<<"after sgd check...";

  // data source config
  EXPECT_EQ(2,trainer.train_data().size());
  LOG(INFO)<<"after size check...";
  const DataSourceProto& data=trainer.train_data(0);
  LOG(INFO)<<"after get data...";
  EXPECT_STREQ("RGBDirSource", data.type().c_str());
  LOG(INFO)<<"after type check...";
  EXPECT_EQ(50000, data.size());
  EXPECT_EQ(3, data.channels());
  LOG(INFO)<<"after data source check...";
}
}  // namespace lapis

// http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_rgb_dir_source.cc
// Copyright © 2014 Wei Wang. All Rights Reserved.
// 2014-07-21 21:52

#include <gtest/gtest.h>
#include <glog/logging.h>
#include <algorithm>

#include "proto/model.pb.h"
#include "disk/rgb_dir_source.h"
#include "disk/label_source.h"

namespace lapis {
// Fixture: an RGBDirSource over three 256x256 test images with a mean
// file for normalization.
class RGBDirSourceTest : public ::testing::Test {
 public:
  RGBDirSourceTest() {
    DataSourceProto ds;
    ds.set_path("src/test/data/rgb_dir");
    ds.set_mean_file("src/test/data/imagenet_mean.binaryproto");
    ds.set_size(3);
    ds.set_height(256);
    ds.set_width(256);
    ds.set_offset(2);
    ds.set_name("rgb dir source");
    rgbs.Init(ds);
  }

 protected:
  RGBDirSource rgbs;
};

// Without input keys, the source enumerates the directory contents.
TEST_F(RGBDirSourceTest, LoadDataNoInputKeys) {
  auto &ptr2names = rgbs.LoadData(nullptr);
  EXPECT_EQ(3, ptr2names->size());
  sort(ptr2names->begin(), ptr2names->end());
  EXPECT_STREQ("img0.JPEG", ptr2names->at(0).c_str());
  EXPECT_STREQ("img1.JPEG", ptr2names->at(1).c_str());
  EXPECT_STREQ("img2.JPEG", ptr2names->at(2).c_str());
}

// With input keys (from a LabelSource), the source preserves key order.
TEST_F(RGBDirSourceTest, LoadDataWithInputKeys) {
  LabelSource ls;
  DataSourceProto ds;
  ds.set_path("src/test/data/label_source.dat");
  ds.set_name("label source");
  ds.set_size(3);
  ls.Init(ds);
  auto ptr2names1 = ls.LoadData(nullptr);
  auto ptr2names2 = rgbs.LoadData(ptr2names1);
  EXPECT_EQ(3, ptr2names2->size());
  for (int i = 0; i < 3; i++)
    EXPECT_STREQ(ptr2names1->at(i).c_str(), ptr2names2->at(i).c_str());
}

// Smoke test: repeated GetData calls must succeed (wrapping past the
// three available images).
TEST_F(RGBDirSourceTest, GetData) {
  Blob b;
  b.Resize(256,256,3,2);
  rgbs.LoadData(nullptr);
  rgbs.GetData(&b);
  rgbs.GetData(&b);
  rgbs.GetData(&b);
}
}  // namespace lapis

// http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_cluster.cc
#include <fstream>
#include "gtest/gtest.h"
#include "proto/cluster.pb.h"
#include "utils/cluster.h"

using namespace singa;

string folder="src/test/data/";
// NOTE(review): everything below is inside a block comment that does not
// close within this chunk — the cluster tests are disabled as committed.
/*
ClusterProto GenClusterProto(){
  ClusterProto proto;
  int nworker=6, nserver=4;
  proto.set_nworkers(nworker);
  proto.set_nservers(nserver);
  proto.set_nworkers_per_group(3);
  proto.set_nservers_per_group(2);
  proto.set_nthreads_per_worker(1);
  proto.set_nthreads_per_server(2);

  proto.set_hostfile(folder+"/hostfile");

  std::ofstream fout(folder+"/hostfile", std::ofstream::out);
  for(int i=0;i<nworker+nserver;i++){
    char tmp[20];
    sprintf(tmp, "awan-0-%02d-0", i);
    fout<<tmp<<std::endl;
  }
  fout.flush();
  fout.close();
  return proto;
}

TEST(ClusterTest, NoServer){
  ClusterProto proto=GenClusterProto();
  proto.set_nservers(0);
  auto cluster=Cluster::Get(proto, 0);
  ASSERT_EQ(proto.nworkers(),cluster->nworkers());
  ASSERT_EQ(0, cluster->nservers());
  ASSERT_EQ(proto.nworkers_per_group(),cluster->nworkers_per_group());
  ASSERT_EQ(proto.nservers_per_group(),cluster->nservers_per_group());
  ASSERT_FALSE(cluster->AmIServer());
  ASSERT_TRUE(cluster->AmIWorker());
ASSERT_EQ(0,cluster->group_procs_id()); + ASSERT_EQ(0,cluster->group_id()); + ASSERT_EQ(2, cluster->nworker_groups()); + ASSERT_EQ(0, cluster->nserver_groups()); + ASSERT_STREQ("awan-0-00-0", cluster->host_addr().c_str()); + + cluster=Cluster::Get(proto, 5); + ASSERT_EQ(2,cluster->group_procs_id()); + ASSERT_EQ(1,cluster->group_id()); + ASSERT_EQ(2, cluster->nworker_groups()); + ASSERT_EQ(0, cluster->nserver_groups()); + ASSERT_STREQ("awan-0-05-0", cluster->host_addr().c_str()); +} + +TEST(ClusterTest, SingleServerGroup){ + ClusterProto proto=GenClusterProto(); + proto.set_nservers(2); + auto cluster=Cluster::Get(proto, 3); + ASSERT_FALSE(cluster->AmIServer()); + ASSERT_TRUE(cluster->AmIWorker()); + ASSERT_EQ(0,cluster->group_procs_id()); + ASSERT_EQ(1,cluster->group_id()); + ASSERT_EQ(2, cluster->nworker_groups()); + ASSERT_EQ(1, cluster->nserver_groups()); + ASSERT_STREQ("awan-0-03-0", cluster->host_addr().c_str()); + + cluster=Cluster::Get(proto, 7); + ASSERT_EQ(1,cluster->group_procs_id()); + ASSERT_EQ(0,cluster->group_id()); + ASSERT_EQ(2, cluster->nworker_groups()); + ASSERT_EQ(1, cluster->nserver_groups()); + ASSERT_STREQ("awan-0-07-0", cluster->host_addr().c_str()); +} + +TEST(ClusterTest, MultiServerGroups){ + ClusterProto proto=GenClusterProto(); + auto cluster=Cluster::Get(proto, 7); + ASSERT_EQ(1,cluster->group_procs_id()); + ASSERT_EQ(0,cluster->group_id()); + ASSERT_EQ(2, cluster->nworker_groups()); + ASSERT_EQ(2, cluster->nserver_groups()); + ASSERT_STREQ("awan-0-07-0", cluster->host_addr().c_str()); + + cluster=Cluster::Get(proto, 8); + ASSERT_TRUE(cluster->AmIServer()); + ASSERT_FALSE(cluster->AmIWorker()); + ASSERT_EQ(0,cluster->group_procs_id()); + ASSERT_EQ(1,cluster->group_id()); + ASSERT_EQ(2, cluster->nworker_groups()); + ASSERT_EQ(2, cluster->nserver_groups()); + ASSERT_STREQ("awan-0-08-0", cluster->host_addr().c_str()); +} +*/ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_communication.cc 
---------------------------------------------------------------------- diff --git a/src/test/test_communication.cc b/src/test/test_communication.cc new file mode 100644 index 0000000..c9c035f --- /dev/null +++ b/src/test/test_communication.cc @@ -0,0 +1,158 @@ +#include <thread> +#include <vector> +#include "gtest/gtest.h" +#include "communication/msg.h" +#include "communication/socket.h" +using std::vector; +using namespace singa; + +const char* ping="PING",*pong="PONG"; +/** + * Connect dealer with (gid, id, flag) to stub router + */ +void Connect(Dealer* dealer, int gid, int id, int flag){ + dealer->Connect("inproc://router"); + Msg msg; + msg.set_src(gid, id, flag); + msg.set_dst(0,0,2); + msg.set_type(0); + msg.add_frame(ping, 4); + dealer->Send(&msg); +} + +/** + * Dealer thread, ping-pong with the stub router + */ +void DealerPingPong(int id){ + Dealer* dealer=new Dealer(); + Connect(dealer, 0, id, 0); + Msg* msg=dealer->Receive(); + int flag=msg->src_flag(); + ASSERT_EQ(2, flag); + ASSERT_EQ(0, msg->dst_group_id()); + ASSERT_EQ(id, msg->dst_id()); + ASSERT_STREQ(pong, (char*)msg->frame_data()); + delete msg; + delete dealer; +} + +/** + * Worker thread, connect to router and communicate with server thread + */ +void WorkerDealer(int sid, int did){ + Dealer* dealer=new Dealer(); + Connect(dealer, 0, sid, 0); + for(int i=0;i<2;i++){ + { + Msg msg; + msg.set_src(0, sid, 0); + msg.set_dst(0, did, 1); + msg.set_type(3); + msg.set_target(i); + dealer->Send(&msg); + } + { + Msg *msg=dealer->Receive(); + ASSERT_EQ(0, msg->src_group_id()); + ASSERT_EQ(did, msg->src_id()); + ASSERT_EQ(1, msg->src_flag()); + delete msg; + } + } + delete dealer; +} + +/** + * Server thread, connect to router and communicate with worker thread + */ +void ServerDealer(int id, int n){ + Dealer* dealer=new Dealer(); + Connect(dealer, 0, id, 1); + for(int i=0;i<n;i++){ + Msg *msg=dealer->Receive(); + Msg reply; + reply.set_dst(msg->src_group_id(), msg->src_id(), msg->src_flag()); + 
reply.set_src(0, id, 1); + dealer->Send(&reply); + delete msg; + } + delete dealer; +} + +TEST(CommunicationTest, DealerRouterPingPong){ + int n=2; + vector<std::thread> threads; + for(int i=0;i<n;i++) + threads.push_back(std::thread(DealerPingPong, i)); + Router* router=new Router(); + router->Bind(""); + for(int k=0;k<n;k++){ + Msg* msg=router->Receive(); + ASSERT_EQ(0, msg->src_group_id()); + ASSERT_EQ(2, msg->dst_flag()); + ASSERT_STREQ(ping, (char*)msg->frame_data()); + + Msg reply; + reply.set_src(0,0,2); + reply.set_dst(msg->src_group_id(), msg->src_id(), msg->src_flag()); + reply.add_frame(pong, 4); + router->Send(&reply); + delete msg; + } + + delete router; + for(auto& thread:threads) + thread.join(); +} + +TEST(CommunicationTest, nWorkers1Server){ + int nworker=2; + vector<std::thread> threads; + for(int i=0;i<nworker;i++) + threads.push_back(std::thread(WorkerDealer, i, 0)); + //threads.push_back(std::thread(ServerDealer, 0, 4)); + Router* router=new Router(); + router->Bind(""); + int nmsg=4*nworker; + int k=0; + while(nmsg>0){ + Msg* msg=router->Receive(); + if(2== msg->dst_flag()){ + ASSERT_STREQ(ping, (char*)msg->frame_data()); + k++; + if(k==nworker) + threads.push_back(std::thread(ServerDealer, 0, 2*nworker)); + }else{ + nmsg--; + router->Send(msg); + } + delete msg; + } + delete router; + for(auto& thread:threads) + thread.join(); +} + +TEST(CommunicationTest, 2Workers2Server){ + vector<std::thread> threads; + threads.push_back(std::thread(WorkerDealer, 0, 0)); + threads.push_back(std::thread(WorkerDealer, 1, 1)); + threads.push_back(std::thread(ServerDealer, 0, 2)); + threads.push_back(std::thread(ServerDealer, 1, 2)); + Router* router=new Router(); + router->Bind(""); + int n=8; + while(n>0){ + Msg* msg=router->Receive(); + if(2== msg->dst_flag()){ + ASSERT_STREQ(ping, (char*)msg->frame_data()); + }else{ + n--; + router->Send(msg); + } + delete msg; + } + delete router; + for(auto& thread:threads) + thread.join(); +} 
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_shard.cc ---------------------------------------------------------------------- diff --git a/src/test/test_shard.cc b/src/test/test_shard.cc new file mode 100644 index 0000000..c96d876 --- /dev/null +++ b/src/test/test_shard.cc @@ -0,0 +1,56 @@ +#include <gtest/gtest.h> +#include <sys/stat.h> + +#include "utils/data_shard.h" + +std::string key[]={"firstkey","secondkey","3key", "key4", "key5"}; +std::string tuple[]={"firsttuple","2th-tuple","thridtuple", "tuple4", "tuple5"}; + +using namespace singa; + +TEST(DataShardTest, CreateDataShard){ + std::string path="src/test/data/shard_test"; + mkdir(path.c_str(), 0755); + DataShard shard(path, DataShard::kCreate, 50); + shard.Insert(key[0], tuple[0]); + shard.Insert(key[1], tuple[1]); + shard.Insert(key[2], tuple[2]); + shard.Flush(); +} + +TEST(DataShardTest, AppendDataShard){ + std::string path="src/test/data/shard_test"; + DataShard shard(path, DataShard::kAppend, 50); + shard.Insert(key[3], tuple[3]); + shard.Insert(key[4], tuple[4]); + shard.Flush(); +} +TEST(DataShardTest, CountDataShard){ + std::string path="src/test/data/shard_test"; + DataShard shard(path, DataShard::kRead, 50); + int count=shard.Count(); + ASSERT_EQ(5, count); +} + +TEST(DataShardTest, ReadDataShard){ + std::string path="src/test/data/shard_test"; + DataShard shard(path, DataShard::kRead, 50); + std::string k, t; + ASSERT_TRUE(shard.Next(&k, &t)); + ASSERT_STREQ(key[0].c_str(), k.c_str()); + ASSERT_STREQ(tuple[0].c_str(), t.c_str()); + ASSERT_TRUE(shard.Next(&k, &t)); + ASSERT_STREQ(key[1].c_str(), k.c_str()); + ASSERT_STREQ(tuple[1].c_str(), t.c_str()); + ASSERT_TRUE(shard.Next(&k, &t)); + ASSERT_TRUE(shard.Next(&k, &t)); + ASSERT_TRUE(shard.Next(&k, &t)); + ASSERT_STREQ(key[4].c_str(), k.c_str()); + ASSERT_STREQ(tuple[4].c_str(), t.c_str()); + + ASSERT_FALSE(shard.Next(&k, &t)); + shard.SeekToFirst(); + ASSERT_TRUE(shard.Next(&k, &t)); + 
ASSERT_STREQ(key[0].c_str(), k.c_str()); + ASSERT_STREQ(tuple[0].c_str(), t.c_str()); +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/pm_server.cc ---------------------------------------------------------------------- diff --git a/src/trainer/pm_server.cc b/src/trainer/pm_server.cc new file mode 100644 index 0000000..28fa28d --- /dev/null +++ b/src/trainer/pm_server.cc @@ -0,0 +1,99 @@ +#include <gflags/gflags.h> +#include <glog/logging.h> +#include "trainer/pm_server.h" +#include "utils/singleton.h" +#include "utils/factory.h" +#include <vector> + +using std::vector; + +namespace singa{ +void PMServer::Setup(int group_id, int server_id, shared_ptr<ParamShard> shard, + const UpdaterProto& proto){ + group_id_=group_id; + server_id_=server_id; + shard_=shard; + updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance() + ->Create("Updater")); + updater_->Init(proto); +} + +PMServer::~PMServer(){ +} + +bool PMServer::SyncNow(){ + return false; +} +Msg* PMServer::HandlePut(Msg **msg){ + int id=(*msg)->target(); + shared_ptr<Param> param=nullptr; + if(shard_->find(id)!=shard_->end()){ + LOG(ERROR)<<"Param ("<<id<<") is put more than once"; + param=shard_->at(id); + }else{ + param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance() + ->Create("Param")); + param->set_id(id); + (*shard_)[id]=param; + } + return param->HandlePutMsg(msg); +} + +Msg* PMServer::HandleGet(Msg **msg){ + int id=(*msg)->target(); + shared_ptr<Param> param=nullptr; + if(shard_->find(id)!=shard_->end()){ + param=shard_->at(id); + return param->HandleGetMsg(msg); + } else { + //re-construct msg to be re-queued. 
+ //the calling function will send this message off + return *msg; + } +} + +Msg* PMServer::HandleUpdate(Msg **msg) { + int id=(*msg)->target(); + shared_ptr<Param> param=nullptr; + if(shard_->find(id)!=shard_->end()){ + //repsonse of the format: <identity><type: kData><paramId><param content> + param=shard_->at(id); + Msg* tmp=static_cast<Msg*>((*msg)->CopyAddr()); + param->ParseUpdateMsg(msg); + updater_->Update(param->version(), param); + param->set_version(param->version()+1); + auto response=param->GenUpdateResponseMsg(); + tmp->SwapAddr(); + response->SetAddr(tmp); + delete tmp; + return response; + } else { + LOG(ERROR)<<"Param ("<<id<<") is not maintained by server ("<<group_id_ + <<", "<<server_id_<<")"; + //re-construct msg to be re-queued. + return *msg; + } +} + +Msg* PMServer::HandleSyncRequest(Msg **msg){ + int id=(*msg)->target(); + shared_ptr<Param> param=nullptr; + if(shard_->find(id)!=shard_->end()){ + //repsonse of the format: <identity><type: kData><paramId><param content> + param=shard_->at(id); + return param->HandleSyncMsg(msg); + } else { + //re-construct msg to be re-queued. 
+ return *msg; + } +} + +int PMServer::HandleSyncResponse(Msg **msg){ + int id=(*msg)->target(); + CHECK(shard_->find(id)!=shard_->end()); + return shard_->at(id)->ParseSyncResponseMsg(msg); +} + +} // namespace singa + + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/pm_worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/pm_worker.cc b/src/trainer/pm_worker.cc new file mode 100644 index 0000000..7269578 --- /dev/null +++ b/src/trainer/pm_worker.cc @@ -0,0 +1,344 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> +#include "gflags/gflags.h" +#include <glog/logging.h> +#include "proto/model.pb.h" +#include "trainer/pm_worker.h" +#include "mshadow/tensor.h" +#include "utils/cluster.h" + + +namespace singa{ + +void PMWorker::Setup(int group_id, int worker_id, + shared_ptr<ParamShard> shard){ + group_id_=group_id; + worker_id_=worker_id; + shard_=shard; +} +int PMWorker::Sharding(int param_id){ + return param_id%Cluster::Get()->nservers_per_group(); +} +/* +int PMWorker::Sharding(int param_id){ + static map<int, int> id2procs; + if(id2procs.find(param_id)==id2procs.end()){ + auto cluster=Cluster::Get(); + int server_group=group_id_%cluster->nserver_groups(); + int nprocs_per_server_group= + cluster->nservers_per_group()/cluster->nservers_per_procs(); + int procsid=server_group*nprocs_per_server_group+ + param_id%nprocs_per_server_group; + procsid= cluster->server_worker_separate()? 
+ cluster->nworker_procs()+procsid:procsid; + id2procs[param_id]=procsid; + } + return id2procs[param_id]; +} +*/ + +Msg* PMWorker::Put(Msg** msg){ + return *msg; +} + +Msg* PMWorker::Put(shared_ptr<Param> param, int step){ + param->set_version(step); + // only owner can put shared parameter + if(param->owner()<0||param->owner()==param->id()){ + Msg* msg= param->GenPutMsg(&step); + msg->set_src(group_id_, worker_id_, kWorkerParam); + msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(), + Sharding(param->id()), kServer); + msg->set_type(kPut); + msg->set_target(param->id()); + return msg; + }else + return nullptr; +} + +Msg* PMWorker::Get(Msg** msg){ + return *msg; +} + +Msg* PMWorker::Get(shared_ptr<Param> param, int step){ + param->set_version(step); + bool send=false; + int id=param->id(); + shared_ptr<ParamCounter> entry=nullptr; + if(param->owner()>=0){ + entry=shard_->at(id); + entry->nGet++; + send=entry->nGet/entry->nLocal==step; + } + if(param->owner()<0||send){ + Msg* msg=nullptr; + if(param->owner()<0){ + msg=param->GenGetMsg(&step); + msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(), + Sharding(id), kServer); + } else { + msg=entry->param->GenGetMsg(&step); + msg->set_dst(entry->owner_procs,kStub); + } + msg->set_src(group_id_, worker_id_, kWorkerParam); + msg->set_type(kGet); + msg->set_target(id); + return msg; + }else + return nullptr; +} + +Msg* PMWorker::Update(Msg** msg){ + return *msg; +} +Msg* PMWorker::Update(shared_ptr<Param> param, int step){ + param->set_version(step); + bool send=false; + int id=param->id(); + shared_ptr<ParamCounter> entry; + if(param->owner()>=0){ + entry=shard_->at(param->id()); + entry->nGet++; + send=entry->nGet/entry->nLocal==step; + auto shape=mshadow::Shape1(param->size()); + mshadow::Tensor<mshadow::cpu,1> grad(param->mutable_cpu_grad(), shape); + mshadow::Tensor<mshadow::cpu,1> agg(entry->param->mutable_cpu_grad(), shape); + agg+=grad; + } + if(param->owner()<0||send){ + 
Msg* msg=nullptr; + if(param->owner()<0){ + msg=param->GenUpdateMsg(&step); + msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(), + Sharding(id), kServer); + } else { + entry->param->GenUpdateMsg(&step); + msg->set_dst(entry->owner_procs,kStub); + memset(param->mutable_cpu_data(), 0, sizeof(float)*param->size()); + } + msg->set_type(kUpdate); + msg->set_target(id); + msg->set_src(group_id_, worker_id_, kWorkerParam); + return msg; + }else + return nullptr; +} + +Msg* PMWorker::Collect(Msg** msg){ + int id=(*msg)->target(); + int type=(*msg)->type(); + auto pp=shard_->at(id)->param; + if(type==kRGet){ + pp->ParseGetResponseMsg(msg); + }else if(type==kRUpdate){ + pp->ParseUpdateResponseMsg(msg); + } + if(pp->owner()>=0){ + // forwarding to workers on other procs + } + delete (*msg); + *msg=nullptr; + return nullptr; +} + +/* +//id is the global worker id +SingaClient::SingaClient(int global_id, Topology &topology, vector<string> &hosts) { + //Read the config files and store endpoints + id_ = global_id; + + int n_workers = hosts.size() - topology.nservers(); + int n_worker_groups = topology.nworker_groups(); + int group_size = n_workers/n_worker_groups; + int server_group_size = topology.nservers()/topology.server_group_size(); + FLAGS_client_threads = topology.worker_threads(); + + local_id_ = (id_-topology.nservers())%group_size;//local worker id. 
+ group_id_ = (id_-topology.nservers())/group_size; + + VLOG(3) << "Parsing client config for "<<hosts[id_]; + + //connect to all server in the server group group_id_ + int start_server_idx = group_id_*server_group_size; + int end_server_idx = start_server_idx+server_group_size; + + for (int i = start_server_idx; i < end_server_idx; i++) { + char *neighbor_endpoint = (char*) malloc(256); + sprintf(neighbor_endpoint, "tcp://%s:%d", hosts[i].c_str(), topology.port()); + neighbors_.push_back(neighbor_endpoint); + VLOG(3) << "Worker neighbor (server): "<<neighbor_endpoint; + } + + sprintf(backend_endpoint_, "inproc://singanus%d",id_); + + //Create shared paramshard + param_shard_ = new ParamShard(id_,0); +} + +void SingaClient::StartClient(){ + //Create and connect sockets to the server + vector<void *> server_sockets; + zctx_t *context = zctx_new(); + int nservers = neighbors_.size(); + int rc; + for (int i=0; i<nservers; i++){ + void *socket = zsocket_new(context, ZMQ_DEALER); + rc = zsocket_connect(socket, neighbors_[i]); + VLOG(3) << "Connected to neighbor " <<neighbors_[i]; + assert(rc==0); + server_sockets.push_back(socket); + } + + //Create and bind backend socket + void *backend = zsocket_new(context, ZMQ_ROUTER); + rc = zsocket_bind(backend, backend_endpoint_); + assert(rc==0); + + //Start client threads + for (int i=0; i<FLAGS_client_threads; i++){ + void * socket = zthread_fork(context, ClientThread, this); + zmsg_t *control_msg = zmsg_new(); + if (i==0 && local_id_==0) + zmsg_pushstr(control_msg,POPULATE); + else + zmsg_pushstr(control_msg, WAIT); + zmsg_send(&control_msg, socket); + } + + //Star the message loop + bool is_running = true; + int nsockets= nservers+1; + while (is_running) { + zmq_pollitem_t items[nsockets]; + for (int i = 0; i < nsockets-1; i++) + items[i] = {server_sockets[i], 0, ZMQ_POLLIN, 0}; + items[nsockets-1] = {backend, 0, ZMQ_POLLIN, 0}; + + int rc = zmq_poll(items,nsockets,-1); + if (rc<0) break; + + for (int i=0; i<nsockets-1; 
i++){ + if (items[i].revents & ZMQ_POLLIN){ + zmsg_t *msg = zmsg_recv(server_sockets[i]); + if (!msg){ + is_running = false; + break; + } + //forward to backend + zmsg_send(&msg, backend); + } + } + if (items[nsockets-1].revents & ZMQ_POLLIN){ + //compute serverId from paramId and forward to the socket + zmsg_t *msg = zmsg_recv(backend); + if (!msg) is_running=false; + zframe_t *identity = zmsg_pop(msg); + zframe_t *type = zmsg_pop(msg); + int paramId; + sscanf(zmsg_popstr(msg), "%d", ¶mId); + zmsg_pushstrf(msg,"%d",paramId); + zmsg_prepend(msg,&type); + zmsg_prepend(msg,&identity); + zmsg_send(&msg, server_sockets[param_to_server_id(paramId)]); + } + } + + zsocket_destroy(context, backend); + for (int i=0; i<nsockets-1; i++) + zsocket_destroy(context, server_sockets[i]); + zctx_destroy(&context); +} + +vector<Param*> gen_random_params() { + int size[] = { 1960000, 2500, 5000000, 2000, 3000000, 1500, 1500000, 1000, 500000, 500, 5000, 10 }; + vector<Param*> params; + for (int i = 0; i < 12; i++) { + ParamProto proto; + proto.set_id(i); + proto.set_init_method(ParamProto::kGaussain); + Param* p = new Param(); + p->Setup(proto, vector<int> { size[i] }, 0); + p->Init(); + params.push_back(p); + } + return params; +} + +//simple mapping +int SingaClient::param_to_server_id(int paramId){ + return paramId % neighbors_.size(); +} + +void ClientThread(void *args, zctx_t *ctx, void *pipe){ + SingaClient *client = static_cast<SingaClient*>(args); + + //Create back-end socket and connect to the main thread + void *backend = zsocket_new(ctx, ZMQ_DEALER); + int rc = zsocket_connect(backend, client->backend_endpoint()); + assert(rc==0); + //Create PMClient object + PMClient *pmclient = new PMClient(client->id(), client->param_shard(), backend); + + //FOR TESTING ONLY. REMOVE THIS! 
+ //wait for control from main thread + vector<Param*> params = gen_random_params(); + zmsg_t *control_msg = zmsg_recv(pipe); + zframe_t *msg = zmsg_pop(control_msg); + if (zframe_streq(msg,WAIT)) + zclock_sleep(2000); //2s + else{ + for (int i=0; i<params.size(); i++){ + pmclient->Put(i, params[i]); + } + VLOG(3)<<"Done PUT requests for populating servers."; + zclock_sleep(2000); + } + zframe_destroy(&msg); + //END TESTING + LOG(ERROR) << "Done putting"; + + //first, get the params + + test_get(pmclient); + test_collect(pmclient); + + + int iterations = 1; + while (iterations<=200){ + VLOG(3) << "Iteration "<<iterations; + test_update(pmclient, params); + test_collect(pmclient); + iterations++; + } + + zsocket_destroy(ctx, backend); +} + +void test_get(PMClient *client){ + for (int i=0; i<12; i++){ + Param pm; + int status = client->Get(i, &pm); + assert(status==NON_LOCAL); + } +} + +void test_collect(PMClient *client){ + for (int i=0; i<12; i++){ + Param pm; + int64_t start_time = zclock_time(); + while (!client->Collect(&pm)) + zclock_sleep(1); + int64_t end_time = zclock_time(); + VLOG(3) << "Collected: " <<(end_time-start_time); + } +} + +void test_update(PMClient *client, vector<Param*> params){ + for (int i=0; i<params.size(); i++) + client->Update(i, params[i]); +} +*/ + + +} //namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/server.cc ---------------------------------------------------------------------- diff --git a/src/trainer/server.cc b/src/trainer/server.cc new file mode 100644 index 0000000..bf0ad03 --- /dev/null +++ b/src/trainer/server.cc @@ -0,0 +1,68 @@ +#include <list> +#include <tuple> +#include <queue> +#include "trainer/server.h" +#include "utils/param.h" +#include "utils/singleton.h" +#include "utils/factory.h" +#include "utils/cluster.h" + + +namespace singa { +Server::Server(int group_id, int server_id): + group_id_(group_id), server_id_(server_id){} + +void Server::Setup(const UpdaterProto& 
proto, + shared_ptr<PMServer::ParamShard> shard, + shared_ptr<Dealer> dealer){ + //VLOG(3) << "Parsing config file for host "<<hosts[id_] << " server id = " <<id_; + pmserver_=shared_ptr<PMServer>(Singleton<Factory<PMServer>>::Instance() + ->Create("PMServer")); + pmserver_->Setup(group_id_, server_id_, shard, proto); + dealer_=dealer; +} + +void Server::Run(){ + Msg* ping=new Msg(); + ping->set_src(group_id_, server_id_, kServer); + ping->set_dst(0,0,kStub); + ping->set_type(kConnect); + dealer_->Send(ping); + int timeout=Cluster::Get()->server_timeout(); + Poller poller; + poller.Add(dealer_.get()); + //start recv loop and process requests + while (true){ + Msg* msg=dealer_->Receive(); + if (msg==nullptr) + break; + Msg* response=nullptr; + int type=msg->type(); + switch (type){ + case kPut: + response = pmserver_->HandlePut(&msg); + break; + case kGet: + response = pmserver_->HandleGet(&msg); + break; + case kUpdate: + response = pmserver_->HandleUpdate(&msg); + break; + case kSyncRequest: + VLOG(3)<<"Handle SYNC-REQUEST"; + response = pmserver_->HandleSyncRequest(&msg); + break; + case kSyncResponse: + VLOG(3) << "Handle SYNC response"; + pmserver_->HandleSyncResponse(&msg); + break; + } + + if (response!=nullptr) + dealer_->Send(response); + } +} + + + +} /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/trainer.cc ---------------------------------------------------------------------- diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc new file mode 100644 index 0000000..3621b7e --- /dev/null +++ b/src/trainer/trainer.cc @@ -0,0 +1,206 @@ +#include <thread> +#include <vector> +#include <map> +#include <glog/logging.h> +#include "trainer/trainer.h" +using std::vector; +using std::map; + +namespace singa { +int ProcsIDOf(int group_id, int id, int flag){ + int procsid; + auto cluster=Cluster::Get(); + if(flag==kServer){ + procsid=group_id*cluster->nservers_per_group()/ + 
cluster->nservers_per_procs()+id/cluster->nservers_per_procs(); + if(cluster->server_worker_separate()) + procsid+=cluster->nworker_procs(); + }else if(flag==kWorkerLayer || flag==kWorkerParam){ + procsid=group_id*cluster->nworkers_per_group() + /cluster->nworkers_per_procs(); + if(cluster->nworkers_per_group()>cluster->nworkers_per_procs()) + procsid+=id/cluster->nworkers_per_procs(); + }else{ + LOG(ERROR)<<"Unkown flag ("<<flag<<")"; + } + return procsid; +} + +void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){ + // register all layers appearing in the neural net + singa::NeuralNet::RegisterLayers(); + Singleton<Factory<singa::Param>>::Instance()->Register( + "Param", CreateInstance(singa::Param, singa::Param)); + Singleton<Factory<singa::Updater>>::Instance() ->Register( + "Updater", CreateInstance(singa::SGDUpdater, singa::Updater)); + Singleton<Factory<singa::PMWorker>>::Instance() ->Register( + "PMWorker", CreateInstance(singa::PMWorker, singa::PMWorker)); + Singleton<Factory<singa::PMServer>>::Instance() ->Register( + "PMServer", CreateInstance(singa::PMServer, singa::PMServer)); + Singleton<Factory<singa::PMServer>>::Instance() ->Register( + "PMServer", CreateInstance(singa::PMServer, singa::PMServer)); +} + +void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, + int procs_id){ + RegisterDefaultClasses(mproto); + + auto cluster=Cluster::Get(cproto, procs_id); + // create servers + vector<shared_ptr<Server>> servers; + int nSocket=1; // the first socket is the router + if(cluster->has_server()){ + int pid=cluster->procs_id(); + if(cluster->server_worker_separate()) + pid-=cluster->nworker_procs(); + int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group(); + int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group(); + int end=start+cluster->nservers_per_group(); + // the ParamShard for servers consists of a dictionary of Param objects + auto shard=make_shared<PMServer::ParamShard>(); + 
for(int sid=start;sid<end;sid++){ + auto server=make_shared<Server>(gid, sid); + auto dealer=make_shared<Dealer>(nSocket++); + dealer->Connect(kInprocRouterEndpoint); + server->Setup(mproto.updater(), shard, dealer); + servers.push_back(server); + } + } + + // create workers + vector<shared_ptr<Worker>> workers; + if(cluster->has_worker()){ + auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain); + int pid=cluster->procs_id(); + int gstart, gend, wstart, wend; + if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){ + // all workers in this procs are from the same group + gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group(); + gend=gstart+1; + wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group(); + wend=wstart+cluster->nworkers_per_group(); + }else{ + // there are multiple groups in this procs + CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0); + int groups_per_procs= + cluster->nworkers_per_procs()/cluster->nworkers_per_group(); + gstart=pid*groups_per_procs; + gend=(pid+1)*groups_per_procs; + wstart=0; + wend=cluster->nworkers_per_group(); + } + for(int gid=gstart;gid<gend;gid++){ + shared_ptr<NeuralNet> train_net, test_net, validation_net; + if(gid==gstart) + train_net=net; + else{ + train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain); + // the train net for other groups may share parameter values from the + // first group + if(mproto.hogwild()) + train_net->ShareParams(net, kValueOnly); + } + if(gid==0){ + // validation and test are performed only by the first group + if(mproto.test_steps()){ + test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest); + if(test_net!=nullptr) + test_net->ShareParams(train_net, kValueOnly); + } + if(mproto.validation_steps()){ + validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kValidation); + if(validation_net!=nullptr) + validation_net->ShareParams(train_net, kValueOnly); + } + } + // create ParamShard for the workers + 
auto shard=make_shared<PMWorker::ParamShard>(); + for(auto layer: train_net->layers()){ + int procsid=ProcsIDOf(gid, layer->locationid(),kWorkerParam); + int local=procsid==cluster->procs_id(); + for(auto param: layer->GetParams()){ + int owner=param->owner()<0||param->owner()==param->id()?procsid:-1; + if(shard->find(param->id())==shard->end()) + (*shard)[param->id()]=make_shared<ParamCounter>(param, local, owner); + else + shard->at(param->id())->AddParam(param, local, owner); + } + } + for(int wid=wstart;wid<wend;wid++){ + shared_ptr<Worker> worker=nullptr; + if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation) + worker=make_shared<BPWorker>(gid, wid); + else{ + // TODO add CDWorker + } + auto layer_dealer=make_shared<Dealer>(nSocket++); + auto param_dealer=make_shared<Dealer>(nSocket++); + layer_dealer->Connect(kInprocRouterEndpoint); + param_dealer->Connect(kInprocRouterEndpoint); + worker->Setup(mproto, train_net, shard, layer_dealer, param_dealer); + worker->set_test_net(test_net); + worker->set_validation_net(validation_net); + workers.push_back(worker); + } + } + } + +#ifdef USE_MPI + for(int i=0;i<nSocket;i++){ + MPIQueues.push_back(make_shared<SafeQueue>()); + } +#endif + vector<std::thread> threads; + for(auto server: servers) + threads.push_back(std::thread(&Server::Run,server)); + for(auto worker: workers) + threads.push_back(std::thread(&Worker::Run,worker)); + Run(); + for(auto& thread: threads) + thread.join(); +} + +void Trainer::Run(){ + auto cluster=Cluster::Get(); + auto router=make_shared<Router>(); + router->Bind(kInprocRouterEndpoint); + if(cluster->nprocs()>1) + router->Bind(cluster->endpoint()); + + map<int, shared_ptr<Dealer>> interprocs_dealers; + Poller poller; + poller.Add(router.get()); + int timeout=cluster->stub_timeout(); + while(true){ + Msg* msg=router->Receive(); + if(msg==nullptr){ + LOG(ERROR)<<"Connection broken!"; + exit(0); + } + int dst_flag=msg->dst_flag(); + int type=msg->type(); + int group_id, id, procs_id; + 
switch (dst_flag){ // TODO process other requests, e.g. RESTful + case kStub: + if(type==kConnect){ + delete msg; + }else{ + // TODO processing requests for worker group spanning multiple procs. + LOG(ERROR)<<"Unkown message type ("<<type<<") to stub"; + } + break; + default: + group_id=msg->dst_group_id(); + id=msg->dst_id(); + procs_id=ProcsIDOf(group_id, id, dst_flag); + if(procs_id!=cluster->procs_id()){ + if (interprocs_dealers.find(procs_id)==interprocs_dealers.end()) + interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id); + interprocs_dealers[procs_id]->Send(msg); + } else + router->Send(msg); + break; + } + } +} +} /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc new file mode 100644 index 0000000..047ec2d --- /dev/null +++ b/src/trainer/worker.cc @@ -0,0 +1,299 @@ +#include <glog/logging.h> +#include <thread> +#include <memory> +#include <iostream> +#include "utils/singleton.h" +#include "utils/factory.h" +#include "trainer/worker.h" +#include "proto/model.pb.h" +using std::thread; +namespace singa { +Worker::Worker( int group_id, int worker_id): + group_id_(group_id), worker_id_(worker_id){ +} + +void Worker::Setup(const ModelProto& model, + shared_ptr<NeuralNet> train_net, + shared_ptr<PMWorker::ParamShard> shard, + shared_ptr<Dealer> layer_dealer, + shared_ptr<Dealer> param_dealer){ + train_net_=train_net; + modelproto_=model; + layer_dealer_=layer_dealer; + param_dealer_=param_dealer; + if(layer_dealer_!=nullptr) + layer_poller_.Add(layer_dealer_.get()); + if(param_dealer_!=nullptr) + param_poller_.Add(param_dealer_.get()); + pmworker_=shared_ptr<PMWorker>(Singleton<Factory<PMWorker>>::Instance() + ->Create("PMWorker")); + pmworker_->Setup(group_id_, worker_id_, shard); + step_=modelproto_.step(); + // init params + for(auto layer: train_net->layers()) + 
if(group_id_==0&&layer->locationid()==worker_id_) + for(auto param: layer->GetParams()){ + if(param->owner()<0||param->owner()==param->id()){ + param->Init(); + Put(param, step_); + } + Get(param, step_); + } +} + +void Worker::Run(){ + step_=modelproto_.step(); + Performance perf(train_net_); + try{ + while(!StopNow(step_)){ + RunOneBatch(step_, &perf); + step_++; + } + }catch(WorkerException& e){ + LOG(ERROR)<<e.what(); + } +} +int Worker::Put(shared_ptr<Param> param, int step){ + auto msg=pmworker_->Put(param, step); + if(msg!=nullptr) + param_dealer_->Send(msg); + return 1; +} +int Worker::Get(shared_ptr<Param> param, int step){ + if(param->version()<step){ + auto msg=pmworker_->Get(param, step); + if(msg!=nullptr) + param_dealer_->Send(msg); + } + return 1; +} +int Worker::Update(shared_ptr<Param> param, int step){ + auto msg=pmworker_->Update(param, step); + if(msg!=nullptr) + param_dealer_->Send(msg); + return 1; +} +int Worker::Collect(shared_ptr<Param> param, int step){ + while(param->version()<step){ + Msg* msg=param_dealer_->Receive(); + if(msg==nullptr) + return 0; + pmworker_->Collect(&msg); + } + return 1; +} + +void Worker::RunOneBatch(int step, Performance* perf){ + //DLOG(ERROR)<<"Step "<<step; + // Test will call Pull which updates the sync time + // Hence we store the sync time, and restore it later + //float tSyncData=tSyncData_, tSyncParam=tSyncParam_; + if(ValidateNow(step)){ + LOG(ERROR)<<"Validation at step "<<step; + Test(validation_net_, modelproto_.validation_steps(), perf!=nullptr); + } + if(TestNow(step)){ + LOG(ERROR)<<"Test at step "<<step; + Test(test_net_, modelproto_.test_steps(), perf!=nullptr); + } + //tSyncData_=tSyncData; tSyncParam_=tSyncParam; + + TrainOneBatch(step); + if(perf!=nullptr){ + perf->Update(); + if(DisplayNow(step)){ + LOG(ERROR)<<"Training at step "<<step; + LOG(ERROR)<<"\t"<<perf->ToString(); + perf->Reset(); + //LOG(ERROR)<<"\t"<<TimerInfo(); + } + } + + /* + if(CheckpointNow(step)){ + 
pm_->Checkpoint(cluster_->workspace()+"/snapshot-"+std::to_string(step)); + } + */ +} + +void Worker::ReceiveBlobs(shared_ptr<NeuralNet> net){ + /* + int type; + char *name; + int64_t tick=zclock_mono(); + zframe_t* frame=zframe_new_empty(); + + zsock_recv(pull_, "isf", &type, &name, &frame); + if(type==kDataFrame){ + auto* dst=static_cast<BridgeDstLayer*>( + net->name2layer(string(name)).get()); + memcpy(dst->mutable_data()->mutable_cpu_data(), zframe_data(frame), + zframe_size(frame)); + dst->set_ready(true); + }else if(type==kGradFrame){ + auto* src=static_cast<BridgeSrcLayer*>(net->name2layer(string(name)).get()); + memcpy(src->mutable_grad()->mutable_cpu_data(), zframe_data(frame), + zframe_size(frame)); + src->set_ready(true); + } + zframe_destroy(&frame); + delete name; + tSyncData_+=zclock_mono()-tick; + */ +} + +void Worker::SendBlob(){ + +} + +void Worker::Test(shared_ptr<NeuralNet> net, int nsteps, bool disperf){ + Performance perf(net); + for(int step=0;step<nsteps;step++){ + TestOneBatch(net, step, kTest); + if(disperf) + perf.Update(); + } + if(disperf) + LOG(ERROR)<<"\t"<<perf.ToString(); +} + +/****************************BPWorker**********************************/ + +void BPWorker::Forward(shared_ptr<NeuralNet> net, int step, bool training){ + auto& layers=net->layers(); + for(auto& layer: layers){ + if(layer->locationid()==worker_id_){ + if(layer->is_bridgedstlayer()){ + //auto* dst=static_cast<BridgeDstLayer*>(layer.get()); + // receive fea blobs + } + if(training){ + for(shared_ptr<Param> p: layer->GetParams()){ + if(Collect(p, step)==0){ + throw WorkerException(); + } + } + } + layer->ComputeFeature(training); + if(layer->is_bridgesrclayer()){ + // send fea blobs + } + if(training&&DisplayDebugInfo(step)&&layer->mutable_data()!=nullptr){ + LOG(INFO)<<StringPrintf("Forward layer %10s data norm1 %13.9f", + layer->name().c_str(), layer->data().asum_data()); + } + } + } +} + +void BPWorker::Backward(shared_ptr<NeuralNet> net, int step){ + auto& 
layers=net->layers(); + for (auto it = layers.rbegin(); it != layers.rend(); it++){ + shared_ptr<Layer> layer=*it; + if(layer->locationid()==worker_id_){ + if(layer->is_bridgesrclayer()){ + //auto* src=static_cast<BridgeSrcLayer*>(layer.get()); + // receive grad blobs + } + layer->ComputeGradient(); + if(DisplayDebugInfo(step)&&layer->mutable_grad()!=nullptr){ + LOG(INFO)<<StringPrintf("Backward layer %10s grad norm1 %13.9f\t", + layer->name().c_str(), layer->grad().asum_data()); + for(shared_ptr<Param> p: layer->GetParams()) + LOG(INFO)<<StringPrintf("param id %2d, name %10s,\ + value norm1 %13.9f, grad norm1 %13.9f", + p->id(), p->name().c_str(), + p->data().asum_data(), p->grad().asum_data()); + } + for(shared_ptr<Param> p: layer->GetParams()){ + Update(p, step); + } + if(layer->is_bridgedstlayer()){ + // send grad blobs + } + } + } +} + +void BPWorker::TrainOneBatch(int step){ + Forward(train_net_, step, true); + Backward(train_net_, step); +} + +void BPWorker::TestOneBatch(shared_ptr<NeuralNet> net,int step, Phase phase){ + Forward(net, step, false); +} + +/*********************Implementation for Performance class*******************/ +Performance::Performance(shared_ptr<NeuralNet> net):net_(net), counter_(0){ + for(auto& layer: net->losslayers()){ + name_.push_back(layer->name()); + metric_.push_back(vector<float>{}); + metric_.back().resize(layer->metric().count(),0.f); + } +} + +void Performance::Update(){ + const auto& losslayers=net_->losslayers(); + for(size_t i=0;i<losslayers.size();i++){ + const float * ptr=losslayers[i]->metric().cpu_data(); + vector<float>& m=metric_.at(i); + for(int j=0;j<losslayers[i]->metric().count();j++) + m[j]+=ptr[j]; + } + counter_++; +} + +void Performance::Reset(){ + for(auto& m: metric_) + for(auto& x: m) + x=0.f; + counter_=0; +} + +string Performance::ToString(){ + string disp=""; + for(size_t i=0;i<metric_.size();i++){ + disp+="Output from "+name_[i]+" layer "; + vector<float> m=metric_.at(i); + for(size_t 
j=0;j<m.size();j++) + disp+=std::to_string(j)+" : "+std::to_string(m[j]/counter_)+"\t"; + disp+="\n"; + } + return disp; +} +/* +void Executor::Setup(int local_threadid, const ModelProto& model){ + tForward_=tBackward_=tSyncData_=tSyncParam_=0; + modelproto_=model; + local_threadid_=local_threadid; + if(model.prefetch()){ + for(auto& layer: train_net_->datalayers()){ + if(cluster_->group_threadid(local_threadid_)==layer->locationid()) + localDataLayers_.push_back(layer); + } + if(localDataLayers_.size()) + prefetch_thread_=std::thread(Executor::PrefetchData, + std::ref(localDataLayers_), true,1); + } + int gthreadid=cluster_->group_threadid(local_threadid); +} + +void Executor::PrefetchData(const vector<DataLayer*>& datalayers, bool training, + int steps){ + if(datalayers.size()==0) + return; + for(int i=0;i<steps;i++){ + for(auto& layer: datalayers){ + layer->Prefetching(training); + for(auto& dstlayer: layer->dstlayers()){ + CHECK(dstlayer->is_parserlayer()); + auto parserlayer=static_cast<ParserLayer*>(dstlayer.get()); + parserlayer->Prefetching(training); + } + } + } +} +*/ + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/blob.cc ---------------------------------------------------------------------- diff --git a/src/utils/blob.cc b/src/utils/blob.cc new file mode 100644 index 0000000..92fc989 --- /dev/null +++ b/src/utils/blob.cc @@ -0,0 +1,330 @@ +/** + * The code is adapted from that of Caffe whose license is attached. + * + * COPYRIGHT + * All contributions by the University of California: + * Copyright (c) 2014, The Regents of the University of California (Regents) + * All rights reserved. + * All other contributions: + * Copyright (c) 2014, the respective contributors + * All rights reserved. + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. 
If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * LICENSE + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * CONTRIBUTION AGREEMENT + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + */ +#include <utility> +#include <math.h> +#include <cblas.h> +#include "utils/blob.h" +/*********************SyncedMemory implementation************************/ + +#define NO_GPU LOG(FATAL) << "CPU-only Mode: cannot make GPU call." 
+// Instantiate a class with float and double specifications. +#define INSTANTIATE_CLASS(classname) \ + template class classname<float>; \ + template class classname<double> +// Disable the copy and assignment operator for a class. +#define DISABLE_COPY_AND_ASSIGN(classname) \ +private:\ + classname(const classname&);\ + classname& operator=(const classname&) + +#ifndef CPU_ONLY +// CUDA: various checks for different function calls. +#define CUDA_CHECK(condition) \ + /* Code block avoids redefinition of cudaError_t error */ \ + do { \ + cudaError_t error = condition; \ + CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ + } while (0) + +#define CUBLAS_CHECK(condition) \ + do { \ + cublasStatus_t status = condition; \ + CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << " " \ + << caffe::cublasGetErrorString(status); \ + } while (0) + +#define CURAND_CHECK(condition) \ + do { \ + curandStatus_t status = condition; \ + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << " " \ + << caffe::curandGetErrorString(status); \ + } while (0) + +#endif // CPU_ONLY + + +SyncedMemory::~SyncedMemory() { + if (cpu_ptr_ && own_cpu_data_) { + FreeHost(cpu_ptr_); + } + +#ifndef CPU_ONLY + if (gpu_ptr_) { + CUDA_CHECK(cudaFree(gpu_ptr_)); + } +#endif // CPU_ONLY +} + +inline void SyncedMemory::to_cpu() { + switch (head_) { + case UNINITIALIZED: + MallocHost(&cpu_ptr_, size_); + memset(cpu_ptr_,0, size_); + head_ = HEAD_AT_CPU; + own_cpu_data_ = true; + break; + case HEAD_AT_GPU: +#ifndef CPU_ONLY + if (cpu_ptr_ == NULL) { + MallocHost(&cpu_ptr_, size_); + own_cpu_data_ = true; + } + CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault)); + head_ = SYNCED; +#else + NO_GPU; +#endif + break; + case HEAD_AT_CPU: + case SYNCED: + break; + } +} + +inline void SyncedMemory::to_gpu() { +#ifndef CPU_ONLY + switch (head_) { + case UNINITIALIZED: + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + CUDA_CHECK(cudaMemset(gpu_ptr_, 0, N)); // NOLINT(caffe/alt_fn) + head_ = HEAD_AT_GPU; + 
break; + case HEAD_AT_CPU: + if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + } + CUDA_CHECK(cudaMemcpy( gpu_ptr_,cpu_ptr_, size_, cudaMemcpyDefault)); + head_ = SYNCED; + break; + case HEAD_AT_GPU: + case SYNCED: + break; + } +#else + NO_GPU; +#endif +} + +const void* SyncedMemory::cpu_data() { + to_cpu(); + return (const void*)cpu_ptr_; +} + +void SyncedMemory::set_cpu_data(void* data) { + CHECK(data); + if (own_cpu_data_) { + FreeHost(cpu_ptr_); + } + cpu_ptr_ = data; + head_ = HEAD_AT_CPU; + own_cpu_data_ = false; +} + +const void* SyncedMemory::gpu_data() { +#ifndef CPU_ONLY + to_gpu(); + return (const void*)gpu_ptr_; +#else + NO_GPU; +#endif + return nullptr; +} + +void* SyncedMemory::mutable_cpu_data() { + to_cpu(); + head_ = HEAD_AT_CPU; + return cpu_ptr_; +} + +void* SyncedMemory::mutable_gpu_data() { +#ifndef CPU_ONLY + to_gpu(); + head_ = HEAD_AT_GPU; + return gpu_ptr_; +#else + NO_GPU; +#endif + return nullptr; +} + +/*********************Blob implementation************************/ + +template <typename Dtype> +Blob<Dtype>::Blob(const vector<int>& shape) + // capacity_ must be initialized before calling Reshape + : capacity_(0) { + Reshape(shape); +} + +template <typename Dtype> +void Blob<Dtype>::Reshape(const vector<int>& shape) { + count_=1; + shape_=shape; + for(size_t i=0;i<shape.size();i++){ + CHECK(shape[i]); + count_*=shape[i]; + } + if (count_ > capacity_) { + capacity_ = count_; + data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype))); + } +} + +template <typename Dtype> +void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) { + Reshape(other.shape()); +} + +template <typename Dtype> +const Dtype* Blob<Dtype>::cpu_data() const { + CHECK(data_); + return (const Dtype*)data_->cpu_data(); +} + +template <typename Dtype> +void Blob<Dtype>::set_cpu_data(Dtype* data) { + CHECK(data); + data_->set_cpu_data(data); +} + +template <typename Dtype> +const Dtype* Blob<Dtype>::gpu_data() const { + CHECK(data_); + return (const 
// Read-only GPU pointer; syncs to the device on demand.
template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_data() const {
  CHECK(data_);
  return (const Dtype*)data_->gpu_data();
}

// Writable CPU pointer; marks the host copy as authoritative.
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_data() {
  CHECK(data_);
  return static_cast<Dtype*>(data_->mutable_cpu_data());
}

// Writable GPU pointer; marks the device copy as authoritative.
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_data() {
  CHECK(data_);
  return static_cast<Dtype*>(data_->mutable_gpu_data());
}

// Shares the other blob's SyncedMemory (no copy); counts must match.
template <typename Dtype>
void Blob<Dtype>::ShareData(const Blob& other) {
  CHECK_EQ(count_, other.count());
  data_ = other.data();
}

// NOTE: despite the name, asum_data returns the *average* absolute value
// (L1 norm divided by count), not the plain sum.
template <> float Blob<float>::asum_data() const {
  if(count()==0)
    return 0.f;
  return cblas_sasum(count(), cpu_data(), 1)/count();
}
// Likewise, sum_data returns the mean of the elements, not their sum.
template <> float Blob<float>::sum_data() const {
  if(count()==0)
    return 0.f;
  float sum=0.f;
  const float *dptr=cpu_data();
  for(int i=0;i<count();i++)
    sum+=dptr[i];
  return sum/count();
}
// Integer specializations are deliberately unimplemented.
template <> unsigned int Blob<unsigned int>::asum_data() const {
  NOT_IMPLEMENTED;
  return 0;
}

template <> int Blob<int>::asum_data() const {
  NOT_IMPLEMENTED;
  return 0;
}

// Swaps the underlying storage of two same-shaped blobs in O(1).
template <typename Dtype>
void Blob<Dtype>::Swap(Blob& other){
  CHECK_EQ(other.count(), count());
  CHECK(std::equal(shape_.begin(), shape_.end(), other.shape_.begin()));
  std::swap(data_, other.data_);
  std::swap(capacity_, other.capacity_);
}

// Copies the contents of 'source' into this blob; with reshape=true a shape
// mismatch triggers a Reshape, otherwise it is fatal.
template <typename Dtype>
void Blob<Dtype>::CopyFrom(const Blob& source, bool reshape) {
  if (!std::equal(shape_.begin(),shape_.end(),source.shape_.begin())) {
    if (reshape) {
      Reshape(source.shape_);
    } else {
      LOG(FATAL) << "Trying to copy blobs of different sizes.";
    }
  }
  // NOTE(review): in GPU builds this performs BOTH a device copy and a host
  // copy (and mutable_cpu_data() re-syncs from the GPU in between), which
  // looks redundant — presumably only one path is needed depending on where
  // the data currently lives; confirm against SyncedMemory's head_ states.
#ifndef CPU_ONLY
  CUDA_CHECK(cudaMemcpy(static_cast<Dtype*>(data_->mutable_gpu_data()),
      source.gpu_data(), sizeof(Dtype) * count_, cudaMemcpyDefault));
#endif
  memcpy(static_cast<Dtype*>(data_->mutable_cpu_data()),source.cpu_data(),
      sizeof(Dtype)*count_);
}

/*
template <typename Dtype>
void Blob<Dtype>::FromProto(const BlobProto& proto) {
  Reshape();
  // copy data
  Dtype* data_vec = mutable_cpu_data();
  for (int i = 0; i < count_; ++i) {
    data_vec[i] = proto.data(i);
  }
}
*/

// Serializes the blob into a BlobProto using the legacy 4-D (NCHW) fields.
// NOTE(review): assumes shape_ has at least 1 and at most 4 dimensions;
// shape_[0] is read unconditionally — confirm callers never pass 0-D blobs.
template <typename Dtype>
void Blob<Dtype>::ToProto(singa::BlobProto* proto) const {
  proto->set_num(shape_[0]);
  if(shape_.size()>1)
    proto->set_channels(shape_[1]);
  if(shape_.size()>2)
    proto->set_height(shape_[2]);
  if(shape_.size()>3)
    proto->set_width(shape_[3]);
  proto->clear_data();
  const Dtype* data_vec = cpu_data();
  for (int i = 0; i < count_; ++i) {
    proto->add_data(data_vec[i]);
  }
}

// Explicit instantiations: float/double via the macro, plus the integer
// variants used for label/index blobs.
INSTANTIATE_CLASS(Blob);
template class Blob<int>;
template class Blob<unsigned int>;
