yuanfz98 opened a new pull request, #12017:
URL: https://github.com/apache/tvm/pull/12017
Hello,

This PR adds support for `aten::rnn_tanh` and `aten::rnn_relu` to the Relay PyTorch frontend. The conversion follows the existing GRU and LSTM implementations in Relay.

Related issue: #11827
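For reference, `torch.nn.RNN` computes h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh), with ReLU in place of tanh when `nonlinearity='relu'`. These are the two variants that show up as `aten::rnn_tanh` and `aten::rnn_relu`. Below is a minimal NumPy sketch of that recurrence for a single-layer, unidirectional RNN with the time axis first. It is not the converter added in this PR. The helper name `rnn_reference` is hypothetical, and the sketch assumes NumPy is available. A reference like this can be handy for sanity-checking converted outputs by hand.

```python
import numpy as np


def rnn_reference(x, h0, w_ih, w_hh, b_ih=None, b_hh=None, nonlinearity="tanh"):
    """Hypothetical reference for a single-layer, unidirectional Elman RNN.

    x:    (seq_len, batch, input_size)
    h0:   (batch, hidden_size)
    w_ih: (hidden_size, input_size)
    w_hh: (hidden_size, hidden_size)
    Returns (outputs of shape (seq_len, batch, hidden_size), final hidden state).
    """
    # tanh corresponds to aten::rnn_tanh, ReLU to aten::rnn_relu
    act = np.tanh if nonlinearity == "tanh" else (lambda v: np.maximum(v, 0.0))
    hidden_size = w_hh.shape[0]
    # Missing biases behave like zero biases (bias=False in torch.nn.RNN)
    b_ih = np.zeros(hidden_size) if b_ih is None else b_ih
    b_hh = np.zeros(hidden_size) if b_hh is None else b_hh
    h = h0
    outputs = []
    for x_t in x:  # unroll over the time axis; each step needs the previous h
        h = act(x_t @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
        outputs.append(h)
    return np.stack(outputs), h


# Usage: 15 timesteps, batch of 32, 5 input features, hidden size 5
rng = np.random.default_rng(0)
x = rng.standard_normal((15, 32, 5))
h0 = np.zeros((32, 5))
w_ih = rng.standard_normal((5, 5))
w_hh = rng.standard_normal((5, 5))
out, h_n = rnn_reference(x, h0, w_ih, w_hh, nonlinearity="relu")
print(out.shape, h_n.shape)  # (15, 32, 5) (32, 5)
```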
```
import torch
from tvm import relay


def test_RNN_torch(num_layers: int,
                   bidirectional: bool,
                   use_bias: bool,
                   hidden_size: int,
                   input_size: int,
                   seq_len: int,
                   batch_first: bool,
                   batch_size: int):
    r'''
    Args:
        num_layers (int): num_layers to be passed to torch.nn.RNN
        bidirectional (bool): whether to build a bidirectional RNN or not
        use_bias (bool): whether to use bias or not
        hidden_size (int): hidden_size of RNN cells
        input_size (int): number of input features
        seq_len (int): number of timesteps in the input data
        batch_first (bool): whether the batch dimension is the first or
            second dimension of the input tensor
        batch_size (int): batch size of the input. If 0, unbatched input
            is fed to the network
    '''
    if batch_size == 0:
        # Unbatched input, as described in the docstring
        input_shape = (seq_len, input_size)
    elif batch_first:
        input_shape = (batch_size, seq_len, input_size)
    else:
        input_shape = (seq_len, batch_size, input_size)
    pytorch_net = torch.nn.Sequential(
        torch.nn.RNN(input_size,
                     hidden_size,
                     batch_first=batch_first,
                     num_layers=num_layers,
                     bidirectional=bidirectional,
                     bias=use_bias)
    )
    # Trace the model; with the default tanh nonlinearity the traced
    # graph contains aten::rnn_tanh
    scripted_model = torch.jit.trace(pytorch_net.eval(),
                                     torch.randn(input_shape))
    mod, params = relay.frontend.from_pytorch(scripted_model,
                                              [('input', input_shape)])
    mod = relay.transform.InferType()(mod)
    print(mod.astext())


if __name__ == "__main__":
    test_RNN_torch(num_layers=1,
                   bidirectional=False,
                   use_bias=True,
                   hidden_size=5,
                   input_size=5,
                   seq_len=15,
                   batch_first=True,
                   batch_size=32)
```
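The test parameters fully determine the shapes `torch.nn.RNN` returns, so they also suggest the shapes one would expect in the converted module's result. Per the `torch.nn.RNN` documentation, the output is (N, L, D*H_out) when `batch_first=True` and (L, N, D*H_out) otherwise, while h_n is always (D*num_layers, N, H_out), where D = 2 for a bidirectional RNN. The pure-Python helper below, `expected_rnn_shapes`, is hypothetical and only restates those conventions. It treats `batch_size == 0` as unbatched input, matching the test's docstring.

```python
def expected_rnn_shapes(num_layers, bidirectional, hidden_size,
                        seq_len, batch_first, batch_size):
    """Hypothetical helper: shapes of (output, h_n) returned by torch.nn.RNN.

    batch_size == 0 means unbatched input, as in the test's docstring.
    """
    num_directions = 2 if bidirectional else 1
    out_features = num_directions * hidden_size
    if batch_size == 0:
        # Unbatched: output (L, D*H_out), h_n (D*num_layers, H_out)
        return (seq_len, out_features), (num_directions * num_layers, hidden_size)
    if batch_first:
        output = (batch_size, seq_len, out_features)
    else:
        output = (seq_len, batch_size, out_features)
    # h_n ignores batch_first: it is always (D * num_layers, batch, hidden)
    h_n = (num_directions * num_layers, batch_size, hidden_size)
    return output, h_n


# Parameters used in the __main__ block of the test above
print(expected_rnn_shapes(1, False, 5, 15, True, 32))
# ((32, 15, 5), (1, 32, 5))
```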
Output:
```
#[version = "0.0.5"]
type List[A] {
  Cons(A, List[A]),
  Nil,
}
type Option[A] {
  Some(A),
  None,
}
type Tree[A] {
  Rose(A, List[Tree[A]]),
}
type tensor_float16_t {
  tensor_nil_float16,
  tensor0_float16(float16),
  tensor1_float16(Tensor[(?), float16]),
  tensor2_float16(Tensor[(?, ?), float16]),
  tensor3_float16(Tensor[(?, ?, ?), float16]),
  tensor4_float16(Tensor[(?, ?, ?, ?), float16]),
  tensor5_float16(Tensor[(?, ?, ?, ?, ?), float16]),
  tensor6_float16(Tensor[(?, ?, ?, ?, ?, ?), float16]),
}
type tensor_float32_t {
  tensor_nil_float32,
  tensor0_float32(float32),
  tensor1_float32(Tensor[(?), float32]),
...
```