Repository: incubator-singa Updated Branches: refs/heads/master 95b1e6dd3 -> b2dc51d23
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/cluster.cc ---------------------------------------------------------------------- diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc new file mode 100644 index 0000000..ac47422 --- /dev/null +++ b/src/utils/cluster.cc @@ -0,0 +1,52 @@ +#include <glog/logging.h> +#include <fcntl.h> +#include <fstream> +#include "utils/cluster.h" +#include "proto/cluster.pb.h" +#include <sys/stat.h> +#include <sys/types.h> +namespace singa { + +std::shared_ptr<Cluster> Cluster::instance_; +Cluster::Cluster(const ClusterProto &cluster, int procs_id) { + procs_id_=procs_id; + cluster_ = cluster; + SetupFolders(cluster); + int nprocs; + if(server_worker_separate()) + nprocs=nworker_procs()+nserver_procs(); + else + nprocs=std::max(nworker_procs(), nserver_procs()); + CHECK_LT(procs_id, nprocs); + if (cluster_.has_nprocs()) + CHECK_EQ(cluster.nprocs(), nprocs); + else + cluster_.set_nprocs(nprocs); + if(nprocs>1){ + std::ifstream ifs(cluster.hostfile(), std::ifstream::in); + std::string line; + while(std::getline(ifs, line)&&endpoints_.size()<nprocs){ + endpoints_.push_back(line); + } + CHECK_EQ(endpoints_.size(), nprocs); + } +} + +void Cluster::SetupFolders(const ClusterProto &cluster){ + // create visulization folder + mkdir(vis_folder().c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); +} + +shared_ptr<Cluster> Cluster::Get(const ClusterProto& cluster, int procs_id){ + instance_.reset(new Cluster(cluster, procs_id)); + return instance_; +} + +shared_ptr<Cluster> Cluster::Get() { + if(!instance_) { + LOG(ERROR)<<"The first call to Get should " + <<"provide the sys/model conf path"; + } + return instance_; +} +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/common.cc ---------------------------------------------------------------------- diff --git a/src/utils/common.cc b/src/utils/common.cc new file mode 100644 index 0000000..0697060 --- /dev/null +++ b/src/utils/common.cc @@ -0,0 +1,89 @@ +#include <fcntl.h> +#include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/text_format.h> +#include <google/protobuf/io/zero_copy_stream_impl.h> +#include "utils/common.h" +using std::ios; +using std::max; +using google::protobuf::io::FileInputStream; +using google::protobuf::io::FileOutputStream; +using google::protobuf::io::ZeroCopyInputStream; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::ZeroCopyOutputStream; +using google::protobuf::io::CodedOutputStream; + +namespace singa { + +const int kBufLen=1024; +std::string IntVecToString(const vector<int>& vec) { + string disp="("; + for(int x: vec) + disp+=std::to_string(x)+", "; + return disp+")"; +} + +/** + * Formatted string. + */ +string VStringPrintf(string fmt, va_list l) { + char buffer[32768]; + vsnprintf(buffer, 32768, fmt.c_str(), l); + return string(buffer); +} + +/** + * Formatted string. + */ +string StringPrintf(string fmt, ...) { + va_list l; + va_start(l, fmt); //fmt.AsString().c_str()); + string result = VStringPrintf(fmt, l); + va_end(l); + return result; +} + +void Debug() { + int i = 0; + char hostname[256]; + gethostname(hostname, sizeof(hostname)); + printf("PID %d on %s ready for attach\n", getpid(), hostname); + fflush(stdout); + while (0 == i) + sleep(5); +} + +// the proto related functions are from Caffe. +void ReadProtoFromTextFile(const char* filename, + ::google::protobuf::Message* proto) { + int fd = open(filename, O_RDONLY); + CHECK_NE(fd, -1) << "File not found: " << filename; + FileInputStream* input = new FileInputStream(fd); + CHECK(google::protobuf::TextFormat::Parse(input, proto)); + delete input; + close(fd); +} +void WriteProtoToTextFile(const Message& proto, const char* filename) { + int fd = open(filename, O_WRONLY | O_CREAT, 0644); + FileOutputStream* output = new FileOutputStream(fd); + CHECK(google::protobuf::TextFormat::Print(proto, output)); + delete output; + close(fd); +} +void ReadProtoFromBinaryFile(const char* filename, Message* proto) { + int fd = open(filename, O_RDONLY); + CHECK_NE(fd, -1) << "File not found: " << filename; + ZeroCopyInputStream* raw_input = new FileInputStream(fd); + CodedInputStream* coded_input = new CodedInputStream(raw_input); + // upper limit 512MB, warning threshold 256MB + coded_input->SetTotalBytesLimit(536870912, 268435456); + CHECK(proto->ParseFromCodedStream(coded_input)); + delete coded_input; + delete raw_input; + close(fd); +} +void WriteProtoToBinaryFile(const Message& proto, const char* filename) { + int fd= open(filename, O_CREAT|O_WRONLY|O_TRUNC, 0644); + CHECK(proto.SerializeToFileDescriptor(fd)); +} + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/data_shard.cc ---------------------------------------------------------------------- diff --git a/src/utils/data_shard.cc b/src/utils/data_shard.cc new file mode 100644 index 0000000..df311e1 --- /dev/null +++ b/src/utils/data_shard.cc @@ -0,0 +1,207 @@ +#include <sys/stat.h> +#include <glog/logging.h> + +#include "utils/data_shard.h" +namespace singa { + +DataShard::DataShard(std::string folder, char mode, int capacity){ + struct stat sb; + if(stat(folder.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode)){ + LOG(INFO)<<"Open shard folder "<<folder; + }else{ + LOG(FATAL)<<"Cannot open shard folder "<<folder; + } + + path_= folder+"/shard.dat"; + if(mode==DataShard::kRead){ + fdat_.open(path_, std::ios::in|std::ios::binary); + CHECK(fdat_.is_open())<<"Cannot create file "<<path_; + } + if(mode==DataShard::kCreate){ + fdat_.open(path_, std::ios::binary|std::ios::out|std::ios::trunc); + CHECK(fdat_.is_open())<<"Cannot create file "<<path_; + } + if(mode==DataShard::kAppend){ + int last_tuple=PrepareForAppend(path_); + fdat_.open(path_, std::ios::binary|std::ios::out|std::ios::in|std::ios::ate); + CHECK(fdat_.is_open())<<"Cannot create file "<<path_; + fdat_.seekp(last_tuple); + } + + mode_=mode; + offset_=0; + bufsize_=0; + capacity_=capacity; + buf_=new char[capacity]; +} + +DataShard::~DataShard(){ + delete buf_; + fdat_.close(); +} + +bool DataShard::Insert(const std::string& key, const Message& val) { + std::string str; + val.SerializeToString(&str); + return Insert(key, str); +} +// insert one complete tuple +bool DataShard::Insert(const std::string& key, const std::string& val) { + if(keys_.find(key)!=keys_.end()||val.size()==0) + return false; + int size=key.size()+val.size()+2*sizeof(size_t); + if(offset_+size>capacity_){ + fdat_.write(buf_, offset_); + offset_=0; + CHECK_LE(size, capacity_)<<"Tuple size is larger than capacity" + <<"Try a larger capacity size"; + } + *reinterpret_cast<size_t*>(buf_+offset_)=key.size(); + offset_+=sizeof(size_t); + memcpy(buf_+offset_, key.data(), key.size()); + offset_+=key.size(); + *reinterpret_cast<size_t*>(buf_+offset_)=val.size(); + offset_+=sizeof(size_t); + memcpy(buf_+offset_, val.data(), val.size()); + offset_+=val.size(); + return true; +} + +void DataShard::Flush() { + fdat_.write(buf_, offset_); + fdat_.flush(); + offset_=0; +} + +int DataShard::Next(std::string *key){ + key->clear(); + int ssize=sizeof(size_t); + if(!PrepareNextField(ssize)) + return 0; + CHECK_LE(offset_+ssize, bufsize_); + int keylen=*reinterpret_cast<size_t*>(buf_+offset_); + offset_+=ssize; + + if(!PrepareNextField(keylen)) + return 0; + CHECK_LE(offset_+keylen, bufsize_); + for(int i=0;i<keylen;i++) + key->push_back(buf_[offset_+i]); + offset_+=keylen; + + if(!PrepareNextField(ssize)) + return 0; + CHECK_LE(offset_+ssize, bufsize_); + int vallen=*reinterpret_cast<size_t*>(buf_+offset_); + offset_+=ssize; + + if(!PrepareNextField(vallen)) + return 0; + CHECK_LE(offset_+vallen, bufsize_); + return vallen; +} + +bool DataShard::Next(std::string *key, Message* val) { + int vallen=Next(key); + if(vallen==0) + return false; + val->ParseFromArray(buf_+offset_, vallen); + offset_+=vallen; + return true; +} + +bool DataShard::Next(std::string *key, std::string* val) { + int vallen=Next(key); + if(vallen==0) + return false; + val->clear(); + for(int i=0;i<vallen;i++) + val->push_back(buf_[offset_+i]); + offset_+=vallen; + return true; +} + +void DataShard::SeekToFirst(){ + CHECK_EQ(mode_, kRead); + bufsize_=0; + offset_=0; + fdat_.close(); + fdat_.open(path_, std::ios::in|std::ios::binary); + CHECK(fdat_.is_open())<<"Cannot create file "<<path_; +} + +// if the buf does not have the next complete field, read data from disk +bool DataShard::PrepareNextField(int size){ + if(offset_+size>bufsize_){ + bufsize_-=offset_; + CHECK_LE(bufsize_, offset_); + for(int i=0;i<bufsize_;i++) + buf_[i]=buf_[i+offset_]; + offset_=0; + if(fdat_.eof()) + return false; + else{ + fdat_.read(buf_+bufsize_, capacity_-bufsize_); + bufsize_+=fdat_.gcount(); + } + } + return true; +} + +const int DataShard::Count() { + std::ifstream fin(path_, std::ios::in|std::ios::binary); + CHECK(fdat_.is_open())<<"Cannot create file "<<path_; + int count=0; + while(true){ + size_t len; + fin.read(reinterpret_cast<char*>(&len), sizeof(len)); + if(fin.good()) + fin.seekg(len, std::ios_base::cur); + else break; + if(fin.good()) + fin.read(reinterpret_cast<char*>(&len), sizeof(len)); + else break; + if(fin.good()) + fin.seekg(len, std::ios_base::cur); + else break; + if(!fin.good()) + break; + count++; + } + fin.close(); + return count; +} + +int DataShard::PrepareForAppend(std::string path){ + std::ifstream fin(path, std::ios::in|std::ios::binary); + if(!fin.is_open()){ + fdat_.open(path, std::ios::out|std::ios::binary); + fdat_.flush(); + fdat_.close(); + return 0; + } + + int last_tuple_offset=0; + char buf[256]; + size_t len; + while(true){ + memset(buf, 0, 256); + fin.read(reinterpret_cast<char*>(&len), sizeof(len)); + if(fin.good()) + fin.read(buf, len); + else break; + if(fin.good()) + fin.read(reinterpret_cast<char*>(&len), sizeof(len)); + else break; + if(fin.good()) + fin.seekg(len, std::ios_base::cur); + else break; + if(fin.good()) + keys_.insert(std::string(buf)); + else break; + last_tuple_offset=fin.tellg(); + } + fin.close(); + return last_tuple_offset; +} +} /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/graph.cc ---------------------------------------------------------------------- diff --git a/src/utils/graph.cc b/src/utils/graph.cc new file mode 100644 index 0000000..d1cece6 --- /dev/null +++ b/src/utils/graph.cc @@ -0,0 +1,148 @@ +#include <algorithm> +#include "utils/graph.h" + +const string Graph::ToString() const { + map<string, string> info; + return ToString(info); +} +const string Graph::ToString(const map<string, string>& info) const { + map<string, int> nodeid; + string disp="{\"directed\":1,\n"; + + // add nodes + disp+="\"nodes\":[\n"; + bool first=true; + + vector<string> colors={"red", "blue", "black", "green"}; + // see for more shapes at http://www.graphviz.org/doc/info/shapes.html + vector<string> shapes={"box", "ellipse"}; + int id=0; + for(auto node: nodes_){ + char str[1024]; + string name=node->name(); + string color=colors[(node->val().locationid)%colors.size()]; + string shape; + string origin=node->val().origin; + if(origin=="kSlice"||origin=="kConcate"||origin=="kSplit" + ||origin=="kBridgeSrc"||origin=="kBridgeDst") + shape=shapes[1]; + else + shape=shapes[0]; + sprintf(str, "{\"id\":\"%s%s\", \"color\":\"%s\",\"shape\":\"%s\"}\n", + name.c_str(), info.find(name)!=info.end()?info.at(name).c_str():"", + color.c_str(), shape.c_str()); + if(!first) + disp+=","; + else + first=false; + disp+=string(str); + nodeid[name]=id++; + } + disp+="]\n,"; + + // add edges + disp+="\"links\":[\n"; + first=true; + for(auto src: nodes_) + for(auto dst: src->dstnodes()){ + char str[1024]; + sprintf(str, "{\"source\":%d, \"target\":%d, \"color\":\"%s\"}\n", + nodeid[src->name()], nodeid[dst->name()], "black"); + if(!first) + disp+=","; + else + first=false; + disp+=string(str); + } + disp+="]\n"; + return disp+"}"; +} +bool Graph::Check() const { + return true; +} + + +// visited all dst nodes and then push current node into the stack +void Graph::topology_sort_inner(SNode node, + map<string, bool> *visited, + std::stack<string> *stack) { + (*visited)[node->name()] = true; + const vector<SNode>& dstnodes=node->dstnodes(); + for (auto it=dstnodes.rbegin();it!=dstnodes.rend();it++) { + if ((*visited)[(*it)->name()]) + continue; + topology_sort_inner((*it),visited, stack); + } + stack->push(node->name()); +} + +// sort to make `bottom' nodes be placed in the front positions +void Graph::Sort() { + // adjacent list from upper layers to lower layers + std::map<string, bool> visited; + // prepare adjacent list; input layers will be processed firstly, + // hence no need to sort them (mark them as visited) + for (SNode node: nodes_) { + visited[node->name()] = false; + } + // the `top' layer in the net will be placed at the bottom of the stack + // and then be processed (i.e., forward) at last + std::stack<string > stack; + for (SNode node: nodes_) { + if (visited[node->name()] == false) + topology_sort_inner(node, &visited, &stack); + } + nodes_.clear(); + + while (!stack.empty()) { + nodes_.push_back(name2node_[stack.top()]); + stack.pop(); + } +} + + + +SNode Graph::InsertSliceNode(SNode srcnode, const vector<SNode>& dstnodes, + const V& info, bool connect_dst){ + V myinfo=info; + myinfo.origin="kSlice"; + SNode node=AddNode("slice-"+srcnode->name(),myinfo); + AddEdge(srcnode, node); + if(connect_dst) + for(SNode dst: dstnodes) + AddEdge(node, dst); + return node; +} +SNode Graph::InsertConcateNode(const vector<SNode>&srcnodes, SNode dstnode, + const V& info){ + V myinfo=info; + myinfo.origin="kConcate"; + SNode node=AddNode("concate-"+dstnode->name(),myinfo); + AddEdge(node, dstnode); + for(SNode src: srcnodes) + AddEdge(src, node); + return node; +} +SNode Graph::InsertSplitNode(SNode srcnode, const vector<SNode>& dstnodes){ + V myinfo=srcnode->val(); + myinfo.origin="kSplit"; + SNode node=AddNode("split-"+srcnode->name(), myinfo); + AddEdge(srcnode, node); + for(SNode dst: dstnodes) + AddEdge(node, dst); + return node; +} +std::pair<SNode, SNode> Graph::InsertBridgeNode(SNode srcnode, SNode dstnode){ + LayerInfo info=srcnode->val(); + info.origin="kBridgeSrc"; + SNode src=AddNode("s-"+srcnode->name()+"-"+dstnode->name(), info); + info=dstnode->val(); + info.origin="kBridgeDst"; + SNode dst=AddNode("d-"+srcnode->name()+"-"+dstnode->name(), info); + AddEdge(srcnode, src); + AddEdge(src, dst); + AddEdge(dst, dstnode); + return pair<SNode, SNode>{src, dst}; +} + + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc new file mode 100644 index 0000000..d64c65d --- /dev/null +++ b/src/utils/param.cc @@ -0,0 +1,345 @@ +#include <glog/logging.h> +#include <cmath> +#include <chrono> +#include <random> +#include "utils/param.h" +#include "mshadow/tensor.h" +#include "utils/singleton.h" +using namespace mshadow; +using std::vector; +using std::string; +namespace singa { + +Param::Param(){ + owner_=-1; + fan_in_=0; + set_version(-1); +} + +Param::~Param(){} + +Msg* Param::GenPutMsg(void* arg){ + char buf[256]; + int v=*(int*)arg; + sprintf(buf, "%d %d %f %f", v, size(), + learning_rate_multiplier(), weight_decay_multiplier()); + Msg* msg=new Msg(); + msg->set_type(kPut); + msg->add_frame(buf, strlen(buf)); + msg->add_frame(mutable_cpu_data(), size()*sizeof(float)); + return msg; +} + +Msg* Param::GenGetMsg(void* arg){ + char buf[10]; + int v=*(int*)arg; + sprintf(buf, "%d", v); + Msg* msg=new Msg(); + msg->set_type(kGet); + msg->add_frame(buf, strlen(buf)); + return msg; +} + +Msg* Param::GenUpdateMsg(void* arg){ + char buf[10]; + int v=*(int*)arg; + sprintf(buf, "%d", v); + Msg* msg=new Msg(); + msg->set_type(kUpdate); + msg->add_frame(buf, strlen(buf)); + + msg->add_frame(mutable_cpu_grad(), size()*sizeof(float)); + return msg; +} + +Msg* Param::GenSyncMsg(void* arg){ + return nullptr; +} + +Msg* Param::HandlePutMsg(Msg** msg){ + int v, size; + float lr, wc; + sscanf(static_cast<char*>((*msg)->frame_data()), "%d %d %f %f", + &v, &size, &lr, &wc); + set_version(v); + proto_.set_learning_rate_multiplier(lr); + proto_.set_weight_decay_multiplier(wc); + CHECK((*msg)->next_frame()); + vector<int> shape{size}; + data_.Reshape(shape); + grad_.Reshape(shape); + history_.Reshape(shape); + CHECK_EQ(size* sizeof(float), (*msg)->frame_size()); + memcpy(data_.mutable_cpu_data(), (*msg)->frame_data(), size*sizeof(float)); + delete (*msg); + *msg=nullptr; + return nullptr; +} + +Msg* Param::HandleGetMsg(Msg** msg){ + int v; + sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v); + CHECK_LE(v, version()); + CHECK(!(*msg)->next_frame()); + (*msg)->add_frame(data_.mutable_cpu_data(), sizeof(float)*size()); + (*msg)->SwapAddr(); + (*msg)->set_type(kRGet); + return *msg; +} + +int Param::ParseUpdateMsg(Msg** msg){ + int v; + sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v); + CHECK_LE(v, version()); + CHECK((*msg)->next_frame()); + memcpy(mutable_cpu_grad(), (*msg)->frame_data(),(*msg)->frame_size()); + delete (*msg); + *msg=nullptr; + return 1; +} + +Msg* Param::GenUpdateResponseMsg(void* arg){ + Msg* msg=new Msg(); + char buf[10]; + sprintf(buf, "%d", version()); + msg->set_type(kRUpdate); + msg->set_target(id()); + msg->add_frame(buf, strlen(buf)); + msg->add_frame(mutable_cpu_data(), size()*sizeof(float)); + return msg; +} + + +Msg* Param::HandleSyncMsg(Msg** msg){ + delete *msg; + *msg=nullptr; + return nullptr; +} + +int Param::ParseSyncResponseMsg(Msg** msg){ + delete *msg; + *msg=nullptr; + return 1; +} +int Param::ParsePutResponseMsg(Msg **msg){ + return ParseSyncResponseMsg(msg); +} +int Param::ParseGetResponseMsg(Msg **msg){ + int v; + sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v); + set_version(v); + CHECK((*msg)->next_frame()); + memcpy(mutable_cpu_data(), (*msg)->frame_data(), (*msg)->frame_size()); + return 1; +} +int Param::ParseUpdateResponseMsg(Msg **msg){ + return ParseGetResponseMsg(msg); +} + +void Param::Setup(const ParamProto& proto, const vector<int>& shape, + int fan_in){ + data_.Reshape(shape); + grad_.Reshape(shape); + history_.Reshape(shape); + proto_=proto; + fan_in_=fan_in; +} + +void Param::Init(int v){ + proto_.set_version(v); + Tensor<cpu, 1> data(data_.mutable_cpu_data(), Shape1(data_.count())); + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + auto random=ASingleton<Random<cpu>>::Instance(seed); + switch (proto_.init_method()) { + case ParamProto::kConstant: + data=proto_.value(); + break; + case ParamProto::kUniform: + random->SampleUniform(data, proto_.low(), proto_.high()); + if(proto_.value()) + data*= proto_.value(); + break; + case ParamProto::kUniformSqrtFanIn: + CHECK_GT(fan_in_,0); + random->SampleUniform(data, proto_.low(), proto_.high()); + if(proto_.value()) + data*= proto_.value()/ sqrt(fan_in_ / 3.0f); + break; + case ParamProto::kUniformSqrtFanInOut: + random->SampleUniform(data, proto_.low(), proto_.high()); + if(proto_.value()) + data*= proto_.value()/ sqrt(data_.shape()[0] +data_.shape()[1]); + break; + case ParamProto::kGaussian: + random->SampleGaussian(data, proto_.mean(), proto_.std()); + if(proto_.value()) + data*= proto_.value(); + break; + case ParamProto::kGaussainSqrtFanIn: + random->SampleGaussian(data, proto_.mean(), proto_.std()); + if(proto_.value()) + data*= proto_.value()/ sqrt(data_.shape()[0]); + break; + default: + LOG(ERROR) << "Illegal parameter init method "; + break; + } +} + +/**************************RandomSyncParam******************************** +const vector<int> RandomSyncParam::RandomSample(int seed, int m, int n){ + vector<int> samples(m); + std::mt19937 gen(seed); + std::uniform_real_distribution<float> dist(0.f,1.f); + for(int i=0,k=0;i<n&&k<m;i++) + if((m-k)*1.0f/(n-i)>dist(gen)){ + samples[k++]=i; + } + return samples; +} + +zmsg_t* RandomSyncParam::HandleSyncMsg(zmsg_t** msg){ + int64_t start=zclock_mono(); + char* control=zframe_strdup(zmsg_first(*msg)); + int seed, count; + sscanf(control, "%d-%d", &seed,&count); + delete control; + zframe_t* syncframe=zmsg_next(*msg); + CHECK_EQ(zframe_size(syncframe), count*sizeof(float)); + float* syncptr=(float*)zframe_data(syncframe); + float* dptr=data_.mutable_cpu_data(); + int k=0; + if(count==data_.count()){ + for(int idx=0;idx<count;idx++){ + float x=dptr[idx]; + dptr[idx]+=syncptr[k]; + syncptr[k]=x; + k++; + } + }else{ + for(int idx: RandomSample(seed, count, data_.count())){ + float x=dptr[idx]; + dptr[idx]+=syncptr[k]; + syncptr[k]=x; + k++; + } + } + CHECK_EQ(k,count); + CHECK_EQ(zframe_size(syncframe), count*sizeof(float)); + return *msg; +} + +zmsg_t *RandomSyncParam::GenSyncMsgFromWorker(float sample_ratio){ + int64_t start=zclock_mono(); + zmsg_t* msg=zmsg_new(); + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + int m=data_.count()*sample_ratio; + zmsg_addstrf(msg, "%u-%d", seed, m); + float* updateptr=new float[m]; + float* dptr=data_.mutable_cpu_data(); + float* sdptr=snapshot_.mutable_cpu_data(); + int k=0; + if(m==data_.count()){ + for(int idx=0;idx<m;idx++) + updateptr[k++]=dptr[idx]-sdptr[idx]; + }else{ + const vector<int> samples=RandomSample(seed, m, data_.count()); + for(int idx:samples){ + updateptr[k++]=dptr[idx]-sdptr[idx]; + } + } + CHECK_EQ(k,m); + zframe_t* frame=zframe_new(updateptr, sizeof(float)*m); + zmsg_append(msg, &frame); + delete updateptr; + worker_gen_sync+=zclock_mono()-start; + return msg; +} + +void RandomSyncParam::ParseSyncMsgFromPS(zmsg_t** msg){ + int64_t start=zclock_mono(); + //LOG(ERROR)<<"worker sync "<<id(); + char* control=zmsg_popstr(*msg); + int seed, count; + sscanf(control, "%u-%d", &seed, &count); + //LOG(ERROR)<<"worker sync "<<id()<<" "<<control; + delete control; + zframe_t* psdataframe=zmsg_pop(*msg); + CHECK_EQ(zframe_size(psdataframe), count*sizeof(float)); + float* psdptr=(float*)zframe_data(psdataframe); + float* dptr=data_.mutable_cpu_data(); + float* sdptr=snapshot_.mutable_cpu_data(); + int k=0; + if(count==data_.count()){ + for(int idx=0;idx<count;idx++){ + dptr[idx]+=psdptr[k++]-sdptr[idx]; + sdptr[idx]=dptr[idx]; + } + }else{ + for(int idx: RandomSample(seed, count, data_.count())){ + dptr[idx]+=psdptr[k++]-sdptr[idx]; + sdptr[idx]=dptr[idx]; + } + } + zframe_destroy(&psdataframe); + worker_handle_sync+=zclock_mono()-start; + zmsg_destroy(msg); +} + + +void RandomSyncParam::Setup(const ParamProto& proto, const vector<int>& shape, + int fan_in){ + Param::Setup(proto, shape, fan_in); + snapshot_.Reshape(shape); +} + +void RandomSyncParam::Init(){ + Param::Init(); + memcpy(snapshot_.mutable_cpu_data(), data_.mutable_cpu_data(), + sizeof(float)*data_.count()); +} +*/ + +/***************************ElasticParam************************************ +zmsg_t* ElasticParam::HandleSyncMsg(zmsg_t** msg){ + int64_t start=zclock_mono(); + char* control=zframe_strdup(zmsg_first(*msg)); + float alpha;int count; + sscanf(control, "%f-%d", &alpha,&count); + delete control; + zframe_t* syncframe=zmsg_next(*msg); + CHECK_EQ(size(), count); + Tensor<cpu, 1> server(data_.mutable_cpu_data(), Shape1(count)); + Tensor<cpu, 1> worker((float*)zframe_data(syncframe), Shape1(count)); + worker=(worker-server)*alpha; + server+=worker; + return *msg; +} + +zmsg_t *ElasticParam::GenSyncMsgFromWorker(float alpha){ + int64_t start=zclock_mono(); + zmsg_t* msg=zmsg_new(); + zmsg_addstrf(msg, "%f-%d", alpha, size()); + zmsg_addmem(msg, mutable_cpu_data(), sizeof(float)*size()); + worker_gen_sync+=zclock_mono()-start; + return msg; +} + +void ElasticParam::ParseSyncMsgFromPS(zmsg_t** msg){ + int64_t start=zclock_mono(); + //LOG(ERROR)<<"worker sync "<<id(); + char* control=zmsg_popstr(*msg); + float alpha;int count; + sscanf(control, "%f-%d", &alpha, &count); + delete control; + zframe_t* frame=zmsg_pop(*msg); + CHECK_EQ(zframe_size(frame), count*sizeof(float)); + Tensor<cpu, 1> diff((float*)zframe_data(frame), Shape1(count)); + Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(count)); + data-=diff; + zframe_destroy(&frame); + zmsg_destroy(msg); + worker_handle_sync+=zclock_mono()-start; +} +*/ +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/updater.cc ---------------------------------------------------------------------- diff --git a/src/utils/updater.cc b/src/utils/updater.cc new file mode 100644 index 0000000..0b89ee8 --- /dev/null +++ b/src/utils/updater.cc @@ -0,0 +1,192 @@ + +#include "utils/updater.h" +#include "mshadow/tensor.h" +#include "mshadow/cxxnet_op.h" +#include "proto/model.pb.h" +using namespace mshadow; +using namespace mshadow::expr; + +namespace singa { + +float Updater::GetLearningRate(int step){ + float ret = 0., r = 0., base=proto_.base_learning_rate(); + int freq=0; + switch (proto_.learning_rate_change_method()) { + case UpdaterProto_ChangeProto_kFixed: + ret = base; + break; + case UpdaterProto_ChangeProto_kLinear: + // a is init, b is the final + freq=proto_.learning_rate_change_frequency(); + r = step * 1.0 / freq; + ret = (1.0 - r) * base + r * proto_.final_learning_rate(); + break; + case UpdaterProto_ChangeProto_kExponential: + // a is init, b is the final, from convnet + CHECK_EQ(base, 2 * proto_.final_learning_rate()) + << "final value should be the half"; + freq=proto_.learning_rate_change_frequency(); + ret = base / pow(2, step * 1. / freq); + break; + case UpdaterProto_ChangeProto_kInverse_t: + // a is init, b is the final, from convnet + CHECK_EQ(base, 2 * proto_.final_learning_rate()) + << "final value should be the half"; + ret = base / (1. + step * 1. / proto_.final_learning_rate()); + break; + case UpdaterProto_ChangeProto_kInverse: + // a is init, b is gamma, c is pow + ret=base*pow(1.f+proto_.gamma()*step, -proto_.pow()); + break; + case UpdaterProto_ChangeProto_kStep: + // a is the base learning rate, b is gamma, from caffe + // notice it is step/change_steps, not step*1.0/change_steps + freq=proto_.learning_rate_change_frequency(); + ret = base * pow(proto_.gamma(), step / freq); + break; + case UpdaterProto_ChangeProto_kFixedStep: + for(size_t i=0;i<proto_.step_size();i++){ + if(step>proto_.step(i)) + ret=proto_.step_lr(i); + } + break; + default: + LOG(ERROR) << "Wrong hyper-parameter update method"; + } + return ret; +} + +/***********************SGD with momentum******************************/ +void SGDUpdater::Init(const UpdaterProto& proto){ + Updater::Init(proto); + base_lr_=proto.base_learning_rate(); + //CHECK_GT(base_lr_, 0); + momentum_=proto.momentum(); + weight_decay_=proto.weight_decay(); +} + +void SGDUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){ + Shape<1> s=Shape1(param->size()); + Tensor<cpu, 1> data(param->mutable_cpu_data(), s); + Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); + float lr=GetLearningRate(step)*param->learning_rate_multiplier(); + float wd=weight_decay_*param->weight_decay_multiplier(); + if(wd>0){ // L2 regularization + grad+=data*wd; + } + if(momentum_>0){ + Tensor<cpu, 1> history(param->mutable_cpu_history(), s); + if(step==0) history=0; + history=history*momentum_-lr*grad; + data+=history; + }else{ + grad*=-lr; + data+=grad; + } +} + +/***********************Nesterov******************************/ +void NesterovUpdater::Init(const UpdaterProto& proto){ + Updater::Init(proto); + base_lr_=proto.base_learning_rate(); + CHECK_GT(base_lr_, 0); + weight_decay_=proto.weight_decay(); +} + +void NesterovUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){ + Shape<1> s=Shape1(param->size()); + Tensor<cpu, 1> data(param->mutable_cpu_data(), s); + Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); + Tensor<cpu, 1> history(param->mutable_cpu_history(), s); + TensorContainer<cpu, 1> tmp(s); + if(step==0) history=0; + float lr=GetLearningRate(step)*param->learning_rate_multiplier(); + float wd=weight_decay_*param->weight_decay_multiplier(); + if(wd>0){ // L2 regularization + grad+=data*wd; + } + Copy(tmp, history); + history=history*momentum_+lr*grad; + tmp=history*(1+momentum_)-tmp*momentum_; + data-=tmp; +} +/***********************AdaGrad******************************/ +void AdaGradUpdater::Init(const UpdaterProto& proto){ + Updater::Init(proto); + base_lr_=proto.base_learning_rate(); + CHECK_GT(base_lr_, 0); + delta_=proto.delta(); + weight_decay_=proto.weight_decay(); +} + +void AdaGradUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){ + Shape<1> s=Shape1(param->size()); + Tensor<cpu, 1> data(param->mutable_cpu_data(), s); + Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); + Tensor<cpu, 1> history(param->mutable_cpu_history(), s); + if(step==0) history=0; + history+=F<op::square>(grad*grad_scale); + float lr=GetLearningRate(step)*param->learning_rate_multiplier(); + float wd=weight_decay_*param->weight_decay_multiplier(); + if(wd>0){ // L2 regularization + grad+=data*wd; + } + data-=lr*grad/(F<op::sqrtop>(history,delta_)); +} + +/***********************RMSProp******************************/ +void RMSPropUpdater::Init(const UpdaterProto& proto){ + Updater::Init(proto); + base_lr_=proto.base_learning_rate(); + CHECK_GT(base_lr_, 0); + delta_=proto.delta(); + rho_=proto.rho(); + weight_decay_=proto.weight_decay(); +} + +void RMSPropUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){ + Shape<1> s=Shape1(param->size()); + Tensor<cpu, 1> data(param->mutable_cpu_data(), s); + Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); + Tensor<cpu, 1> history(param->mutable_cpu_history(), s); + if(step==0) history=0; + history=history*rho_+(1-rho_)*F<op::square>(grad*grad_scale); + float lr=GetLearningRate(step)*param->learning_rate_multiplier(); + float wd=weight_decay_*param->weight_decay_multiplier(); + if(wd>0){ // L2 regularization + grad+=data*wd; + } + data-=lr*grad/(F<op::sqrtop>(history,delta_)); +} + +/***********************AdaDelta****************************** +void AdaDeltaUpdater::Init(const UpdaterProto& proto){ + Updater::Init(proto); + delta_=proto.delta(); + rho_=proto.rho(); + weight_decay_=proto.weight_decay(); +} + +void AdaDeltaUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){ + Shape<1> s=Shape1(param->size()); + Tensor<cpu, 1> data(param->mutable_cpu_data(), s); + Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s); + Tensor<cpu, 1> history(param->mutable_cpu_history(), s); + Tensor<cpu, 1> update(param->mutable_cpu_update(), s); + TensorContainer<cpu, 1> tmp(s); + float wd=weight_decay_*param->weight_decay_multiplier(); + if(wd>0){ // L2 regularization + grad+=data*wd; + } + if(step==0){ + history=0; + update=0; + } + history=history*rho_+(1-rho_)*F<op::square>(grad*grad_scale); + tmp=grad*F<op::sqrtop>(update, delta_)/F<op::sqrtop>(history, delta_); + update=rho_*update+(1-rho_)*F<op::square>(tmp); + data-=tmp; +} +*/ + +} /* singa */
