Repository: incubator-singa
Updated Branches: refs/heads/master 56292f1fb -> e16cea129
SINGA-346 Update cudnn from V5 to V7 support cudnn5 (conv and rnn has API changes from v5 to v7)

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e2092030
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e2092030
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e2092030

Branch: refs/heads/master
Commit: e20920309bfeb6ed7e0adf3d529c2fba1d44ad2f
Parents: 56292f1
Author: Wang Wei <[email protected]>
Authored: Thu Jul 5 22:57:33 2018 +0800
Committer: wang wei <[email protected]>
Committed: Sun Jul 8 16:00:38 2018 +0800

----------------------------------------------------------------------
 src/model/layer/cudnn_convolution.cc | 101 +++++++++---------
 src/model/layer/cudnn_rnn.cc | 165 +++++++++++++++---------------
 2 files changed, 137 insertions(+), 129 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e2092030/src/model/layer/cudnn_convolution.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc index 8846746..1b12f93 100644 --- a/src/model/layer/cudnn_convolution.cc +++ b/src/model/layer/cudnn_convolution.cc @@ -44,7 +44,7 @@ void CudnnConvolution::Setup(const Shape& in_sample, const LayerConf &conf) { CHECK(prefer_ == "fastest" || prefer_ == "limited_workspace" || prefer_ == "no_workspace" || prefer_ == "autotune") << "CudnnConvolution only supports four algorithm preferences: fastest, " - "limited_workspace, no_workspace and autotune"; + "limited_workspace, no_workspace and autotune"; } void CudnnConvolution::ToDevice(std::shared_ptr<Device> device) { @@ -70,16 +70,19 @@ void CudnnConvolution::InitCudnn(const Tensor &input) { GetCudnnDataType(dtype), batchsize, channels_, height_, width_)); CUDNN_CHECK(cudnnSetTensor4dDescriptor( - y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize, - num_filters_, conv_height_, conv_width_)); + y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize, + num_filters_, conv_height_, conv_width_)); if (bias_term_) CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, num_filters_, 1, 1)); CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_, - stride_h_, stride_w_, 1, 1, - CUDNN_CROSS_CORRELATION, - GetCudnnDataType(dtype))); + stride_h_, stride_w_, 1, 1, // dilation x and y + CUDNN_CROSS_CORRELATION +#if CUDNN_MAJOR == 5 + , GetCudnnDataType(dtype) +#endif // CUDNN_MAJOR + )); CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype), CUDNN_TENSOR_NCHW, num_filters_, channels_, kernel_h_, kernel_w_)); @@ -102,15 +105,15 @@ void CudnnConvolution::InitCudnn(const Tensor &input) { bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT; } CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm( - ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref, - workspace_byte_limit_, &fp_alg_)); + ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref, - workspace_byte_limit_, &fp_alg_)); CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm( - ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, - bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_)); + ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, + bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_)); // deprecated in cudnn v7
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm( - ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, - bwd_data_pref, workspace_byte_limit_, &bp_data_alg_)); + ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, + bwd_data_pref, workspace_byte_limit_, &bp_data_alg_)); } else if (prefer_ == "autotune") { const int topk = 1; int num_fp_alg, num_bp_filt_alg, num_bp_data_alg; @@ -118,16 +121,16 @@ void CudnnConvolution::InitCudnn(const Tensor &input) { cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk]; cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk]; CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm( - ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk, - &num_fp_alg, fp_alg_perf)); + ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk, + &num_fp_alg, fp_alg_perf)); fp_alg_ = fp_alg_perf[0].algo; CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm( - ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk, - &num_bp_filt_alg, bp_filt_perf)); + ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk, + &num_bp_filt_alg, bp_filt_perf)); bp_filter_alg_ = bp_filt_perf[0].algo; CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm( - ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk, - &num_bp_data_alg, bp_data_perf)); + ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk, + &num_bp_data_alg, bp_data_perf)); bp_data_alg_ = bp_data_perf[0].algo; } else { LOG(FATAL) << "Preferred algorithm is not available!"; @@ -135,22 +138,22 @@ void CudnnConvolution::InitCudnn(const Tensor &input) { size_t fp_byte, bp_data_byte, bp_filter_byte; CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize( - ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_, - &fp_byte)); + ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_, + &fp_byte)); CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize( - ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, - bp_data_alg_, &bp_data_byte)); + ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, + bp_data_alg_, &bp_data_byte)); CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize( - ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, - bp_filter_alg_, &bp_filter_byte)); + ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, + bp_filter_alg_, &bp_filter_byte)); workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) / - sizeof(float) + + sizeof(float) + 1; if (workspace_count_ * sizeof(float) > workspace_byte_limit_) LOG(WARNING) << "The required memory for workspace (" - << workspace_count_ * sizeof(float) - << ") is larger than the expected Bytes (" - << workspace_byte_limit_ << ")"; + << workspace_count_ * sizeof(float) + << ") is larger than the expected Bytes (" + << workspace_byte_limit_ << ")"; workspace_ = Tensor(Shape{workspace_count_}, dev, dtype); has_init_cudnn_ = true; } @@ -170,23 +173,23 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) { int n, c, h, w, s; cudnnDataType_t type; CUDNN_CHECK(cudnnGetTensor4dDescriptor(x_desc_, &type, &n, &c, &h, &w, - &s, &s, &s, &s)); + &s, &s, &s, &s)); if (batchsize != static_cast<size_t>(n)) InitCudnn(input); CHECK(input.shape(1) == static_cast<size_t>(c) - && input.shape(2) == static_cast<size_t>(h) - && input.shape(3) == static_cast<size_t>(w)) - << "input sample shape should not change" - << "previous shape " << c << ", " << h << ", " << w - << "current shape " << 
input.shape(1) << ", " << input.shape(2) << ", " - << input.shape(3); + && input.shape(2) == static_cast<size_t>(h) + && input.shape(3) == static_cast<size_t>(w)) + << "input sample shape should not change" + << "previous shape " << c << ", " << h << ", " << w + << "current shape " << input.shape(1) << ", " << input.shape(2) << ", " + << input.shape(3); } Shape shape{batchsize, num_filters_, conv_height_, conv_width_}; Tensor output(shape, dev, dtype); - output.device()->Exec([input, output, this](Context *ctx) { + output.device()->Exec([input, output, this](Context * ctx) { Block *inblock = input.block(), *outblock = output.block(), - *wblock = this->weight_.block(); + *wblock = this->weight_.block(); float alpha = 1.f, beta = 0.f; cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(), this->filter_desc_, wblock->data(), @@ -197,7 +200,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) { }, {input.block(), weight_.block()}, {output.block()}, workspace_.block()); if (bias_term_) { - output.device()->Exec([output, this](Context *ctx) { + output.device()->Exec([output, this](Context * ctx) { float beta = 1.f, alpha = 1.0f; Block *outblock = output.block(), *bblock = this->bias_.block(); cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_, @@ -209,7 +212,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) { } const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward( - int flag, const Tensor &grad) { +int flag, const Tensor &grad) { CHECK(has_init_cudnn_); CHECK_EQ(grad.device()->lang(), kCuda); CHECK_EQ(grad.nDim(), 4u); @@ -225,7 +228,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward( // LOG(ERROR) << "backward bias"; if (bias_term_) { db.ResetLike(bias_); - dx.device()->Exec([grad, db, this](Context *ctx) { + dx.device()->Exec([grad, db, this](Context * ctx) { Block *dyblock = grad.block(), *dbblock = db.block(); float alpha = 1.f, beta = 0.f; cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_, @@ -234,22 +237,22 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward( }, {grad.block()}, {db.block()}); } // LOG(ERROR) << "backward w"; - dx.device()->Exec([grad, dw, src_data, this](Context *ctx) { + dx.device()->Exec([grad, dw, src_data, this](Context * ctx) { Block *inblock = src_data.block(), *dyblock = grad.block(), - *dwblock = dw.block(); + *dwblock = dw.block(); float alpha = 1.f, beta = 0.f; cudnnConvolutionBackwardFilter( - ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(), - this->y_desc_, dyblock->data(), this->conv_desc_, this->bp_filter_alg_, - this->workspace_.block()->mutable_data(), - this->workspace_count_ * sizeof(float), &beta, this->filter_desc_, - dwblock->mutable_data()); + ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(), + this->y_desc_, dyblock->data(), this->conv_desc_, this->bp_filter_alg_, + this->workspace_.block()->mutable_data(), + this->workspace_count_ * sizeof(float), &beta, this->filter_desc_, + dwblock->mutable_data()); }, {grad.block(), src_data.block()}, {dw.block(), workspace_.block()}); // LOG(ERROR) << "backward src"; - dx.device()->Exec([dx, grad, this](Context *ctx) { + dx.device()->Exec([dx, grad, this](Context * ctx) { Block *wblock = this->weight_.block(), *dyblock = grad.block(), - *dxblock = dx.block(); + *dxblock = dx.block(); float alpha = 1.f, beta = 0.f; cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, this->filter_desc_, wblock->data(), this->y_desc_, dyblock->data(), 
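The convolution changes above are gated at compile time on the CUDNN_MAJOR macro from cudnn.h, so the same source tree can build against cuDNN 5 and cuDNN 7. Below is a minimal, self-contained sketch of that version-gating idiom, kept separate from the SINGA sources; it assumes nothing beyond the stock cudnn.h version macros and cudnnGetVersion(), and the file name, messages, and build command are illustrative assumptions rather than anything from the patch.

// cudnn_version_check.cc - standalone sketch of compile-time vs. run-time
// cuDNN version gating (illustrative, not part of SINGA).
// Build (assumed paths): g++ cudnn_version_check.cc -I/usr/local/cuda/include \
//   -L/usr/local/cuda/lib64 -lcudnn -o cudnn_version_check
#include <cstdio>
#include <cudnn.h>

int main() {
  // Compile-time check: CUDNN_MAJOR/MINOR/PATCHLEVEL come from the cudnn.h
  // seen at build time.  A comparison like this must use #if, not #ifdef,
  // because #ifdef only tests whether the macro is defined at all.
#if CUDNN_MAJOR == 5
  std::printf("Built against cuDNN 5; take the pre-v6 code paths.\n");
#else
  std::printf("Built against cuDNN %d.%d.%d; take the newer code paths.\n",
              CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
#endif
  // Run-time check: the loaded libcudnn can differ from the build headers.
  std::printf("cuDNN library version at run time: %zu\n", cudnnGetVersion());
  return 0;
}

The patch relies only on the compile-time form, which is why SINGA has to be rebuilt when switching between cuDNN releases.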
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e2092030/src/model/layer/cudnn_rnn.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc index fb5fee0..28a52c5 100644 --- a/src/model/layer/cudnn_rnn.cc +++ b/src/model/layer/cudnn_rnn.cc @@ -125,8 +125,8 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) { CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size)); dropout_state_ = Tensor(Shape{state_size}, dev, kChar); CUDNN_CHECK(cudnnSetDropoutDescriptor( - dropout_desc_, ctx->cudnn_handle, 1 - dropout_, // keep probability - dropout_state_.block()->mutable_data(), state_size, seed_)); + dropout_desc_, ctx->cudnn_handle, 1 - dropout_, // keep probability + dropout_state_.block()->mutable_data(), state_size, seed_)); CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_)); cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; @@ -144,10 +144,15 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) { rnn_mode = CUDNN_RNN_TANH; else if (rnn_mode_ == "gru") rnn_mode = CUDNN_GRU; +#ifdef CUDNN_MAJOR == 5 + CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_, + dropout_desc_, input_mode, direction, + rnn_mode, dtype_)); +#else CUDNN_CHECK(cudnnSetRNNDescriptor(ctx->cudnn_handle, rnn_desc_, hidden_size_, num_stacks_, dropout_desc_, input_mode, direction, rnn_mode, CUDNN_RNN_ALGO_STANDARD, dtype_)); - +#endif size_t weight_size; CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0], &weight_size, dtype_)); @@ -199,7 +204,7 @@ void CudnnRNN::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) { } CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_, - seq_length, x_descs_, &count)); + seq_length, x_descs_, &count)); if (reserve_space_.Size() != count) { reserve_space_ = Tensor(Shape{count}, dev, kChar); // reserve_space_.SetValue(0); @@ -263,8 +268,8 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) { if (rnn_desc_ != nullptr) CHECK_EQ(dtype_, GetCudnnDataType(dtype)) - << "Cannot change cudnn data type during training from " << dtype_ - << " to " << GetCudnnDataType(dtype); + << "Cannot change cudnn data type during training from " << dtype_ + << " to " << GetCudnnDataType(dtype); else dtype_ = GetCudnnDataType(dtype); @@ -303,57 +308,57 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) { // LOG(INFO) << "hidden size " << hy.Size(); // LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1(); Block *inb = input.block(), *outb = output.block(), - *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(), - *hyb = hy.block(), *cyb = cy.block(), - *wspace = this->workspace_.block(), - *rspace = this->reserve_space_.block(); + *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(), + *hyb = hy.block(), *cyb = cy.block(), + *wspace = this->workspace_.block(), + *rspace = this->reserve_space_.block(); if (flag & kTrain) { CHECK_EQ(reserve_space_.device()->lang(), kCuda); CHECK_EQ(did, reserve_space_.device()->id()); dev->Exec( - [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context *ctx) { - // clang-format off - cudnnRNNForwardTraining( - ctx->cudnn_handle, - this->rnn_desc_, - this->seq_length_, - this->x_descs_, inb->data(), - this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), - this->cx_desc_, cxb == nullptr ? 
nullptr : cxb->data(), - this->weight_desc_, wb->data(), - this->y_descs_, outb->mutable_data(), - this->hy_desc_, hyb->mutable_data(), - this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), - wspace->mutable_data(), - this->workspace_.Size(), rspace->mutable_data(), - this->reserve_space_.Size()); - // clang-format on - }, - {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace}); + [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context * ctx) { + // clang-format off + cudnnRNNForwardTraining( + ctx->cudnn_handle, + this->rnn_desc_, + this->seq_length_, + this->x_descs_, inb->data(), + this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + this->weight_desc_, wb->data(), + this->y_descs_, outb->mutable_data(), + this->hy_desc_, hyb->mutable_data(), + this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), + wspace->mutable_data(), + this->workspace_.Size(), rspace->mutable_data(), + this->reserve_space_.Size()); + // clang-format on + }, + {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace}); buf_.push(input); buf_.push(output); buf_.push(hx); buf_.push(cx); } else { - dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context *ctx) { + dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context * ctx) { // clang-format off cudnnRNNForwardInference( - ctx->cudnn_handle, - this->rnn_desc_, - this->seq_length_, - this->x_descs_, inb->data(), - this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), - this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(), - this->weight_desc_, wb->data(), - this->y_descs_, outb->mutable_data(), - this->hy_desc_, hyb->mutable_data(), - this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), - wspace->mutable_data(), this->workspace_.Size()); + ctx->cudnn_handle, + this->rnn_desc_, + this->seq_length_, + this->x_descs_, inb->data(), + this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + this->weight_desc_, wb->data(), + this->y_descs_, outb->mutable_data(), + this->hy_desc_, hyb->mutable_data(), + this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), + wspace->mutable_data(), this->workspace_.Size()); // clang-format on }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace}); } auto outputs = - SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output); + SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output); outputs.push_back(hy); if (has_cell_) outputs.push_back(cy); return outputs; @@ -361,7 +366,7 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) { // TODO(wangwei) check Tensor device to be on cuda? 
const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward( - int flag, const vector<Tensor> &grads) { +int flag, const vector<Tensor> &grads) { // dhy (and dcy) is at last const Tensor cx = buf_.top(); // cannot use const Tensor& due to pop() buf_.pop(); @@ -395,45 +400,45 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward( dcx.ResetLike(dhx); dw.SetValue(0.0f); Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(), - *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(), - *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(), - *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(), - *wspace = workspace_.block(), *rspace = reserve_space_.block(); + *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(), + *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(), + *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(), + *wspace = workspace_.block(), *rspace = reserve_space_.block(); y.device()->Exec( - [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace, - rspace, this](Context *ctx) { - // clang-format off - cudnnRNNBackwardData( - ctx->cudnn_handle, - this->rnn_desc_, - this->seq_length_, - this->y_descs_, yb->data(), - this->dy_descs_, dyb->data(), - this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(), - this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(), - this->weight_desc_, wb->data(), - this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), - this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(), - this->dx_descs_, dxb->mutable_data(), - this->dhx_desc_, dhxb->mutable_data(), - this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(), - wspace->mutable_data(), this->workspace_.Size(), - rspace->mutable_data(), this->reserve_space_.Size()); - cudnnRNNBackwardWeights( - ctx->cudnn_handle, - this->rnn_desc_, - this->seq_length_, - this->x_descs_, xb->data(), - this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), - this->y_descs_, yb->data(), - wspace->data(), this->workspace_.Size(), - this->dweight_desc_, dwb->mutable_data(), - rspace->data(), this->reserve_space_.Size()); - // clang-format on - }, - {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace}, - {dxb, dwb, dhxb, dcxb, wspace, rspace}); + [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace, + rspace, this](Context * ctx) { + // clang-format off + cudnnRNNBackwardData( + ctx->cudnn_handle, + this->rnn_desc_, + this->seq_length_, + this->y_descs_, yb->data(), + this->dy_descs_, dyb->data(), + this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(), + this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(), + this->weight_desc_, wb->data(), + this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + this->dx_descs_, dxb->mutable_data(), + this->dhx_desc_, dhxb->mutable_data(), + this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(), + wspace->mutable_data(), this->workspace_.Size(), + rspace->mutable_data(), this->reserve_space_.Size()); + cudnnRNNBackwardWeights( + ctx->cudnn_handle, + this->rnn_desc_, + this->seq_length_, + this->x_descs_, xb->data(), + this->hx_desc_, hxb == nullptr ? 
nullptr : hxb->data(), + this->y_descs_, yb->data(), + wspace->data(), this->workspace_.Size(), + this->dweight_desc_, dwb->mutable_data(), + rspace->data(), this->reserve_space_.Size()); + // clang-format on + }, + {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace}, + {dxb, dwb, dhxb, dcxb, wspace, rspace}); vector <Tensor> param_grad{dw}; auto data_grads = SplitOutput(num_dy, input_size_, grads, dx);
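For reference, the RNN hunk above comes down to choosing between two cudnnSetRNNDescriptor signatures: the cuDNN 5 form without a handle or algorithm argument, and the cuDNN 7 form that adds both. The helper below is a hedged sketch of that guard only, not SINGA code; MakeRnnDescriptor and CHECK_CUDNN are illustrative names, and the argument lists mirror the two calls shown in the diff. Note that the version test needs #if CUDNN_MAJOR == 5 rather than #ifdef, since #ifdef ignores the == 5 comparison.

// rnn_descriptor_sketch.cc - version-guarded RNN descriptor setup
// (illustrative, not part of SINGA).
#include <cstdio>
#include <cstdlib>
#include <cudnn.h>

// Minimal status-checking macro, analogous in spirit to SINGA's CUDNN_CHECK.
#define CHECK_CUDNN(expr)                                                \
  do {                                                                   \
    cudnnStatus_t status = (expr);                                       \
    if (status != CUDNN_STATUS_SUCCESS) {                                \
      std::fprintf(stderr, "cuDNN error %s at %s:%d\n",                  \
                   cudnnGetErrorString(status), __FILE__, __LINE__);     \
      std::exit(EXIT_FAILURE);                                           \
    }                                                                    \
  } while (0)

// Creates and configures an RNN descriptor for either cuDNN 5 or cuDNN 7.
cudnnRNNDescriptor_t MakeRnnDescriptor(cudnnHandle_t handle,
                                       cudnnDropoutDescriptor_t dropout_desc,
                                       int hidden_size, int num_stacks,
                                       cudnnRNNMode_t rnn_mode,
                                       cudnnDataType_t dtype) {
  cudnnRNNDescriptor_t rnn_desc;
  CHECK_CUDNN(cudnnCreateRNNDescriptor(&rnn_desc));
#if CUDNN_MAJOR == 5
  (void)handle;  // the cuDNN 5 call does not take the handle
  CHECK_CUDNN(cudnnSetRNNDescriptor(rnn_desc, hidden_size, num_stacks,
                                    dropout_desc, CUDNN_LINEAR_INPUT,
                                    CUDNN_UNIDIRECTIONAL, rnn_mode, dtype));
#else
  // cuDNN 7 signature: adds the handle and an explicit algorithm choice.
  CHECK_CUDNN(cudnnSetRNNDescriptor(handle, rnn_desc, hidden_size, num_stacks,
                                    dropout_desc, CUDNN_LINEAR_INPUT,
                                    CUDNN_UNIDIRECTIONAL, rnn_mode,
                                    CUDNN_RNN_ALGO_STANDARD, dtype));
#endif
  return rnn_desc;
}

In CudnnRNN::SetRNNDescriptor the equivalent call is made with the layer's own dropout descriptor and hyper-parameters; the point here is only that the preprocessor, not run-time logic, selects which signature is compiled.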
