Repository: incubator-singa Updated Branches: refs/heads/dev 37e6ad283 -> 7333517b4
SINGA-199 Implement Python classes for SGD optimizers

Add regularizer and constraint Python classes, and Python wrappers for all C++ optimizers.
Add swig file for the optimizer and test_optimizer.py; add initializer.py.

Commit from @aaronwwf:
- change module name from 'singa' to 'singa_wrap'
- fix model_optimizer.i naming mistake

Fix all bugs. Changed CnMemPool's initialized and mtx fields from static to
instance members. Add a global counter to check the number of CnMemPool
instances (which must be less than 2): CnMemPool must be used as a singleton,
because cnmemInit can only be called once. Python unittest may run test cases
in parallel, in which case the CUDA devices need to be created outside of the
test cases.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/7333517b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/7333517b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/7333517b

Branch: refs/heads/dev
Commit: 7333517b492c49b9fcde22998063b84f98178fe3
Parents: 37e6ad2
Author: Wei Wang <[email protected]>
Authored: Tue Jun 21 18:25:54 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Mon Jul 4 20:18:24 2016 +0800

----------------------------------------------------------------------
 include/singa/core/device.h       |   2 +-
 include/singa/core/memory.h       |   9 +-
 include/singa/model/optimizer.h   |  33 +++-
 src/CMakeLists.txt                |   8 +-
 src/core/memory/memory.cc         |  33 ++--
 src/model/optimizer/adagrad.cc    |   4 +-
 src/model/optimizer/optimizer.cc  |   3 +
 src/python/device.py              |  99 +++++-----
 src/python/initializer.py         |  37 ++++
 src/python/layer.py               |   3 +-
 src/python/optimizer.py           | 330 +++++++++++++++++++++++++++++++++
 src/python/setup.py.in            |  64 +++----
 src/python/swig/core_device.i     |  48 ++---
 src/python/swig/model_layer.i     |   3 +-
 src/python/swig/model_optimizer.i |  70 +++++++
 src/python/swig/singa.i           |   3 +-
 src/python/tensor.py              |  56 +++---
 test/python/test_optimizer.py     | 104 +++++++++++
 test/python/test_tensor.py        |  40 ++--
 test/singa/test_adagrad.cc        |   8 +-
 test/singa/test_platform.cc       |  23 ++-
 test/singa/test_snapshot.cc       |  18 +-
 22 files changed, 785 insertions(+), 213 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/include/singa/core/device.h
----------------------------------------------------------------------
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index 17fa66a..17613bb 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -233,7 +233,7 @@ class Platform {
   /// except the context initialization.
static bool CheckDevice(const int device_id); - private: +// private: Platform() {}; // No need to construct an instance as it has no member fields }; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/include/singa/core/memory.h ---------------------------------------------------------------------- diff --git a/include/singa/core/memory.h b/include/singa/core/memory.h index 1019ed3..f664f95 100644 --- a/include/singa/core/memory.h +++ b/include/singa/core/memory.h @@ -20,6 +20,7 @@ #define SINGA_CORE_MEMORY_H_ #include <mutex> +#include <atomic> #include "singa/proto/core.pb.h" #include "singa/singa_config.h" @@ -68,12 +69,16 @@ class CnMemPool : public DeviceMemPool { protected: void Init(); + private: + MemPoolConf conf_; // whether the (global) memory pool has been initialized - static bool initialized; + bool initialized_ = false; // lock on the initialized variable - static std::mutex mtx; + std::mutex mtx_; + + static std::atomic<int> pool_count; }; class CudaMemPool : public DeviceMemPool { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/include/singa/model/optimizer.h ---------------------------------------------------------------------- diff --git a/include/singa/model/optimizer.h b/include/singa/model/optimizer.h index 2ec68fe..454a387 100644 --- a/include/singa/model/optimizer.h +++ b/include/singa/model/optimizer.h @@ -128,6 +128,9 @@ class Constraint { float threshold_; }; +inline std::shared_ptr<Constraint> CreateConstraint(std::string type) { + return std::make_shared<Constraint>(); +} /// Apply regularization for parameters (gradient), e.g., L1 norm and L2 norm. /// TODO(wangwei) implement a sub-class for each type of regularizer class Regularizer { @@ -159,6 +162,11 @@ class Regularizer { string type_ = "NotSet"; float coefficient_; }; +inline std::shared_ptr<Regularizer> CreateRegularizer(std::string type) { + return std::make_shared<Regularizer>(); +} + + // =============Vallina SGD with Momentum===================================== class SGD : public Optimizer { @@ -173,7 +181,6 @@ class SGD : public Optimizer { void SetMomentumGenerator(std::function<float(int)> func) { momentum_generator_ = func; } - virtual ~SGD() = default; private: std::unordered_map<string, Tensor> history_gradient_; @@ -181,7 +188,7 @@ class SGD : public Optimizer { }; // =============Nesterov====================================================== -class Nesterov : Optimizer { +class Nesterov : public Optimizer { public: void Setup(const OptimizerConf& conf); /// Apply the updating algorithm. @@ -193,7 +200,6 @@ class Nesterov : Optimizer { void SetMomentumGenerator(std::function<float(int)> func) { momentum_generator_ = func; } - virtual ~Nesterov() = default; private: std::unordered_map<string, Tensor> history_gradient_; @@ -201,20 +207,19 @@ class Nesterov : Optimizer { }; // =============Adagrad======================================================= -class Adagrad : Optimizer { +class AdaGrad : public Optimizer { public: void Setup(const OptimizerConf& conf); /// Apply the updating algorithm. void Apply(int step, float lr, const string& name, const Tensor& grad, Tensor* value) override; - virtual ~Adagrad() = default; private: std::unordered_map<string, Tensor> history_gradient_; float delta_; }; // =============RMSProp======================================================= -class RMSProp : Optimizer { +class RMSProp : public Optimizer { public: void Setup(const OptimizerConf& conf); /// Apply the updating algorithm. 
@@ -226,6 +231,22 @@ class RMSProp : Optimizer { std::unordered_map<string, Tensor> history_gradient_; float delta_, rho_; }; + + +inline std::shared_ptr<Optimizer> CreateOptimizer(const string& type) { + std::shared_ptr<Optimizer> opt; + if (type == "SGD") + opt = std::shared_ptr<Optimizer>(new SGD()); + else if (type == "RMSProp") + opt = std::shared_ptr<Optimizer>(new RMSProp()); + else if (type == "AdaGrad") + opt = std::shared_ptr<Optimizer>(new AdaGrad()); + else if (type == "Nesterov") + opt = std::shared_ptr<Optimizer>(new Nesterov()); + else + LOG(FATAL) << "Unknown optimizer type : " << type; + return opt; +} // ============LocalAllReduce for single node multiple workers ============== /// Updater for training models on a single node with multiple devices (workers) /// All model parameters are partitioned such that each parameter is updated on http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 54d19ec..c479252 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -83,13 +83,13 @@ IF(USE_PYTHON) swig_generate_cxx(python_srcs ${python_files}) FILE(COPY python/ DESTINATION ${CMAKE_BINARY_DIR}/python/singa FILES_MATCHING PATTERN "swig" EXCLUDE PATTERN "*.py") SET(python_cxxs "${CMAKE_SOURCE_DIR}/src/core/tensor/tensor.cc;${CMAKE_SOURCE_DIR}/src/core/device/device.cc") - ADD_LIBRARY(_singa SHARED ${python_srcs} ${python_cxxs}) + ADD_LIBRARY(_singa_wrap SHARED ${python_srcs} ${python_cxxs}) SET(WRAPPER_LINKER_LIBS "${SINGA_LINKER_LIBS};protobuf") - TARGET_LINK_LIBRARIES(_singa ${WRAPPER_LINKER_LIBS}) - TARGET_INCLUDE_DIRECTORIES(_singa PRIVATE ${PYTHON_INCLUDE_DIRS}) + TARGET_LINK_LIBRARIES(_singa_wrap ${WRAPPER_LINKER_LIBS}) + TARGET_INCLUDE_DIRECTORIES(_singa_wrap PRIVATE ${PYTHON_INCLUDE_DIRS}) #message(STATUS "PYTHON_INCLUDE_DIRS ${PYTHON_INCLUDE_DIRS}") - SET_TARGET_PROPERTIES(_singa + SET_TARGET_PROPERTIES(_singa_wrap PROPERTIES PREFIX "" LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/python/singa ) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/core/memory/memory.cc ---------------------------------------------------------------------- diff --git a/src/core/memory/memory.cc b/src/core/memory/memory.cc index 63ffc2d..fa4b305 100644 --- a/src/core/memory/memory.cc +++ b/src/core/memory/memory.cc @@ -23,9 +23,7 @@ #ifdef USE_CUDA namespace singa { -bool singa::CnMemPool::initialized = false; -std::mutex singa::CnMemPool::mtx; - +std::atomic<int> CnMemPool::pool_count(0); std::pair<size_t, size_t> CnMemPool::GetMemUsage() { size_t free, total; auto status = cnmemMemGetInfo(&free, &total, NULL); @@ -39,17 +37,21 @@ CnMemPool::CnMemPool(int numDevices, size_t init_size, size_t max_size) { conf_.add_device(i); conf_.set_init_size(init_size); conf_.set_max_size(max_size); + CHECK_LT(++pool_count, 2) << "CnMemPool must be used as a singleton."; } -CnMemPool::CnMemPool(const MemPoolConf &conf) { conf_ = conf; } +CnMemPool::CnMemPool(const MemPoolConf &conf) { + conf_ = conf; + CHECK_LT(++pool_count, 2) << "CnMemPool must be used as a singleton."; +} void CnMemPool::Init() { - mtx.lock(); - if (!initialized) { + mtx_.lock(); + if (!initialized_) { const size_t kNBytesPerMB = (1u << 20); CHECK_GE(conf_.device_size(), 1); cnmemDevice_t *settingPtr = new cnmemDevice_t[conf_.device_size()]; - CHECK_GT(conf_.init_size(), 0); + CHECK_GT(conf_.init_size(), 0u); int i = 0; for (auto device : 
conf_.device()) { settingPtr[i].device = device; @@ -63,24 +65,25 @@ void CnMemPool::Init() { CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS) << " " << cnmemGetErrorString(status); delete[] settingPtr; - initialized = true; + initialized_ = true; } - mtx.unlock(); + mtx_.unlock(); } CnMemPool::~CnMemPool() { - mtx.lock(); - if (initialized) { + mtx_.lock(); + if (initialized_) { cnmemStatus_t status = cnmemFinalize(); CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS) << " " << cnmemGetErrorString(status); - initialized = false; + initialized_ = false; + --pool_count; } - mtx.unlock(); + mtx_.unlock(); } void CnMemPool::Malloc(void **ptr, const size_t size) { - if (!initialized) + if (!initialized_) Init(); cnmemStatus_t status = cnmemMalloc(ptr, size, NULL); CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS) @@ -88,7 +91,7 @@ void CnMemPool::Malloc(void **ptr, const size_t size) { } void CnMemPool::Free(void *ptr) { - CHECK(initialized) << "Cannot free the memory as the pool is not initialzied"; + CHECK(initialized_) << "Cannot free the memory as the pool is not initialzied"; cnmemStatus_t status = cnmemFree(ptr, NULL); CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS) << " " << cnmemGetErrorString(status); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/model/optimizer/adagrad.cc ---------------------------------------------------------------------- diff --git a/src/model/optimizer/adagrad.cc b/src/model/optimizer/adagrad.cc index fec9c96..0f2119b 100644 --- a/src/model/optimizer/adagrad.cc +++ b/src/model/optimizer/adagrad.cc @@ -21,11 +21,11 @@ #include <functional> namespace singa { -void Adagrad::Setup(const OptimizerConf& conf) { delta_ = conf.delta(); } +void AdaGrad::Setup(const OptimizerConf& conf) { delta_ = conf.delta(); } // history += grad*grad; // value = value - lr*grad/sqrt(history+delta) -void Adagrad::Apply(int step, float lr, const string& name, const Tensor& grad, +void AdaGrad::Apply(int step, float lr, const string& name, const Tensor& grad, Tensor* value) { if (history_gradient_.find(name) == history_gradient_.end()) history_gradient_[name].ResetLike(*value); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/model/optimizer/optimizer.cc ---------------------------------------------------------------------- diff --git a/src/model/optimizer/optimizer.cc b/src/model/optimizer/optimizer.cc index 9be47c8..9fee18d 100644 --- a/src/model/optimizer/optimizer.cc +++ b/src/model/optimizer/optimizer.cc @@ -84,6 +84,9 @@ void Optimizer::Apply(int step, const string& name, Tensor* grad, void Regularizer::Setup(const RegularizerConf& conf) { type_ = conf.type(); coefficient_ = conf.coefficient(); + if (type_ != "L2" && type_ != "l2") { + CHECK(type_ == "NotSet") << "Unknown regularizer type = " << type_; + } } void Regularizer::Apply(int step, Tensor* value, Tensor* grad, float scale) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/device.py ---------------------------------------------------------------------- diff --git a/src/python/device.py b/src/python/device.py index 0877ae5..0472d8d 100644 --- a/src/python/device.py +++ b/src/python/device.py @@ -1,60 +1,42 @@ -#!/usr/bin/env python - -# /************************************************************ -# * -# * Licensed to the Apache Software Foundation (ASF) under one -# * or more contributor license agreements. See the NOTICE file -# * distributed with this work for additional information -# * regarding copyright ownership. 
The ASF licenses this file -# * to you under the Apache License, Version 2.0 (the -# * "License"); you may not use this file except in compliance -# * with the License. You may obtain a copy of the License at -# * -# * http://www.apache.org/licenses/LICENSE-2.0 -# * -# * Unless required by applicable law or agreed to in writing, -# * software distributed under the License is distributed on an -# * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# * KIND, either express or implied. See the License for the -# * specific language governing permissions and limitations -# * under the License. -# * -# *************************************************************/ - +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ============================================================================= ''' This script includes Device class and its subclasses for python users to call singa::Device and its methods ''' -import sys -import os -import numpy as np -from . import singa +from . import singa_wrap as singa class Device(object): - ''' Class and member functions for singa::Device - ''' + """ Class and member functions for singa::Device. - def __init__(self, id=-1, num_executors=1, scheduler='sync', vm='gc-only', - device='cpu'): - ''' id = (int) // device ID - num_executors = (int) // # of executors (e.g., cuda streams) - scheduler = (string) // identifier of scheduler type (default - // scheduler run operations synchronously) - vm = (string) // virtual memory type (default vm only - // provides garbage collection) - (TODO) max mem size to use (in MB) - ''' - if device == 'gpu': - self.singa_device = singa.CudaGPU(id, num_executors, scheduler, vm) - else: - self.singa_device = singa.CppCPU(id, num_executors, scheduler, vm) + Create Device instances using the CreateXXXDevice. + """ + def __init__(self, id, device): + """Device constructor given device ID. + Args: + id (int): device ID. 
+ device: swig shared_ptr<Device> + """ self.id = id - self.num_executors = num_executors - self.scheduler = scheduler - self.vm = vm + self.singa_device = device def set_rand_seed(self, seed): self.singa_device.SetRandSeed(seed) @@ -66,14 +48,27 @@ class Device(object): return self.singa_device.id() -class CppCPU(Device): +class Platform(object): + @staticmethod + def get_num_gpus(): + return singa.Platform.GetNumGPUs() - def __init__(self, id=-1, num_executors=1, scheduler='sync', vm='gc-only'): - super(CppCPU, self).__init__(id, num_executors, scheduler, vm) + @staticmethod + def get_gpu_ids(): + return singa.Platform.GetGPUIDs() + @staticmethod + def get_gpu_mem_size(id): + return singa.Platform.GetGPUMemSize(id) -class CudaGPU(Device): + @staticmethod + def device_query(id, verbose=False): + return singa.Platform.DeviceQuery(id, verbose) - def __init__(self, id=0, num_executors=1, scheduler='sync', vm='gc-only'): - super(CudaGPU, self).__init__(id, num_executors, scheduler, vm, 'gpu') + @staticmethod + def create_raw_cuda_gpus(num): + return singa.Platform.CreateCudaGPUs(num) + @staticmethod + def create_cuda_gpu(): + return singa.Platform.CreateCudaGPUs(1)[0] http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/initializer.py ---------------------------------------------------------------------- diff --git a/src/python/initializer.py b/src/python/initializer.py new file mode 100644 index 0000000..bc1e8a0 --- /dev/null +++ b/src/python/initializer.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ============================================================================= +"""Popular initialization methods for parameter values (Tensor ojects)""" + +import math + + +def uniform(t, low=0, high=1): + t.uniform(low, high) + + +def gaussian(t, mean=0, std=0.01): + t.gaussian(mean, std) + + +def xavier(t): + scale = math.sqrt(6.0 / (t.shape[0] + t.shape[1])) + t.uniform(-scale, scale) + + +def msra(t): + t.gaussian(0, math.sqrt(2.0 / t.shape[0])) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/layer.py ---------------------------------------------------------------------- diff --git a/src/python/layer.py b/src/python/layer.py index a1ec556..2151766 100644 --- a/src/python/layer.py +++ b/src/python/layer.py @@ -24,7 +24,8 @@ import sys import os import numpy as np -from . import singa + +from . 
import singa_wrap as singa from .proto.core_pb2 import * from .proto.model_pb2 import * http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/optimizer.py ---------------------------------------------------------------------- diff --git a/src/python/optimizer.py b/src/python/optimizer.py new file mode 100644 index 0000000..43b4c9d --- /dev/null +++ b/src/python/optimizer.py @@ -0,0 +1,330 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ============================================================================= +""" Python wrappers for optimizers implemented by C++.""" + +from . import singa_wrap as singa +import tensor +from proto import model_pb2 + + +class Optimizer(object): + """Base python optimizer. + + Usages: + 1. construct the optimizer + 2. (optional) register each parameter with its specs. + 3. use the optimizer to update parameter values given parameter + gradients and other optional info + """ + def __init__(self, lr=None, momentum=None, decay=None, lr_gen=None, + momentum_gen=None, regularizer=None, constraint=None): + """Constructor. + + Args: + lr: a constant or a function that generates learning rate given a + step, which is mutually exclusive with 'lr_gen'. + momentum: a constant or a function that generates the momentum value + given a step. + decay (float): the coefficent for L2 regularizer, which is mutually + exclusive with 'regularizer'. + lr_gen (function): a function returns the learning rate given + the current training step. It is mutually exclusive with lr. If + both are not set, the apply_with_lr function should be used for + param updating. + momentum_gen (function): a function returns the momentum value given + the current training step. It is mutually exclusive with + momentum. + regularizer: an instance of Regularizer or RegularizerConf; If set, + regularization would be applied in apply_with_lr(). + Users can also do regularization outside. + constraint: an instance of Constraint or ConstraintConf; If set, + constraint would be applied inside apply_with_lr(). Users can + also do regularization outside. 
+ """ + if lr is not None: + assert lr_gen is None, 'Cannot set lr and lr_gen at the same time' + + def lr_gen(step): + return lr + self.lr_gen = lr_gen + if momentum is not None: + assert momentum_gen is None, 'Cannot set momentum and momentum_gen'\ + ' at the same time' + + def momentum_gen(step): + return momentum + self.momentum_gen = momentum_gen + if decay is not None: + assert regularizer is None, \ + 'Cannot set decay and regularizer at the same time' + regularizer = L2Regularizer(decay) + if regularizer is not None: + if type(regularizer) is model_pb2.RegularizerConf: + self.regularizer = CppRegularizer(regularizer) + else: + self.regularizer = regularizer + else: + self.regularizer = None + if constraint is not None: + if type(constraint) is model_pb2.ConstraintConf: + self.constraint = CppConstraint(constraint) + else: + self.constraint = constraint + else: + self.constraint = None + self.regularizers = {} + self.constraints = {} + + def register(self, name, specs): + """Register the param specs, including creating regularizer and + constraint per param object. Param specific regularizer and constraint + have higher priority than the global ones. + + Args: + name (str): parameter name + specs (ParamSpec): protobuf obj + """ + if specs.has_regularizer(): + self.regularizers[name] = CppRegularizer(specs.constraint) + if specs.has_constraint(): + self.constraints[name] = CppConstraint(specs.regularizer) + if specs.has_lr_mult(): + self.learning_rate_multiplier[name] = specs.lr_mult() + if specs.has_decay_mult(): + self.decay_multiplier[name] = specs.decay_mult() + + def apply_regularizer_constraint(self, value, grad, name=None, step=None): + """Apply regularization and constraint if available. + + If there are both global regularizer (constraint) and param specific + regularizer (constraint), it would use the param specific one. + + Args: + value (Tensor): parameter value Tensor + grad (Tensor): parameter gradient Tensor + name (string): to get parameter specific regularizer or constraint + step (int): some regularizer or constraint would use step + + Return: + the updated gradient Tensor + """ + if name is not None and name in self.constraints: + self.constraints[name].apply(value, grad, step) + elif self.constraint is not None: + self.constraint.apply(step, value, grad) + + if name is not None and name in self.regularizers: + self.regularizers[name].apply(value, grad, step) + elif self.regularizer is not None: + self.regularizer.apply(step, value, grad) + return grad + + def apply_with_lr(self, step, lr, grad, value, name=None): + """Do update with given learning rate. + + The subclass optimizer must override this function. + Args: + step (int): training step (could be iteration or epoch) + lr (float): learning rate + grad (Tensor): parameter gradient + value (Tesnor): parameter value + name (string): paramter name to retrieval parameter specific + updating rules (including regularizer and constraint) + + Return: + updated parameter value + """ + assert False, 'This is the base function, pls call the subclass func' + return value + + def apply(self, step, grad, value, name=None): + """Do update assume the learning rate generator is set. + + The subclass optimizer does not need to override this function. 
+ Args: + step (int): training step (could be iteration or epoch) + grad (Tensor): parameter gradient + value (Tesnor): parameter value + name (string): paramter name to retrieval parameter specific + updating rules (including regularizer and constraint) + + Return: + updated parameter value + """ + + assert self.lr_gen is not None, 'Learning rate generator is not set.'\ + 'Either set the lr_gen in constructor or call apply_with_lr' + lr = self.lr_gen(step) + return self.apply_with_lr(step, lr, grad, value, name) + + +class SGD(Optimizer): + def __init__(self, lr=None, momentum=None, decay=None, **kwargs): + """The vallina Stochasitc Gradient Descent algorithm. + + See the base Optimizer for all arguments. + """ + super(SGD, self).__init__(lr, momentum, decay) + conf = model_pb2.OptimizerConf() + self.opt = singa.CreateOptimizer('SGD') + self.opt.Setup(conf.SerializeToString()) + + def apply_with_lr(self, step, lr, grad, value, name): + self.apply_regularizer_constraint(step, value, grad, name) + self.opt.Apply(step, lr, name, grad.singa_tensor, value.singa_tensor) + return value + + +class Nesterov(Optimizer): + def __init__(self, lr=None, momentum=0.9, decay=None, **kwargs): + """The SGD with Nesterov momentum + + See the base Optimizer for all arguments. + """ + super(Nesterov, self).__init__(lr, momentum, decay, kwargs) + conf = model_pb2.OptimizerConf() + self.opt = singa.CreateOptimizer('Nesterov') + self.opt.Setup(conf.SerializeToString()) + + def apply_with_lr(self, step, lr, grad, value, name): + self.apply_regularizer_constraint(step, value, grad, name) + self.opt.Apply(step, lr, name, grad.singa_tensor, value.singa_tensor) + return value + + +class AdaGrad(Optimizer): + def __init__(self, epsilon=1e-8, lr=None, decay=None, **kwargs): + """AdaGrad optimizer. + + See the base Optimizer for all constructor args. + Args: + epsilon (float): small number for preventing numeric error. + """ + super(RMSProp, self).__init__(lr, decay, **kwargs) + conf = model_pb2.OptimizerConf() + conf.delta = epsilon + self.opt = singa.CreateOptimizer('AdaGrad') + self.opt.Setup(conf.SerializeToString()) + + def apply_with_lr(self, step, lr, grad, value, name): + grad = self.apply_regularizer_constraint(step, value, grad, name) + self.opt.Apply(step, lr, name, grad.singa_tensor, value.singa_tensor) + return value + + +class RMSProp(Optimizer): + def __init__(self, rho=0.9, epsilon=1e-8, lr=None, decay=None, **kwargs): + """RMSProp optimizer. + + See the base Optimizer for all constructor args. + Args: + rho (float): float within [0, 1] + epsilon (float): small value for preventing numeric error + """ + super(RMSProp, self).__init__(lr, decay, kwargs) + conf = model_pb2.OptimizerConf() + conf.rho = rho + conf.delta = epsilon + self.opt = singa.CreateOptimizer('RMSProp') + self.opt.Setup(conf.SerializeToString()) + + def apply_with_lr(self, step, lr, grad, value, name): + grad = self.apply_regularizer_constraint(step, value, grad, name) + self.opt.Apply(step, lr, name, grad.singa_tensor, value.singa_tensor) + return value + + +class Regularizer(object): + """Base Python regularizer for parameter gradients. + """ + def apply(self, value, grad): + assert False, 'Not Implemented. Call the subclass function.' + return grad + + +class CppRegularizer(Regularizer): + """Wrapper for regularizer implemented using C++. + """ + def __init__(self, conf): + """Constructor. + + Args: + conf (RegularizerConf): protobuf message for the configuration. 
+ """ + self.reg = singa.CreateRegularizer(conf.type) + self.reg.Setup(conf.SerializeToString()) + + def apply(self, step, value, grad): + self.reg.Apply(step, value.singa_tensor, grad.singa_tensor) + return grad + + +class L2Regularizer(Regularizer): + """L2 regularization""" + def __init__(self, coefficient): + """ + Args: + coefficient (float): regularization coefficient. + """ + self.coefficient = coefficient + + def apply(self, step, value, grad, coefficient=None): + if coefficient is None: + assert self.coefficient is not None, 'Must set the coefficient' + coefficient = self.coefficient + tensor.axpy(coefficient, value, grad) + return grad + + +class Constraint(object): + """Base Python constraint class for paramter gradients. + """ + def apply(self, step, value, grad): + return grad + + +class CppConstraint(Constraint): + """Wrapper for constraints implemented using C++. + """ + def __init__(self, conf): + """Constructor. + + Args: + conf (ConstraintConf): protobuf message for the configuration. + """ + self.constraint = singa.CreateConstraint(conf.type) + self.constraint.Setup(conf.SerializeToString()) + + def apply(self, step, value, grad): + self.constraint.Apply(step, value.singa_tensor, grad.singa_tensor) + return grad + + +class L2Constraint(Constraint): + """Rescale the gradient to make the L2 norm <= a given threshold. + """ + def __init__(self, threshold=None): + self.threshold = threshold + + def apply(self, step, value, grad, threshold=None): + if threshold is None: + assert self.threshold is not None, 'Must set the threshold' + threshold = self.threshold + nrm = grad.nrm2() + grad *= threshold / nrm + return grad http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/setup.py.in ---------------------------------------------------------------------- diff --git a/src/python/setup.py.in b/src/python/setup.py.in index 9002c70..739104d 100644 --- a/src/python/setup.py.in +++ b/src/python/setup.py.in @@ -1,7 +1,6 @@ # Always prefer setuptools over distutils -from setuptools import setup, find_packages -from codecs import open -from os import path +from setuptools import setup + setup( name='singa', @@ -12,8 +11,8 @@ setup( url='https://github.com/apache/incubator-singa', - author='NUS Database Group', - author_email='[email protected]', + author='Apache SINGA (incubating)', + author_email='[email protected]', license='Apache 2', @@ -37,40 +36,43 @@ setup( keywords='deep learning singa apache', - packages= ['singa','singa.proto'], - - # py_modules=["singa"], + packages= ['singa', 'singa.proto'], - # install_requires=['peppercorn'], + ''' + py_modules=["singa"], - # List additional groups of dependencies here (e.g. development - # dependencies). You can install these using the following syntax, - # for example: - # $ pip install -e .[dev,test] - #extras_require={ - # 'dev': ['check-manifest'], - # 'test': ['coverage'], - #}, + install_requires=['peppercorn'], + List additional groups of dependencies here (e.g. development + dependencies). You can install these using the following syntax, + for example: + $ pip install -e .[dev,test] + extras_require={ + 'dev': ['check-manifest'], + 'test': ['coverage'], + }, - # If there are data files included in your packages that need to be - # installed, specify them here. If using Python 2.6 or less, then these - # have to be included in MANIFEST.in as well. + If there are data files included in your packages that need to be + installed, specify them here. 
If using Python 2.6 or less, then these + have to be included in MANIFEST.in as well. + ''' package_data={ - 'singa': ['_singa.so'], + 'singa': ['_singa_wrap.so'], }, - # Although 'package_data' is the preferred approach, in some case you may - # need to place data files outside of your packages. See: - # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa - # In this case, 'data_file' will be installed into '<sys.prefix>/my_data' - #data_files=[('my_data', ['data/data_file'])], - - # To provide executable scripts, use entry points in preference to the - # "scripts" keyword. Entry points provide cross-platform support and allow - # pip to create the appropriate form of executable for the target platform. + ''' + Although 'package_data' is the preferred approach, in some case you may + need to place data files outside of your packages. See: + http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa + In this case, 'data_file' will be installed into '<sys.prefix>/my_data' + data_files=[('my_data', ['data/data_file'])], + + To provide executable scripts, use entry points in preference to the + "scripts" keyword. Entry points provide cross-platform support and allow + pip to create the appropriate form of executable for the target platform. + ''' entry_points={ 'console_scripts': [ 'singa=singa:main', ], }, -) \ No newline at end of file +) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/swig/core_device.i ---------------------------------------------------------------------- diff --git a/src/python/swig/core_device.i b/src/python/swig/core_device.i index e410982..b79d37e 100644 --- a/src/python/swig/core_device.i +++ b/src/python/swig/core_device.i @@ -24,6 +24,7 @@ %module core_device %include "std_vector.i" %include "std_string.i" +%include "std_pair.i" %include "std_shared_ptr.i" %{ @@ -32,32 +33,31 @@ /* smart pointer to avoid memory leak */ %shared_ptr(singa::Device); -%shared_ptr(singa::CppCPU); -%shared_ptr(singa::CudaGPU); + +namespace std{ +%template(sizePair) std::pair<size_t, size_t>; +%template(vectorPair) std::vector<std::pair<size_t, size_t>>; +%template(vectorSharedPtr) std::vector<std::shared_ptr<singa::Device>>; +} namespace singa{ - class Device { - public: - virtual void SetRandSeed(unsigned seed) = 0; - std::shared_ptr<Device> host(); - int id() const; - }; - - class CppCPU : public Device { - public: - CppCPU(); - void SetRandSeed(unsigned seed) override; - /* (TODO) add necessary functions of CppCPU class - */ - }; - - class CudaGPU : public Device { - public: - CudaGPU(); - void SetRandSeed(unsigned seed) override; - /* (TODO) add necessary functions of CudaGPU class - */ - }; +class Device { + public: + virtual void SetRandSeed(unsigned seed) = 0; + std::shared_ptr<Device> host(); + int id() const; +}; + +class Platform { + public: + static int GetNumGPUs(); + static const std::vector<int> GetGPUIDs(); + static const std::pair<size_t, size_t> GetGPUMemSize(const int device); + static const std::vector<std::pair<size_t, size_t>> GetGPUMemSize(); + static const std::string DeviceQuery(int id, bool verbose = false); + static const std::vector<std::shared_ptr<Device> > + CreateCudaGPUs(const size_t num_devices, size_t init_size = 0); +}; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/swig/model_layer.i ---------------------------------------------------------------------- diff --git a/src/python/swig/model_layer.i 
b/src/python/swig/model_layer.i index 15bd05f..ee7c319 100644 --- a/src/python/swig/model_layer.i +++ b/src/python/swig/model_layer.i @@ -43,7 +43,7 @@ namespace std { %template(tensorVector) vector<Tensor>; %template(tensorPtrVector) vector<Tensor*>; %template(ttvecPair) pair<Tensor, vector<Tensor>>; - %template(tvectvecPair) pair<vector<Tensor>, vector<Tensor>>; + %template(tvecPair) pair<vector<Tensor>, vector<Tensor>>; } @@ -83,6 +83,5 @@ namespace singa { virtual const std::pair<std::vector<Tensor>, std::vector<Tensor>> Backward(int flag, const vector<Tensor>& grads); }; - } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/swig/model_optimizer.i ---------------------------------------------------------------------- diff --git a/src/python/swig/model_optimizer.i b/src/python/swig/model_optimizer.i new file mode 100644 index 0000000..ee60f54 --- /dev/null +++ b/src/python/swig/model_optimizer.i @@ -0,0 +1,70 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +/*interface file for swig */ + +%module model_optimizer +%include "std_vector.i" +%include "std_string.i" +%include "std_pair.i" +%include "std_shared_ptr.i" + +%{ +#include "singa/model/optimizer.h" +#include "singa/proto/model.pb.h" +using singa::Tensor; +using singa::ParamSpec; +using singa::OptimizerConf; +%} + + +%shared_ptr(singa::Optimizer) +%shared_ptr(singa::Regularizer) +%shared_ptr(singa::Constraint) + +namespace singa { +class Optimizer { + public: + // Optimizer() = default; + virtual ~Optimizer() = default; + void Setup(const std::string& str); + virtual void Apply(int step, float lr, const std::string& name, + const Tensor& grad, Tensor* value) = 0; +}; +inline std::shared_ptr<Optimizer> CreateOptimizer(const std::string& type); + +class Constraint { + public: + Constraint() = default; + void Setup(const std::string& conf_str); + void Apply(int step, Tensor* grad, Tensor* value); +}; + +inline std::shared_ptr<Constraint> CreateConstraint(const std::string& type); + +class Regularizer { + public: + Regularizer() = default; + void Setup(const std::string& conf_str); + void Apply(int step, Tensor* grad, Tensor* value); +}; +inline std::shared_ptr<Regularizer> CreateRegularizer(const std::string& type); +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/swig/singa.i ---------------------------------------------------------------------- diff --git a/src/python/swig/singa.i b/src/python/swig/singa.i index 8b5e2dc..dbf621a 100644 --- a/src/python/swig/singa.i +++ b/src/python/swig/singa.i @@ -21,7 +21,8 @@ /*interface file for swig */ -%module singa +%module singa_wrap %include "core_tensor.i" %include "core_device.i" %include "model_layer.i" +%include "model_optimizer.i" http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/src/python/tensor.py ---------------------------------------------------------------------- diff --git a/src/python/tensor.py b/src/python/tensor.py index 3b9fa52..099e706 100644 --- a/src/python/tensor.py +++ b/src/python/tensor.py @@ -1,37 +1,28 @@ -#!/usr/bin/env python - -# /************************************************************ -# * -# * Licensed to the Apache Software Foundation (ASF) under one -# * or more contributor license agreements. See the NOTICE file -# * distributed with this work for additional information -# * regarding copyright ownership. The ASF licenses this file -# * to you under the Apache License, Version 2.0 (the -# * "License"); you may not use this file except in compliance -# * with the License. You may obtain a copy of the License at -# * -# * http://www.apache.org/licenses/LICENSE-2.0 -# * -# * Unless required by applicable law or agreed to in writing, -# * software distributed under the License is distributed on an -# * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# * KIND, either express or implied. See the License for the -# * specific language governing permissions and limitations -# * under the License. -# * -# *************************************************************/ - -''' +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ============================================================================= +""" This script includes Tensor class and its methods for python users to call singa::Tensor and its methods -''' -import sys -import os -import numpy as np +""" -from . import singa - -from .proto.core_pb2 import * +import numpy as np +from proto.core_pb2 import * +from . import singa_wrap as singa class Tensor(object): @@ -109,7 +100,7 @@ class Tensor(object): self.singa_tensor.ToHost() def nrm2(self): - self.singa_tensor.L2() + return self.singa_tensor.L2() def set_value(self, x): if type(x) == float: @@ -503,4 +494,3 @@ def _call_singa_func(_singa_func, *args): new_t.device = new_t.singa_tensor.device() new_t.dtype = new_t.singa_tensor.data_type() return new_t - http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/test/python/test_optimizer.py ---------------------------------------------------------------------- diff --git a/test/python/test_optimizer.py b/test/python/test_optimizer.py new file mode 100644 index 0000000..fa062c8 --- /dev/null +++ b/test/python/test_optimizer.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# ============================================================================= +import sys +import os +import unittest +import numpy as np + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python')) + +import singa.tensor as tensor +import singa.optimizer as opt +import singa.device as device + +cuda = device.Platform.create_cuda_gpu() + + +class TestOptimizer(unittest.TestCase): + + def setUp(self): + self.np_W = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + self.W = tensor.from_numpy(self.np_W) + self.np_g = np.array([0.1, 0.3, 0.1, 0.2], dtype=np.float32) + self.g = tensor.from_numpy(self.np_g) + + def to_cuda(self): + self.W.to_device(cuda) + self.g.to_device(cuda) + + def test_sgd(self): + lr = 0.1 + sgd = opt.SGD(lr) + sgd.apply(0, self.g, self.W, 'w') + w = tensor.to_numpy(self.W) + for i in range(self.W.size()): + self.assertAlmostEqual(w[i], self.np_W[i] - lr * self.np_g[i]) + + def test_sgd_cuda(self): + lr = 0.1 + sgd = opt.SGD(lr) + self.to_cuda() + sgd.apply(0, self.g, self.W, 'w') + self.W.to_host() + w = tensor.to_numpy(self.W) + for i in range(self.W.size()): + self.assertAlmostEqual(w[i], self.np_W[i] - lr * self.np_g[i]) + + def test_constraint(self): + threshold = 0.02 + cons = opt.L2Constraint(threshold) + cons.apply(0, self.W, self.g) + g = tensor.to_numpy(self.g) + nrm = np.linalg.norm(self.np_g) / self.np_g.size + for i in range(g.size): + self.assertAlmostEqual(g[i], self.np_g[i] * threshold / nrm) + + def test_constraint_cuda(self): + threshold = 0.02 + self.to_cuda() + cons = opt.L2Constraint(threshold) + cons.apply(0, self.W, self.g) + self.g.to_host() + g = tensor.to_numpy(self.g) + nrm = np.linalg.norm(self.np_g) / self.np_g.size + for i in range(g.size): + self.assertAlmostEqual(g[i], self.np_g[i] * threshold / nrm) + + def test_regularizer(self): + coefficient = 0.0001 + reg = opt.L2Regularizer(coefficient) + reg.apply(0, self.W, self.g) + g = tensor.to_numpy(self.g) + for i in range(g.size): + self.assertAlmostEqual(g[i], + self.np_g[i] + coefficient * self.np_W[i]) + + def test_regularizer_cuda(self): + coefficient = 0.0001 + reg = opt.L2Regularizer(coefficient) + self.to_cuda() + reg.apply(0, self.W, self.g) + self.g.to_host() + g = tensor.to_numpy(self.g) + for i in range(g.size): + self.assertAlmostEqual(g[i], + self.np_g[i] + coefficient * self.np_W[i]) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/test/python/test_tensor.py ---------------------------------------------------------------------- diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py index de9012e..c999705 100644 --- a/test/python/test_tensor.py +++ b/test/python/test_tensor.py @@ -1,26 +1,20 @@ -#!/usr/bin/env python - -#/************************************************************ -#* -#* Licensed to the Apache Software Foundation (ASF) under one -#* or more contributor license agreements. See the NOTICE file -#* distributed with this work for additional information -#* regarding copyright ownership. The ASF licenses this file -#* to you under the Apache License, Version 2.0 (the -#* "License"); you may not use this file except in compliance -#* with the License. 
You may obtain a copy of the License at -#* -#* http://www.apache.org/licenses/LICENSE-2.0 -#* -#* Unless required by applicable law or agreed to in writing, -#* software distributed under the License is distributed on an -#* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -#* KIND, either express or implied. See the License for the -#* specific language governing permissions and limitations -#* under the License. -#* -#*************************************************************/ - +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ============================================================================= import sys import os import math http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/test/singa/test_adagrad.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_adagrad.cc b/test/singa/test_adagrad.cc index c45dcef..e8cd062 100644 --- a/test/singa/test_adagrad.cc +++ b/test/singa/test_adagrad.cc @@ -24,8 +24,8 @@ #include "singa/singa_config.h" #include <cmath> -TEST(Adagrad, ApplyCPU) { - singa::Adagrad adagrad; +TEST(AdaGrad, ApplyCPU) { + singa::AdaGrad adagrad; float lr = 0.1f; const float v[4] = {0.1, 0.2, 0.3, 0.4}; const float g[4] = {0.01, 0.02, 0.03, 0.04}; @@ -58,8 +58,8 @@ TEST(Adagrad, ApplyCPU) { } #ifdef USE_CUDA -TEST(Adagrad, ApplyCUDA) { - singa::Adagrad adagrad; +TEST(AdaGrad, ApplyCUDA) { + singa::AdaGrad adagrad; float lr = 0.1f; const float v[4] = {0.1, 0.2, 0.3, 0.4}; const float g[4] = {0.01, 0.02, 0.03, 0.04}; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/test/singa/test_platform.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_platform.cc b/test/singa/test_platform.cc index a7c2b10..f50c978 100644 --- a/test/singa/test_platform.cc +++ b/test/singa/test_platform.cc @@ -19,6 +19,7 @@ #include "gtest/gtest.h" #include "singa/core/device.h" +#include "singa/core/tensor.h" #ifdef USE_CUDA using singa::Platform; @@ -29,7 +30,7 @@ TEST(Platform, NumGPUs) { } TEST(Platform, QueryMem) { - int n = Platform::GetNumGPUs(); + size_t n = Platform::GetNumGPUs(); auto ids = Platform::GetGPUIDs(); EXPECT_EQ(ids.size(), n); auto mem = Platform::GetGPUMemSize(); @@ -39,7 +40,7 @@ TEST(Platform, QueryMem) { TEST(Platform, CreateDevice) { auto dev = Platform::CreateCudaGPUs(1).at(0); - int size[] = { 128, 256, 3, 24 }; + size_t size[] = { 128, 256, 3, 24 }; { auto ptr = dev->NewBlock(size[0]); auto allocated = dev->GetAllocatedMem(); @@ -72,9 +73,25 @@ TEST(Platform, CreateMultDevice) { auto devs = Platform::CreateCudaGPUs(n); for (auto dev : devs) { auto b = dev->NewBlock(32); - EXPECT_LE(32, dev->GetAllocatedMem()); + EXPECT_LE(32u, dev->GetAllocatedMem()); dev->FreeBlock(b); } } + 
+TEST(Platform, CreatTensor) { + auto cuda = Platform::CreateCudaGPUs(1)[0]; + singa::Tensor t(singa::Shape{2,3,4}, cuda); + t.SetValue(2.1f); + t.ToHost(); + auto tPtr = t.data<float>(); + for (size_t i = 0; i < t.Size(); i++) + EXPECT_FLOAT_EQ(tPtr[i], 2.1f); + t.ToDevice(cuda); + t = t * 3.0f; + t.ToHost(); + tPtr = t.data<float>(); + for (size_t i = 0; i < t.Size(); i++) + EXPECT_FLOAT_EQ(tPtr[i], 2.1f * 3.0f); +} #endif http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7333517b/test/singa/test_snapshot.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_snapshot.cc b/test/singa/test_snapshot.cc index 26f1f8c..e83145b 100644 --- a/test/singa/test_snapshot.cc +++ b/test/singa/test_snapshot.cc @@ -49,12 +49,12 @@ TEST(Snapshot, ReadTest) { singa::Tensor param_1, param_2; singa::Shape shape1, shape2; shape1 = snapshot.ReadShape("Param_1"); - EXPECT_EQ(shape1.size(), 1); - EXPECT_EQ(shape1[0], 4); + EXPECT_EQ(shape1.size(), 1u); + EXPECT_EQ(shape1[0], 4u); shape2 = snapshot.ReadShape("Param_2"); - EXPECT_EQ(shape2.size(), 2); - EXPECT_EQ(shape2[0], 2); - EXPECT_EQ(shape2[1], 2); + EXPECT_EQ(shape2.size(), 2u); + EXPECT_EQ(shape2[0], 2u); + EXPECT_EQ(shape2[1], 2u); param_1 = snapshot.Read("Param_1"); const float* data_1 = param_1.data<float>(); for (size_t i = 0; i < singa::Product(shape1); ++i) @@ -84,8 +84,8 @@ TEST(Snapshot, ReadIntTest) { singa::Snapshot int_snapshot_read(prefix+".int", singa::Snapshot::kRead); singa::Shape shape; shape = int_snapshot_read.ReadShape("IntParam"); - EXPECT_EQ(shape.size(), 1); - EXPECT_EQ(shape[0], 4); + EXPECT_EQ(shape.size(), 1u); + EXPECT_EQ(shape[0], 4u); singa::Tensor int_param = int_snapshot_read.Read("IntParam"); const int* param_data = int_param.data<int>(); for (size_t i = 0; i < singa::Product(shape); ++i) @@ -106,8 +106,8 @@ TEST(Snapshot, ReadDoubleTest) { singa::Snapshot double_snapshot_read(prefix+".double", singa::Snapshot::kRead); singa::Shape shape; shape = double_snapshot_read.ReadShape("DoubleParam"); - EXPECT_EQ(shape.size(), 1); - EXPECT_EQ(shape[0], 4); + EXPECT_EQ(shape.size(), 1u); + EXPECT_EQ(shape[0], 4u); singa::Tensor double_param = double_snapshot_read.Read("DoubleParam"); const double* param_data = double_param.data<double>(); for (size_t i = 0; i < singa::Product(shape); ++i)
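
----------------------------------------------------------------------
A minimal usage sketch of the Python API added by this commit, mirroring the
new test/python/test_optimizer.py. The build/python path and the presence of a
CUDA-enabled build are assumptions; all names used (device.Platform.create_cuda_gpu,
tensor.from_numpy, optimizer.SGD and its apply method) come from the files added
above. The CUDA device is created once at module scope, because CnMemPool is a
singleton (cnmemInit can only be called once per process); this is also why the
new unit test creates its device outside the test cases.

import os
import sys
import numpy as np

# Assumption: running from the source tree, with the swig-wrapped package
# built into build/python (same layout as test/python/test_optimizer.py).
sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))

import singa.tensor as tensor
import singa.optimizer as opt
import singa.device as device

# Create the CUDA device once per process; CnMemPool must be a singleton
# because cnmemInit can only be called once.
cuda = device.Platform.create_cuda_gpu()

w = tensor.from_numpy(np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32))
g = tensor.from_numpy(np.array([0.1, 0.3, 0.1, 0.2], dtype=np.float32))
w.to_device(cuda)
g.to_device(cuda)

# Vanilla SGD with a constant learning rate; apply() reads the lr from the
# generator installed by the constructor and delegates to the C++ optimizer.
sgd = opt.SGD(lr=0.1)
sgd.apply(0, g, w, 'w')

w.to_host()
print(tensor.to_numpy(w))  # each entry equals the old value minus 0.1 * gradient

Swapping in 'Nesterov', 'AdaGrad', or 'RMSProp' goes through the same
CreateOptimizer factory added to include/singa/model/optimizer.h.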
