SINGA-170 Add Dropout layer and CudnnDropout layer. Add test_dropout.cc for the Dropout class. Add RNN base layer draft. Add math functions to support Dropout.
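----------------------------------------------------------------------
The core of this change is an inverted-dropout implementation: Forward() samples a Bernoulli(1 - dropout_ratio) mask, scales it by 1 / (1 - dropout_ratio), and multiplies it into the input, while Backward() multiplies the incoming gradient by the same saved mask. The standalone sketch below illustrates only that math, using plain C++ and std::vector rather than the SINGA Tensor/Blob API; the function names are illustrative, not part of the patch.

#include <cstdio>
#include <random>
#include <vector>

// Forward: keep each element with probability 1 - p and scale it by
// 1 / (1 - p), so the expected value of the output matches the input.
std::vector<float> DropoutForward(const std::vector<float>& x, float p,
                                  std::vector<float>* mask, std::mt19937* gen) {
  std::bernoulli_distribution keep(1.0 - p);
  const float scale = 1.0f / (1.0f - p);
  mask->resize(x.size());
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    (*mask)[i] = keep(*gen) ? scale : 0.0f;  // mask entries are 0 or scale
    y[i] = x[i] * (*mask)[i];
  }
  return y;
}

// Backward: dx = dy * mask, reusing the mask saved by Forward.
std::vector<float> DropoutBackward(const std::vector<float>& dy,
                                   const std::vector<float>& mask) {
  std::vector<float> dx(dy.size());
  for (size_t i = 0; i < dy.size(); ++i) dx[i] = dy[i] * mask[i];
  return dx;
}

int main() {
  std::mt19937 gen(0);
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> mask;
  std::vector<float> y = DropoutForward(x, 0.5f, &mask, &gen);
  std::vector<float> dy = {1.0f, 1.0f, 1.0f, 1.0f};
  std::vector<float> dx = DropoutBackward(dy, mask);
  for (size_t i = 0; i < x.size(); ++i)
    std::printf("y[%zu]=%g dx[%zu]=%g\n", i, y[i], i, dx[i]);
  return 0;
}

This is the same invariant the new test checks: every output element is either 0 or scale * x[i], and every gradient element is either 0 or scale * dy[i].
----------------------------------------------------------------------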
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/c3a0558c Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/c3a0558c Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/c3a0558c Branch: refs/heads/master Commit: c3a0558cf5896a9313e9e5c2636e742ec8649fad Parents: 99e0d24 Author: Wei Wang <[email protected]> Authored: Tue May 17 15:42:43 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Tue May 17 15:42:43 2016 +0800 ---------------------------------------------------------------------- include/singa/core/device.h | 2 + include/singa/core/tensor.h | 20 ++++----- include/singa/model/layer.h | 71 ++++++++++---------------------- include/singa/model/rnn.h | 29 ------------- src/core/device/device.cc | 4 +- src/core/tensor/tensor.cc | 15 ++++--- src/core/tensor/tensor_math_cpp.h | 24 ++++++++++- src/model/layer/cudnn_dropout.cc | 71 ++++++++++++++++---------------- src/model/layer/cudnn_dropout.h | 14 +++---- src/model/layer/cudnn_utils.h | 14 ++++--- src/model/layer/dropout.cc | 6 +-- src/model/layer/dropout.h | 11 ++++- src/model/layer/rnn.h | 59 ++++++++++++++++++++++++++ test/singa/test_dropout.cc | 75 +++++++++++++++++++++++++++++++++- test/singa/test_tensor.cc | 3 +- 15 files changed, 266 insertions(+), 152 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/include/singa/core/device.h ---------------------------------------------------------------------- diff --git a/include/singa/core/device.h b/include/singa/core/device.h index f3bb5a2..b96efca 100644 --- a/include/singa/core/device.h +++ b/include/singa/core/device.h @@ -114,6 +114,7 @@ class Device { // SafeQueue<Operation> op_queue_; // SafeQueue<Operation> op_log_; /// The host device + Context ctx_; Device* host_; }; // Implement Device using Cpp libs. @@ -129,6 +130,7 @@ class CppDevice : public Device { /// Free cpu memory. void Free(void* ptr) override; + }; /// a singleton CppDevice as the host for all devices. http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 4807123..6c20c4f 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -155,38 +155,38 @@ class Tensor { Tensor T() const; /// Copy the meta info with data blob shared. - void operator=(const Tensor& t); + Tensor& operator=(const Tensor& t); /// Copy the meta info with data blob shared. - void operator=(Tensor&& t); + Tensor& operator=(Tensor&& t); - void operator+=(const Tensor& t); + Tensor& operator+=(const Tensor& t); // void operator+=(Tensor&& t); - void operator-=(const Tensor& t); + Tensor& operator-=(const Tensor& t); // void operator-=(Tensor&& t); - void operator*=(const Tensor& t); + Tensor& operator*=(const Tensor& t); // void operator*=(Tensor&& t); - void operator/=(const Tensor& t); + Tensor& operator/=(const Tensor& t); // void operator/=(Tensor&& t); // Scalar operations. 
/// T is a scalar type template<typename DType> - void operator+=(DType x); + Tensor& operator+=(DType x); /// T is a scalar type template <typename DType> - void operator-=(const DType x); + Tensor& operator-=(const DType x); /// T is a scalar type template <typename DType> - void operator*=(const DType x); + Tensor& operator*=(const DType x); /// T is a scalar type template <typename DType> - void operator/=(const DType x); + Tensor& operator/=(const DType x); /// save Tensor into a proto msg // void ToProto(TensorProto* t); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/include/singa/model/layer.h ---------------------------------------------------------------------- diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h index 48fc58f..a4c4630 100644 --- a/include/singa/model/layer.h +++ b/include/singa/model/layer.h @@ -42,9 +42,9 @@ class Layer { } // ============= Following Functions could be override ===================== - /// Destruct the objecst created by this layer. + /// Destruct objects created by this layer. virtual ~Layer() { - for (Tensor * t : param_values_) { + for (Tensor* t : param_values_) { delete t; } } @@ -56,19 +56,18 @@ class Layer { /// Set meta data fields configured in 'conf' (a proto message). virtual void Setup(const LayerConf& conf) { name_ = conf.name(); - for (const auto& spec : conf.param()) - param_specs_.push_back(spec); + for (const auto& spec : conf.param()) param_specs_.push_back(spec); // TODO(wangwei) load param values from checkpoint blobs. } /// Do feature transformation for the given 'input' tensor (denoted as x). - /// 'flag' is either kPhaseTrain or kPhaseTest for feed-forward nets, and + /// 'flag' is either kTrain or kEval for feed-forward nets, and /// would be used for other phases of training other nets. For example, when /// training RBM, we may create an alias of this function as ComputeFeature - /// where flag could be kPositivePhase and kNegativePhase. + /// where flag could be kPositive and kNegative. /// It will return a Tensor (denoted as y). /// If the 'input' or 'output' is required for computing the gradients in - /// Backward(), then push them into the states_ stack. + /// Backward(), then buffer them as internal data. virtual const Tensor Forward(int flag, const Tensor& input) { LOG(FATAL) << "Not implemented"; Tensor t; @@ -77,10 +76,12 @@ class Layer { /// \copydoc Forward(int flag, const Tensor& input) /// Accept multiple input tensors and generate multiple output tensors. + /// If there is only one input tensor, it will call Forward(int, const + /// Tensor&) by default. Users can override this function for layers who + /// generate more than one outputs. virtual const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) { vector<Tensor> ret; - if (inputs.size() == 1) - ret.push_back(Forward(flag, inputs.at(0))); + if (inputs.size() == 1) ret.push_back(Forward(flag, inputs.at(0))); LOG(FATAL) << "Not implemented"; return ret; @@ -88,19 +89,14 @@ class Layer { /// Compute gradients of this layer. /// Specifically, there are two types of gradients: - /// 1. gradients of preceding layers, i.e., dx. - /// 2. gradients of parameters of this layer. - /// 1 and 2 are returned as a pair of vector<Tensor> + /// 1. gradient of the preceding layer, i.e., dx. + /// 2. gradients of parameters of this layer, e.g., dw for weight matrix. 
/// 1 is an empty tensor if there is no preceding layer or there is no need to - /// compute dx (e.g., x is from a data layer); 2 is empty if this layer has no - /// parameters. - /// 'flag' is either kTrainPhase or kTestPhase for feed-forward nets, and + /// compute dx (e.g., x is from a data layer); 2 is an empty vector if this + // layer has no parameters. + /// 'flag' is either kTrain or kEval for feed-forward nets, and /// would be used for other phases when training other nets. /// 'grad' is a Tensor for gradient (dy) from the upper layer. - /// Some layer would use 'input' or 'output' from Forward to compute the - /// gradients of parameters. Backward() pop out the state data. - /// It is useful for RNN layers, where the same layer is used multiple - /// times just like unrolling the layer. virtual const std::pair<Tensor, vector<Tensor>> Backward(int flag, const Tensor& grad) { LOG(FATAL) << "Not implemented!"; @@ -117,7 +113,7 @@ class Layer { auto ret = Backward(flag, grads.at(0)); input_grad.push_back(ret.first); param_grad = ret.second; - } else { + } else { LOG(FATAL) << "Not implemented"; } return std::make_pair(input_grad, param_grad); @@ -137,7 +133,7 @@ class Layer { /// Serialize the layer info (including params) into a LayerConf proto message virtual void ToProto(LayerConf* conf) const { conf->set_name(name_); - for (const auto& spec: param_specs_) { + for (const auto& spec : param_specs_) { ParamSpec* p = conf->add_param(); p->CopyFrom(spec); } @@ -157,19 +153,13 @@ class Layer { } /// Return specs/configuration of all parameter instances of this layer. /// \ref ParamSpec. - const vector<ParamSpec> param_specs() { - return param_specs_; - } + const vector<ParamSpec> param_specs() { return param_specs_; } /// Return the i-th ParamSpec. - const ParamSpec& param_specs(int i) { - return param_specs_.at(i); - } + const ParamSpec& param_specs(int i) { return param_specs_.at(i); } /// Return pointers to parameter Tensor s. - const vector<Tensor*> param_values() { - return param_values_; - } + const vector<Tensor*> param_values() { return param_values_; } /// Return a pointer to the 'i'-th parameter Tensor. Tensor* param_value(size_t i) { @@ -180,8 +170,7 @@ class Layer { /// Return names of all parmaeters. const vector<string> param_names() { vector<string> pname; - for (const auto& spec: param_specs_) - pname.push_back(spec.name()); + for (const auto& spec : param_specs_) pname.push_back(spec.name()); return pname; } @@ -195,29 +184,11 @@ class Layer { /// Used for debugging and logging. const std::string name() const { return name_; } - /* - std::stack<Tensor> states() const { - return states_; - } - */ - protected: std::string name_; vector<Tensor*> param_values_; vector<ParamSpec> param_specs_; - /// Used to store input or output of Forward(), which would be used in - /// Backward. Rules: - /// 1. push the 'input' or 'output' into states_ if the flag of Forward() is - /// for training. - /// 2. pop data out in Backward(). - /// TODO(wangwei) enable this feature for rnn layers. - // std::stack<Tensor*> states_; }; -// =========================================================================== -// Order layer sub-classes based on alphabetical order of the first letter. 
-// =========================================================================== - - } // namespace singa #endif // SINGA_LAYER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/include/singa/model/rnn.h ---------------------------------------------------------------------- diff --git a/include/singa/model/rnn.h b/include/singa/model/rnn.h deleted file mode 100644 index 7d2c20c..0000000 --- a/include/singa/model/rnn.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -namespace singa { - -class RNN { - - - - -}; - -} /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/core/device/device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/device.cc b/src/core/device/device.cc index b2a8705..33f5bd8 100644 --- a/src/core/device/device.cc +++ b/src/core/device/device.cc @@ -23,11 +23,13 @@ Device::Device(int id, int num_executors, string scheduler, string vm) : id_(id) { scheduler_ = nullptr; vm_ = nullptr; + ctx_.seed = 0; + ctx_.random_generator = std::mt19937(ctx_.seed); } void Device::Exec(function<void(Context*)> fn, const vector<Blob*> read_blobs, const vector<Blob*> write_blobs, bool use_rand_generator) { - fn(nullptr); + fn(&ctx_); } Blob* Device::NewBlob(int size) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 8352b48..cd62a38 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -71,7 +71,7 @@ Tensor::Tensor(Tensor&& t) } void Tensor::ResetLike(const Tensor& t) { - if (blob_->size() != t.MemSize()) { + if (blob_ == nullptr || blob_->size() != t.MemSize()) { if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); shape_ = t.shape_; device_ = t.device_; @@ -152,7 +152,7 @@ Tensor Tensor::T() const { return t; } -void Tensor::operator=(const Tensor& t) { +Tensor& Tensor::operator=(const Tensor& t) { if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); transpose_ = t.transpose_; @@ -161,9 +161,10 @@ void Tensor::operator=(const Tensor& t) { device_ = t.device_; blob_ = t.blob(); blob_->IncRefCount(); + return *this; } -void Tensor::operator=(Tensor&& t) { +Tensor& Tensor::operator=(Tensor&& t) { if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); transpose_ = t.transpose_; @@ -171,10 +172,11 @@ void Tensor::operator=(Tensor&& t) { device_ = t.device_; blob_ = t.blob_; t.blob_ = nullptr; + return *this; } #define GenUnaryTensorArgMemberFunction(op, fn) \ - void Tensor::op(const Tensor& t) { fn(*this, t, this); } + 
Tensor& Tensor::op(const Tensor& t) { fn(*this, t, this); return *this; } GenUnaryTensorArgMemberFunction(operator+=, Add); GenUnaryTensorArgMemberFunction(operator-=, Sub); @@ -183,10 +185,11 @@ GenUnaryTensorArgMemberFunction(operator/=, Div); #define GenUnaryScalarArgMemberFunction(op, fn) \ template <typename DType> \ - void Tensor::op(DType x) { \ + Tensor& Tensor::op(DType x) { \ fn(*this, x, this); \ + return *this; \ } \ - template void Tensor::op<float>(float x) + template Tensor& Tensor::op<float>(float x) GenUnaryScalarArgMemberFunction(operator-=, Sub); GenUnaryScalarArgMemberFunction(operator+=, Add); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index 9e7ed30..2cbc225 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -39,6 +39,26 @@ void Add<float, lib::Cpp>(int count, dptr[i] = lptr[i] + rptr[i]; } } +template <> +void EltwiseMult<float, lib::Cpp>(int count, const Blob* input, float x, Blob* ret, Context* ctx) +{ + float *dptr = static_cast<float*>(ret->mutable_data()); + const float *lptr = static_cast<const float*>(input->data()); + for (int i = 0; i < count; i++) { + dptr[i] = lptr[i] * x; + } +} + +template <> +void EltwiseMult<float, lib::Cpp>(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) +{ + float *dptr = static_cast<float*>(ret->mutable_data()); + const float *lptr = static_cast<const float*>(lhs->data()); + const float *rptr = static_cast<const float*>(rhs->data()); + for (int i = 0; i < count; i++) { + dptr[i] = lptr[i] * rptr[i]; + } +} template <> void Bernoulli<float, lib::Cpp>(int count, float p, Blob* ret, @@ -46,7 +66,7 @@ void Bernoulli<float, lib::Cpp>(int count, float p, Blob* ret, std::bernoulli_distribution distribution(p); float* ptr = static_cast<float*>(ret->mutable_data()); for (int i = 0; i < count; i ++) { - ptr[i] = static_cast<float>(distribution(ctx->random_generator)); + ptr[i] = distribution(ctx->random_generator) ? 
1.0f : 0.0f; } } @@ -69,6 +89,8 @@ void Gaussian<float, lib::Cpp>(int count, float mean, float std, Blob* ret, ptr[i] = static_cast<float>(distribution(ctx->random_generator)); } } + + #ifdef USE_CBLAS template<> void Dot<float, lib::Cpp>(int count, http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/model/layer/cudnn_dropout.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc index 926ccb9..4d5f5d5 100644 --- a/src/model/layer/cudnn_dropout.cc +++ b/src/model/layer/cudnn_dropout.cc @@ -17,18 +17,16 @@ */ #ifdef USE_CUDNN // cudnn dropout is added in cudnn 5 -//#if CUDNN_MAJOR_VERSION >= 5 -#include "./cudnn_utils.h" +#if CUDNN_MAJOR_VERSION >= 5 #include "./cudnn_dropout.h" +#include "./cudnn_utils.h" #include "singa/utils/logging.h" namespace singa { CudnnDropout::~CudnnDropout() { if (drop_desc_ != nullptr) CUDNN_CHECK(cudnnDestroyDropoutDescriptor(drop_desc_)); - if (x_desc_ != nullptr) - CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc_)); - if (y_desc_ != nullptr) - CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_)); + if (x_desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc_)); + if (y_desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_)); } void CudnnDropout::InitCudnn(int size, DataType dtype, Context* ctx) { @@ -37,18 +35,16 @@ void CudnnDropout::InitCudnn(int size, DataType dtype, Context* ctx) { CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_)); CUDNN_CHECK(cudnnCreateDropoutDescriptor(&drop_desc_)); - int dim[] = {size}; - int stride[] = {1}; - CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_desc_, GetCudnnDataType(dtype), 1, - dim, stride)); - CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_desc_, GetCudnnDataType(dtype), 1, - dim, stride)); + CUDNN_CHECK(cudnnSetTensor4dDescriptor( + x_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, 1, 1, size)); + CUDNN_CHECK(cudnnSetTensor4dDescriptor( + y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, 1, 1, size)); cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size_); cudnnDropoutGetReserveSpaceSize(x_desc_, &reserve_size_); - cudnnSetDropoutDescriptor(drop_desc_, ctx->cudnn_handle, dropout_ratio_, - state_.blob()->mutable_data(), - state_size_, ctx->seed); + cudnnSetDropoutDescriptor(drop_desc_, ctx->cudnn_handle, 1 - dropout_ratio_, + state_.blob()->mutable_data(), state_size_, + ctx->seed); has_init_cudnn_ = true; } @@ -59,23 +55,27 @@ const Tensor CudnnDropout::Forward(int flag, const Tensor& input) { if (!has_init_cudnn_) { input.device()->Exec( [size, dtype, this](Context* ctx) { - this->InitCudnn(size, dtype, ctx); + this->InitCudnn(size, dtype, ctx); }, - {}, {state_.blob()}); + {}, {this->state_.blob()}); mask_.ResetLike(input); + // TODO(wangwei) update for async running, + // where reserve_size_ may not available CHECK_EQ(reserve_size_, mask_.MemSize()); } - Tensor out; - out.ResetLike(input); - Blob *inblob = input.blob(), *outblob = out.blob(), *mblob = mask_.blob(); - out.device()->Exec( - [inblob, outblob, mblob, this](Context* ctx) { - cudnnDropoutForward( - ctx->cudnn_handle, this->drop_desc_, this->x_desc_, inblob->data(), - this->y_desc_, outblob->mutable_data(), mblob, this->reserve_size_); + Tensor output; + output.ResetLike(input); + output.device()->Exec( + [input, output, this](Context* ctx) { + Blob *inblob = input.blob(), *outblob = output.blob(), + *mblob = mask_.blob(); + cudnnDropoutForward(ctx->cudnn_handle, this->drop_desc_, + this->x_desc_, 
inblob->data(), this->y_desc_, + outblob->mutable_data(), mblob, + this->reserve_size_); }, - {inblob}, {mblob, outblob}); - return out; + {input.blob()}, {output.blob(), mask_.blob()}); + return output; } else { return input; } @@ -87,20 +87,21 @@ const std::pair<Tensor, vector<Tensor>> CudnnDropout::Backward( Tensor dx; if (flag & kTrain) { dx.ResetLike(grad); - Blob *dyblob = grad.blob(), *dxblob = dx.blob(), *mblob = mask_.blob(); dx.device()->Exec( - [dyblob, dxblob, mblob, this](Context* ctx) { - cudnnDropoutBackward(ctx->cudnn_handle, this->drop_desc_, - this->y_desc_, dyblob->data(), this->x_desc_, - dxblob->mutable_data(), mblob, - this->reserve_size_); + [dx, grad, this](Context* ctx) { + Blob *dyblob = grad.blob(), *dxblob = dx.blob(), + *mblob = this->mask_.blob(); + cudnnDropoutBackward(ctx->cudnn_handle, this->drop_desc_, + this->y_desc_, dyblob->data(), this->x_desc_, + dxblob->mutable_data(), mblob->mutable_data(), + this->reserve_size_); }, - {dyblob, mblob}, {dxblob}); + {grad.blob(), mask_.blob()}, {dx.blob()}); } else { LOG(ERROR) << "Do not call backward for evaluation phase"; } return std::make_pair(dx, param_grad); } } // namespace singa -//#endif // CUDNN_VERSION_MAJOR>=5 +#endif // CUDNN_VERSION_MAJOR>=5 #endif // USE_CUDNN http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/model/layer/cudnn_dropout.h ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_dropout.h b/src/model/layer/cudnn_dropout.h index 0a19214..d2b68b9 100644 --- a/src/model/layer/cudnn_dropout.h +++ b/src/model/layer/cudnn_dropout.h @@ -20,12 +20,12 @@ #define SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_ #ifdef USE_CUDNN // cudnn dropout is added in cudnn 5 -//#if CUDNN_MAJOR_VERSION >= 5 +#if CUDNN_MAJOR_VERSION >= 5 -#include "singa/model/layer.h" +#include "./dropout.h" #include "singa/core/common.h" +#include "singa/model/layer.h" #include "singa/proto/core.pb.h" -#include "./dropout.h" namespace singa { class CudnnDropout : public Dropout { @@ -35,8 +35,8 @@ class CudnnDropout : public Dropout { const std::string layer_type() const override { return "CudnnDropout"; } const Tensor Forward(int flag, const Tensor& input) override; - const std::pair<Tensor, vector<Tensor>> Backward( - int flag, const Tensor& grad) override; + const std::pair<Tensor, vector<Tensor>> Backward(int flag, + const Tensor& grad) override; /// Init cudnn related data structures. 
void InitCudnn(int size, DataType dtype, Context* ctx); @@ -49,6 +49,6 @@ class CudnnDropout : public Dropout { Tensor state_; }; } // namespace -//#endif // CUDNN_VERSION_MAJOR>=5 +#endif // CUDNN_VERSION_MAJOR>=5 #endif // USE_CUDNN -#endif // SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_ +#endif // SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/model/layer/cudnn_utils.h ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_utils.h b/src/model/layer/cudnn_utils.h index 735ec13..92c8df7 100644 --- a/src/model/layer/cudnn_utils.h +++ b/src/model/layer/cudnn_utils.h @@ -17,10 +17,12 @@ */ #ifndef SINGA_MODEL_LAYER_CUDNN_BASE_H_ #define SINGA_MODEL_LAYER_CUDNN_BASE_H_ + #ifdef USE_CUDNN + +#include <cudnn.h> #include "singa/proto/core.pb.h" #include "singa/utils/logging.h" -#include <cudnn.h> namespace singa { inline cudnnDataType_t GetCudnnDataType(DataType dtype) { cudnnDataType_t ret; @@ -41,11 +43,11 @@ inline cudnnDataType_t GetCudnnDataType(DataType dtype) { return ret; } -#define CUDNN_CHECK(condition) \ - do { \ - cudnnStatus_t status = condition; \ - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " "\ - << cudnnGetErrorString(status); \ +#define CUDNN_CHECK(condition) \ + do { \ + cudnnStatus_t status = condition; \ + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " " \ + << cudnnGetErrorString(status); \ } while (0) /* http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/model/layer/dropout.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/dropout.cc b/src/model/layer/dropout.cc index f0fe25b..c2c97be 100644 --- a/src/model/layer/dropout.cc +++ b/src/model/layer/dropout.cc @@ -30,7 +30,7 @@ const Tensor Dropout::Forward(int flag, const Tensor& input) { if (flag & kTrain) { mask_.ResetLike(input); // set mask_[i] = 1 with prob 1-dropout_rato_ - Bernoulli(1 - dropout_ratio_, &mask_); + Bernoulli(1.0f - dropout_ratio_, &mask_); mask_ *= 1.0f / (1.0f - dropout_ratio_); out = input * mask_; } else { @@ -39,8 +39,8 @@ const Tensor Dropout::Forward(int flag, const Tensor& input) { return out; } -const std::pair<Tensor, vector<Tensor>> Dropout::Backward( - int flag, const Tensor& grad) { +const std::pair<Tensor, vector<Tensor>> Dropout::Backward(int flag, + const Tensor& grad) { vector<Tensor> param_grad; Tensor input_grad; if (flag & kTrain) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/model/layer/dropout.h ---------------------------------------------------------------------- diff --git a/src/model/layer/dropout.h b/src/model/layer/dropout.h index de349a5..a6e733a 100644 --- a/src/model/layer/dropout.h +++ b/src/model/layer/dropout.h @@ -31,7 +31,8 @@ class Dropout : public Layer { /// if flag is kTrain, then do dropout with given dropout_ratio; /// otherwise if it is kEval, copy input directly to the output /// TODO(wangwei) There are diff implementations, Caffe vs - /// <a href="https://github.com/nitishsrivastava/deepnet/blob/master/deepnet/fastdropoutnet.py"> + /// <a + /// href="https://github.com/nitishsrivastava/deepnet/blob/master/deepnet/fastdropoutnet.py"> const Tensor Forward(int flag, const Tensor& input) override; /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&); @@ -40,6 +41,14 @@ class Dropout : public Layer { void ToDevice(Device* device) override; + float dropout_ratio() const { + return dropout_ratio_; + } + + const Tensor& mask() const { + return mask_; + } + 
protected: /// the proability to set each element to 0. float dropout_ratio_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/src/model/layer/rnn.h ---------------------------------------------------------------------- diff --git a/src/model/layer/rnn.h b/src/model/layer/rnn.h new file mode 100644 index 0000000..a6ba461 --- /dev/null +++ b/src/model/layer/rnn.h @@ -0,0 +1,59 @@ + /** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SINGA_MODEL_LAYER_DROPOUT_H_ +#define SINGA_MODEL_LAYER_DROPOUT_H_ +#include "singa/model/layer.h" +namespace singa { +/// To enable use the same layer multiple times in one iteration in RNN, +/// the Forward() function pushes the 'input' or 'output' that are +/// necessary for Backward() in a stack (states_). If neither 'input' or +/// 'output' is used by Backward(), then do not store them. The Backward() +/// pops data from the states_ stack to compute gradients. Users are +/// responsible for accumulating the gradients for the same parameters. +class RNN : public Layer { + public: + /// \copydoc Layer::layer_type() + const std::string layer_type() const override { return "RNN"; } + + /// \copydoc Layer::Setup(const LayerConf&); + void Setup(const LayerConf& conf) override; + + /// \copydoc Layer::Forward(int flag, const vector<Tensor>&) + const vector<Tensor> Forward(int flag, const vector<Tensor>& input) override; + + /// \copydoc Layer::Backward(int, const vector<Tensor>&); + const std::pair<vector<Tensor>, vector<Tensor>> Backward( + int flag, const vector<Tensor>& grad) override; + + void ToDevice(Device* device) override; + + /// Return the internal state stack, which should be empty at the beginning + /// of + /// one iteration. + std::stack<Tensor> states() const { return states_; } + + protected: + /// Storing input or output from Forward(), which are used in Backward(). + /// Rules: + /// 1. push the 'input' or 'output' into states_ if the flag of Forward() is + /// for kTrain and 'input' or 'output' is necessary for Backward(). + /// 2. pop data out in Backward(). 
+ std::stack<Tensor*> states_; +}; +} // namespace singa +#endif // SINGA_MODEL_LAYER_DROPOUT_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/test/singa/test_dropout.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_dropout.cc b/test/singa/test_dropout.cc index cfe9d73..3190ecd 100644 --- a/test/singa/test_dropout.cc +++ b/test/singa/test_dropout.cc @@ -19,11 +19,82 @@ * *************************************************************/ -#include "gtest/gtest.h" #include "../src/model/layer/dropout.h" +#include "gtest/gtest.h" + +using singa::Dropout; +TEST(DropoutLayer, Setup) { + Dropout drop; + EXPECT_EQ("Dropout", drop.layer_type()); + + singa::LayerConf conf; + singa::DropoutConf* dropconf = conf.mutable_dropout_conf(); + dropconf->set_dropout_ratio(0.8); + + drop.Setup(conf); + EXPECT_EQ(0.8f, drop.dropout_ratio()); +} + +TEST(DropoutLayer, Forward) { + const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + size_t n = sizeof(x) / sizeof(float); + singa::Tensor in(singa::Shape{n}); + in.CopyDataFromHostPtr(x, n); + + float pdrop = 0.5; + Dropout drop; + singa::LayerConf conf; + singa::DropoutConf* dropconf = conf.mutable_dropout_conf(); + dropconf->set_dropout_ratio(pdrop); + drop.Setup(conf); + float scale = 1.0f / (1.0f - pdrop); + + singa::Tensor out1 = drop.Forward(singa::kTrain, in); + + const float* mptr = static_cast<const float*>(drop.mask().blob()->data()); + for (size_t i = 0; i < n; i++) + EXPECT_FLOAT_EQ(0, mptr[i] * (mptr[i] - scale)); + + const float* outptr1 = static_cast<const float*>(out1.blob()->data()); + EXPECT_EQ(n, out1.Size()); + // the output value should be 0 or the same as the input + EXPECT_EQ(0.f, outptr1[0] * (outptr1[0] - scale * x[0])); + EXPECT_EQ(0.f, outptr1[1] * (outptr1[1] - scale * x[1])); + EXPECT_EQ(0.f, outptr1[7] * (outptr1[7] - scale * x[7])); + + singa::Tensor out2 = drop.Forward(singa::kEval, in); + EXPECT_EQ(n, out2.Size()); + const float* outptr2 = static_cast<const float*>(out2.blob()->data()); + // the output value should be the same as the input + EXPECT_EQ(x[0], outptr2[0]); + EXPECT_EQ(x[1], outptr2[1]); + EXPECT_EQ(x[7], outptr2[7]); +} + +TEST(DropoutLayer, Backward) { + const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + size_t n = sizeof(x) / sizeof(float); + singa::Tensor in(singa::Shape{n}); + in.CopyDataFromHostPtr(x, n); + float pdrop = 0.5; + float scale = 1.0f / (1.0f - pdrop); -TEST(TestDropoutLayer, Setup) { + Dropout drop; + singa::LayerConf conf; + singa::DropoutConf* dropconf = conf.mutable_dropout_conf(); + dropconf->set_dropout_ratio(pdrop); + drop.Setup(conf); + singa::Tensor out1 = drop.Forward(singa::kTrain, in); + const float dy[] = {4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f}; + singa::Tensor grad(singa::Shape{n}); + grad.CopyDataFromHostPtr(dy, n); + const float* mptr = static_cast<const float*>(drop.mask().blob()->data()); + const auto ret = drop.Backward(singa::kTrain, grad); + const float* dx = static_cast<const float*>(ret.first.blob()->data()); + EXPECT_FLOAT_EQ(dx[0], dy[0] * (mptr[0] > 0 ? 
1.0f : 0.0f) * scale); + EXPECT_FLOAT_EQ(dx[1], dy[1] * (mptr[1] > 0) * scale); + EXPECT_FLOAT_EQ(dx[7], dy[7] * (mptr[7] > 0) * scale); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c3a0558c/test/singa/test_tensor.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_tensor.cc b/test/singa/test_tensor.cc index ae20823..8c3c901 100644 --- a/test/singa/test_tensor.cc +++ b/test/singa/test_tensor.cc @@ -107,7 +107,8 @@ TEST(TensorClass, T) { EXPECT_EQ(true, o.transpose()); EXPECT_EQ(t.blob(), o.blob()); EXPECT_EQ(t.data_type(), o.data_type()); - EXPECT_TRUE((t.shape() == o.shape())); + EXPECT_EQ(t.shape()[0], o.shape()[1]); + EXPECT_EQ(t.shape()[1], o.shape()[0]); }
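----------------------------------------------------------------------
One consequence of the Tensor changes in include/singa/core/tensor.h above: operator=, operator+=, operator-=, operator*= and operator/= now return Tensor& instead of void, which is the conventional C++ signature and allows chained expressions such as a = b = c, as well as in-place scaling in expression context like the mask_ *= 1.0f / (1.0f - dropout_ratio_) line in Dropout::Forward. The toy class below is not SINGA code; it only illustrates why the reference return matters.

#include <cstdio>

// Toy value wrapper; the Value& returns are what make chaining legal.
class Value {
 public:
  explicit Value(float x = 0.0f) : x_(x) {}
  Value& operator=(const Value& rhs) { x_ = rhs.x_; return *this; }
  Value& operator+=(float d) { x_ += d; return *this; }
  Value& operator*=(float s) { x_ *= s; return *this; }
  float get() const { return x_; }

 private:
  float x_;
};

int main() {
  Value a, b, c(3.0f);
  a = b = c;            // right-to-left chained assignment: a = (b = c)
  (a += 1.0f) *= 2.0f;  // compound operators can also be chained
  std::printf("a=%g b=%g c=%g\n", a.get(), b.get(), c.get());  // a=8 b=3 c=3
  return 0;
}
----------------------------------------------------------------------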
