http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2498ff13/include/singa.h ---------------------------------------------------------------------- diff --cc include/singa.h index 6fb9e97,82df64b..52d1f90 --- a/include/singa.h +++ b/include/singa.h @@@ -1,92 -1,40 +1,88 @@@ #ifndef SINGA_SINGA_H_ #define SINGA_SINGA_H_ ++ ++#include <cblas.h> #include <gflags/gflags.h> #include <glog/logging.h> --#include <cblas.h> - -#include "utils/common.h" +#include <string> - - #include "utils/common.h" ++#include "communication/socket.h" ++#include "neuralnet/neuralnet.h" #include "proto/job.pb.h" #include "proto/singa.pb.h" -- ++#include "trainer/trainer.h" ++#include "utils/common.h" #include "utils/param.h" #include "utils/singleton.h" #include "utils/factory.h" --#include "neuralnet/neuralnet.h" --#include "trainer/trainer.h" --#include "communication/socket.h" - +namespace singa { + -DEFINE_string(singa_conf, "conf/singa.conf", "Global config file"); +class Driver { + public: + /** + * Init SINGA, including init glog, parse job id and job conf from cmd line, + * and register built-in layer, worker, updater, param subclasses. + * + * May be used for MPI init if it is used for message passing. + */ + void Init(int argc, char** argv); + /** + * Register a Layer subclass. + * + * T is the subclass. + * @param type layer type ID. If called by users, it should be different to + * the types of built-in layers. + * @return 0 if success; otherwise -1. + */ + template<typename T> + int RegisterLayer(int type); + /** + * Register Updater subclasses. + * + * T is the subclass. + * @param type updater type ID. If called by users, it should be different to + * the types of built-in updaters. + * @return 0 if success; otherwise -1. + */ + template<typename T> + int RegisterUpdater(int type); + /** + * Register Worker subclasses. + * + * T is the subclass. + * @param type worker type ID. If called by users, it should be different to + * the types of built-in workers + * @return 0 if success; otherwise -1. + */ + template<typename T> + int RegisterWorker(int type); + /** + * Register Param subclasses. + * + * T is the subclass. + * @param type param type. If called by users, it should be different to the + * types of built-in params. SINGA currently provides only one built-in Param + * implementation whose type ID is 0. + * @return 0 if success; otherwise -1. + */ + template<typename T> + int RegisterParam(int type); - + /** + * Submit the job configuration for starting the job. + * @param resume resume from last checkpoint if true. + * @param job job configuration + */ + void Submit(bool resume, const JobProto& job); - + /** + * @return job ID which is generated by zookeeper and passed in by the + * launching script. + */ - int job_id() const { - return job_id_; - } ++ inline int job_id() const { return job_id_; } + + private: + int job_id_; +}; + -namespace singa { -void SubmitJob(int job, bool resume, const JobProto& jobConf) { - SingaProto singaConf; - ReadProtoFromTextFile(FLAGS_singa_conf.c_str(), &singaConf); - if (singaConf.has_log_dir()) - SetupLog(singaConf.log_dir(), - std::to_string(job) + "-" + jobConf.name()); - if (jobConf.num_openblas_threads() != 1) - LOG(WARNING) << "openblas is set with " << jobConf.num_openblas_threads() - << " threads"; - openblas_set_num_threads(jobConf.num_openblas_threads()); - JobProto proto; - proto.CopyFrom(jobConf); - proto.set_id(job); - Trainer trainer; - trainer.Start(resume, singaConf, &proto); -} } // namespace singa --#endif // SINGA_SINGA_H_ ++#endif // SINGA_SINGA_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2498ff13/src/driver.cc ---------------------------------------------------------------------- diff --cc src/driver.cc index 05c1195,0000000..5469583 mode 100644,000000..100644 --- a/src/driver.cc +++ b/src/driver.cc @@@ -1,101 -1,0 +1,102 @@@ +#include "singa.h" ++ +namespace singa { + +/** + * the job and singa_conf arguments are passed by the singa script which is + * transparent to users + */ +DEFINE_int32(job, -1, "Unique job ID generated from singa-run.sh"); +DEFINE_string(singa_conf, "conf/singa.conf", "Global config file"); + +void Driver::Init(int argc, char **argv) { + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + job_id_ = FLAGS_job; + + // register layers + RegisterLayer<BridgeDstLayer>(kBridgeDst); + RegisterLayer<BridgeSrcLayer>(kBridgeSrc); + RegisterLayer<ConvolutionLayer>(kConvolution); + RegisterLayer<ConcateLayer>(kConcate); + RegisterLayer<DropoutLayer>(kDropout); + RegisterLayer<InnerProductLayer>(kInnerProduct); + RegisterLayer<LabelLayer>(kLabel); + RegisterLayer<LRNLayer>(kLRN); + RegisterLayer<MnistLayer>(kMnist); + RegisterLayer<PrefetchLayer>(kPrefetch); + RegisterLayer<PoolingLayer>(kPooling); + RegisterLayer<RGBImageLayer>(kRGBImage); + RegisterLayer<ReLULayer>(kReLU); + RegisterLayer<ShardDataLayer>(kShardData); + RegisterLayer<SliceLayer>(kSlice); + RegisterLayer<SoftmaxLossLayer>(kSoftmaxLoss); + RegisterLayer<SplitLayer>(kSplit); + RegisterLayer<TanhLayer>(kTanh); + RegisterLayer<RBMVisLayer>(kRBMVis); + RegisterLayer<RBMHidLayer>(kRBMHid); +#ifdef USE_LMDB - RegisterLayer(factory, LMDBData); ++ RegisterLayer<LMDBDataLayer>(kLMDBData); +#endif + + // register updater + RegisterUpdater<AdaGradUpdater>(kAdaGrad); + RegisterUpdater<NesterovUpdater>(kNesterov); - // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp); ++ // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp); + RegisterUpdater<SGDUpdater>(kSGD); + + // register worker + RegisterWorker<BPWorker>(kBP); + RegisterWorker<CDWorker>(kCD); + + // register param + RegisterParam<Param>(0); +} + +template<typename T> +int Driver::RegisterLayer(int type) { + auto factory = Singleton<Factory<singa::Layer>>::Instance(); + factory->Register(type, CreateInstance(T, Layer)); + return 1; +} + +template<typename T> +int Driver::RegisterParam(int type) { + auto factory = Singleton<Factory<singa::Param>>::Instance(); + factory->Register(type, CreateInstance(T, Param)); + return 1; +} + +template<typename T> +int Driver::RegisterUpdater(int type) { + auto factory = Singleton<Factory<singa::Updater>>::Instance(); + factory->Register(type, CreateInstance(T, Updater)); + return 1; +} + +template<typename T> +int Driver::RegisterWorker(int type) { + auto factory = Singleton<Factory<singa::Worker>>::Instance(); + factory->Register(type, CreateInstance(T, Worker)); + return 1; +} + +void Driver::Submit(bool resume, const JobProto& jobConf) { + SingaProto singaConf; + ReadProtoFromTextFile(FLAGS_singa_conf.c_str(), &singaConf); + if (singaConf.has_log_dir()) + SetupLog(singaConf.log_dir(), std::to_string(FLAGS_job) + + "-" + jobConf.name()); + if (jobConf.num_openblas_threads() != 1) + LOG(WARNING) << "openblas with " + << jobConf.num_openblas_threads() << " threads"; + openblas_set_num_threads(jobConf.num_openblas_threads()); + + JobProto job; + job.CopyFrom(jobConf); + job.set_id(job_id_); + Trainer trainer; + trainer.Start(resume, singaConf, &job); +} + +} // namespace singa
