SINGA-170 Add Dropout layer and CudnnDropout layer; both pass compilation, but there is a link error for cudnn.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/99e0d24d Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/99e0d24d Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/99e0d24d Branch: refs/heads/dev Commit: 99e0d24d90fa1c588d73f87f402dfb0ac36ca8a7 Parents: 02851fa Author: Wei Wang <[email protected]> Authored: Mon May 16 21:40:24 2016 +0800 Committer: wangwei <[email protected]> Committed: Tue May 17 00:40:24 2016 +0800 ---------------------------------------------------------------------- CMakeLists.txt | 7 +- include/singa/core/common.h | 29 ++++- include/singa/core/device.h | 4 +- include/singa/core/tensor.h | 62 ++++++----- include/singa/model/layer.h | 190 +++++++++++++++++++++++++------- include/singa/model/param.h | 97 ---------------- src/CMakeLists.txt | 7 +- src/core/device/device.cc | 4 +- src/core/tensor/tensor.cc | 107 ++++++++++-------- src/core/tensor/tensor_math.h | 11 +- src/core/tensor/tensor_math_cpp.h | 29 +++++ src/core/tensor/tensor_math_cuda.h | 24 ++-- src/model/layer/conv.cc | 27 ----- src/model/layer/cudnn_dropout.cc | 106 ++++++++++++++++++ src/model/layer/cudnn_dropout.h | 54 +++++++++ src/model/layer/cudnn_utils.h | 83 ++++++++++++++ src/model/layer/dropout.cc | 60 ++++++++++ src/model/layer/dropout.h | 49 ++++++++ src/model/layer/layer.cc | 30 ----- src/proto/core.proto | 3 +- src/proto/layer.proto | 10 +- test/singa/test_dropout.cc | 29 +++++ test/singa/test_tensor.cc | 10 +- 23 files changed, 722 insertions(+), 310 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/CMakeLists.txt b/CMakeLists.txt index 67a82e5..dd92d03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ CMAKE_MINIMUM_REQUIRED(VERSION 2.6) PROJECT(singa) -SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11") +SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11 -DUSE_CUDA -DUSE_CUDNN") # Flags IF(UNIX OR APPLE) @@ -10,12 +10,13 @@ ENDIF() # Includes SET(singa_include_dir ${PROJECT_SOURCE_DIR}/include) INCLUDE_DIRECTORIES(${singa_include_dir} ${PROJECT_BINARY_DIR}) +INCLUDE_DIRECTORIES("/home/wangwei/local/cudnn5/include" "/usr/local/cuda/include") SET(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/lib) SET(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin) -SET(singa_linker_lib) -LINK_DIRECTORIES(${LIBRARY_OUTPUT_PATH}) +SET(singa_linker_lib cudnn) +LINK_DIRECTORIES(${LIBRARY_OUTPUT_PATH} "/home/wangwei/local/cudnn5/lib64/") INCLUDE(cmake/ProtoBuf.cmake) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/core/common.h ---------------------------------------------------------------------- diff --git a/include/singa/core/common.h b/include/singa/core/common.h index 1d73f67..4d783fb 100644 --- a/include/singa/core/common.h +++ b/include/singa/core/common.h @@ -18,9 +18,18 @@ #ifndef SINGA_CORE_COMMON_H_ #define SINGA_CORE_COMMON_H_ - +#include <random> +#include <chrono> #include "singa/utils/logging.h" +#ifdef USE_CUDA +#include <cuda_runtime.h> +#include "cublas_v2.h" +#ifdef USE_CUDNN +#include <cudnn.h> +#endif +#endif + namespace singa { namespace lib { /// To implemente functions using cpp libraries @@ -37,10 +46,10 @@ typedef unsigned char Byte; /// Blob reprent a chunk of memory (on device or host) managed by VirtualMemory. 
class Blob { public: - Blob(void* ptr, int size) : data_(ptr), size_(size), ref_count_(1) {} + Blob(void* ptr, size_t size) : data_(ptr), size_(size), ref_count_(1) {} void* mutable_data() const { return data_; } const void* data() const { return data_; } - int size() const { return size_; } + size_t size() const { return size_; } int IncRefCount() { ref_count_++; return ref_count_; @@ -54,11 +63,21 @@ class Blob { private: void* data_ = nullptr; - int size_ = 0; + size_t size_ = 0; int ref_count_ = 0; }; -class Context {}; +typedef struct _Context { + std::mt19937 random_generator; + unsigned long long seed; +#ifdef USE_CUDA + cublasHandle_t cublas_handle; + cudaStream_t stream; +#ifdef USE_CUDNN + cudnnHandle_t cudnn_handle; +#endif +#endif +} Context; } // namespace singa #endif // SINGA_CORE_COMMON_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/core/device.h ---------------------------------------------------------------------- diff --git a/include/singa/core/device.h b/include/singa/core/device.h index fa30d6d..f3bb5a2 100644 --- a/include/singa/core/device.h +++ b/include/singa/core/device.h @@ -79,8 +79,8 @@ class Device { void CopyDataFromHostPtr(Blob* dst, const void* src, size_t size); /// Submit the operation to the device, which may execute it right now or /// delay it depending on the scheduler. - void Submit(function<void(Context*)> fn, const vector<Blob*> read_blobs, - const vector<Blob*> write_blobs); + void Exec(function<void(Context*)> fn, const vector<Blob*> read_blobs, + const vector<Blob*> write_blobs, bool use_rand_generator = false); // Wait for one event. // void WaitFor(); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 4278078..4807123 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -31,25 +31,23 @@ using std::vector; using std::tuple; namespace singa { -typedef vector<int> Shape; -inline int Product(Shape shape) { - if (shape.size() == 0) - return 0; - return Product(shape.begin(), shape.end()); -} - -inline int Product(vector<int>::iterator begin, vector<int>::iterator end) { - CHECK(begin != end); - int v = 1; - for (auto it = being; it < end; it++) - v* = *it; +typedef vector<size_t> Shape; +typedef Shape::iterator ShapeIter; +inline size_t Product(const Shape& shape, int start = 0, size_t len = 0) { + if (len == 0) + len = shape.size(); + CHECK_LE(len, shape.size()); + size_t v = 1; + for (unsigned int i = start; i < len; i ++) + v *= shape[i]; return v; } /// hardcode the width of types defined in DataType -const int kDataWidth[] = {4, 2, 4, 1}; -inline int SizeOf(DataType t) { - static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(int), +const size_t kDataWidth[] = {sizeof(float), sizeof(float) / 2, sizeof(int), + sizeof(char), sizeof(double)}; +inline size_t SizeOf(DataType t) { + static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(size_t), "Num of data types not match num of data width"); CHECK_GT(kNumDataType, t); return kDataWidth[t]; @@ -112,18 +110,23 @@ class Tensor { } /// Return number of total elements - int Size() const { + size_t Size() const { return blob_->size() / SizeOf(data_type_); } /// Return memory size (i.e., Bytes) - int MemSize() const { + size_t MemSize() const { return blob_->size(); } /// Reset the tensor shape, it may reallocate blob, if MemSize() changes. 
void ReShape(const Shape& shape); + /// Reset the shape, device, and data type as given tensor. + /// If blob size changes, then reallocate a new blob. The previous blob would + /// be deleted. + void ResetLike(const Tensor& t); + /// Reset the data type, it would reallocate blob if type changes. void AsType(DataType type); @@ -136,7 +139,7 @@ class Tensor { /// For init the tensor values, copy 'num' elements. template<typename DType> - void CopyDataFromHostPtr(const DType* src, int num); + void CopyDataFromHostPtr(const DType* src, size_t num); /// Copy data from another Tensor which may be on a diff device. /// Meta data would not be copied! @@ -207,17 +210,17 @@ class Tensor { /// The first 'src_offset' ('dst_offset') elements will be skipped. void CopyData(Tensor* dst, const Tensor& src, - int num, - int src_offset = 0, - int dst_offset = 0); + size_t num, + size_t src_offset = 0, + size_t dst_offset = 0); /// Copy 'nBytes' bytes of src data to dst. /// The first 'src_offset' ('dst_offset') bytes will be skipped. void CopyRawData(Tensor* dst, const Tensor& src, - int nBytes, - int src_offset = 0, - int dst_offset = 0); + size_t nBytes, + size_t src_offset = 0, + size_t dst_offset = 0); // ==================Simple Linear Algebra Operations========================= Tensor Abs(const Tensor& t); @@ -306,15 +309,15 @@ void Mult(DType alpha, const Tensor& lhs, DType beta, const Tensor& rhs, // tempalte<typename DType> T Dot(const Tensor& lhs, const Tensor& rhs); //================Random operations========================================== -/// For each element x set x = 0 if random() < p; otherwise x = 1. -Tensor Bernoulli(float p, Blob* t); +/// For each element x set x = 1 if random() < p; otherwise x = 1. +void Bernoulli(float p, Tensor* t); /// Fill in Tensor 't' following uniform distribution. -Tensor Uniform(float low, DType high, Blob* t); +void Uniform(float low, float high, Tensor* t); /// Fill in Tensor 't' following Gaussian distribution. -Tensor Gaussian(float mean, DType std, Blob* t); +void Gaussian(float mean, float std, Tensor* t); //================Neural Net operations====================================== -// following API of cudnn, e.g., conv, pool, lrn, batchnorm, softmax +/* following API of cudnn, e.g., conv, pool, lrn, batchnorm, softmax void ConvFwd(const ConvConf& conf, const Tensor& x, const Tensor& w, Tensor* y); void ConvBwdBias(const ConvConf& conf, const Tensor& dy, Tensor* db); void ConvBwdFilter(const ConvConf& conf, const Tensor& dy, const Tensor& x, @@ -325,6 +328,7 @@ void PoolFwd(const PoolConf& conf, const Tensor& x, Tensor* y, Tensor* mask = nullptr); void PoolBwd(const PoolConf& conf, const Tensor& y, const Tensor& dy, const Tensor& x, Tensor* dx); +*/ } // namespace singa #endif // SINGA_CORE_TENSOR_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/model/layer.h ---------------------------------------------------------------------- diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h index 7b9b6d4..48fc58f 100644 --- a/include/singa/model/layer.h +++ b/include/singa/model/layer.h @@ -21,6 +21,7 @@ #include <vector> #include <string> +#include <stack> #include "singa/core/tensor.h" #include "singa/proto/layer.pb.h" @@ -28,14 +29,10 @@ namespace singa { /// The base layer class. /// Generally, a layer conducts feature transformation against a set of Tensor -/// to generate a set of Tensor. Each layer may have some parameters represented -/// by Param instances. +/// to generate a set of Tensor. 
Each layer may have some parameters. class Layer { public: Layer() = default; - /// Each layer sub-class would optionaly have a type name. - /// Used for debugging and logging. - virtual const std::string layer_type() const { return "Unknown"; } /// Set meta data fields from a string representing a proto message. void Setup(const string& proto_str) { @@ -44,68 +41,183 @@ class Layer { this->Setup(conf); } + // ============= Following Functions could be override ===================== + /// Destruct the objecst created by this layer. + virtual ~Layer() { + for (Tensor * t : param_values_) { + delete t; + } + } + + /// Each layer sub-class would optionaly have a type name. + /// Used for debugging and logging. + virtual const std::string layer_type() const { return "Unknown"; } + /// Set meta data fields configured in 'conf' (a proto message). virtual void Setup(const LayerConf& conf) { name_ = conf.name(); + for (const auto& spec : conf.param()) + param_specs_.push_back(spec); + // TODO(wangwei) load param values from checkpoint blobs. } - /// Do feature transformation for given 'input' Tensor. - /// It is the forward pass for feed-forward nets and rnn nets. + /// Do feature transformation for the given 'input' tensor (denoted as x). /// 'flag' is either kPhaseTrain or kPhaseTest for feed-forward nets, and - /// would be used for phases of training other nets. - /// It will return a set of Tensor. - virtual const vector<Tensor> ComputeFeature(int flag, - const vector<Tensor>& input) { - return vector<Tensor>{}; - } - /// Compute gradients of parameters of this layer. - /// It would also compute the gradients for other layers, e.g., the - /// preceding layers in topology order. It would return an empty vector if - /// this layer does not need to compute gradients for other layers. - /// 'flag' is either kPhaseTrain or kPhaseTest for feed-forward nets, and - /// would be used for phases of training other nets. - /// 'input' is a vector of Tensor for gradients from other layers. - virtual const vector<Tensor> ComputeGradient(int flag, - const vector<Tensor>& input) { - return vector<Tensor>{}; + /// would be used for other phases of training other nets. For example, when + /// training RBM, we may create an alias of this function as ComputeFeature + /// where flag could be kPositivePhase and kNegativePhase. + /// It will return a Tensor (denoted as y). + /// If the 'input' or 'output' is required for computing the gradients in + /// Backward(), then push them into the states_ stack. + virtual const Tensor Forward(int flag, const Tensor& input) { + LOG(FATAL) << "Not implemented"; + Tensor t; + return t; + } + + /// \copydoc Forward(int flag, const Tensor& input) + /// Accept multiple input tensors and generate multiple output tensors. + virtual const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) { + vector<Tensor> ret; + if (inputs.size() == 1) + ret.push_back(Forward(flag, inputs.at(0))); + + LOG(FATAL) << "Not implemented"; + return ret; + } + + /// Compute gradients of this layer. + /// Specifically, there are two types of gradients: + /// 1. gradients of preceding layers, i.e., dx. + /// 2. gradients of parameters of this layer. + /// 1 and 2 are returned as a pair of vector<Tensor> + /// 1 is an empty tensor if there is no preceding layer or there is no need to + /// compute dx (e.g., x is from a data layer); 2 is empty if this layer has no + /// parameters. 
+ /// 'flag' is either kTrainPhase or kTestPhase for feed-forward nets, and + /// would be used for other phases when training other nets. + /// 'grad' is a Tensor for gradient (dy) from the upper layer. + /// Some layer would use 'input' or 'output' from Forward to compute the + /// gradients of parameters. Backward() pop out the state data. + /// It is useful for RNN layers, where the same layer is used multiple + /// times just like unrolling the layer. + virtual const std::pair<Tensor, vector<Tensor>> Backward(int flag, + const Tensor& grad) { + LOG(FATAL) << "Not implemented!"; + Tensor t; + return std::make_pair(t, vector<Tensor>{}); + } + + /// \copydoc Backward(int, const vector<Tensor>&) + /// For Forward(int, const vector<Tensor>&) + virtual const std::pair<vector<Tensor>, vector<Tensor>> Backward( + int flag, const vector<Tensor>& grads) { + vector<Tensor> input_grad, param_grad; + if (grads.size() == 1u) { + auto ret = Backward(flag, grads.at(0)); + input_grad.push_back(ret.first); + param_grad = ret.second; + } else { + LOG(FATAL) << "Not implemented"; + } + return std::make_pair(input_grad, param_grad); } - // return <dx> <dw (ParamGrad)> - /// Move the layer (including its parameters and other Tensor) onto the given - /// device + /// Move the layer (including its parameters and other internal Tensor) onto + /// the given device virtual void ToDevice(Device* device) { - // for (auto p : params_) - // p->ToDevice(device); + for (auto p : param_values_) p->ToDevice(device); } - /// Set the data type of Tensor s in this layer. + /// Set the data type of Tensor in this layer. virtual void AsType(DataType dtype) { - // for (auto p : params_) - // p->AsType(dtype); + for (auto p : param_values_) p->AsType(dtype); } - /// Serialize the layer info, including params)_, into a LayerConf message. - virtual std::string ToProto(LayerConf* conf) const { + /// Serialize the layer info (including params) into a LayerConf proto message + virtual void ToProto(LayerConf* conf) const { conf->set_name(name_); + for (const auto& spec: param_specs_) { + ParamSpec* p = conf->add_param(); + p->CopyFrom(spec); + } + // TODO(wangwei) add param values into conf; } + // ======================================================================== + /// Serialize the layer info, including params_, into a string representing /// a LayerParameter message. - std::string ToProtoStr() const; + std::string ToProtoStr() const { + LayerConf conf; + ToProto(&conf); + string str; + conf.SerializeToString(&str); + return str; + } + /// Return specs/configuration of all parameter instances of this layer. + /// \ref ParamSpec. + const vector<ParamSpec> param_specs() { + return param_specs_; + } - /// Return all Param instances of this layer. - /// Each layer could cache the Param objects. - /// To save memory of , it can also create it when this function - /// is called - const vector<Param*> GetParam(); + /// Return the i-th ParamSpec. + const ParamSpec& param_specs(int i) { + return param_specs_.at(i); + } + + /// Return pointers to parameter Tensor s. + const vector<Tensor*> param_values() { + return param_values_; + } + + /// Return a pointer to the 'i'-th parameter Tensor. + Tensor* param_value(size_t i) { + CHECK_LT(i, param_values_.size()); + return param_values_[i]; + } + + /// Return names of all parmaeters. + const vector<string> param_names() { + vector<string> pname; + for (const auto& spec: param_specs_) + pname.push_back(spec.name()); + return pname; + } + + /// Return the 'i'-th parameter name. 
+ const string& param_name(size_t i) { + CHECK_LT(i, param_specs_.size()); + return param_specs_.at(i).name(); + } /// Each layer instance would optionally have a name. /// Used for debugging and logging. const std::string name() const { return name_; } + /* + std::stack<Tensor> states() const { + return states_; + } + */ + protected: std::string name_; + vector<Tensor*> param_values_; + vector<ParamSpec> param_specs_; + /// Used to store input or output of Forward(), which would be used in + /// Backward. Rules: + /// 1. push the 'input' or 'output' into states_ if the flag of Forward() is + /// for training. + /// 2. pop data out in Backward(). + /// TODO(wangwei) enable this feature for rnn layers. + // std::stack<Tensor*> states_; }; +// =========================================================================== +// Order layer sub-classes based on alphabetical order of the first letter. +// =========================================================================== + + } // namespace singa #endif // SINGA_LAYER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/model/param.h ---------------------------------------------------------------------- diff --git a/include/singa/model/param.h b/include/singa/model/param.h deleted file mode 100644 index b859b1c..0000000 --- a/include/singa/model/param.h +++ /dev/null @@ -1,97 +0,0 @@ -/************************************************************ -* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -* -*************************************************************/ - -#ifndef SINGA_MODEL_PARAM_H_ -#define SINGA_MODEL_PARAM_H_ -#include "singa/core/tensor.h" -#include <vector> -#include <string> -using std::vector; -using std::string; -namespace singa { -/// Base Param class for storing set of parameters, e.g., a weight matrix or a -/// bias vector. -/// It includes multiple Tensor s for parameter values, gradients, etc. -class Param { - public: - ~Param(); - Param(const ParamSpec& conf); - Param(Param&& p); - Param(const Param& p); - void operator=(Param&& p); - void operator=(const Param& p); - - Tensor& value() { - return value_; - } - - Tensor& grad() { - return grad_; - } - - void set_value(const Tensor& t) { - value_ = t; - } - - void set_value(Tensor&& t) { - value_ = std::move(t); - } - - void set_grad(const Tensor& t) { - isGradValid_ = true; - grad_ = t; - } - - void set_grad(Tensor&& t) { - grad_ = std::move(t); - } - - // void Compress(); - // string ToString(); - - protected: - string name_; - Tensor value_; - float lr_mult_ = 1.0f, decay_mult_ = 1.0f; -}; - -class ParamGrad { -// return grad tensor or data to recover the grad tensor, e.g., if W = U * V -// then, ParamGrad could just store U and V. provide func for serailize and -// deserialize. 
-}; - -// updater just copy the ParamGrad to a device and submit ops to that device, e.g., -// add grad; check update_condidtion; apply sgd; copy back. -// consider rpc (no rmda). - -Param* CreateParam(string type) { - Param* p = nullptr; - if (type == "default") - p = new Param(); - else - LOG(FATAL) << "Currently param type " << type << " is not implemented." - << "Pls use the 'default' type"; - return p; -} -#endif // SINGA_MODEL_PARAM_H_ - -} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d8bec8d..e2e923e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -15,7 +15,12 @@ FILE(GLOB_RECURSE core_source ${CMAKE_CURRENT_SOURCE_DIR}/core/ "*.cc") ADD_LIBRARY(singa_core SHARED ${core_source}) TARGET_LINK_LIBRARIES(singa_core ${singa_linker_libs}) list(APPEND singa_linker_libs singa_core) -MESSAGE(STATUS "link libs " ${singa_linker_libs}) +#MESSAGE(STATUS "link libs " ${singa_linker_libs}) + +FILE(GLOB_RECURSE model_source ${CMAKE_CURRENT_SOURCE_DIR}/model/ "*.cc") +ADD_LIBRARY(singa_model SHARED ${model_source}) +TARGET_LINK_LIBRARIES(singa_model ${singa_linker_libs}) +list(APPEND singa_linker_libs singa_model) #ADD_LIBRARY(singa_layer SHARED ${LAYER_SOURCE}) #ADD_LIBRARY(singa_model SHARED ${MODEL_SOURCE}) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/device/device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/device.cc b/src/core/device/device.cc index 4976a32..b2a8705 100644 --- a/src/core/device/device.cc +++ b/src/core/device/device.cc @@ -25,8 +25,8 @@ Device::Device(int id, int num_executors, string scheduler, string vm) vm_ = nullptr; } -void Device::Submit(function<void(Context*)> fn, const vector<Blob*> read_blobs, - const vector<Blob*> write_blobs) { +void Device::Exec(function<void(Context*)> fn, const vector<Blob*> read_blobs, + const vector<Blob*> write_blobs, bool use_rand_generator) { fn(nullptr); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 51b785e..8352b48 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -20,6 +20,7 @@ #include "./tensor_math_cpp.h" #include "./tensor_math_cuda.h" #include "./tensor_math_opencl.h" +#include <utility> namespace singa { @@ -69,6 +70,16 @@ Tensor::Tensor(Tensor&& t) t.blob_ = nullptr; } +void Tensor::ResetLike(const Tensor& t) { + if (blob_->size() != t.MemSize()) { + if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); + shape_ = t.shape_; + device_ = t.device_; + data_type_ = t.data_type_; + blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); + } +} + void Tensor::ReShape(const Shape& shape) { if (shape_ != shape) { if (blob_ != nullptr && blob_->DecRefCount() == 0) @@ -105,7 +116,7 @@ void Tensor::ToHost() { } template<typename DType> -void Tensor::CopyDataFromHostPtr(const DType* src, int num) { +void Tensor::CopyDataFromHostPtr(const DType* src, size_t num) { CHECK_EQ(sizeof(DType), SizeOf(data_type_)) << "data_type is " << DataType_Name(data_type_) << " user given type is of size " @@ -115,7 +126,7 @@ void Tensor::CopyDataFromHostPtr(const DType* src, int num) { else LOG(WARNING) << "Copy data from null 
host ptr"; } -template void Tensor::CopyDataFromHostPtr(const float* src, int num); +template void Tensor::CopyDataFromHostPtr(const float* src, size_t num); void Tensor::CopyData(const Tensor& src) { CHECK_EQ(Size(), src.Size()); @@ -134,10 +145,10 @@ Tensor Tensor::Clone() { } Tensor Tensor::T() const { - CHECK_EQ(shape_.size(), 2); + CHECK_EQ(shape_.size(), 2u); Tensor t(*this); t.transpose_ = ~transpose_; - std::swap(shape_[0], shape_[1]); + std::swap(t.shape_[0], t.shape_[1]); return t; } @@ -185,21 +196,21 @@ GenUnaryScalarArgMemberFunction(operator/=, Div); // ====================Tensor Operations======================================= void CopyData(Tensor* dst, const Tensor& src, - int num, - int dst_offset, - int src_offset) { + size_t num, + size_t dst_offset, + size_t src_offset) { CHECK_GE(src.Size(), src_offset + num); CHECK_GE(dst->Size(), dst_offset + num); - int width = SizeOf(src.data_type()); + auto width = SizeOf(src.data_type()); CHECK_EQ(width, SizeOf(dst->data_type())); CopyRawData(dst, src, num * width, dst_offset * width, src_offset * width); } void CopyRawData(Tensor* dst, const Tensor& src, - int nBytes, - int dst_offset, - int src_offset) { + size_t nBytes, + size_t dst_offset, + size_t src_offset) { CHECK_GE(src.MemSize(), src_offset + nBytes); CHECK_GE(dst->MemSize(), dst_offset + nBytes); Device* src_dev = src.device(), *dst_dev = dst->device(); @@ -286,7 +297,7 @@ void CopyRawData(Tensor* dst, #define EltwiseUnaryTensorFn(fn, t, ret) \ do { \ TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \ - ret->device()->Submit( \ + ret->device()->Exec( \ [t, ret](Context* ctx) { \ fn<DType, Lib>(t.Size(), t.blob(), ret->blob(), ctx); \ }, \ @@ -320,14 +331,14 @@ Tensor Softmax(const Tensor& t, int axis) { void Softmax(const Tensor& t, Tensor* ret, int axis) { int nrow = 1, ncol = t.Size(), size = ncol; CHECK_GE(axis, -1); - CHECK_GT(t.shape().size(), 0); + CHECK_GT(t.shape().size(), 0u); if (axis > -1) { - nrow = Product(t.shape().begin(), t.shape().begin() + axis + 1); + nrow = Product(t.shape(), 0, axis + 1); CHECK_EQ(size % nrow, 0) << "Size = " << size << " nrow = " << nrow; ncol = size / nrow; } TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { - ret->device()->Submit( + ret->device()->Exec( [nrow, ncol, t, ret](Context* ctx) { Softmax<DType, Lib>(nrow, ncol, t.blob(), ret->blob(), ctx); }, @@ -338,8 +349,8 @@ void Softmax(const Tensor& t, Tensor* ret, int axis) { #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret) \ do { \ TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, { \ - ret->device()->Submit( \ - CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type())); \ + CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type())); \ + ret->device()->Exec( \ [lhs, rhs, ret](Context* ctx) { \ fn<DType, Lib>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), \ ctx); \ @@ -364,28 +375,28 @@ GenBinaryTensorFunction(operator*, EltwiseMult); GenBinaryTensorFunction(operator/, Div); GenBinaryTensorFunction(Pow, Pow); -#define EltwiseTensorScalarFn(fn, t, x, ret) \ - do { \ - TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \ - ret->device()->Submit( \ - static_assert(typeid(x) == typeid(DType), \ - "The Scalar type must match the Tensor data type"); \ - [t, x, ret](Context* ctx) { \ - fn<DType, Lib>(t.Size(), t.blob(), x, ret->blob(), ctx); \ - }, \ - {t.blob()}, {ret->blob()}); \ - }); \ +#define EltwiseTensorScalarFn(fn, t, x, ret) \ + do { \ + TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, 
{ \ + static_assert(std::is_same<SType, DType>::value, \ + "The Scalar type must match the Tensor data type"); \ + ret->device()->Exec( \ + [t, x, ret](Context* ctx) { \ + fn<DType, Lib>(t.Size(), t.blob(), x, ret->blob(), ctx); \ + }, \ + {t.blob()}, {ret->blob()}); \ + }); \ } while (0) #define GenTensorScalarFunction(op, fn) \ - template <typename DType> \ - Tensor op(const Tensor& t, DType x) { \ + template <typename SType> \ + Tensor op(const Tensor& t, SType x) { \ Tensor ret(t.shape(), t.device(), t.data_type()); \ fn(t, x, &ret); \ return ret; \ } \ - template <typename DType> \ - void fn(const Tensor& t, DType x, Tensor* ret) { \ + template <typename SType> \ + void fn(const Tensor& t, SType x, Tensor* ret) { \ EltwiseTensorScalarFn(fn, t, x, ret); \ } \ template Tensor op<float>(const Tensor& t, float x); \ @@ -424,15 +435,15 @@ template Tensor Mult<float>(float alpha, const Tensor& lhs, float beta, template <typename SType> void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, Tensor* C) { - CHECK_EQ(A.shape().size(), 2); + CHECK_EQ(A.shape().size(), 2u); bool transA = A.transpose(); - int m = transA ? A.shape()[1] : A.shape()[0], n = 0; - if (B.shape().size() == 1) { + size_t m = transA ? A.shape()[1] : A.shape()[0], n = 0; + if (B.shape().size() == 1u) { n = C->Size(); TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, { static_assert(std::is_same<SType, DType>::value, "The scalar type must be the same as the tensor data type"); - C->device()->Submit( + C->device()->Exec( [transA, m, n, alpha, A, beta, B, C](Context* ctx) { GEMV<DType, Lib>(transA, m, n, alpha, A.blob(), B.blob(), beta, C->blob(), ctx); @@ -442,7 +453,7 @@ void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, Tensor* C) } else { CHECK(!C->transpose()); bool transB = B.transpose(); - int k = transB ? B.shape()[1] : B.shape()[0]; + size_t k = transB ? 
B.shape()[1] : B.shape()[0]; n = C->shape()[1]; CHECK_EQ(C->shape()[0], m); CHECK_EQ(A.Size(), m * k); @@ -450,7 +461,7 @@ void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, Tensor* C) TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, { static_assert(std::is_same<SType, DType>::value, "The scalar type must be the same as the tensor data type"); - C->device()->Submit( + C->device()->Exec( [transA, transB, m, n, k, alpha, A, beta, B, C](Context* ctx) { GEMM<DType, Lib>(transA, transB, m, n, k, alpha, A.blob(), B.blob(), beta, C->blob(), ctx); @@ -468,7 +479,7 @@ template void Mult<float>(float alpha, const Tensor& lhs, float beta, void Conv(const OpConf* conf, const Tensor& input, const Tensor& W, const Tensor& b, Tensor* ret) { TYPE_LIB_SWITCH(input.data_type(), DType, input.device()->nn_lib(), Lib, { - ret->device()->Submit( + ret->device()->Exec( [conf, input, W, b, ret](Context* ctx) { Conv<DType, Lib>(conf, input.blob(), W.blob(), b.blob(), ret->blob(), ctx); @@ -477,33 +488,33 @@ void Conv(const OpConf* conf, const Tensor& input, const Tensor& W, }); } */ -void Bernoulli(float threshold, Tensor* t) { +void Bernoulli(float p, Tensor* t) { TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { - t->device()->Submit( - [threshold, t](Context* ctx) { - Bernoulli<DType, Lib>(t->Size(), threshold, t->blob(), ctx); + t->device()->Exec( + [p, t](Context* ctx) { + Bernoulli<DType, Lib>(t->Size(), p, t->blob(), ctx); }, - {}, {t->blob()}); + {}, {t->blob()}, true); }); } void Uniform(float low, float high, Tensor* t) { TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { - t->device()->Submit( + t->device()->Exec( [low, high, t](Context* ctx) { Uniform<DType, Lib>(t->Size(), low, high, t->blob(), ctx); }, - {}, {t->blob()}); + {}, {t->blob()}, true); }); } void Gaussian(float mean, float std, Tensor* t) { TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { - t->device()->Submit( + t->device()->Exec( [mean, std, t](Context* ctx) { Gaussian<DType, Lib>(t->Size(), mean, std, t->blob(), ctx); }, - {}, {t->blob()}); + {}, {t->blob()}, true); }); } } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/tensor/tensor_math.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h index a4f68e3..aa520c9 100644 --- a/src/core/tensor/tensor_math.h +++ b/src/core/tensor/tensor_math.h @@ -96,6 +96,12 @@ void Sigmoid(int count, const Blob* input, Blob* ret, Context* ctx) { LOG(FATAL) << "Not Implemented"; } +/// Do softmax for each row invidually +template <typename DType, typename Lib> +void Softmax(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + /// Element-wise operation, do v^x for every v from the input tensor template <typename DType, typename Lib> void Pow(int count, const Blob* input, DType x, Blob* ret, Context* ctx) { @@ -258,7 +264,7 @@ void GEMM(bool transA, bool transB, int m, int n, int k, DType alpha, // Get the random generator from 'ctx' // If DType is not float, then convert the threshold to DType template <typename DType, typename Lib> -void Bernoulli(int count, float threshold, Blob* ret, Context* ctx) { +void Bernoulli(int count, float p, Blob* ret, Context* ctx) { LOG(FATAL) << "Not Implemented"; } // The random generator should be extracted from ctx. 
@@ -274,7 +280,7 @@ void Gaussian(int count, float mean, float std, Blob* ret, Context* ctx) { LOG(FATAL) << "Not Implemented"; } -// ================Neural net functions======================================= +/* ================Neural net functions======================================= template <typename DType, typename Lib> void ConvFwd(ConvConf* conf, const Blob* x, const Blob* w, Blob* y, Context* ctx) { @@ -296,6 +302,7 @@ void PoolBwd(const PoolConf* conf, const Blob* y, const Blob* dy, const Blob* x, Blob* dx, Context* ctx) { LOG(FATAL) << "Not Implemented"; } +*/ } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index a953085..9e7ed30 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -40,6 +40,35 @@ void Add<float, lib::Cpp>(int count, } } +template <> +void Bernoulli<float, lib::Cpp>(int count, float p, Blob* ret, + Context* ctx) { + std::bernoulli_distribution distribution(p); + float* ptr = static_cast<float*>(ret->mutable_data()); + for (int i = 0; i < count; i ++) { + ptr[i] = static_cast<float>(distribution(ctx->random_generator)); + } +} + +template <> +void Uniform<float, lib::Cpp>(int count, float low, float high, Blob* ret, + Context* ctx) { + std::uniform_real_distribution<float> distribution(low, high); + float* ptr = static_cast<float*>(ret->mutable_data()); + for (int i = 0; i < count; i ++) { + ptr[i] = static_cast<float>(distribution(ctx->random_generator)); + } +} + +template <> +void Gaussian<float, lib::Cpp>(int count, float mean, float std, Blob* ret, + Context* ctx) { + std::normal_distribution<float> distribution(mean, std); + float* ptr = static_cast<float*>(ret->mutable_data()); + for (int i = 0; i < count; i++) { + ptr[i] = static_cast<float>(distribution(ctx->random_generator)); + } +} #ifdef USE_CBLAS template<> void Dot<float, lib::Cpp>(int count, http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/tensor/tensor_math_cuda.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h index e1c72d8..c5ea3c4 100644 --- a/src/core/tensor/tensor_math_cuda.h +++ b/src/core/tensor/tensor_math_cuda.h @@ -28,24 +28,16 @@ namespace singa { template<> void Add<float, lib::Cuda>(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { - cublasSetStream(ctx->handle, ctx->stream); - cublasScopy(ctx->handle, count, lhs->data(), 1, ret->mutable_data(), 1); - cublasSaxpy(ctx->handle, 1.0f, rhs->data(), 1, ret->mutable_data(), 1); + /* + cublasSetStream(ctx->cublas_handle, ctx->stream); + const float* lptr = static_cast<const float*>(lhs->data()); + const float* rptr = static_cast<const float*>(rhs->data()); + float* ptr = static_cast<float*>(ret->mutable_data()); + cublasScopy(ctx->cublas_handle, count, lptr, 1, ptr, 1); + cublasSaxpy(ctx->cublas_handle, 1.0f, rptr, 1, ptr, 1); + */ } -#ifdef USE_CUDNN -template<> -void Conv<float, lib::Cudnn>(const OpConf *conf, - const Blob* input, - const Blob* W, - const Blob* b, - Blob* ret, - Context* ctx) { - // auto conv_conf = conf->CastTo<ConvConf>(); - // conv op -} - -#endif #endif } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/conv.cc 
---------------------------------------------------------------------- diff --git a/src/model/layer/conv.cc b/src/model/layer/conv.cc deleted file mode 100644 index d1a7d2c..0000000 --- a/src/model/layer/conv.cc +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -namespace singa { - - - - - - -} /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/cudnn_dropout.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc new file mode 100644 index 0000000..926ccb9 --- /dev/null +++ b/src/model/layer/cudnn_dropout.cc @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef USE_CUDNN +// cudnn dropout is added in cudnn 5 +//#if CUDNN_MAJOR_VERSION >= 5 +#include "./cudnn_utils.h" +#include "./cudnn_dropout.h" +#include "singa/utils/logging.h" +namespace singa { +CudnnDropout::~CudnnDropout() { + if (drop_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyDropoutDescriptor(drop_desc_)); + if (x_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc_)); + if (y_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_)); +} + +void CudnnDropout::InitCudnn(int size, DataType dtype, Context* ctx) { + CHECK(!has_init_cudnn_); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_)); + CUDNN_CHECK(cudnnCreateDropoutDescriptor(&drop_desc_)); + + int dim[] = {size}; + int stride[] = {1}; + CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_desc_, GetCudnnDataType(dtype), 1, + dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_desc_, GetCudnnDataType(dtype), 1, + dim, stride)); + + cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size_); + cudnnDropoutGetReserveSpaceSize(x_desc_, &reserve_size_); + cudnnSetDropoutDescriptor(drop_desc_, ctx->cudnn_handle, dropout_ratio_, + state_.blob()->mutable_data(), + state_size_, ctx->seed); + has_init_cudnn_ = true; +} + +const Tensor CudnnDropout::Forward(int flag, const Tensor& input) { + if (flag & kTrain) { + auto size = input.Size(); + DataType dtype = input.data_type(); + if (!has_init_cudnn_) { + input.device()->Exec( + [size, dtype, this](Context* ctx) { + this->InitCudnn(size, dtype, ctx); + }, + {}, {state_.blob()}); + mask_.ResetLike(input); + CHECK_EQ(reserve_size_, mask_.MemSize()); + } + Tensor out; + out.ResetLike(input); + Blob *inblob = input.blob(), *outblob = out.blob(), *mblob = mask_.blob(); + out.device()->Exec( + [inblob, outblob, mblob, this](Context* ctx) { + cudnnDropoutForward( + ctx->cudnn_handle, this->drop_desc_, this->x_desc_, inblob->data(), + this->y_desc_, outblob->mutable_data(), mblob, this->reserve_size_); + }, + {inblob}, {mblob, outblob}); + return out; + } else { + return input; + } +} + +const std::pair<Tensor, vector<Tensor>> CudnnDropout::Backward( + int flag, const Tensor& grad) { + vector<Tensor> param_grad; + Tensor dx; + if (flag & kTrain) { + dx.ResetLike(grad); + Blob *dyblob = grad.blob(), *dxblob = dx.blob(), *mblob = mask_.blob(); + dx.device()->Exec( + [dyblob, dxblob, mblob, this](Context* ctx) { + cudnnDropoutBackward(ctx->cudnn_handle, this->drop_desc_, + this->y_desc_, dyblob->data(), this->x_desc_, + dxblob->mutable_data(), mblob, + this->reserve_size_); + }, + {dyblob, mblob}, {dxblob}); + } else { + LOG(ERROR) << "Do not call backward for evaluation phase"; + } + return std::make_pair(dx, param_grad); +} +} // namespace singa +//#endif // CUDNN_VERSION_MAJOR>=5 +#endif // USE_CUDNN http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/cudnn_dropout.h ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_dropout.h b/src/model/layer/cudnn_dropout.h new file mode 100644 index 0000000..0a19214 --- /dev/null +++ b/src/model/layer/cudnn_dropout.h @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_ +#define SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_ +#ifdef USE_CUDNN +// cudnn dropout is added in cudnn 5 +//#if CUDNN_MAJOR_VERSION >= 5 + +#include "singa/model/layer.h" +#include "singa/core/common.h" +#include "singa/proto/core.pb.h" +#include "./dropout.h" + +namespace singa { +class CudnnDropout : public Dropout { + public: + ~CudnnDropout(); + /// \copydoc Layer::layer_type() + const std::string layer_type() const override { return "CudnnDropout"; } + + const Tensor Forward(int flag, const Tensor& input) override; + const std::pair<Tensor, vector<Tensor>> Backward( + int flag, const Tensor& grad) override; + + /// Init cudnn related data structures. + void InitCudnn(int size, DataType dtype, Context* ctx); + + private: + bool has_init_cudnn_ = false; + cudnnDropoutDescriptor_t drop_desc_; + cudnnTensorDescriptor_t x_desc_, y_desc_; + size_t state_size_, reserve_size_; + Tensor state_; +}; +} // namespace +//#endif // CUDNN_VERSION_MAJOR>=5 +#endif // USE_CUDNN +#endif // SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/cudnn_utils.h ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_utils.h b/src/model/layer/cudnn_utils.h new file mode 100644 index 0000000..735ec13 --- /dev/null +++ b/src/model/layer/cudnn_utils.h @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef SINGA_MODEL_LAYER_CUDNN_BASE_H_ +#define SINGA_MODEL_LAYER_CUDNN_BASE_H_ +#ifdef USE_CUDNN +#include "singa/proto/core.pb.h" +#include "singa/utils/logging.h" +#include <cudnn.h> +namespace singa { +inline cudnnDataType_t GetCudnnDataType(DataType dtype) { + cudnnDataType_t ret; + switch (dtype) { + case kFloat32: + ret = CUDNN_DATA_FLOAT; + break; + case kDouble: + ret = CUDNN_DATA_DOUBLE; + break; + case kFloat16: + ret = CUDNN_DATA_HALF; + break; + default: + LOG(FATAL) << "The data type " << DataType_Name(dtype) + << " is not support by cudnn"; + } + return ret; +} + +#define CUDNN_CHECK(condition) \ + do { \ + cudnnStatus_t status = condition; \ + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " "\ + << cudnnGetErrorString(status); \ + } while (0) + +/* +inline const char* cudnnGetErrorString(cudnnStatus_t status) { + switch (status) { + case CUDNN_STATUS_SUCCESS: + return "CUDNN_STATUS_SUCCESS"; + case CUDNN_STATUS_NOT_INITIALIZED: + return "CUDNN_STATUS_NOT_INITIALIZED"; + case CUDNN_STATUS_ALLOC_FAILED: + return "CUDNN_STATUS_ALLOC_FAILED"; + case CUDNN_STATUS_BAD_PARAM: + return "CUDNN_STATUS_BAD_PARAM"; + case CUDNN_STATUS_INTERNAL_ERROR: + return "CUDNN_STATUS_INTERNAL_ERROR"; + case CUDNN_STATUS_INVALID_VALUE: + return "CUDNN_STATUS_INVALID_VALUE"; + case CUDNN_STATUS_ARCH_MISMATCH: + return "CUDNN_STATUS_ARCH_MISMATCH"; + case CUDNN_STATUS_MAPPING_ERROR: + return "CUDNN_STATUS_MAPPING_ERROR"; + case CUDNN_STATUS_EXECUTION_FAILED: + return "CUDNN_STATUS_EXECUTION_FAILED"; + case CUDNN_STATUS_NOT_SUPPORTED: + return "CUDNN_STATUS_NOT_SUPPORTED"; + case CUDNN_STATUS_LICENSE_ERROR: + return "CUDNN_STATUS_LICENSE_ERROR"; + } + return "Unknown cudnn status"; +} +*/ + +} // namespace singa +#endif // USE_CUDNN +#endif // SINGA_MODEL_LAYER_CUDNN_BASE_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/dropout.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/dropout.cc b/src/model/layer/dropout.cc new file mode 100644 index 0000000..f0fe25b --- /dev/null +++ b/src/model/layer/dropout.cc @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "singa/model/layer.h" +#include "./dropout.h" +namespace singa { + +void Dropout::Setup(const LayerConf& conf) { + Layer::Setup(conf); + dropout_ratio_ = conf.dropout_conf().dropout_ratio(); +} + +const Tensor Dropout::Forward(int flag, const Tensor& input) { + Tensor out; + if (flag & kTrain) { + mask_.ResetLike(input); + // set mask_[i] = 1 with prob 1-dropout_rato_ + Bernoulli(1 - dropout_ratio_, &mask_); + mask_ *= 1.0f / (1.0f - dropout_ratio_); + out = input * mask_; + } else { + out = input; + } + return out; +} + +const std::pair<Tensor, vector<Tensor>> Dropout::Backward( + int flag, const Tensor& grad) { + vector<Tensor> param_grad; + Tensor input_grad; + if (flag & kTrain) { + // note mask is already scaled by 1/(1-dropout_ratio_) + input_grad = grad * mask_; + } else { + LOG(ERROR) << "Do not call backward for evaluation phase"; + } + return std::make_pair(input_grad, param_grad); +} + +void Dropout::ToDevice(Device* device) { + Layer::ToDevice(device); + mask_.ToDevice(device); +} + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/dropout.h ---------------------------------------------------------------------- diff --git a/src/model/layer/dropout.h b/src/model/layer/dropout.h new file mode 100644 index 0000000..de349a5 --- /dev/null +++ b/src/model/layer/dropout.h @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SINGA_MODEL_LAYER_DROPOUT_H_ +#define SINGA_MODEL_LAYER_DROPOUT_H_ +#include "singa/model/layer.h" +namespace singa { +class Dropout : public Layer { + public: + /// \copydoc Layer::layer_type() + const std::string layer_type() const override { return "Dropout"; } + + /// \copydoc Layer::Setup(const LayerConf&); + void Setup(const LayerConf& conf) override; + + /// \copydoc Layer::Forward(int flag, const Tensor&) + /// if flag is kTrain, then do dropout with given dropout_ratio; + /// otherwise if it is kEval, copy input directly to the output + /// TODO(wangwei) There are diff implementations, Caffe vs + /// <a href="https://github.com/nitishsrivastava/deepnet/blob/master/deepnet/fastdropoutnet.py"> + const Tensor Forward(int flag, const Tensor& input) override; + + /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&); + const std::pair<Tensor, vector<Tensor>> Backward(int flag, + const Tensor& grad) override; + + void ToDevice(Device* device) override; + + protected: + /// the proability to set each element to 0. 
+ float dropout_ratio_; + Tensor mask_; +}; +} // namespace singa +#endif // SINGA_MODEL_LAYER_DROPOUT_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/layer.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/layer.cc b/src/model/layer/layer.cc deleted file mode 100644 index 0e83cde..0000000 --- a/src/model/layer/layer.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "singa/model/layer.h" - -namespace singa { -const vector<Tensor> ComputeFeature(int flag, const vector<Tensor>& input) { - const vector<Blob*> input_blobs; - -} - -void ComputeFeature(int flag, const vector<Tensor>& input) { - const vector<Blob*> input_blobs; - -} -} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/proto/core.proto ---------------------------------------------------------------------- diff --git a/src/proto/core.proto b/src/proto/core.proto index c137186..f366ed0 100644 --- a/src/proto/core.proto +++ b/src/proto/core.proto @@ -26,7 +26,8 @@ enum DataType { kFloat16 = 1; kInt = 2; kChar = 3; - kNumDataType = 4; + kDouble = 4; + kNumDataType = 5; } enum LibType { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/proto/layer.proto ---------------------------------------------------------------------- diff --git a/src/proto/layer.proto b/src/proto/layer.proto index 0fbbb5d..3d130ea 100644 --- a/src/proto/layer.proto +++ b/src/proto/layer.proto @@ -98,11 +98,15 @@ message ParamSpec { // The multiplier on the global weight decay for this parameter. optional float decay_mult = 4 [default = 1.0]; - // SINGA field for creating diff Param, e.g. SparseParam or CompressableParam - // Curently only have a default param implementation. - optional string type = 20 [default = "default"]; + // SINGA uses this filed internally. Users just configure the fillers in + // Layer specific conf message as caffe (style). + optional FillerConf filler = 20; } +enum Phase { + kTrain = 4; + kEval = 8; +} // NOTE // Update the next available ID when you add a new LayerConf field. // http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/test/singa/test_dropout.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_dropout.cc b/test/singa/test_dropout.cc new file mode 100644 index 0000000..cfe9d73 --- /dev/null +++ b/test/singa/test_dropout.cc @@ -0,0 +1,29 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. 
See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +#include "gtest/gtest.h" +#include "../src/model/layer/dropout.h" + + +TEST(TestDropoutLayer, Setup) { + + +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/test/singa/test_tensor.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_tensor.cc b/test/singa/test_tensor.cc index 86200a8..ae20823 100644 --- a/test/singa/test_tensor.cc +++ b/test/singa/test_tensor.cc @@ -6,19 +6,19 @@ using singa::Device; TEST(TensorTest, TestConstructor) { singa::Tensor float_t(singa::Shape{2,3}); - EXPECT_EQ(6, float_t.Size()); + EXPECT_EQ(6u, float_t.Size()); EXPECT_EQ(sizeof(float) * 6, float_t.MemSize()); EXPECT_EQ(singa::kFloat32, float_t.data_type()); auto s = float_t.shape(); - EXPECT_EQ(s[0], 2); - EXPECT_EQ(s[1], 3); + EXPECT_EQ(s[0], 2u); + EXPECT_EQ(s[1], 3u); EXPECT_NE(float_t.device(), nullptr); singa::Tensor float16_t(Shape{2,3}, singa::kFloat16); EXPECT_EQ(singa::kFloat16, float16_t.data_type()); - EXPECT_EQ(6, float16_t.Size()); - EXPECT_EQ(12, float16_t.blob()->size()); + EXPECT_EQ(6u, float16_t.Size()); + EXPECT_EQ(12u, float16_t.blob()->size()); singa::Tensor x(float16_t); EXPECT_EQ(float16_t.Size(), x.Size());
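For reference, a minimal usage sketch of the Dropout layer added by this patch. It assumes a DropoutConf message in layer.proto with a dropout_ratio field (the field Dropout::Setup reads) and its standard generated accessors, plus the kTrain flag added to layer.proto; the layer name "drop1", the 2x3 input shape, and the random fill values are illustrative only and not part of the commit.

    #include "singa/core/tensor.h"
    #include "singa/proto/layer.pb.h"
    #include "../src/model/layer/dropout.h"   // include path as used by test_dropout.cc

    void DropoutExample() {
      singa::LayerConf conf;
      conf.set_name("drop1");                                // hypothetical layer name
      conf.mutable_dropout_conf()->set_dropout_ratio(0.2f);  // drop each element with prob 0.2 (assumed accessor)
      singa::Dropout drop;
      drop.Setup(conf);

      singa::Tensor x(singa::Shape{2, 3});
      singa::Uniform(0.0f, 1.0f, &x);                        // fill the input on the default device

      // Training phase: a Bernoulli mask is sampled and scaled by 1/(1 - dropout_ratio).
      singa::Tensor y = drop.Forward(singa::kTrain, x);

      // Gradient from the layer above (random values stand in for a real loss gradient).
      singa::Tensor dy;
      dy.ResetLike(y);
      singa::Gaussian(0.0f, 1.0f, &dy);

      // Returns <dx, parameter gradients>; Dropout has no parameters, so the vector is empty.
      auto ret = drop.Backward(singa::kTrain, dy);
      singa::Tensor dx = ret.first;
    }

Because the mask is scaled by 1/(1 - dropout_ratio) inside Forward (inverted dropout), the expected activation magnitude is unchanged, so the evaluation path can return the input unmodified, which is exactly what the non-training branch of Dropout::Forward does.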
