[
https://issues.apache.org/jira/browse/SINGA-218?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15413850#comment-15413850
]
ASF subversion and git services commented on SINGA-218:
-------------------------------------------------------
Commit 8e0b1083992f471849bb80b0a8e869767ee9edc0 in incubator-singa's branch
refs/heads/dev from [~wangwei.cs]
[ https://git-wip-us.apache.org/repos/asf?p=incubator-singa.git;h=8e0b108 ]
SINGA-218 Implementation for RNN CUDNN version
Finish the CudnnRNN layer.
Pass test for tanh rnn.
RNN forward accepts a vector of input tensors: <x0, x1, ... x(n-1), hx, cx>.
x(i) is the i-th input tensor; hx is the initial hidden tensor, which may
be a dummy tensor. A dummy tensor is a tensor created without
shape/device/data_type; during computation, cudnnRNN uses 0s for it.
cx is not needed for relu/tanh/gru RNNs. For lstm, it may also be a dummy
tensor like hx.
The output is: <y0, y1, ... y(n-1), hy, cy>.
relu/tanh/gru RNNs do not have cy; lstm has both hy and cy.
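The forward convention above can be sketched with a plain NumPy tanh RNN. This is only an illustration of the <x0, ..., x(n-1), hx> layout and the dummy-tensor rule (None stands in for a dummy tensor and is replaced by zeros); the function and parameter names are hypothetical, not SINGA's actual API:

```python
import numpy as np

def tanh_rnn_forward(inputs, Wx, Wh, b):
    """Schematic tanh RNN forward following the <x0, ..., x(n-1), hx> layout.

    A dummy initial hidden state is passed as None and replaced by zeros,
    mirroring how cudnnRNN substitutes 0s for a dummy tensor.
    Returns <y0, ..., y(n-1), hy>.
    """
    *xs, hx = inputs
    if hx is None:  # dummy hx -> use zeros
        hx = np.zeros((xs[0].shape[0], Wh.shape[0]))
    h, ys = hx, []
    for x in xs:
        h = np.tanh(x @ Wx + h @ Wh + b)  # one recurrent step
        ys.append(h)
    return ys + [h]  # outputs per step, then the final hidden state hy
```

Passing None for hx therefore gives the same result as passing an explicit zero tensor.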
RNN backward accepts a vector of input gradient tensors: <dy0, dy1, ...
dy(n-1), dhy, dcy>.
dhy is required for all RNNs but may be a dummy tensor, in which case
a tensor of 0s is used for dhy during computation. dcy is used
only for lstm and may also be a dummy tensor.
The output is: <dw, <dx0, dx1, ... dx(n-1), dhx, dcx>>,
where dhx is the gradient of hx; dcx is used only for lstm.
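The backward convention can be sketched the same way: a NumPy BPTT pass for the tanh cell that takes <dy0, ..., dy(n-1), dhy> and returns <dw, <dx0, ..., dx(n-1), dhx>>. Again, this is a hypothetical illustration of the tensor layout, not SINGA's implementation; the cache of inputs and hidden states is assumed to come from the forward pass:

```python
import numpy as np

def tanh_rnn_backward(grads, cache, Wx, Wh):
    """Schematic BPTT following the <dy0, ..., dy(n-1), dhy> layout.

    A dummy dhy is passed as None and treated as zeros.
    cache = (xs, hs) where hs[0] is hx and hs[t+1] is the step-t output.
    Returns (dw, dxs) with dw = (dWx, dWh, db) and dxs = <dx0, ..., dhx>.
    """
    *dys, dhy = grads
    xs, hs = cache
    if dhy is None:  # dummy dhy -> use zeros
        dhy = np.zeros_like(hs[-1])
    dWx, dWh = np.zeros_like(Wx), np.zeros_like(Wh)
    db = np.zeros(Wx.shape[1])
    dxs = [None] * len(dys)
    dh_next = dhy
    for t in reversed(range(len(dys))):
        dh = dys[t] + dh_next            # gradient from output and next step
        dpre = dh * (1 - hs[t + 1] ** 2)  # derivative of tanh
        dWx += xs[t].T @ dpre
        dWh += hs[t].T @ dpre
        db += dpre.sum(axis=0)
        dxs[t] = dpre @ Wx.T
        dh_next = dpre @ Wh.T
    return (dWx, dWh, db), dxs + [dh_next]  # dh_next here is dhx
```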
The CudnnRNN layer must be moved onto a CUDA device; otherwise a memory
error occurs (the weight is on the CPU).
> Implementation for RNN CUDNN version
> ------------------------------------
>
> Key: SINGA-218
> URL: https://issues.apache.org/jira/browse/SINGA-218
> Project: Singa
> Issue Type: New Feature
> Reporter: ZHAOJING
>
> (1) cudnn rnn implementation (cudnn_rnn.h, cudnn_rnn.cc, rnn.cc, rnn.h,
> test_cudnn_rnn.cc).
> (2) The weight shapes are now calculated manually instead of using the API
> provided by CUDNN.
> (3) Test for RNN_cudnn_Tanh (unidirectional, 1 hidden layer).
--
This message was sent by Atlassian JIRA
(v6.3.4#6332)