yuanfz98 opened a new pull request, #12017:
URL: https://github.com/apache/tvm/pull/12017

   Hello,
   
   This PR adds support for aten::rnn_tanh and aten::rnn_relu. The approach follows the existing GRU and LSTM implementations in Relay.
   
   Linked issue: #11827
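
   For context, a vanilla RNN step computes `h' = act(x @ W_ih^T + b_ih + h @ W_hh^T + b_hh)`, where `act` is `tanh` for aten::rnn_tanh and `relu` for aten::rnn_relu (this is the same per-step math `torch.nn.RNN` uses). A minimal NumPy sketch of one cell step, purely for illustration (the function name and shapes here are my own, not the actual Relay converter code):

   ```python
   import numpy as np

   def rnn_cell(x, h, w_ih, w_hh, b_ih=None, b_hh=None, nonlinearity="tanh"):
       """One step of a vanilla RNN cell.

       x:    (batch, input_size) input at the current timestep
       h:    (batch, hidden_size) previous hidden state
       w_ih: (hidden_size, input_size) input-to-hidden weights
       w_hh: (hidden_size, hidden_size) hidden-to-hidden weights
       """
       act = np.tanh if nonlinearity == "tanh" else lambda v: np.maximum(v, 0.0)
       out = x @ w_ih.T + h @ w_hh.T
       if b_ih is not None:
           out = out + b_ih
       if b_hh is not None:
           out = out + b_hh
       return act(out)
   ```

   A multi-layer (and bidirectional) RNN then unrolls this cell over the sequence, which is what the converter has to express in Relay, mirroring how the GRU/LSTM converters already do it.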
   
   ```
    import torch
    from tvm import relay


    def test_RNN_torch(num_layers: int,
                       bidirectional: bool,
                       use_bias: bool,
                       hidden_size: int,
                       input_size: int,
                       seq_len: int,
                       batch_first: bool,
                       batch_size: int):
        r'''
        Args:
            num_layers (int): num_layers to be passed to torch.nn.RNN
            bidirectional (bool): whether to build a bidirectional RNN or not
            use_bias (bool): whether to use bias or not
            hidden_size (int): hidden_size of the RNN cells
            input_size (int): number of input features
            seq_len (int): timesteps in the input data
            batch_first (bool): whether the batch dimension is the first or
                second dimension of the input tensor
            batch_size (int): batch size of the input; if 0, unbatched input
                is fed to the network
        '''
        if batch_size == 0:
            # Unbatched input: (seq_len, input_size)
            input_shape = (seq_len, input_size)
        elif batch_first:
            input_shape = (batch_size, seq_len, input_size)
        else:
            input_shape = (seq_len, batch_size, input_size)

        pytorch_net = torch.nn.Sequential(
            torch.nn.RNN(input_size,
                         hidden_size,
                         batch_first=batch_first,
                         num_layers=num_layers,
                         bidirectional=bidirectional,
                         bias=use_bias)
        )

        scripted_model = torch.jit.trace(pytorch_net.eval(),
                                         torch.randn(input_shape))

        mod, params = relay.frontend.from_pytorch(scripted_model,
                                                  [('input', input_shape)])
        mod = relay.transform.InferType()(mod)
        print(mod.astext())


    if __name__ == "__main__":
        test_RNN_torch(1, False, True, 5, 5, 15, True, 32)
   
   ```
   
   Output (truncated):
   
   ```
   #[version = "0.0.5"]
   type List[A] {
     Cons(A, List[A]),
     Nil,
   }
   
   type Option[A] {
     Some(A),
     None,
   }
   
   type Tree[A] {
     Rose(A, List[Tree[A]]),
   }
   
   type tensor_float16_t {
     tensor_nil_float16,
     tensor0_float16(float16),
     tensor1_float16(Tensor[(?), float16]),
     tensor2_float16(Tensor[(?, ?), float16]),
     tensor3_float16(Tensor[(?, ?, ?), float16]),
     tensor4_float16(Tensor[(?, ?, ?, ?), float16]),
     tensor5_float16(Tensor[(?, ?, ?, ?, ?), float16]),
     tensor6_float16(Tensor[(?, ?, ?, ?, ?, ?), float16]),
   }
   
   type tensor_float32_t {
     tensor_nil_float32,
     tensor0_float32(float32),
     tensor1_float32(Tensor[(?), float32]),
   ...
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
