This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new f0f8af1 [MXNET-266] Fix cudnn_conv and cudnn_deconv deadlock (#10392) f0f8af1 is described below commit f0f8af1e193894ab21774f1ad8e06498c4c25ff8 Author: reminisce <wujun....@gmail.com> AuthorDate: Wed Apr 4 14:19:24 2018 -0700 [MXNET-266] Fix cudnn_conv and cudnn_deconv deadlock (#10392) * Fix deadlock of cudnn_conv wrapper * Fix deconv deadlock * Fix lint * Revert "Fix lint" This reverts commit 66f0936de9822cd9ccd00038cb0c563cfaafcd64. * Fix lint * Fix indentation --- src/operator/nn/convolution.cu | 36 +- src/operator/nn/cudnn/cudnn_convolution-inl.h | 420 +++++++++++------------ src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 434 ++++++++++++------------ src/operator/nn/deconvolution.cc | 4 + src/operator/nn/deconvolution.cu | 14 +- tests/python/unittest/test_operator.py | 9 +- 6 files changed, 455 insertions(+), 462 deletions(-) diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index f6d14e3..045e570 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -36,10 +36,12 @@ namespace op { #if MXNET_USE_CUDNN == 1 template<typename DType> -static CuDNNConvolutionOp<DType> &GetCuDNNConvOp(const ConvolutionParam& param, - int forward_compute_type, int backward_compute_type, - const std::vector<TShape>& in_shape, const std::vector<TShape>& out_shape, - const Context& ctx) { +static CuDNNConvolutionOp<DType>& GetCuDNNConvOp(const ConvolutionParam& param, + int forward_compute_type, + int backward_compute_type, + const std::vector<TShape>& in_shape, + const std::vector<TShape>& out_shape, + const RunContext& rctx) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map<ConvSignature, std::shared_ptr<CuDNNConvolutionOp<DType> >, @@ -62,7 +64,7 @@ static CuDNNConvolutionOp<DType> &GetCuDNNConvOp(const ConvolutionParam& param, key.AddSign(backward_compute_type); key.AddSign(in_shape); key.AddSign(out_shape); - key.AddSign(ctx.dev_id); + key.AddSign(rctx.ctx.dev_id); auto it = ops.find(key); if (it == ops.end()) { @@ -72,7 +74,7 @@ static CuDNNConvolutionOp<DType> &GetCuDNNConvOp(const ConvolutionParam& param, CHECK(ins_ret.second); it = ins_ret.first; it->second->Init(param, forward_compute_type, backward_compute_type, in_shape, - out_shape, ctx); + out_shape, rctx); } return *it->second; } @@ -80,9 +82,10 @@ static CuDNNConvolutionOp<DType> &GetCuDNNConvOp(const ConvolutionParam& param, template<> void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, const std::vector<TBlob>& inputs, - const std::vector<OpReqType>& req, - const std::vector<TBlob>& outputs) { + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed); int dtype = inputs[conv::kData].type_flag_; @@ -120,7 +123,7 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs, op.Init(param); op.Forward(ctx, inputs, req, outputs); } else if (!CuDNNConvolutionOp<DType>::Supports(param, - compute_type, compute_type, ctx.run_ctx.ctx)) { + compute_type, compute_type, ctx.run_ctx.ctx.dev_id)) { LOG(WARNING) << "This convolution is not supported by cudnn, MXNET convolution is applied."; ConvolutionOp<gpu, DType> op; op.Init(param); @@ -131,7 +134,7 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < in_shape.size(); i++) in_shape[i] = inputs[i].shape_; CuDNNConvolutionOp<DType> &op = GetCuDNNConvOp<DType>(param, - compute_type, compute_type, in_shape, out_shape, ctx.run_ctx.ctx); + compute_type, compute_type, in_shape, out_shape, ctx.run_ctx); op.Forward(ctx, inputs, req, outputs); } }) @@ -146,9 +149,10 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs, template<> void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, const std::vector<TBlob>& inputs, - const std::vector<OpReqType>& req, - const std::vector<TBlob>& outputs) { + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed); std::vector<TBlob> in_data(inputs.begin() + 1, inputs.end()); const TBlob &out_grad = inputs[0]; @@ -190,7 +194,7 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs, op.Init(param); op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad); } else if (!CuDNNConvolutionOp<DType>::Supports(param, - compute_type, compute_type, ctx.run_ctx.ctx)) { + compute_type, compute_type, ctx.run_ctx.ctx.dev_id)) { LOG(WARNING) << "This convolution is not supported by cudnn, MXNET convolution is applied."; ConvolutionOp<gpu, DType> op; op.Init(param); @@ -202,7 +206,7 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < in_shape.size(); i++) in_shape[i] = in_data[i].shape_; CuDNNConvolutionOp<DType> &op = GetCuDNNConvOp<DType>(param, - compute_type, compute_type, in_shape, out_shape, ctx.run_ctx.ctx); + compute_type, compute_type, in_shape, out_shape, ctx.run_ctx); op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad); } }) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 229ba3c..ca60c99 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -55,11 +55,11 @@ class CuDNNConvolutionOp { } void Init(const ConvolutionParam& param, - int forward_compute_type, - int backward_compute_type, - const std::vector<TShape>& in_shape, - const std::vector<TShape>& out_shape, - const Context& ctx) { + int forward_compute_type, + int backward_compute_type, + const std::vector<TShape>& in_shape, + const std::vector<TShape>& out_shape, + const RunContext& rctx) { using namespace mshadow; this->param_ = param; InitBufferForParam(); @@ -90,10 +90,10 @@ class CuDNNConvolutionOp { param_.layout.value() == kNCDHW) << "Need CuDNN > 5.0 for layout support"; #endif // Double check to make sure this class supports the operation - if (!Supports(param, forward_compute_type, backward_compute_type, ctx)) + if (!Supports(param, forward_compute_type, backward_compute_type, rctx.ctx.dev_id)) LOG(FATAL) << "Need CuDNN >= 6.0 for dilated convolution."; - InitDescriptors(ctx, in_shape, out_shape, + InitDescriptors(in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); if (!param_.cudnn_tune) { @@ -105,7 +105,7 @@ class CuDNNConvolutionOp { // approach keeps the treatment of convolution cases uniform and will // naturally respond to more algorithms supporting dilated convolutions in // future cuDNN releases. - SelectAlgo(ctx, in_shape, out_shape, + SelectAlgo(rctx, in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); } @@ -120,9 +120,9 @@ class CuDNNConvolutionOp { } void Forward(const OpContext &ctx, - const std::vector<TBlob> &in_data, - const std::vector<OpReqType> &req, - const std::vector<TBlob> &out_data) { + const std::vector<TBlob> &in_data, + const std::vector<OpReqType> &req, + const std::vector<TBlob> &out_data) { using namespace mshadow; size_t expected = param_.no_bias ? 2 : 3; CHECK_EQ(in_data.size(), expected); @@ -289,7 +289,7 @@ class CuDNNConvolutionOp { static bool Supports(ConvolutionParam param, int forward_compute_type, int backward_compute_type, - const Context &ctx) { + int dev_id) { using namespace mshadow; // NDHWC not supported, NHWC not supported in true fp16 @@ -301,7 +301,7 @@ class CuDNNConvolutionOp { return false; // Permits graceful fallback to pseudo-fp16 on heterogenous systems - if (!SupportsFloat16Compute(ctx.dev_id) && + if (!SupportsFloat16Compute(dev_id) && (forward_compute_type == kFloat16 || backward_compute_type == kFloat16)) { return false; } @@ -329,8 +329,7 @@ class CuDNNConvolutionOp { return converted; } - void InitDescriptors(const Context& ctx, - const std::vector<TShape>& in_shape, + void InitDescriptors(const std::vector<TShape>& in_shape, const std::vector<TShape>& out_shape, cudnnDataType_t cudnn_forward_compute_type, cudnnDataType_t cudnn_backward_compute_type) { @@ -512,220 +511,213 @@ class CuDNNConvolutionOp { } } - void SelectAlgo(const Context& ctx, + void SelectAlgo(const RunContext& rctx, const std::vector<TShape>& in_shape, const std::vector<TShape>& out_shape, cudnnDataType_t cudnn_forward_compute_type, cudnnDataType_t cudnn_backward_compute_type) { if (!CuDNNConvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_, cudnn_forward_compute_type, cudnn_backward_compute_type, - SMArch(ctx.dev_id), &forward_algo_, &back_algo_, + SMArch(rctx.ctx.dev_id), &forward_algo_, &back_algo_, &back_algo_w_)) { - // Not in algo registry, must determine via *Get*() or *Find*() - Engine::VarHandle var = Engine::Get()->NewVariable(); - Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - mshadow::Stream<gpu> *s = rctx.get_stream<gpu>(); - CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle); - size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType)); - #if CUDNN_MAJOR >= 7 - // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire - // story: the notion of whether the algo ran in Tensor Core mode is not known. - // Since we want to report the Tensor Core mode in the verbose output, we switch - // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches - // that of *Find*(), we can unify the find-vs-get logic by using function pointers. - - // Forward Algorithm Find/Get() v7 - std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_)); - int actual_fwd_algos = 0; - auto fwd_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 - : cudnnFindConvolutionForwardAlgorithm; - CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - fwd_results.size(), - &actual_fwd_algos, - fwd_results.data())); - fwd_results.resize(actual_fwd_algos); - AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t, - cudnnConvolutionFwdAlgo_t>(fwd_results, "forward", - workspace_byte, &forward_algo_); - - // Backprop-to-Filter Algorithm Find/Get() v7 - auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); - std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos); - int actual_bwd_filter_algos = 0; - auto bwd_filter_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 - : cudnnFindConvolutionBackwardFilterAlgorithm; - CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - bwd_filt_results.size(), - &actual_bwd_filter_algos, - bwd_filt_results.data())); - bwd_filt_results.resize(actual_bwd_filter_algos); - AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t, - cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter", - workspace_byte, &back_algo_w_); - - // Backprop-to-Data Algorithm Find/Get() v7 - auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); - std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos); - int actual_bwd_data_algos = 0; - auto bwd_data_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7 - : cudnnFindConvolutionBackwardDataAlgorithm; - CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, + mshadow::Stream<gpu> *s = rctx.get_stream<gpu>(); + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle); + size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType)); + #if CUDNN_MAJOR >= 7 + // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire + // story: the notion of whether the algo ran in Tensor Core mode is not known. + // Since we want to report the Tensor Core mode in the verbose output, we switch + // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches + // that of *Find*(), we can unify the find-vs-get logic by using function pointers. + + // Forward Algorithm Find/Get() v7 + std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_)); + int actual_fwd_algos = 0; + auto fwd_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 + : cudnnFindConvolutionForwardAlgorithm; + CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, + in_desc_, + filter_desc_, + forward_conv_desc_, + out_desc_, + fwd_results.size(), + &actual_fwd_algos, + fwd_results.data())); + fwd_results.resize(actual_fwd_algos); + AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t, + cudnnConvolutionFwdAlgo_t>(fwd_results, "forward", + workspace_byte, &forward_algo_); + + // Backprop-to-Filter Algorithm Find/Get() v7 + auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); + std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos); + int actual_bwd_filter_algos = 0; + auto bwd_filter_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 + : cudnnFindConvolutionBackwardFilterAlgorithm; + CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, + in_desc_, + out_desc_, + back_conv_desc_w_, + filter_desc_, + bwd_filt_results.size(), + &actual_bwd_filter_algos, + bwd_filt_results.data())); + bwd_filt_results.resize(actual_bwd_filter_algos); + AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t, + cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter", + workspace_byte, &back_algo_w_); + + // Backprop-to-Data Algorithm Find/Get() v7 + auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); + std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos); + int actual_bwd_data_algos = 0; + auto bwd_data_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7 + : cudnnFindConvolutionBackwardDataAlgorithm; + CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, + filter_desc_, + out_desc_, + back_conv_desc_, + in_desc_, + bwd_data_results.size(), + &actual_bwd_data_algos, + bwd_data_results.data())); + bwd_data_results.resize(actual_bwd_data_algos); + AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t, + cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data", + workspace_byte, &back_algo_); + #else + // CUDNN_MAJOR < 7 + const int kMaxAlgos = 10; + int nalgo = kMaxAlgos; + int i = 0; + // Forward Algorithm Find/Get, v6 and earlier + if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { + // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is + // supported. Hard-coded this since the algo find() or get() throws an FPE. + forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); + } else if (!param_.cudnn_tune.value()) { + cudnnConvolutionFwdAlgo_t fastest_fwd_algo; + CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, + in_desc_, filter_desc_, + forward_conv_desc_, out_desc_, - back_conv_desc_, - in_desc_, - bwd_data_results.size(), - &actual_bwd_data_algos, - bwd_data_results.data())); - bwd_data_results.resize(actual_bwd_data_algos); - AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t, - cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data", - workspace_byte, &back_algo_); - #else - // CUDNN_MAJOR < 7 - const int kMaxAlgos = 10; - int nalgo = kMaxAlgos; - int i = 0; - // Forward Algorithm Find/Get, v6 and earlier - if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { - // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is - // supported. Hard-coded this since the algo find() or get() throws an FPE. - forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); - } else if (!param_.cudnn_tune.value()) { - cudnnConvolutionFwdAlgo_t fastest_fwd_algo; - CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_fwd_algo)); - forward_algo_.Set(fastest_fwd_algo, false); - } else { - cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - kMaxAlgos, - &nalgo, - fwd_algo)); - i = 0; - while (i < nalgo - && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited - && fwd_algo[i].memory > workspace_byte))) - ++i; - if (i == nalgo) { - LOG(FATAL) << "Failed to find a forward convolution algorithm."; - } else { - forward_algo_.Set(fwd_algo[i].algo, false); - } - } - // Backprop-to-Filter Algorithm Find/Get, v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_filt_algo)); - back_algo_w_.Set(fastest_bwd_filt_algo, false); - } else { - cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - kMaxAlgos, - &nalgo, - bwd_filter_algo)); - i = 0; - while (i < nalgo - && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited - && bwd_filter_algo[i].memory > workspace_byte))) - ++i; - if (i == nalgo) { - LOG(FATAL) << "Failed to find a backward filter convolution algorithm."; - } else { - back_algo_w_.Set(bwd_filter_algo[i].algo, false); - } - } - // Backprop-to-Data Algorithm Get(), v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, - in_desc_, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_data_algo)); - back_algo_.Set(fastest_bwd_data_algo, false); - } else { - cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, - in_desc_, - kMaxAlgos, - &nalgo, - bwd_data_algo)); - i = 0; - while (i < nalgo - && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited - && bwd_data_algo[i].memory > workspace_byte))) - ++i; - if (i == nalgo) { - LOG(FATAL) << "Failed to find a backward data convolution algorithm."; - } else { - back_algo_.Set(bwd_data_algo[i].algo, false); - } - } - #endif // CUDNN_MAJOR < 7 - // An algo specification by the user may be cached here, but another - // convolution will match only if identically specified. - // We're caching results of *Get* as well as *Find*, but these records - // will be held distinctly because param_.cudnn_tune is part of the key. - CuDNNConvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - SMArch(ctx.dev_id), this->forward_algo_, - this->back_algo_, this->back_algo_w_); - on_complete(); - }, ctx, {}, {var}); - Engine::Get()->WaitForVar(var); - Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var); + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_fwd_algo)); + forward_algo_.Set(fastest_fwd_algo, false); + } else { + cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, + in_desc_, + filter_desc_, + forward_conv_desc_, + out_desc_, + kMaxAlgos, + &nalgo, + fwd_algo)); + i = 0; + while (i < nalgo + && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == conv::kLimited + && fwd_algo[i].memory > workspace_byte))) + ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find a forward convolution algorithm."; + } else { + forward_algo_.Set(fwd_algo[i].algo, false); + } + } + // Backprop-to-Filter Algorithm Find/Get, v6 and earlier + if (!param_.cudnn_tune.value()) { + cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; + CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + in_desc_, + out_desc_, + back_conv_desc_w_, + filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_bwd_filt_algo)); + back_algo_w_.Set(fastest_bwd_filt_algo, false); + } else { + cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + in_desc_, + out_desc_, + back_conv_desc_w_, + filter_desc_, + kMaxAlgos, + &nalgo, + bwd_filter_algo)); + i = 0; + while (i < nalgo + && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == conv::kLimited + && bwd_filter_algo[i].memory > workspace_byte))) + ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find a backward filter convolution algorithm."; + } else { + back_algo_w_.Set(bwd_filter_algo[i].algo, false); + } + } + // Backprop-to-Data Algorithm Get(), v6 and earlier + if (!param_.cudnn_tune.value()) { + cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; + CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + out_desc_, + back_conv_desc_, + in_desc_, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_bwd_data_algo)); + back_algo_.Set(fastest_bwd_data_algo, false); + } else { + cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + out_desc_, + back_conv_desc_, + in_desc_, + kMaxAlgos, + &nalgo, + bwd_data_algo)); + i = 0; + while (i < nalgo + && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == conv::kLimited + && bwd_data_algo[i].memory > workspace_byte))) + ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find a backward data convolution algorithm."; + } else { + back_algo_.Set(bwd_data_algo[i].algo, false); + } + } + #endif // CUDNN_MAJOR < 7 + // An algo specification by the user may be cached here, but another + // convolution will match only if identically specified. + // We're caching results of *Get* as well as *Find*, but these records + // will be held distinctly because param_.cudnn_tune is part of the key. + CuDNNConvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + SMArch(rctx.ctx.dev_id), this->forward_algo_, + this->back_algo_, this->back_algo_w_); } // If we're allowing Tensor Core variants of the algos to be considered in // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest, // we must change the descriptor to preclude Tensor Core. Simplest is to // once again set the mathType in all cases. #if CUDNN_MAJOR >= 7 - CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, forward_algo_.MathType())); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, back_algo_.MathType())); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType())); + CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, forward_algo_.MathType())); + CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, back_algo_.MathType())); + CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType())); #endif } diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 3c80cdc..cb0de4c 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -52,11 +52,11 @@ class CuDNNDeconvolutionOp { } void Init(DeconvolutionParam param, - int forward_compute_type, - int backward_compute_type, - const std::vector<TShape>& in_shape, - const std::vector<TShape>& out_shape, - const Context& ctx) { + int forward_compute_type, + int backward_compute_type, + const std::vector<TShape>& in_shape, + const std::vector<TShape>& out_shape, + const RunContext& rctx) { using namespace mshadow; this->param_ = param; InitBufferForParam(); @@ -87,10 +87,10 @@ class CuDNNDeconvolutionOp { param_.layout.value() == kNCDHW) << "Need CuDNN > 5.0 for layout support"; #endif // Double check to make sure this class supports the operation - if (!Supports(param, forward_compute_type, backward_compute_type, ctx)) + if (!Supports(param, forward_compute_type, backward_compute_type, rctx.ctx.dev_id)) LOG(FATAL) << "Need CuDNN >= 6.0 for dilated deconvolution."; - InitDescriptors(ctx, in_shape, out_shape, + InitDescriptors(in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); if (!param_.cudnn_tune) { @@ -102,7 +102,7 @@ class CuDNNDeconvolutionOp { // approach keeps the treatment of convolution cases uniform and will // naturally respond to more algorithms supporting dilated convolutions in // future cuDNN releases. - SelectAlgo(ctx, in_shape, out_shape, + SelectAlgo(rctx, in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); } @@ -117,9 +117,9 @@ class CuDNNDeconvolutionOp { } void Forward(const OpContext &ctx, - const std::vector<TBlob> &in_data, - const std::vector<OpReqType> &req, - const std::vector<TBlob> &out_data) { + const std::vector<TBlob> &in_data, + const std::vector<OpReqType> &req, + const std::vector<TBlob> &out_data) { using namespace mshadow; size_t expected = param_.no_bias ? 2 : 3; CHECK_EQ(in_data.size(), expected); @@ -193,10 +193,10 @@ class CuDNNDeconvolutionOp { } void Backward(const OpContext &ctx, - const std::vector<TBlob> &out_grad, - const std::vector<TBlob> &in_data, - const std::vector<OpReqType> &req, - const std::vector<TBlob> &in_grad) { + const std::vector<TBlob> &out_grad, + const std::vector<TBlob> &in_data, + const std::vector<OpReqType> &req, + const std::vector<TBlob> &in_grad) { using namespace mshadow; using namespace mshadow::expr; size_t expected = param_.no_bias == 0 ? 3 : 2; @@ -299,7 +299,7 @@ class CuDNNDeconvolutionOp { static bool Supports(DeconvolutionParam param, int forward_compute_type, int backward_compute_type, - const Context &ctx) { + int dev_id) { using namespace mshadow; // NDHWC not supported, NHWC not supported in true fp16 @@ -311,7 +311,7 @@ class CuDNNDeconvolutionOp { return false; // Permits graceful fallback to pseudo-fp16 on heterogenous systems - if (!SupportsFloat16Compute(ctx.dev_id) && + if (!SupportsFloat16Compute(dev_id) && (forward_compute_type == kFloat16 || backward_compute_type == kFloat16)) { return false; } @@ -344,8 +344,7 @@ class CuDNNDeconvolutionOp { return converted; } - inline void InitDescriptors(const Context& ctx, - const std::vector<TShape> &in_shape, + inline void InitDescriptors(const std::vector<TShape> &in_shape, const std::vector<TShape> &out_shape, cudnnDataType_t cudnn_forward_compute_type, cudnnDataType_t cudnn_backward_compute_type) { @@ -536,7 +535,7 @@ class CuDNNDeconvolutionOp { } } - void SelectAlgo(const Context& ctx, + void SelectAlgo(const RunContext& rctx, const std::vector<TShape>& in_shape, const std::vector<TShape>& out_shape, cudnnDataType_t cudnn_forward_compute_type, @@ -544,222 +543,215 @@ class CuDNNDeconvolutionOp { if (!CuDNNDeconvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_, cudnn_forward_compute_type, cudnn_backward_compute_type, - SMArch(ctx.dev_id), &forward_algo_, + SMArch(rctx.ctx.dev_id), &forward_algo_, &back_algo_, &back_algo_w_)) { - // Not in algo registry, must determine via *Get*() or *Find*() - Engine::VarHandle var = Engine::Get()->NewVariable(); - Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - mshadow::Stream <gpu> *s = rctx.get_stream<gpu>(); - CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle); - size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType)); - #if CUDNN_MAJOR >= 7 - // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire - // story: the notion of whether the algo ran in Tensor Core mode is not known. - // Since we want to report the Tensor Core mode in the verbose output, we switch - // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches - // that of *Find*(), we can unify the find-vs-get logic by using function pointers. - - // Forward Algorithm Find/Get() v7 - std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_)); - int actual_fwd_algos = 0; - auto fwd_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 - : cudnnFindConvolutionForwardAlgorithm; - CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, - out_desc_, - filter_desc_, - back_conv_desc_, // fwd algo used to backprop-to-data - in_desc_, - fwd_results.size(), - &actual_fwd_algos, - fwd_results.data())); - fwd_results.resize(actual_fwd_algos); - AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t, - cudnnConvolutionFwdAlgo_t>(fwd_results, "forward", - workspace_byte, &forward_algo_); - - // Backprop-to-Filter Algorithm Find/Get() v7 - auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); - std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos); - int actual_bwd_filter_algos = 0; - auto bwd_filter_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 - : cudnnFindConvolutionBackwardFilterAlgorithm; - CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, + mshadow::Stream <gpu> *s = rctx.get_stream<gpu>(); + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle); + size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType)); + #if CUDNN_MAJOR >= 7 + // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire + // story: the notion of whether the algo ran in Tensor Core mode is not known. + // Since we want to report the Tensor Core mode in the verbose output, we switch + // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches + // that of *Find*(), we can unify the find-vs-get logic by using function pointers. + + // Forward Algorithm Find/Get() v7 + std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_)); + int actual_fwd_algos = 0; + auto fwd_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 + : cudnnFindConvolutionForwardAlgorithm; + CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, + out_desc_, + filter_desc_, + back_conv_desc_, // fwd algo used to backprop-to-data + in_desc_, + fwd_results.size(), + &actual_fwd_algos, + fwd_results.data())); + fwd_results.resize(actual_fwd_algos); + AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t, + cudnnConvolutionFwdAlgo_t>(fwd_results, "forward", + workspace_byte, &forward_algo_); + + // Backprop-to-Filter Algorithm Find/Get() v7 + auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); + std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos); + int actual_bwd_filter_algos = 0; + auto bwd_filter_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 + : cudnnFindConvolutionBackwardFilterAlgorithm; + CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, + out_desc_, + in_desc_, + back_conv_desc_, + filter_desc_, + bwd_filt_results.size(), + &actual_bwd_filter_algos, + bwd_filt_results.data())); + bwd_filt_results.resize(actual_bwd_filter_algos); + AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t, + cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter", + workspace_byte, &back_algo_w_); + + // Backprop-to-Data Algorithm Find/Get() v7 + auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); + std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos); + int actual_bwd_data_algos = 0; + auto bwd_data_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7 + : cudnnFindConvolutionBackwardDataAlgorithm; + CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, + filter_desc_, + in_desc_, + forward_conv_desc_, // bwd algo used in inference + out_desc_, + bwd_data_results.size(), + &actual_bwd_data_algos, + bwd_data_results.data())); + bwd_data_results.resize(actual_bwd_data_algos); + AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t, + cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data", + workspace_byte, &back_algo_); + #else + // CUDNN_MAJOR < 7 + const int kMaxAlgos = 10; + int nalgo = kMaxAlgos; + int i = 0; + // Forward Algorithm Find/Get, v6 and earlier + if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { + // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is + // supported. Hard-coded this since the algo find() or get() throws an FPE. + forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); + } else if (!param_.cudnn_tune.value()) { + cudnnConvolutionFwdAlgo_t fastest_fwd_algo; + CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, out_desc_, - in_desc_, - back_conv_desc_, filter_desc_, - bwd_filt_results.size(), - &actual_bwd_filter_algos, - bwd_filt_results.data())); - bwd_filt_results.resize(actual_bwd_filter_algos); - AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t, - cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter", - workspace_byte, &back_algo_w_); - - // Backprop-to-Data Algorithm Find/Get() v7 - auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); - std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos); - int actual_bwd_data_algos = 0; - auto bwd_data_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7 - : cudnnFindConvolutionBackwardDataAlgorithm; - CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // bwd algo used in inference - out_desc_, - bwd_data_results.size(), - &actual_bwd_data_algos, - bwd_data_results.data())); - bwd_data_results.resize(actual_bwd_data_algos); - AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t, - cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data", - workspace_byte, &back_algo_); - #else - // CUDNN_MAJOR < 7 - const int kMaxAlgos = 10; - int nalgo = kMaxAlgos; - int i = 0; - // Forward Algorithm Find/Get, v6 and earlier - if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { - // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is - // supported. Hard-coded this since the algo find() or get() throws an FPE. - forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); - } else if (!param_.cudnn_tune.value()) { - cudnnConvolutionFwdAlgo_t fastest_fwd_algo; - CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - back_conv_desc_, // fwd algo used in dgrad - in_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_fwd_algo)); - forward_algo_.Set(fastest_fwd_algo, false); + back_conv_desc_, // fwd algo used in dgrad + in_desc_, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_fwd_algo)); + forward_algo_.Set(fastest_fwd_algo, false); + } else { + cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, + out_desc_, + filter_desc_, + back_conv_desc_, // fwd algo used in dgrad + in_desc_, + kMaxAlgos, + &nalgo, + fwd_algo)); + i = 0; + while (i < nalgo + && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && fwd_algo[i].memory > workspace_byte))) + ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find a 'forward' convolution algorithm " << + "(for use in deconvolution operator backprop-to-data)."; } else { - cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - back_conv_desc_, // fwd algo used in dgrad - in_desc_, - kMaxAlgos, - &nalgo, - fwd_algo)); - i = 0; - while (i < nalgo - && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && fwd_algo[i].memory > workspace_byte))) - ++i; - if (i == nalgo) { - LOG(FATAL) << "Failed to find a 'forward' convolution algorithm " << - "(for use in deconvolution operator backprop-to-data)."; - } else { - forward_algo_.Set(fwd_algo[i].algo, false); - } + forward_algo_.Set(fwd_algo[i].algo, false); } - // Backprop-to-Filter Algorithm Find/Get, v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - back_conv_desc_, - filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_filt_algo)); - back_algo_w_.Set(fastest_bwd_filt_algo, false); + } + // Backprop-to-Filter Algorithm Find/Get, v6 and earlier + if (!param_.cudnn_tune.value()) { + cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; + CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + out_desc_, + in_desc_, + back_conv_desc_, + filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_bwd_filt_algo)); + back_algo_w_.Set(fastest_bwd_filt_algo, false); + } else { + cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + out_desc_, + in_desc_, + back_conv_desc_, + filter_desc_, + kMaxAlgos, + &nalgo, + bwd_filter_algo)); + i = 0; + while (i < nalgo + && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && bwd_filter_algo[i].memory > workspace_byte))) + ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find a backward filter convolution algorithm " << + "(for use in deconvolution operator backprop-to-filter)."; } else { - cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - back_conv_desc_, - filter_desc_, - kMaxAlgos, - &nalgo, - bwd_filter_algo)); - i = 0; - while (i < nalgo - && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && bwd_filter_algo[i].memory > workspace_byte))) - ++i; - if (i == nalgo) { - LOG(FATAL) << "Failed to find a backward filter convolution algorithm " << - "(for use in deconvolution operator backprop-to-filter)."; - } else { - back_algo_w_.Set(bwd_filter_algo[i].algo, false); - } + back_algo_w_.Set(bwd_filter_algo[i].algo, false); } - // Backprop-to-Data Algorithm Get(), v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // bwd algo used for inference - out_desc_, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_data_algo)); - back_algo_.Set(fastest_bwd_data_algo, false); + } + // Backprop-to-Data Algorithm Get(), v6 and earlier + if (!param_.cudnn_tune.value()) { + cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; + CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + in_desc_, + forward_conv_desc_, // bwd algo used for inference + out_desc_, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_bwd_data_algo)); + back_algo_.Set(fastest_bwd_data_algo, false); + } else { + cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + in_desc_, + forward_conv_desc_, // bwd algo used in inference + out_desc_, + kMaxAlgos, + &nalgo, + bwd_data_algo)); + i = 0; + while (i < nalgo + && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && bwd_data_algo[i].memory > workspace_byte))) + ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find a backward data convolution algorithm." << + "(for use in deconvolution operator forward inference)."; } else { - cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // bwd algo used in inference - out_desc_, - kMaxAlgos, - &nalgo, - bwd_data_algo)); - i = 0; - while (i < nalgo - && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && bwd_data_algo[i].memory > workspace_byte))) - ++i; - if (i == nalgo) { - LOG(FATAL) << "Failed to find a backward data convolution algorithm." << - "(for use in deconvolution operator forward inference)."; - } else { - back_algo_.Set(bwd_data_algo[i].algo, false); - } + back_algo_.Set(bwd_data_algo[i].algo, false); } - #endif // CUDNN_MAJOR < 7 - // An algo specification by the user may be cached here, but another - // convolution will match only if identically specified. - // We're caching results of *Get* as well as *Find*, but these records - // will be held distinctly because param_.cudnn_tune is part of the key. - CuDNNDeconvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - SMArch(ctx.dev_id), this->forward_algo_, - this->back_algo_, this->back_algo_w_); - on_complete(); - }, ctx, {}, {var}); - Engine::Get()->WaitForVar(var); - Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var); + } + #endif // CUDNN_MAJOR < 7 + // An algo specification by the user may be cached here, but another + // convolution will match only if identically specified. + // We're caching results of *Get* as well as *Find*, but these records + // will be held distinctly because param_.cudnn_tune is part of the key. + CuDNNDeconvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + SMArch(rctx.ctx.dev_id), this->forward_algo_, + this->back_algo_, this->back_algo_w_); } // If we're allowing Tensor Core variants of the algos to be considered in // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest, // we must change the descriptor to preclude Tensor Core. Simplest is to // once again set the mathType in all cases. #if CUDNN_MAJOR >= 7 - // The next two code lines will look like they have typos, but they don't! - // The forward_conv_desc_ is used during inference, which invokes the back_algo_. - // Thus, the mathType of the back_algo_ should be stored in the forward_conv_desc_. - // Conversely, the back_conv_desc_ is used during training backprop, which invokes - // the forward_algo_. Thus, the mathType of the forward_algo_ should be stored - // in the back_conv_desc_. - CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, back_algo_.MathType())); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, forward_algo_.MathType())); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType())); + // The next two code lines will look like they have typos, but they don't! + // The forward_conv_desc_ is used during inference, which invokes the back_algo_. + // Thus, the mathType of the back_algo_ should be stored in the forward_conv_desc_. + // Conversely, the back_conv_desc_ is used during training backprop, which invokes + // the forward_algo_. Thus, the mathType of the forward_algo_ should be stored + // in the back_conv_desc_. + CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, back_algo_.MathType())); + CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, forward_algo_.MathType())); + CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType())); #endif } diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index a3fc915..0d1b391 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -391,6 +391,10 @@ NNVM_REGISTER_OP(Deconvolution) [](const NodeAttrs& attrs) { return ListArguments(nnvm::get<DeconvolutionParam>(attrs.parsed)); }) +.set_attr<nnvm::FListOutputNames>("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector<std::string>{"output"}; +}) .set_attr<nnvm::FInferShape>("FInferShape", DeconvolutionShape) .set_attr<nnvm::FInferType>("FInferType", DeconvolutionType) .set_attr<FInferStorageType>("FInferStorageType", DeconvStorageType) diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu index 086b470..1cabe73 100644 --- a/src/operator/nn/deconvolution.cu +++ b/src/operator/nn/deconvolution.cu @@ -39,7 +39,7 @@ static CuDNNDeconvolutionOp<DType> &GetCuDNNDeconvOp(const DeconvolutionParam& p int backward_compute_type, const std::vector<TShape>& in_shape, const std::vector<TShape>& out_shape, - const Context& ctx) { + const RunContext& rctx) { static thread_local std::unordered_map<DeconvSignature, std::shared_ptr<CuDNNDeconvolutionOp<DType> >, OpHash> ops; @@ -56,7 +56,7 @@ static CuDNNDeconvolutionOp<DType> &GetCuDNNDeconvOp(const DeconvolutionParam& p key.AddSign(backward_compute_type); key.AddSign(in_shape); key.AddSign(out_shape); - key.AddSign(ctx.dev_id); + key.AddSign(rctx.ctx.dev_id); auto it = ops.find(key); if (it == ops.end()) { @@ -66,7 +66,7 @@ static CuDNNDeconvolutionOp<DType> &GetCuDNNDeconvOp(const DeconvolutionParam& p CHECK(ins_ret.second); it = ins_ret.first; it->second->Init(param, forward_compute_type, backward_compute_type, in_shape, - out_shape, ctx); + out_shape, rctx); } return *it->second; } @@ -91,7 +91,7 @@ void DeconvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs, op.Init(param); op.Forward(ctx, inputs, req, outputs); } else if (!CuDNNDeconvolutionOp<DType>::Supports(param, - compute_type, compute_type, ctx.run_ctx.ctx)) { + compute_type, compute_type, ctx.run_ctx.ctx.dev_id)) { LOG(WARNING) << "This deconvolution is not supported by cudnn, MXNET deconvolution is applied."; DeconvolutionOp<gpu, DType> op; @@ -104,7 +104,7 @@ void DeconvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs, in_shape[i] = inputs[i].shape_; } GetCuDNNDeconvOp<DType>(param, compute_type, compute_type, - in_shape, out_shape, ctx.run_ctx.ctx).Forward(ctx, inputs, req, outputs); + in_shape, out_shape, ctx.run_ctx).Forward(ctx, inputs, req, outputs); } }) #else @@ -138,7 +138,7 @@ void DeconvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs, op.Init(param); op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad); } else if (!CuDNNDeconvolutionOp<DType>::Supports(param, - compute_type, compute_type, ctx.run_ctx.ctx)) { + compute_type, compute_type, ctx.run_ctx.ctx.dev_id)) { LOG(WARNING) << "This deconvolution is not supported by cudnn, MXNET deconvolution is applied."; DeconvolutionOp<gpu, DType> op; @@ -151,7 +151,7 @@ void DeconvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs, in_shape[i] = in_data[i].shape_; } GetCuDNNDeconvOp<DType>(param, compute_type, compute_type, - in_shape, out_shape, ctx.run_ctx.ctx).Backward(ctx, + in_shape, out_shape, ctx.run_ctx).Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad); } }) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0039449..a0ae480 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5432,10 +5432,11 @@ def test_op_output_names_monitor(): assert output_name == expected_name data = mx.sym.Variable('data', shape=(10, 3, 10, 10)) - # Temporarily disabling convolutional test as it is exposing a hang. - # See: https://github.com/apache/incubator-mxnet/issues/10341 - # conv_sym = mx.sym.Convolution(data, kernel=(2, 2), num_filter=1, name='conv') - # check_name(conv_sym, ['conv_output']) + conv_sym = mx.sym.Convolution(data, kernel=(2, 2), num_filter=1, name='conv') + check_name(conv_sym, ['conv_output']) + + deconv_sym = mx.sym.Deconvolution(data, kernel=(2, 2), num_filter=1, name='deconv') + check_name(deconv_sym, ['deconv_output']) fc_sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc') check_name(fc_sym, ['fc_output']) -- To stop receiving notification emails like this one, please contact j...@apache.org.