SINGA-178 Add Convolution layer and Pooling layer

Minor updates to variable names and to the InitCudnn arguments. Fix compilation warnings about signed/unsigned integer comparisons. Format code. All tests pass.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/7d149ecf Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/7d149ecf Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/7d149ecf Branch: refs/heads/dev Commit: 7d149ecf786f816cf2da47ea9e5bb86f8fecdd6b Parents: 152056d Author: Wei Wang <[email protected]> Authored: Mon May 30 16:53:40 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Mon May 30 18:17:57 2016 +0800 ---------------------------------------------------------------------- include/singa/core/device.h | 12 +-- include/singa/core/tensor.h | 2 +- include/singa/model/layer.h | 7 +- include/singa/utils/string.h | 81 ++++++++++++++++++ include/singa/utils/tokenizer.h | 65 -------------- src/core/tensor/tensor.cc | 8 +- src/model/layer/convolution.cc | 33 ++++++-- src/model/layer/convolution.h | 3 +- src/model/layer/cudnn_convolution.cc | 135 ++++++++++++++---------------- src/model/layer/cudnn_convolution.h | 11 ++- src/model/layer/cudnn_pooling.cc | 40 ++++----- src/model/layer/cudnn_pooling.h | 2 +- src/model/layer/pooling.cc | 8 +- src/model/layer/pooling.h | 3 +- src/proto/model.proto | 15 ++-- test/singa/test_cudnn_convolution.cc | 50 +++++------ test/singa/test_cudnn_pooling.cc | 26 +++--- 17 files changed, 274 insertions(+), 227 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/core/device.h ---------------------------------------------------------------------- diff --git a/include/singa/core/device.h b/include/singa/core/device.h index a4b3f6d..56eda70 100644 --- a/include/singa/core/device.h +++ b/include/singa/core/device.h @@ -77,6 +77,10 @@ class Device { Device* host() const { return host_;} + Context* context(int k) { + return &ctx_; + } + int id() const { return id_; } protected: @@ -104,6 +108,8 @@ class Device { // SafeQueue<Operation> op_log_; /// The host device Device* host_; + // TODO(wangwei) define multiple contexts, one per executor + Context ctx_; }; /// Represent a CPU device which may have multiple threads/executors. @@ -125,9 +131,6 @@ class CppCPU : public Device { /// Free cpu memory. void Free(void* ptr) override; - - protected: - Context ctx_; }; /// a singleton CppDevice as the host for all devices. @@ -177,9 +180,6 @@ class CudaGPU : public Device { /// Free cpu memory. void Free(void* ptr) override; - - protected: - Context ctx_; }; /// CudaCPU which uses cudaMallocHost to allocate pinned memory for host. 
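The device.h hunk above moves the per-subclass Context members into the base Device and exposes the (currently single) context through context(int). Below is a minimal sketch of how a layer can now reach the cuDNN handle through its input tensor's device, mirroring the InitCudnn rewrite later in this commit; the function name UseCudnnHandle is illustrative only, not part of the commit:

#include <cudnn.h>
#include "singa/core/device.h"
#include "singa/core/tensor.h"

// Illustrative sketch, not part of this commit. context(0) returns the one
// shared Context until per-executor contexts are added (see the TODO above).
void UseCudnnHandle(const singa::Tensor& input) {
  singa::Device* dev = input.device();
  singa::Context* ctx = dev->context(0);
  cudnnHandle_t handle = ctx->cudnn_handle;  // handle the cuDNN layers run on
  // ... create cuDNN descriptors and issue calls against this handle ...
}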
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index f51c899..8682bca 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -97,7 +97,7 @@ public: return shape_.at(idx); } - int nDim() const { return shape_.size(); } + size_t nDim() const { return shape_.size(); } bool transpose() const { return transpose_; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/model/layer.h ---------------------------------------------------------------------- diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h index c6a3bd1..82c8edc 100644 --- a/include/singa/model/layer.h +++ b/include/singa/model/layer.h @@ -44,7 +44,7 @@ class Layer { // ============= Following Functions could be overridden ===================== /// Destruct objects created by this layer. - virtual ~Layer() {}; + virtual ~Layer() {}; /// Each layer sub-class would optionally have a type name. /// Used for debugging and logging. @@ -160,7 +160,10 @@ class Layer { const vector<ParamSpec> param_specs() { return param_specs_; } /// Return the i-th ParamSpec. - const ParamSpec& param_specs(int i) { return param_specs_.at(i); } + const ParamSpec& param_specs(size_t i) { + CHECK_LT(i, param_specs_.size()); + return param_specs_.at(i); + } /// Return pointers to parameter Tensor s. const vector<Tensor*> param_values() { return param_values_; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/utils/string.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/string.h b/include/singa/utils/string.h new file mode 100644 index 0000000..b739afc --- /dev/null +++ b/include/singa/utils/string.h @@ -0,0 +1,81 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +#ifndef SINGA_UTILS_STRING_H_ +#define SINGA_UTILS_STRING_H_ + +#include <string> +#include <algorithm> +#include "singa/utils/logging.h" + +namespace singa { +inline bool icasecmp(const string& l, const string& r) { + return l.size() == r.size() && + equal(l.cbegin(), l.cend(), r.cbegin(), + [](string::value_type l1, string::value_type r1) { + return toupper(l1) == toupper(r1); + }); +} + +inline string ToLowerCase(const string& input) { + string out; + out.resize(input.size()); + std::transform(input.begin(), input.end(), out.begin(), ::tolower); + return out; +} + +/** + * Tokenize a string.
+ * + * example: + * Tokenizer t("assa,asf;wes", ",;"); + * string x; + * t >> x; // x is assa + * t >> x; // x is asf + * t >> x; // x is wes + * cout << (t >> x); // print 0. + */ + +class Tokenizer { + public: + Tokenizer(const std::string& str, const std::string& sep): start_(0), + sep_(sep), buf_(str) {} + Tokenizer & operator>>(std::string& out) { + CHECK_LT(start_, buf_.length()); + size_t start = start_; + auto pos = buf_.find_first_of(sep_, start); + if (pos == std::string::npos) + pos = buf_.length(); + start_ = pos + 1; + out = buf_.substr(start, pos - start); + return *this; + } + bool Valid() { return start_ < buf_.length(); } + + private: + size_t start_; + std::string sep_; + const std::string& buf_; +}; + +} // namespace singa + +#endif // SINGA_UTILS_STRING_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/utils/tokenizer.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/tokenizer.h b/include/singa/utils/tokenizer.h deleted file mode 100644 index 92c24b6..0000000 --- a/include/singa/utils/tokenizer.h +++ /dev/null @@ -1,65 +0,0 @@ -/************************************************************ -* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -* -*************************************************************/ - -#ifndef SINGA_UTILS_TOKENIZER_H_ -#define SINGA_UTILS_TOKENIZER_H_ - -#include <string> -#include "singa/utils/logging.h" - -namespace singa { -/** - * Tokenize a string. - * - * example: - * Tokenizer t("assa,asf;wes", ",;"); - * string x; - * t >> x; // x is assa - * t >> x; // x is asf - * t >> x; // x is wes - * cout << (t >> x); // print 0.
- */ - -class Tokenizer { - public: - Tokenizer(const std::string& str, const std::string& sep): start_(0), - sep_(sep), buf_(str) {} - Tokenizer & operator>>(std::string& out) { - CHECK_LT(start_, buf_.length()); - int start = start_; - auto pos = buf_.find_first_of(sep_, start); - if (pos == std::string::npos) - pos = buf_.length(); - start_ = pos + 1; - out = buf_.substr(start, pos); - return *this; - } - bool Valid() { return start_ < buf_.length(); } - - private: - unsigned start_; - std::string sep_; - const std::string& buf_; -}; - -} // namespace singa - -#endif // SINGA_UTILS_TOKENIZER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 0e47a4f..fcf42c2 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -562,8 +562,8 @@ void AddColumn(const float alpha, const float beta, const Tensor &v, Tensor X = M->T(); AddRow(v, &X); } else { - CHECK_EQ(M->nDim(), 2); - CHECK_EQ(v.nDim(), 1); + CHECK_EQ(M->nDim(), 2u); + CHECK_EQ(v.nDim(), 1u); size_t nb_row = M->shape(0), nb_col = M->shape(1); CHECK_EQ(nb_row, v.Size()); @@ -581,8 +581,8 @@ void AddRow(const float alpha, const float beta, const Tensor &v, Tensor *M) { Tensor X = M->T(); AddColumn(v, &X); } else { - CHECK_EQ(M->nDim(), 2); - CHECK_EQ(v.nDim(), 1); + CHECK_EQ(M->nDim(), 2u); + CHECK_EQ(v.nDim(), 1u); size_t nb_row = M->shape(0), nb_col = M->shape(1); CHECK_EQ(nb_col, v.Size()); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/convolution.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc index 6406a31..50ee3c8 100644 --- a/src/model/layer/convolution.cc +++ b/src/model/layer/convolution.cc @@ -28,32 +28,51 @@ void Convolution::Setup(const LayerConf &conf) { ConvolutionConf conv_conf = conf.convolution_conf(); // kernel_size, pad, and stride are repeated fields. 
if (conv_conf.kernel_size_size() > 0) { - kernel_w_ = kernel_h_ = conv_conf.kernel_size(0); + if (conv_conf.kernel_size_size() == 1) { + kernel_w_ = kernel_h_ = conv_conf.kernel_size(0); + } else { + kernel_w_ = conv_conf.kernel_size(0); + kernel_h_ = conv_conf.kernel_size(1); + } } else { kernel_w_ = conv_conf.kernel_w(); kernel_h_ = conv_conf.kernel_h(); } - CHECK_NE(kernel_w_, 0); - CHECK_NE(kernel_h_, 0); + CHECK_GT(kernel_w_, 0u); + CHECK_GT(kernel_h_, 0u); if (conv_conf.pad_size() > 0) { - pad_w_ = pad_h_ = conv_conf.pad(0); + if (conv_conf.pad_size() == 1) { + pad_w_ = pad_h_ = conv_conf.pad(0); + } else { + pad_w_ = conv_conf.pad(0); + pad_h_ = conv_conf.pad(1); + } } else { pad_w_ = conv_conf.pad_w(); pad_h_ = conv_conf.pad_h(); } + CHECK_GE(pad_w_, 0u); + CHECK_GE(pad_h_, 0u); if (conv_conf.stride_size() > 0) { - stride_w_ = stride_h_ = conv_conf.stride(0); + if (conv_conf.stride_size() == 1) { + stride_w_ = stride_h_ = conv_conf.stride(0); + } else { + stride_w_ = conv_conf.stride(0); + stride_h_ = conv_conf.stride(1); + } } else { stride_w_ = conv_conf.stride_w(); stride_h_ = conv_conf.stride_h(); } + CHECK_GT(stride_w_, 0u); + CHECK_GT(stride_h_, 0u); num_filters_ = conv_conf.num_output(); bias_term_ = conv_conf.bias_term(); - // Shape of src + // Shape of input image channels_ = conv_conf.channels(); height_ = conv_conf.height(); width_ = conv_conf.width(); @@ -68,7 +87,7 @@ void Convolution::Setup(const LayerConf &conf) { bias_.Reshape(Shape{num_filters_}); // Push back params into param_values_ // Assume the order of param is: weight, bias - for (const auto& spec : conf.param()) param_specs_.push_back(spec); + for (const auto &spec : conf.param()) param_specs_.push_back(spec); param_values_.push_back(&weight_); param_values_.push_back(&bias_); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/convolution.h ---------------------------------------------------------------------- diff --git a/src/model/layer/convolution.h b/src/model/layer/convolution.h index a9bf833..477efb3 100644 --- a/src/model/layer/convolution.h +++ b/src/model/layer/convolution.h @@ -47,7 +47,6 @@ class Convolution : public Layer { size_t stride_w() const { return stride_w_; } size_t stride_h() const { return stride_h_; } size_t num_filters() const { return num_filters_; } - size_t batchsize() const { return batchsize_; } size_t channels() const { return channels_; } size_t height() const { return height_; } size_t width() const { return width_; } @@ -67,7 +66,7 @@ class Convolution : public Layer { protected: size_t kernel_w_, pad_w_, stride_w_; size_t kernel_h_, pad_h_, stride_h_; - size_t batchsize_, channels_, height_, width_; + size_t channels_, height_, width_; size_t col_height_, col_width_, conv_height_, conv_width_, num_filters_; Tensor weight_, bias_; // store intermediate data, i.e., input tensor http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/cudnn_convolution.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc index ec7cd6a..922b7e0 100644 --- a/src/model/layer/cudnn_convolution.cc +++ b/src/model/layer/cudnn_convolution.cc @@ -39,9 +39,9 @@ void CudnnConvolution::Setup(const LayerConf &conf) { ConvolutionConf conv_conf = conf.convolution_conf(); // convert MB to bytes workspace_byte_limit_ = conv_conf.workspace_byte_limit() << 20; - pref_ = conv_conf.algo_pref(); - CHECK(pref_ == "fastest" || pref_ == 
"limited_workspace" || - pref_ == "no_workspace") + prefer_ = ToLowerCase(conv_conf.prefer()); + CHECK(prefer_ == "fastest" || prefer_ == "limited_workspace" || + prefer_ == "no_workspace") << "CudnnConvolution only supports three algorithm preferences: fastest, " "limited_workspace and no_workspace"; } @@ -52,8 +52,12 @@ void CudnnConvolution::ToDevice(Device *device) { workspace_.ToDevice(device); } -void CudnnConvolution::InitCudnn(DataType dtype, Device *dev, Context *ctx) { +void CudnnConvolution::InitCudnn(const Tensor& input) { CHECK(!has_init_cudnn_); + DataType dtype = input.data_type(); + Device *dev = input.device(); + Context *ctx = dev->context(0); + size_t batchsize = input.shape(0); CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_)); @@ -61,10 +65,10 @@ void CudnnConvolution::InitCudnn(DataType dtype, Device *dev, Context *ctx) { CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_)); CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, - GetCudnnDataType(dtype), batchsize_, + GetCudnnDataType(dtype), batchsize, channels_, height_, width_)); CUDNN_CHECK(cudnnSetTensor4dDescriptor( - y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize_, + y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize, num_filters_, conv_height_, conv_width_)); if (bias_term_) CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW, @@ -88,20 +92,20 @@ void CudnnConvolution::InitCudnn(DataType dtype, Device *dev, Context *ctx) { cudnnConvolutionFwdPreference_t fwd_pref; cudnnConvolutionBwdFilterPreference_t bwd_filt_pref; cudnnConvolutionBwdDataPreference_t bwd_data_pref; - if (pref_ == "fastest") { + if (prefer_ == "fastest") { fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST; bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST; bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST; - } else if (pref_ == "limited_workspace") { + } else if (prefer_ == "limited_workspace") { fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT; bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT; bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT; - } else if (pref_ == "no_workspace") { + } else if (prefer_ == "no_workspace") { fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT; } else { - LOG(FATAL) << "Algorithm preference is not implemented!"; + LOG(FATAL) << "Preferred algorithm is not available!"; } CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm( ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref, @@ -133,51 +137,46 @@ void CudnnConvolution::InitCudnn(DataType dtype, Device *dev, Context *ctx) { const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) { CHECK_EQ(input.device()->lang(), kCuda); - CHECK_EQ(input.shape().size(), 4); + CHECK_EQ(input.nDim(), 4u); buf_.push(input); - batchsize_ = input.shape()[0]; + size_t batchsize = input.shape()[0]; DataType dtype = input.data_type(); Device *dev = input.device(); - if (!has_init_cudnn_) InitCudnn(dtype, dev, dev->context(0)); + if (!has_init_cudnn_) InitCudnn(input); - Shape shape{batchsize_, num_filters_, conv_height_, conv_width_}; + Shape shape{batchsize, num_filters_, conv_height_, conv_width_}; Tensor output(shape, dev, dtype); - float alpha = 1.f, beta = 0.f; - output.device()->Exec( - [input, output, 
alpha, beta, this](Context *ctx) { - Blob *inblob = input.blob(), *outblob = output.blob(), - *wblob = this->weight_.blob(); - cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_, - inblob->data(), this->filter_desc_, - wblob->data(), this->conv_desc_, this->fp_alg_, - this->workspace_.blob()->mutable_data(), - this->workspace_count_ * sizeof(float), &beta, - this->y_desc_, outblob->mutable_data()); - }, - {input.blob(), weight_.blob()}, {output.blob()}, workspace_.blob()); + output.device()->Exec([input, output, this](Context *ctx) { + Blob *inblob = input.blob(), *outblob = output.blob(), + *wblob = this->weight_.blob(); + float alpha = 1.f, beta = 0.f; + cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_, + inblob->data(), this->filter_desc_, wblob->data(), + this->conv_desc_, this->fp_alg_, + this->workspace_.blob()->mutable_data(), + this->workspace_count_ * sizeof(float), &beta, + this->y_desc_, outblob->mutable_data()); + }, {input.blob(), weight_.blob()}, {output.blob()}, workspace_.blob()); if (bias_term_) { - beta = 1.f; - output.device()->Exec( - [output, alpha, beta, this](Context *ctx) { - Blob *outblob = output.blob(), *bblob = this->bias_.blob(); - cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_, - bblob->data(), &beta, this->y_desc_, - outblob->mutable_data()); - }, - {output.blob(), bias_.blob()}, {output.blob()}); + output.device()->Exec([output, this](Context *ctx) { + float beta = 1.f, alpha = 1.0f; + Blob *outblob = output.blob(), *bblob = this->bias_.blob(); + cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_, bblob->data(), + &beta, this->y_desc_, outblob->mutable_data()); + }, {output.blob(), bias_.blob()}, {output.blob()}); } return output; } const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward( int flag, const Tensor &grad) { + CHECK(has_init_cudnn_); CHECK_EQ(grad.device()->lang(), kCuda); - CHECK_EQ(grad.shape().size(), 4); + CHECK_EQ(grad.nDim(), 4u); Tensor src_data = buf_.top(); buf_.pop(); - float alpha = 1.f, beta = 0.f; vector<Tensor> param_grad; Tensor dx; dx.ResetLike(src_data); @@ -187,42 +186,38 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward( // LOG(ERROR) << "backward bias"; if (bias_term_) { - dx.device()->Exec( - [grad, db, alpha, beta, this](Context *ctx) { - Blob *dyblob = grad.blob(), *dbblob = db.blob(); - cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_, - dyblob->data(), &beta, this->bias_desc_, - dbblob->mutable_data()); - }, - {grad.blob()}, {db.blob()}); + dx.device()->Exec([grad, db, this](Context *ctx) { + Blob *dyblob = grad.blob(), *dbblob = db.blob(); + float alpha = 1.f, beta = 0.f; + cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_, + dyblob->data(), &beta, this->bias_desc_, + dbblob->mutable_data()); + }, {grad.blob()}, {db.blob()}); } // LOG(ERROR) << "backward w"; - dx.device()->Exec( - [grad, dw, src_data, alpha, beta, this](Context *ctx) { - Blob *inblob = src_data.blob(), *dyblob = grad.blob(), - *dwblob = dw.blob(); - cudnnConvolutionBackwardFilter( - ctx->cudnn_handle, &alpha, this->x_desc_, inblob->data(), - this->y_desc_, dyblob->data(), this->conv_desc_, - this->bp_filter_alg_, this->workspace_.blob()->mutable_data(), - this->workspace_count_ * sizeof(float), &beta, this->filter_desc_, - dwblob->mutable_data()); - }, - {grad.blob(), src_data.blob()}, {dw.blob(), workspace_.blob()}); + dx.device()->Exec([grad, dw, src_data, this](Context *ctx) { + Blob *inblob = src_data.blob(), *dyblob = 
grad.blob(), *dwblob = dw.blob(); + float alpha = 1.f, beta = 0.f; + cudnnConvolutionBackwardFilter( + ctx->cudnn_handle, &alpha, this->x_desc_, inblob->data(), this->y_desc_, + dyblob->data(), this->conv_desc_, this->bp_filter_alg_, + this->workspace_.blob()->mutable_data(), + this->workspace_count_ * sizeof(float), &beta, this->filter_desc_, + dwblob->mutable_data()); + }, {grad.blob(), src_data.blob()}, {dw.blob(), workspace_.blob()}); // LOG(ERROR) << "backward src"; - dx.device()->Exec( - [dx, grad, alpha, beta, this](Context *ctx) { - Blob *wblob = this->weight_.blob(), *dyblob = grad.blob(), - *dxblob = dx.blob(); - cudnnConvolutionBackwardData( - ctx->cudnn_handle, &alpha, this->filter_desc_, wblob->data(), - this->y_desc_, dyblob->data(), this->conv_desc_, this->bp_data_alg_, - this->workspace_.blob()->mutable_data(), - this->workspace_count_ * sizeof(float), &beta, this->x_desc_, - dxblob->mutable_data()); - }, - {grad.blob(), weight_.blob()}, {dx.blob(), workspace_.blob()}); + dx.device()->Exec([dx, grad, this](Context *ctx) { + Blob *wblob = this->weight_.blob(), *dyblob = grad.blob(), + *dxblob = dx.blob(); + float alpha = 1.f, beta = 0.f; + cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, this->filter_desc_, + wblob->data(), this->y_desc_, dyblob->data(), + this->conv_desc_, this->bp_data_alg_, + this->workspace_.blob()->mutable_data(), + this->workspace_count_ * sizeof(float), &beta, + this->x_desc_, dxblob->mutable_data()); + }, {grad.blob(), weight_.blob()}, {dx.blob(), workspace_.blob()}); param_grad.push_back(dw); param_grad.push_back(db); return std::make_pair(dx, param_grad); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/cudnn_convolution.h ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_convolution.h b/src/model/layer/cudnn_convolution.h index cf04be0..b86c576 100644 --- a/src/model/layer/cudnn_convolution.h +++ b/src/model/layer/cudnn_convolution.h @@ -27,6 +27,7 @@ #include "singa/core/common.h" #include "singa/model/layer.h" #include "singa/proto/core.pb.h" +#include "singa/utils/string.h" namespace singa { class CudnnConvolution : public Convolution { @@ -41,13 +42,15 @@ class CudnnConvolution : public Convolution { /// \copydoc Layer::Setup(const LayerConf&); void Setup(const LayerConf &conf) override; - /// Init cudnn related data structures. - void InitCudnn(DataType dtype, Device *dev, Context *ctx); void ToDevice(Device *device) override; size_t workspace_byte_limit() { return workspace_byte_limit_; } - string pref() { return pref_; } + string prefer() { return prefer_; } + + protected: + /// Init cudnn related data structures. 
+ void InitCudnn(const Tensor& input); protected: bool has_init_cudnn_ = false; @@ -61,7 +64,7 @@ class CudnnConvolution : public Convolution { cudnnConvolutionBwdDataAlgo_t bp_data_alg_; size_t workspace_byte_limit_, workspace_count_; Tensor workspace_; - string pref_; + string prefer_; }; } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/cudnn_pooling.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_pooling.cc b/src/model/layer/cudnn_pooling.cc index d68bcd2..afbc490 100644 --- a/src/model/layer/cudnn_pooling.cc +++ b/src/model/layer/cudnn_pooling.cc @@ -41,17 +41,19 @@ void CudnnPooling::Setup(const LayerConf &conf) { nan_prop_ = CUDNN_NOT_PROPAGATE_NAN; } -void CudnnPooling::InitCudnn(DataType dtype) { +void CudnnPooling::InitCudnn(const Tensor& input) { CHECK(!has_init_cudnn_); + DataType dtype = input.data_type(); + size_t batchsize = input.shape(0); CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_)); CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc_)); CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, - GetCudnnDataType(dtype), batchsize_, + GetCudnnDataType(dtype), batchsize, channels_, height_, width_)); CUDNN_CHECK(cudnnSetTensor4dDescriptor( - y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize_, + y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize, channels_, pooled_height_, pooled_width_)); auto pool_method = CUDNN_POOLING_MAX; if (pool_ == PoolingConf_PoolMethod_MAX) @@ -77,19 +79,19 @@ void CudnnPooling::InitCudnn(DataType dtype) { const Tensor CudnnPooling::Forward(int flag, const Tensor &input) { CHECK_EQ(input.device()->lang(), kCuda); - CHECK_EQ(input.shape().size(), 4); + CHECK_EQ(input.nDim(), 4u); buf_.push(input); - batchsize_ = input.shape()[0]; + size_t batchsize = input.shape(0); DataType dtype = input.data_type(); Device *dev = input.device(); - float alpha = 1.0f, beta = 0.0f; - if (!has_init_cudnn_) InitCudnn(dtype); + if (!has_init_cudnn_) InitCudnn(input); - Shape shape{batchsize_, channels_, pooled_height_, pooled_width_}; + Shape shape{batchsize, channels_, pooled_height_, pooled_width_}; Tensor output = Tensor(shape, dev, dtype); output.device()->Exec( - [input, output, alpha, beta, this](Context *ctx) { + [input, output, this](Context *ctx) { Blob *inblob = input.blob(), *outblob = output.blob(); + float alpha = 1.0f, beta = 0.0f; cudnnPoolingForward(ctx->cudnn_handle, this->pool_desc_, &alpha, this->x_desc_, inblob->data(), &beta, this->y_desc_, outblob->mutable_data()); @@ -102,26 +104,26 @@ const Tensor CudnnPooling::Forward(int flag, const Tensor &input) { const std::pair<Tensor, vector<Tensor>> CudnnPooling::Backward( int flag, const Tensor &grad) { CHECK_EQ(grad.device()->lang(), kCuda); - CHECK_EQ(grad.shape().size(), 4); + CHECK_EQ(grad.nDim(), 4u); vector<Tensor> param_grad; - Tensor dx; - Tensor data = buf_.top(); + Tensor y = buf_.top(); buf_.pop(); - Tensor src_data = buf_.top(); + Tensor x = buf_.top(); buf_.pop(); - dx.ResetLike(src_data); + Tensor dx; + dx.ResetLike(x); - float alpha = 1.0f, beta = 0.0f; dx.device()->Exec( - [dx, grad, src_data, data, alpha, beta, this](Context *ctx) { - Blob *dyblob = grad.blob(), *dxblob = dx.blob(), - *yblob = data.blob(), *xblob = src_data.blob(); + [dx, grad, x, y, this](Context *ctx) { + Blob *dyblob = grad.blob(), *dxblob = dx.blob(), *yblob = y.blob(), + *xblob = x.blob(); + float alpha = 
1.0f, beta = 0.0f; cudnnPoolingBackward(ctx->cudnn_handle, this->pool_desc_, &alpha, this->y_desc_, yblob->data(), this->y_desc_, dyblob->data(), this->x_desc_, xblob->data(), &beta, this->x_desc_, dxblob->mutable_data()); }, - {grad.blob(), data.blob(), src_data.blob()}, {dx.blob()}); + {grad.blob(), y.blob(), x.blob()}, {dx.blob()}); return std::make_pair(dx, param_grad); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/cudnn_pooling.h ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_pooling.h b/src/model/layer/cudnn_pooling.h index 14bdf40..1a38cd5 100644 --- a/src/model/layer/cudnn_pooling.h +++ b/src/model/layer/cudnn_pooling.h @@ -43,7 +43,7 @@ class CudnnPooling : public Pooling { const Tensor &grad) override; /// Init cudnn related data structures. - void InitCudnn(DataType dtype); + void InitCudnn(const Tensor& input); private: bool has_init_cudnn_ = false; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/pooling.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/pooling.cc b/src/model/layer/pooling.cc index 05c6bc9..2655369 100644 --- a/src/model/layer/pooling.cc +++ b/src/model/layer/pooling.cc @@ -30,8 +30,8 @@ void Pooling::Setup(const LayerConf& conf) { kernel_w_ = pool_conf.kernel_w(); kernel_h_ = pool_conf.kernel_h(); } - CHECK_NE(kernel_w_, 0); - CHECK_NE(kernel_h_, 0); + CHECK_GT(kernel_w_, 0u); + CHECK_GT(kernel_h_, 0u); if (pool_conf.has_pad()) { pad_w_ = pad_h_ = pool_conf.pad(); @@ -39,6 +39,8 @@ void Pooling::Setup(const LayerConf& conf) { pad_w_ = pool_conf.pad_w(); pad_h_ = pool_conf.pad_h(); } + CHECK_GE(pad_w_, 0u); + CHECK_GE(pad_h_, 0u); if (pool_conf.has_stride()) { stride_w_ = stride_h_ = pool_conf.stride(); @@ -46,6 +48,8 @@ void Pooling::Setup(const LayerConf& conf) { stride_w_ = pool_conf.stride_w(); stride_h_ = pool_conf.stride_h(); } + CHECK_GT(stride_w_, 0u); + CHECK_GT(stride_h_, 0u); pool_ = pool_conf.pool(); CHECK(pool_ == PoolingConf_PoolMethod_AVE || http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/pooling.h ---------------------------------------------------------------------- diff --git a/src/model/layer/pooling.h b/src/model/layer/pooling.h index ce6670d..522b603 100644 --- a/src/model/layer/pooling.h +++ b/src/model/layer/pooling.h @@ -46,7 +46,6 @@ class Pooling : public Layer { size_t stride_w() const { return stride_w_; } size_t stride_h() const { return stride_h_; } PoolingConf_PoolMethod pool_method() const { return pool_; } - size_t batchsize() const { return batchsize_; } size_t channels() const { return channels_; } size_t height() const { return height_; } size_t width() const { return width_; } @@ -54,7 +53,7 @@ class Pooling : public Layer { protected: size_t kernel_w_, pad_w_, stride_w_; size_t kernel_h_, pad_h_, stride_h_; - size_t batchsize_, channels_, height_, width_, pooled_height_, pooled_width_; + size_t channels_, height_, width_, pooled_height_, pooled_width_; PoolingConf_PoolMethod pool_; // To store the input and output(of forward) tensors std::stack<Tensor> buf_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/proto/model.proto ---------------------------------------------------------------------- diff --git a/src/proto/model.proto b/src/proto/model.proto index 03ad6ad..66296d5 100644 --- a/src/proto/model.proto +++ b/src/proto/model.proto @@ -306,7 +306,8 @@ message 
ConvolutionConf { optional uint32 stride_h = 13; // The stride height (2D only) optional uint32 stride_w = 14; // The stride width (2D only) - optional uint32 group = 5 [default = 1]; // The group size for group conv + // SINGA: not supported. + // optional uint32 group = 5 [default = 1]; // The group size for group conv optional FillerConf weight_filler = 7; // The filler for the weight optional FillerConf bias_filler = 8; // The filler for the bias @@ -326,20 +327,24 @@ message ConvolutionConf { // With (N, C, D, H, W) inputs, and axis == 1, we perform // N independent 3D convolutions, sliding (C/g)-channels // filters across the spatial axes (D, H, W) of the input. - optional int32 axis = 16 [default = 1]; + // SINGA: not supported; + // optional int32 axis = 16 [default = 1]; // Whether to force use of the general ND convolution, even if a specific // implementation for blobs of the appropriate number of spatial dimensions // is available. (Currently, there is only a 2D-specific convolution // implementation; for input blobs with num_axes != 2, this option is // ignored and the ND implementation will be used.) - optional bool force_nd_im2col = 17 [default = false]; - // add by xiangrui + // SINGA: not supported; + // optional bool force_nd_im2col = 17 [default = false]; + + + // SINGA: add by xiangrui // cudnn workspace size in MB optional int32 workspace_byte_limit = 50 [default = 512]; // cudnn algorithm preference // options: "fastest", "limited_workspace", "no_workspace" - optional string algo_pref = 51 [default = "fastest"]; + optional string prefer = 51 [default = "fastest"]; // input shape optional int32 channels = 52; optional int32 height = 53; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/test/singa/test_cudnn_convolution.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cudnn_convolution.cc b/test/singa/test_cudnn_convolution.cc index 0955c82..73359b4 100644 --- a/test/singa/test_cudnn_convolution.cc +++ b/test/singa/test_cudnn_convolution.cc @@ -40,31 +40,31 @@ TEST(CudnnConvolution, Setup) { convconf->set_bias_term(true); // MB convconf->set_workspace_byte_limit(256); - convconf->set_algo_pref("fastest"); + convconf->set_prefer("fastest"); convconf->set_channels(1); convconf->set_height(3); convconf->set_width(3); conv.Setup(conf); - EXPECT_EQ(2, conv.kernel_h()); - EXPECT_EQ(2, conv.kernel_w()); - EXPECT_EQ(1, conv.pad_h()); - EXPECT_EQ(1, conv.pad_w()); - EXPECT_EQ(1, conv.stride_h()); - EXPECT_EQ(1, conv.stride_w()); - EXPECT_EQ(2, conv.num_filters()); + EXPECT_EQ(2u, conv.kernel_h()); + EXPECT_EQ(2u, conv.kernel_w()); + EXPECT_EQ(1u, conv.pad_h()); + EXPECT_EQ(1u, conv.pad_w()); + EXPECT_EQ(1u, conv.stride_h()); + EXPECT_EQ(1u, conv.stride_w()); + EXPECT_EQ(2u, conv.num_filters()); EXPECT_EQ(true, conv.bias_term()); - EXPECT_EQ(256 << 20, conv.workspace_byte_limit()); - EXPECT_STREQ("fastest", conv.pref().c_str()); - EXPECT_EQ(1, conv.channels()); - EXPECT_EQ(3, conv.height()); - EXPECT_EQ(3, conv.width()); + EXPECT_EQ(256u << 20, conv.workspace_byte_limit()); + EXPECT_STREQ("fastest", conv.prefer().c_str()); + EXPECT_EQ(1u, conv.channels()); + EXPECT_EQ(3u, conv.height()); + EXPECT_EQ(3u, conv.width()); } TEST(CudnnConvolution, Forward) { const size_t batchsize = 1, c = 1, h = 3, w = 3; const float x[batchsize * c * h * w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, - 6.0f, 7.0f, 8.0f, 9.0f}; + 6.0f, 7.0f, 8.0f, 9.0f}; singa::CudaGPU cuda(0, 1); singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda); 
in.CopyDataFromHostPtr(x, batchsize * c * h * w); @@ -94,7 +94,7 @@ TEST(CudnnConvolution, Forward) { convconf->set_bias_term(true); // MB convconf->set_workspace_byte_limit(256); - convconf->set_algo_pref("fastest"); + convconf->set_prefer("fastest"); convconf->set_channels(1); convconf->set_height(3); convconf->set_width(3); @@ -106,7 +106,7 @@ TEST(CudnnConvolution, Forward) { out1.ToDevice(&host); const float *outptr1 = out1.data<const float *>(); // Input: 3*3; kernel: 3*3; stride: 2*2; padding: 1*1. - EXPECT_EQ(4, out1.Size()); + EXPECT_EQ(4u, out1.Size()); EXPECT_EQ(3.0f, outptr1[0]); EXPECT_EQ(7.0f, outptr1[1]); @@ -118,7 +118,7 @@ TEST(CudnnConvolution, Backward) { // src_data const size_t batchsize = 1, c = 1, src_h = 3, src_w = 3; const float x[batchsize * c * src_h * src_w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, - 6.0f, 7.0f, 8.0f, 9.0f}; + 6.0f, 7.0f, 8.0f, 9.0f}; singa::CudaGPU cuda(0, 1); singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, &cuda); in.CopyDataFromHostPtr(x, batchsize * c * src_h * src_w); @@ -148,7 +148,7 @@ TEST(CudnnConvolution, Backward) { convconf->set_num_output(1); convconf->set_bias_term(true); convconf->set_workspace_byte_limit(256); - convconf->set_algo_pref("fastest"); + convconf->set_prefer("fastest"); convconf->set_channels(1); convconf->set_height(3); convconf->set_width(3); @@ -159,8 +159,10 @@ TEST(CudnnConvolution, Backward) { // grad const size_t grad_h = 2, grad_w = 2; - const float dy[batchsize * num_filters * grad_h * grad_w] = {0.1f, 0.2f, 0.3f, 0.4f}; - singa::Tensor grad(singa::Shape{batchsize, num_filters, grad_h, grad_w}, &cuda); + const float dy[batchsize * num_filters * grad_h * grad_w] = {0.1f, 0.2f, 0.3f, + 0.4f}; + singa::Tensor grad(singa::Shape{batchsize, num_filters, grad_h, grad_w}, + &cuda); grad.CopyDataFromHostPtr(dy, batchsize * num_filters * grad_h * grad_w); const auto ret = conv.Backward(singa::kTrain, grad); @@ -169,7 +171,7 @@ TEST(CudnnConvolution, Backward) { in_grad.ToDevice(&host); const float *dx = in_grad.data<const float *>(); const float *wptr = we; - EXPECT_EQ(9, in_grad.Size()); + EXPECT_EQ(9u, in_grad.Size()); EXPECT_EQ(dy[0] * wptr[4], dx[0]); EXPECT_EQ(dy[0] * wptr[5] + dy[1] * wptr[3], dx[1]); EXPECT_EQ(dy[1] * wptr[4], dx[2]); @@ -190,7 +192,7 @@ TEST(CudnnConvolution, Backward) { EXPECT_EQ(dy[0] + dy[1] + dy[2] + dy[3], dbptr[0]); const float *dwptr = dw.data<const float *>(); - EXPECT_EQ(9, dw.Size()); + EXPECT_EQ(9u, dw.Size()); EXPECT_EQ(dy[3] * x[4], dwptr[0]); EXPECT_EQ(dy[3] * x[5] + dy[2] * x[3], dwptr[1]); EXPECT_EQ(dy[2] * x[4], dwptr[2]); @@ -201,5 +203,5 @@ TEST(CudnnConvolution, Backward) { EXPECT_EQ(dy[1] * x[4], dwptr[6]); EXPECT_EQ(dy[0] * x[3] + dy[1] * x[5], dwptr[7]); EXPECT_EQ(dy[0] * x[4], dwptr[8]); -} // USE_CUDNN -#endif +} +#endif // USE_CUDNN http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/test/singa/test_cudnn_pooling.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cudnn_pooling.cc b/test/singa/test_cudnn_pooling.cc index 0bfd620..e66f212 100644 --- a/test/singa/test_cudnn_pooling.cc +++ b/test/singa/test_cudnn_pooling.cc @@ -43,23 +43,23 @@ TEST(CudnnPooling, Setup) { pool.Setup(conf); EXPECT_EQ(singa::PoolingConf_PoolMethod_MAX, pool.pool_method()); - EXPECT_EQ(1, pool.kernel_h()); - EXPECT_EQ(2, pool.kernel_w()); - EXPECT_EQ(1, pool.pad_h()); - EXPECT_EQ(0, pool.pad_w()); - EXPECT_EQ(2, pool.stride_h()); - EXPECT_EQ(1, pool.stride_w()); - EXPECT_EQ(1, pool.channels()); - EXPECT_EQ(3, 
pool.height()); - EXPECT_EQ(3, pool.width()); + EXPECT_EQ(1u, pool.kernel_h()); + EXPECT_EQ(2u, pool.kernel_w()); + EXPECT_EQ(1u, pool.pad_h()); + EXPECT_EQ(0u, pool.pad_w()); + EXPECT_EQ(2u, pool.stride_h()); + EXPECT_EQ(1u, pool.stride_w()); + EXPECT_EQ(1u, pool.channels()); + EXPECT_EQ(3u, pool.height()); + EXPECT_EQ(3u, pool.width()); } TEST(CudnnPooling, Forward) { const size_t batchsize = 1, c = 1, h = 3, w = 3; const float x[batchsize * c * h * w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, - 6.0f, 7.0f, 8.0f, 9.0f}; + 6.0f, 7.0f, 8.0f, 9.0f}; singa::CudaGPU cuda(0, 1); - singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda); + singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda); in.CopyDataFromHostPtr(x, batchsize * c * h * w); CudnnPooling pool; @@ -83,7 +83,7 @@ TEST(CudnnPooling, Forward) { out1.ToDevice(&host); const float *outptr1 = out1.data<const float *>(); // Input: 3*3; kernel: 2*2; stride: 1*1; no padding. - EXPECT_EQ(4, out1.Size()); + EXPECT_EQ(4u, out1.Size()); EXPECT_EQ(5.0f, outptr1[0]); EXPECT_EQ(6.0f, outptr1[1]); EXPECT_EQ(8.0f, outptr1[2]); @@ -127,7 +127,7 @@ TEST(CudnnPooling, Backward) { singa::Tensor in_grad = ret.first; in_grad.ToDevice(&host); const float *dx = in_grad.data<const float *>(); - EXPECT_EQ(9, in_grad.Size()); + EXPECT_EQ(9u, in_grad.Size()); EXPECT_EQ(0.0f, dx[0]); EXPECT_EQ(0.0f, dx[1]); EXPECT_EQ(0.0f, dx[2]);
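For reference, a small sketch of the new string.h helpers that back the algo_pref -> prefer rename: CudnnConvolution::Setup() normalizes the configured string with ToLowerCase, so the preference value is effectively case-insensitive. The main() wrapper below is illustrative only, not part of the commit:

#include <iostream>
#include <string>
#include "singa/utils/string.h"

int main() {
  // ToLowerCase copies and normalizes, so a config value like "Fastest"
  // still matches the check in CudnnConvolution::Setup().
  std::string p = singa::ToLowerCase("Fastest");
  std::cout << (p == "fastest") << "\n";  // prints 1
  // icasecmp compares two strings case-insensitively without copying them.
  std::cout << singa::icasecmp("No_Workspace", "no_workspace") << "\n";  // 1
  return 0;
}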
