This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b819364  [Frontend] [Torch] [ONNX] GRU layer (#8781)
b819364 is described below

commit b8193646fa9f97fc3476b5275d8ce8b0270408a3
Author: Valery Chernov <[email protected]>
AuthorDate: Wed Aug 25 10:48:14 2021 +0300

    [Frontend] [Torch] [ONNX] GRU layer (#8781)
    
    * GRU cell was implemented in common.py. GRU support was added on the PyTorch frontend side
    
    * update GRU in common.py and the ONNX frontend
    
    * fix issue related to GRU accuracy in the PyTorch and ONNX frontends
    
    * small fixes and removal of excess code
    
    * common GRU was additionally updated. The tuned PyTorch GRU was significantly accelerated
    
    * GRU cell in the ONNX frontend now uses common.py; the previous implementation was removed
    
    * small fixes in comments
    
    * fixes after review. A GRU test was implemented for the PyTorch frontend
    
    * tests for RNN layers were unified for the PyTorch frontend
    
    Co-authored-by: Valery Chernov <[email protected]>
---
 python/tvm/relay/frontend/common.py         |  84 ++++++
 python/tvm/relay/frontend/onnx.py           | 149 +++++-----
 python/tvm/relay/frontend/pytorch.py        | 189 +++++++++++-
 tests/python/frontend/pytorch/test_lstms.py | 363 -----------------------
 tests/python/frontend/pytorch/test_rnns.py  | 430 ++++++++++++++++++++++++++++
 5 files changed, 774 insertions(+), 441 deletions(-)
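
For reference, the new common gru_cell computes the hidden-state update as
hidden_state = (hidden_state - n_gate) * z_gate + n_gate, which is algebraically
the standard GRU update z * h + (1 - z) * n. A minimal NumPy sketch (not part of
the commit; names are illustrative) checking that identity:

    import numpy as np

    h = np.random.rand(2, 16)  # previous hidden state (batch, hidden_size)
    z = np.random.rand(2, 16)  # update gate output
    n = np.random.rand(2, 16)  # new gate output

    fused = (h - n) * z + n             # form used in gru_cell
    standard = z * h + (1.0 - z) * n    # textbook GRU update
    assert np.allclose(fused, standard)
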

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 077b942..ce048105 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -658,6 +658,90 @@ def unbind(data, axis=0):
     return _expr.TupleWrapper(_expr.Tuple(ret), selections)
 
 
+def gru_cell(
+    input_seqs,
+    hidden_state,
+    w_inp,
+    w_hid,
+    b_inp=None,
+    b_hid=None,
+    rz_act=_op.sigmoid,
+    n_act=_op.tanh,
+    backwards=False,
+    linear_before_reset=True,
+):
+    """
+    Common implementation of GRU cell for all frontends of TVM
+    TODO(vvchernov): currently it is used by pytorch and ONNX. Extend for other frontends
+
+    Parameters
+    ----------
+    input_seqs : List[relay.Expr]
+        The sequence of input tensors
+        Input tensor should be 2d until issue #8412 is resolved
+        Shape = (batch, feature_size)
+    hidden_state : relay.Expr
+        Hidden state. shape = (batch_size, hidden_size)
+    w_inp, w_hid : relay.Expr
+        weight matrices. wi shape = (3 * hidden_size, feature_size)
+        wh shape = (3 * hidden_size, hidden_size)
+        NOTE: wi = (w_ir|w_iz|w_in) for reset, update and new gates.
+        The order is important for correct GRU calculation!
+    b_inp, b_hid : relay.Expr
+        bias matrices. The same order of internal parts as for weights. shape = (3 * hidden_size)
+    rz_act : relay.op
+        activation function for reset and update gates. it is sigmoid by default
+    n_act : relay.op
+        activation function for new gate. it is tanh by default
+    linear_before_reset : bool
+        Flag to apply the linear transform of the hidden state before the reset gate is used
+    backwards : bool
+        Flag for reverse pass of GRU
+
+    Returns
+    -------
+    result : List[relay.Expr], relay.Expr
+        The sequence of computed results and the final hidden state
+    """
+
+    outputs_list = []
+    for x_t in input_seqs if not backwards else reversed(input_seqs):
+        xwt = _op.nn.dense(x_t, w_inp)
+        if linear_before_reset:
+            hwt = _op.nn.dense(hidden_state, w_hid)
+            if b_inp is not None and b_hid is not None:
+                xwt += b_inp
+                hwt += b_hid
+            i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
+            h_r, h_z, h_n = _op.split(hwt, 3, axis=-1)
+            r_gate = rz_act(i_r + h_r)
+            z_gate = rz_act(i_z + h_z)
+            n_gate = n_act(i_n + r_gate * h_n)
+        else:
+            i_r, i_z, i_n = _op.split(xwt, 3, axis=1)
+            w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
+            r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
+            z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
+            if b_inp is not None and b_hid is not None:
+                b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
+                b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
+                r_gate += b_ir + b_hr
+                z_gate += b_iz + b_hz
+                i_n += b_in
+                h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn
+            else:
+                h_n = _op.nn.dense((r_gate * hidden_state), w_hn)
+            r_gate = rz_act(r_gate)
+            z_gate = rz_act(z_gate)
+            n_gate = n_act(i_n + h_n)
+
+        hidden_state = (hidden_state - n_gate) * z_gate + n_gate
+
+        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
+
+    return outputs_list, hidden_state
+
+
 def lstm_cell(
     input_seqs,
     hidden_state,
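
As a quick orientation, the new helper can be exercised directly from Relay. A
minimal sketch (not part of the commit; shapes and variable names are
illustrative, biases left at their None defaults):

    from tvm import relay
    from tvm.relay.frontend.common import gru_cell

    batch, feature, hidden, seq_len = 2, 8, 16, 3
    # one 2-D tensor per time step, as required while issue #8412 is open
    xs = [relay.var("x%d" % t, shape=(batch, feature)) for t in range(seq_len)]
    h0 = relay.var("h0", shape=(batch, hidden))
    w_inp = relay.var("w_inp", shape=(3 * hidden, feature))  # r|z|n stacked
    w_hid = relay.var("w_hid", shape=(3 * hidden, hidden))
    outputs, h_last = gru_cell(xs, h0, w_inp, w_hid)
    # outputs: list of per-step hidden states; h_last: final hidden state
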
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 0f78c32..5471f67 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -47,6 +47,7 @@ from .common import (
     infer_value,
     new_var,
     unbind,
+    gru_cell,
     lstm_cell,
 )
 
@@ -2349,56 +2350,41 @@ class GRU(RNN):
     """Operator convert for GRU"""
 
     @classmethod
-    def generate_gru(
-        cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False
+    def bidir_gru_cell(
+        cls,
+        input_seqs,
+        weight_dicts,
+        acts,
     ):
-        """Create an unrolled gru loop.
-
-        See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
         """
-        h_list = []
-        seq_length = len(X_steps)
-        for i in range(seq_length):
-            step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
-            step = _op.squeeze(step, axis=[0])
-            current = _op.nn.dense(step, W)
-            cz, cr, ch = _op.split(current, 3, axis=1)
-            rz, rr, rh = _op.split(R, 3, axis=0)
-            z = cz + _op.nn.dense(H_t, rz)
-            r = cr + _op.nn.dense(H_t, rr)
-            if B is not None:
-                WB, RB = _op.split(B, 2)
-                wbz, wbr, wbh = _op.split(WB, 3, axis=-1)
-                rbz, rbr, rbh = _op.split(RB, 3, axis=-1)
-                z += wbz + rbz
-                r += wbr + rbr
-                if linear_before_reset:
-                    h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh
-                else:
-                    h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh
-            else:
-                if linear_before_reset:
-                    h = ch + (r * (_op.nn.dense(H_t, rh)))
-                else:
-                    h = ch + _op.nn.dense((r * H_t), rh)
-
-            z = f_act(z)
-            r = f_act(r)
-            h = g_act(h)
-
-            H_t = ((_expr.const(1, dtype=W_dtype) - z) * h) + (z * H_t)
-            h_list.append(_op.expand_dims(H_t, axis=0))
+        Bidirectional GRU cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t = gru_cell(
+            input_seqs,
+            **weight_dicts[0],
+            rz_act=acts[0],
+            n_act=acts[1],
+        )
 
-        if backwards:
-            # Canonical view is hidden states from the first token not last
-            h_list = h_list[::-1]
+        reverse_outputs, rev_H_t = gru_cell(
+            input_seqs,
+            **weight_dicts[1],
+            rz_act=acts[2],
+            n_act=acts[3],
+            backwards=True,
+        )
 
-        # Concatenate outputs and add back in direction axis.
-        concatenated = _op.concatenate(h_list, 0)
-        output = _op.expand_dims(concatenated, axis=1)
-        H_t = _op.expand_dims(H_t, axis=0)
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
+            )
 
-        return output, H_t
+        return (
+            _op.stack(final_outputs, axis=0),
+            _op.stack([fw_H_t, rev_H_t], axis=0),
+        )
 
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
@@ -2416,20 +2402,14 @@ class GRU(RNN):
         W_dtype = infer_type(Wp).checked_type.dtype
 
         if num_directions not in [1, 2]:
-            raise NotImplementedError(
-                f"Directions for GRUs should be either 1 or 2 got 
{num_directions}"
-            )
+            raise ValueError("num_directions must be either 1 or 2!")
 
         X_shape = infer_shape(X)
         hidden_size = infer_shape(Rp)[-1]
         batch_size = X_shape[1]
 
-        # Initialize state if not provided.
-        # Otherwise remove bidirectional axis.
         if Hp_0 is None:
+            Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
-        if Bp is None:
-            Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype)
 
         if "activations" in attr:
             activations = attr["activations"]
@@ -2460,39 +2440,54 @@ class GRU(RNN):
         else:
             acts = [_op.sigmoid, _op.tanh] * 2
 
-        result_output = []
-        result_H = []
+        # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
+        X_steps = unbind(X, axis=0)
 
-        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
         H_ts = _op.split(Hp_0, num_directions)
         Ws = _op.split(Wp, num_directions)
         Rs = _op.split(Rp, num_directions)
-        Bs = _op.split(Bp, num_directions)
 
+        if Bp is not None:
+            Bs = _op.split(Bp, num_directions)
+
+        weights_dicts = []
         for i in range(num_directions):
-            H_t = _op.squeeze(H_ts[i], axis=[0])
-            W = _op.squeeze(Ws[i], axis=[0])
-            R = _op.squeeze(Rs[i], axis=[0])
-            B = _op.squeeze(Bs[i], axis=[0])
-            f_act, g_act = acts[i * 2 : (i + 1) * 2]
-            output, H = GRU.generate_gru(
-                X_steps=X_steps,
-                H_t=H_t,
-                W=W,
-                R=R,
-                B=B,
-                linear_before_reset=linear_before_reset,
-                f_act=f_act,
-                g_act=g_act,
-                W_dtype=W_dtype,
-                backwards=i == 1,
-            )
+            weights_dict = {}
+
+            weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
+            weights_dict["linear_before_reset"] = linear_before_reset
+
+            # Weights permutation: ONNX format is z|r|n, gru_cell expects r|z|n
+            matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3)
+            weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0)
+            matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3)
+            weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0)
+            if Bp is not None:
+                Bi, Bh = _op.split(Bs[i], 2, -1)
+                matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3)
+                weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], 
axis=0)
+                matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3)
+                weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], 
axis=0)
+            weights_dicts.append(weights_dict)
 
-            result_output.append(output)
-            result_H.append(H)
+        if num_directions == 2:
+            output, H = GRU.bidir_gru_cell(
+                input_seqs=X_steps,
+                weight_dicts=weights_dicts,
+                acts=acts,
+            )
+        else:
+            # outputs shape = [seqs_num, (batch_size, hidden_size)]
+            outputs, H = gru_cell(
+                input_seqs=X_steps,
+                **weights_dicts[0],
+                rz_act=acts[0],
+                n_act=acts[1],
+            )
 
-        output = _op.concatenate(result_output, axis=1)
-        H = _op.concatenate(result_H, axis=0)
+            # output shape = (seqs_num, num_directions, batch_size, hidden_size)
+            output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
+            H = _op.expand_dims(H, axis=0)
 
         return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)
 
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7c10889..613643f 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -39,7 +39,7 @@ from ..loops import while_loop
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
-from .common import AttrCvt, get_relay_op, unbind, lstm_cell
+from .common import AttrCvt, get_relay_op, unbind, lstm_cell, gru_cell
 from .common import infer_value as _infer_value
 from .common import infer_shape as _infer_shape
 from .common import infer_value_simulated as _infer_value_simulated
@@ -2315,6 +2315,192 @@ class PyTorchOpConverter:
         axis = inputs[1]
         return _op.transform.reverse(data, axis=axis[0])
 
+    def bidir_gru_cell(
+        self,
+        input_seqs,
+        weights_dicts,
+    ):
+        """
+        Bidirectional GRU cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t = gru_cell(
+            input_seqs,
+            **weights_dicts[0],
+        )
+
+        reverse_outputs, rev_H_t = gru_cell(
+            input_seqs,
+            **weights_dicts[1],
+            backwards=True,
+        )
+
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
+            )
+
+        return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
+
+    def gru_layers(self, input_data, layer_weights_dicts, bidirectional, dropout_p=0.0):
+        """
+        Method iterates over layers of a stacked GRU
+        """
+        layers_num = len(layer_weights_dicts)
+        # split input sequence to samples set
+        input_seqs = unbind(input_data, 0)  # [seq_num, (batch, feature_size)]
+        output_hiddens = []
+        for i in range(layers_num):
+            weights_dicts = layer_weights_dicts[i]
+            # input_seqs shape = [seq_num, (batch, feature_size)] or
+            # [seq_num, (batch, 2*hidden_size)] for bidirectional
+            if bidirectional:
+                input_seqs, H_t = self.bidir_gru_cell(input_seqs, weights_dicts)
+            else:
+                input_seqs, H_t = gru_cell(input_seqs, **weights_dicts[0])
+
+            output_hiddens.append(H_t)
+
+            # TODO (vvchernov): in pytorch implementation train is also checked
+            # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+            # /aten/src/ATen/native/RNN.cpp#L1054
+            if dropout_p != 0 and i < layers_num - 1:
+                # for input in input_seqs:
+                #     input = _op.dropout(input, dropout_p)
+                raise NotImplementedError("Dropout for GRU has not been 
supported yet!")
+
+        return _op.stack(input_seqs, 0), _op.stack(output_hiddens, 0)
+
+    def gru(self, inputs, input_types):
+        """
+        Description of GRU in pytorch:
+        https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=gru#torch.nn.GRU
+        """
+        # TODO (vvchernov): support dropout
+        assert len(inputs) == 9, "Input of size 9 is expected"
+        # Unpack inputs; optional inputs that were not provided will be None.
+        _X = inputs[0]
+        # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
+
+        hidden_state = inputs[1]
+        # Hidden state shape (hidden_layers_num, batch, hidden_size)
+
+        _weights = inputs[2]
+        # Wi layer[0] shape (3 * hidden_size, feature_size)
+        # Wh layer[0] shape (3 * hidden_size, hidden_size)
+        # Bi layer[0] shape (3 * hidden_size)
+        # Bh layer[0] shape (3 * hidden_size)
+
+        # Wi layer[>0] shape (3 * hidden_size, hidden_size * num_directions)
+        # Wh layer[>0] shape (3 * hidden_size, hidden_size)
+        # Bi layer[>0] shape (3 * hidden_size)
+        # Bh layer[>0] shape (3 * hidden_size)
+
+        # Scalar inputs
+        has_biases = inputs[3]
+        num_layers = inputs[4]
+        dropout_p = inputs[5]  # dropout probability, if 0.0 it means there is no dropout
+        # train = inputs[6]
+        bidirectional = inputs[7]
+        batch_first = inputs[8]
+
+        num_directions = 1
+        if bidirectional:
+            num_directions = 2
+
+        rsd = len(_weights) % num_layers
+        assert rsd == 0, "The number of weights must be a multiple of the 
number of layers!"
+        rsd = (len(_weights) / num_layers) % num_directions
+        assert (
+            rsd == 0
+        ), "The number of weights in layer must be a multiple of the number of 
directions!"
+
+        weights_num = int(len(_weights) / num_layers / num_directions)
+        if has_biases:
+            assert weights_num == 4, "The number of weights per layer is expected to be 4"
+        else:
+            assert weights_num == 2, "The number of weights per layer is expected to be 2"
+
+        X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
+        # TODO (vvchernov): Which data type should be used? from input or weights?
+        # Instead of it _infer_type(X).checked_type.dtype can be used
+        X_dtype = input_types[0]
+        X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
+
+        hidden_size = int(_infer_shape(_weights[0])[0] / 3)
+        batch_size = X_shape[1]
+
+        # Initialize hidden states if not provided.
+        layers_h = []
+        hidden_layers_num = num_directions * num_layers
+        if hidden_state is None:
+            h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+            for i in range(hidden_layers_num):
+                layers_h.append(h_0)
+        else:
+            layers_h = unbind(hidden_state, 0)
+
+        layer_weights_dicts = []
+        k = 0  # layer counter
+        if has_biases:
+            names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        else:
+            names = ["hidden_state", "w_inp", "w_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        assert (
+            len(layer_weights_dicts) == num_layers and k == num_layers
+        ), "For stacked GRU number of weights sets should be the same as 
number of layers!"
+
+        output, out_hidden_state = self.gru_layers(
+            X,
+            layer_weights_dicts,
+            bidirectional,
+            dropout_p=dropout_p,
+        )
+
+        # output shape = (seq_num, batch, hidden_size) or
+        # (seq_num, batch, 2*hidden_size) for bidirectional
+        if batch_first:
+            output = _op.transpose(output, (1, 0, 2))
+
+        return (output, out_hidden_state)
+
     def bidir_lstm_cell(
         self,
         input_seqs,
@@ -2792,6 +2978,7 @@ class PyTorchOpConverter:
             "aten::nll_loss": self.nll_loss,
             "aten::nll_loss2d": self.nll_loss,
             "aten::flip": self.flip,
+            "aten::gru": self.gru,
             "aten::lstm": self.lstm,
         }
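
With aten::gru registered above, a torch.nn.GRU module can be traced and
imported like the existing LSTM path. A minimal end-to-end sketch (not part of
the commit; the wrapper module, names and shapes are illustrative, mirroring
the new test below):

    import torch
    from tvm import relay

    class GRUWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gru = torch.nn.GRU(input_size=8, hidden_size=16, num_layers=2)

        def forward(self, x):
            out, _ = self.gru(x)  # drop the final hidden state
            return out

    model = GRUWrapper().eval()
    inp = torch.rand(2, 2, 8)  # (seq_num, batch, feature_size)
    traced = torch.jit.trace(model, inp)
    mod, params = relay.frontend.from_pytorch(traced, [("input", (2, 2, 8))])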
 
diff --git a/tests/python/frontend/pytorch/test_lstms.py b/tests/python/frontend/pytorch/test_lstms.py
deleted file mode 100644
index 967245e..0000000
--- a/tests/python/frontend/pytorch/test_lstms.py
+++ /dev/null
@@ -1,363 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import tvm
-import tvm.testing
-import numpy as np
-import torch
-import onnx
-import io
-import sys
-import pytest
-
-from tvm import relay
-from tvm.contrib import graph_executor
-
-from torch import nn
-
-## Model parameters
-model_feature_size = 16
-model_hidden_size = 32
-model_num_layers = 2
-seqs_length = 2
-projection_size = 20
-batch_size = 2
-
-
-def check_torch_version_for_proj_in_lstm():
-    """
-    proj_size parameter is supported in torch.nn.LSTM layer started from 1.8.0 torch version
-    """
-    me = False
-
-    version = torch.__version__
-    major, minor, micro = version.split(".")
-
-    if int(major) > 1:
-        me = True
-    elif int(major) == 1:
-        if int(minor) >= 8:
-            me = True
-
-    return me
-
-
-class LSTM_Model(nn.Module):
-    def __init__(
-        self,
-        device,
-        batch_first=False,
-        layer_num=1,
-        bidirectional=False,
-        proj_size=0,
-        use_bias=True,
-        rnd_weights_init=False,
-    ):
-        super().__init__()
-
-        self.device = device
-        self.batch_first = batch_first
-        self.use_bias = use_bias
-
-        if check_torch_version_for_proj_in_lstm():
-            self.lstm = nn.LSTM(
-                input_size=model_feature_size,
-                hidden_size=model_hidden_size,
-                num_layers=layer_num,
-                bidirectional=bidirectional,
-                proj_size=proj_size,
-                batch_first=batch_first,
-                bias=use_bias,
-            ).to(device)
-        else:
-            if proj_size > 0:
-                print(
-                    "WARNING: projection is not supported for torch version 
less than 1.8.0! ",
-                    "LSTM was constructed without projection!",
-                )
-                # sys.exit()
-            self.lstm = nn.LSTM(
-                input_size=model_feature_size,
-                hidden_size=model_hidden_size,
-                num_layers=layer_num,
-                bidirectional=bidirectional,
-                batch_first=batch_first,
-                bias=use_bias,
-            ).to(device)
-
-        if rnd_weights_init:
-            self.gen_rnd_weights()
-
-    def forward(self, input, hidden_init=None):
-        """
-        Computes the output tensor after input inference along LSTM layer.
-
-        :param input: batch of data as a tensor of shape (seqs_length, batch_size, model_feature_size) or (batch_size, seqs_length, model_feature_size) if self.batch_first = True
-        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None.
-        :return: the output tensor of shape (batch_size, model_hidden_size)
-        """
-        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
-        # and the final cell state.
-        out, (hidden, cell) = self.lstm(input, hidden_init)
-
-        return out
-
-    def gen_rnd_weights(self):
-        """
-        Generate random weigths for the model with biases
-        Without projection:
-            For first weights group:
-                Wi (4*model_hidden_size, model_feature_size)
-                Wh (4*model_hidden_size, model_hidden_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-            For first bidirectional weights group:
-                Wi (4*model_hidden_size, model_feature_size)
-                Wh (4*model_hidden_size, model_hidden_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-            For other weights group:
-                Wi (4*model_hidden_size, model_hidden_size)
-                Wh (4*model_hidden_size, model_hidden_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-        With projection:
-            For first weights group:
-                Wi (4*model_hidden_size, model_feature_size)
-                Wh (4*model_hidden_size, proj_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-                P  (proj_size, model_hidden_size)
-            For first bidirectional weights group:
-                Wi (4*model_hidden_size, model_feature_size)
-                Wh (4*model_hidden_size, proj_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-                P  (proj_size, model_hidden_size)
-            For other weights group:
-                Wi (4*model_hidden_size, proj_size * num_directions)
-                Wh (4*model_hidden_size, proj_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-                P  (proj_size, model_hidden_size)
-        For generation of random weigths for the model without biases Bi and Bh are skipped
-        """
-        for weight_group in self.lstm.all_weights:
-            for weight in weight_group:
-                weight.data = torch.rand(weight.shape)
-
-    def get_dummy_input(self):
-        shape = [seqs_length, batch_size, model_feature_size]
-        if self.batch_first:
-            shape = [batch_size, seqs_length, model_feature_size]
-        res = torch.rand(shape)
-
-        return res, shape
-
-
-def compare(input, gold_data, rtol=1e-5, atol=1e-5):
-    tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
-
-
-def check_lstm_with_type(
-    lstm_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)
-):
-    has_proj = "p" in lstm_type
-
-    device = torch.device("cpu")
-    hidden_layers_num = 1
-    model = None
-    for batch_first in (True, False):
-        for use_bias in (True, False):
-            for rnd_weights in [True]:  # (True, False):
-                if lstm_type == "uni":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                elif lstm_type == "b":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        bidirectional=True,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = 2
-                elif lstm_type == "p":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        proj_size=projection_size,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                elif lstm_type == "s":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        layer_num=model_num_layers,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = model_num_layers
-                elif lstm_type == "sb":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        bidirectional=True,
-                        layer_num=model_num_layers,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = 2 * model_num_layers
-                elif lstm_type == "sp":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        layer_num=model_num_layers,
-                        proj_size=projection_size,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = model_num_layers
-                elif lstm_type == "bp":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        bidirectional=True,
-                        proj_size=projection_size,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = 2
-                elif lstm_type == "sbp":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        bidirectional=True,
-                        layer_num=model_num_layers,
-                        proj_size=projection_size,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = 2 * model_num_layers
-                else:
-                    print("WARNING: LSTM type {} is not supported 
here!".format(lstm_type))
-                    return
-
-                model.eval()
-
-                # Get golden output from original model
-                input_hidden_shape = (hidden_layers_num, batch_size, model_hidden_size)
-                input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size)
-                dummy_input, input_shape = model.get_dummy_input()
-                golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy()
-
-                dtype = "float32"
-                h_zeros = np.zeros(input_hidden_shape, dtype=dtype)
-                if has_proj:
-                    h_zeros = np.zeros(input_hidden_shape_with_proj, dtype=dtype)
-                c_zeros = np.zeros(input_hidden_shape, dtype=dtype)
-
-                tvm_output = None
-                for format in ["ts"]:  # ["ts", "onnx"]:
-                    if format == "ts":
-                        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
-                        traced_script_module = torch.jit.trace(model, dummy_input).eval()
-
-                        # Import model to Relay
-                        shape_list = [("input", input_shape)]
-                        mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list)
-
-                        # Model compilation by tvm
-                        with tvm.transform.PassContext(opt_level=3):
-                            lib = relay.build(mod, target=target, params=params)
-                    elif format == "onnx":
-                        if has_proj:
-                            print(
-                                "WARNING: torch.onnx.export does not support 
conversion LSTM with projection "
-                                "from pytorch! TODO: waiting for the support 
and correct test after that."
-                            )
-                            continue
-                        onnx_io = io.BytesIO()
-                        with torch.no_grad():
-                            h0 = torch.rand(input_hidden_shape)
-                            if has_proj:
-                                h0 = torch.rand(input_hidden_shape_with_proj)
-                            c0 = torch.rand(input_hidden_shape)
-                            input_names = ["input", "h0", "c0"]
-
-                            # default export (without dynamic input)
-                            torch.onnx.export(
-                                model, (dummy_input, (h0, c0)), onnx_io, input_names=input_names
-                            )
-                        onnx_io.seek(0, 0)
-                        onnx_model = onnx.load_model(onnx_io)
-
-                        # Import model to Relay
-                        shape_dict = {
-                            "input": input_shape,
-                            "h0": input_hidden_shape,
-                            "c0": input_hidden_shape,
-                        }
-                        if has_proj:
-                            shape_dict = {
-                                "input": input_shape,
-                                "h0": input_hidden_shape_with_proj,
-                                "c0": input_hidden_shape,
-                            }
-                        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
-
-                        # Model compilation by tvm
-                        with tvm.transform.PassContext(opt_level=1):
-                            lib = relay.build(mod, target=target, params=params)
-
-                    # Inference of the model with given input data
-                    m = graph_executor.GraphModule(lib["default"](dev))
-
-                    # Set inputs
-                    m.set_input(
-                        input=tvm.nd.array(dummy_input.numpy().astype(dtype)),
-                        h0=tvm.nd.array(h_zeros),
-                        c0=tvm.nd.array(c_zeros),
-                    )
-                    # Execute
-                    m.run()
-                    # Get outputs (converted to numpy array)
-                    tvm_output = m.get_output(0).numpy()
-
-                    compare(tvm_output, golden_output_batch)
-
-
[email protected]_gpu
-def test_lstms():
-    for target, dev in tvm.testing.enabled_targets():
-        check_lstm_with_type("uni", target, dev)
-        # check_lstm_with_type("p", target, dev)
-        check_lstm_with_type("s", target, dev)
-        check_lstm_with_type("b", target, dev)
-        # check_lstm_with_type("bp", target, dev)
-        # check_lstm_with_type("sp", target, dev)
-        check_lstm_with_type("sb", target, dev)
-        # check_lstm_with_type("sbp", target, dev)
-
-
-if __name__ == "__main__":
-    test_lstms()
diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py
new file mode 100644
index 0000000..b5784a6
--- /dev/null
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -0,0 +1,430 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+import torch
+import onnx
+import io
+import sys
+
+from tvm import relay
+from tvm.contrib import graph_executor
+
+from torch import nn
+
+## LSTM parameters
+lstm_feature_size = 16
+lstm_hidden_size = 32
+lstm_projection_size = 20
+
+## GRU parameters
+gru_feature_size = 8
+gru_hidden_size = 16
+
+num_layers = 2
+seqs_length = 2
+batch_size = 2
+
+
+class RNN_Model(nn.Module):
+    """
+    Base class for the RNN layer classes.
+    It contains common fields and methods for the child classes.
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+        # model is defined in child class
+        self.model = None
+
+    def forward(self, input, hidden_init=None):
+        """
+        Computes the output tensor after input inference along RNN layer.
+
+        :param input: batch of data as a tensor of shape (seqs_length, batch_size, feature_size) or (batch_size, seqs_length, feature_size) if self.batch_first = True
+        :param hidden_init: initial hidden state(s) of the RNN as a tensor(s) of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None.
+        :return: the output tensor of shape (batch_size, hidden_size)
+        """
+        if self.model is None:
+            raise NotImplementedError("self.model must be defined in 
subclasses!")
+        out, _ = self.model(input, hidden_init)
+
+        return out
+
+    def gen_rnd_weights(self):
+        """
+        Generate random weights for the model
+        """
+        if self.model is None:
+            raise NotImplementedError("self.model must be defined in 
subclasses!")
+        with torch.no_grad():
+            for weight_group in self.model.all_weights:
+                for weight in weight_group:
+                    weight.data = torch.rand(weight.shape)
+
+    def get_dummy_inputs(self):
+        raise NotImplementedError("subclasses must override get_dummy_inputs()!")
+
+    def get_input_names(self):
+        raise NotImplementedError("subclasses must override get_input_names()!")
+
+    def get_shape_desc(self, frontend_type):
+        raise NotImplementedError("subclasses must override get_shape_desc(frontend_type)!")
+
+    def get_tvm_inputs(self, dtype):
+        raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!")
+
+
+class GRU_Model(RNN_Model):
+    def __init__(
+        self,
+        seq_len=seqs_length,
+        batch_size=batch_size,
+        feature_size=gru_feature_size,
+        hidden_size=gru_hidden_size,
+        batch_first=False,
+        layer_num=1,
+        bidirectional=False,
+        use_bias=True,
+        rnd_weights_init=False,
+    ):
+        super().__init__()
+
+        # Shapes
+        self.shape = [seq_len, batch_size, feature_size]
+        if batch_first:
+            self.shape = [batch_size, seq_len, feature_size]
+        layers_num = 2 * layer_num if bidirectional else layer_num
+        self.h0_shape = [layers_num, batch_size, hidden_size]
+        # Dummy inputs
+        self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape))
+
+        self.model = nn.GRU(
+            input_size=feature_size,
+            hidden_size=hidden_size,
+            num_layers=layer_num,
+            bidirectional=bidirectional,
+            batch_first=batch_first,
+            bias=use_bias,
+        )
+
+        if rnd_weights_init:
+            self.gen_rnd_weights()
+
+    def gen_rnd_weights(self):
+        """
+        Generate random weights for the model with biases
+        For first uni- and bidirectional weights group:
+            Wi (3*hidden_size, feature_size)
+            Wh (3*hidden_size, hidden_size)
+            Bi (3*hidden_size)
+            Bh (3*hidden_size)
+        For other weights group:
+            Wi (3*hidden_size, hidden_size)
+            Wh (3*hidden_size, hidden_size)
+            Bi (3*hidden_size)
+            Bh (3*hidden_size)
+        For generation of random weights for the model without biases, the Bi and Bh weights are skipped
+        """
+        super().gen_rnd_weights()
+
+    def get_dummy_inputs(self):
+        return self.dummy_inputs
+
+    def get_input_names(self):
+        return ["input", "h0"]
+
+    def get_shape_desc(self, frontend_type):
+        shape_desc = None
+        if frontend_type == "pt":  # PyTorch
+            shape_desc = [("input", self.shape)]
+        elif frontend_type == "onnx":  # ONNX
+            shape_desc = {
+                "input": self.shape,
+                "h0": self.h0_shape,
+            }
+        return shape_desc
+
+    def get_tvm_inputs(self, dtype):
+        return {
+            "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
+            "h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)),
+        }
+
+
+def check_torch_version_for_proj_in_lstm():
+    """
+    proj_size parameter is supported in the torch.nn.LSTM layer starting from torch version 1.8.0
+    """
+    me = False
+
+    version = torch.__version__
+    major, minor, micro = version.split(".")
+
+    if int(major) > 1:
+        me = True
+    elif int(major) == 1:
+        if int(minor) >= 8:
+            me = True
+
+    return me
+
+
+class LSTM_Model(RNN_Model):
+    def __init__(
+        self,
+        seq_len=seqs_length,
+        batch_size=batch_size,
+        feature_size=lstm_feature_size,
+        hidden_size=lstm_hidden_size,
+        batch_first=False,
+        layer_num=1,
+        bidirectional=False,
+        proj_size=0,
+        use_bias=True,
+        rnd_weights_init=False,
+    ):
+        super().__init__()
+
+        # Shapes
+        self.shape = [seq_len, batch_size, feature_size]
+        if batch_first:
+            self.shape = [batch_size, seq_len, feature_size]
+        layers_num = 2 * layer_num if bidirectional else layer_num
+        self.h0_shape = [layers_num, batch_size, hidden_size]
+        if proj_size > 0:
+            self.h0_shape = [layers_num, batch_size, proj_size]
+        self.c0_shape = [layers_num, batch_size, hidden_size]
+        # Dummy inputs
+        self.dummy_inputs = (
+            torch.rand(self.shape),
+            (torch.zeros(self.h0_shape), torch.zeros(self.c0_shape)),
+        )
+
+        if check_torch_version_for_proj_in_lstm():
+            self.model = nn.LSTM(
+                input_size=lstm_feature_size,
+                hidden_size=lstm_hidden_size,
+                num_layers=layer_num,
+                bidirectional=bidirectional,
+                proj_size=proj_size,
+                batch_first=batch_first,
+                bias=use_bias,
+            )
+        else:
+            if proj_size > 0:
+                print(
+                    "WARNING: projection is not supported for torch version 
less than 1.8.0! ",
+                    "LSTM was constructed without projection!",
+                )
+                # sys.exit()
+            self.model = nn.LSTM(
+                input_size=lstm_feature_size,
+                hidden_size=lstm_hidden_size,
+                num_layers=layer_num,
+                bidirectional=bidirectional,
+                batch_first=batch_first,
+                bias=use_bias,
+            )
+
+        if rnd_weights_init:
+            self.gen_rnd_weights()
+
+    def gen_rnd_weights(self):
+        """
+        Generate random weights for the model with biases
+        Without projection:
+            For first weights group:
+                Wi (4*lstm_hidden_size, lstm_feature_size)
+                Wh (4*lstm_hidden_size, lstm_hidden_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+            For first bidirectional weights group:
+                Wi (4*lstm_hidden_size, lstm_feature_size)
+                Wh (4*lstm_hidden_size, lstm_hidden_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+            For other weights group:
+                Wi (4*lstm_hidden_size, lstm_hidden_size)
+                Wh (4*lstm_hidden_size, lstm_hidden_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+        With projection:
+            For first weights group:
+                Wi (4*lstm_hidden_size, lstm_feature_size)
+                Wh (4*lstm_hidden_size, proj_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+                P  (proj_size, lstm_hidden_size)
+            For first bidirectional weights group:
+                Wi (4*lstm_hidden_size, lstm_feature_size)
+                Wh (4*lstm_hidden_size, proj_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+                P  (proj_size, lstm_hidden_size)
+            For other weights group:
+                Wi (4*lstm_hidden_size, proj_size * num_directions)
+                Wh (4*lstm_hidden_size, proj_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+                P  (proj_size, lstm_hidden_size)
+        For generation of random weights for the model without biases, Bi and Bh are skipped
+        """
+        super().gen_rnd_weights()
+
+    def get_dummy_inputs(self):
+        return self.dummy_inputs
+
+    def get_input_names(self):
+        return ["input", "h0", "c0"]
+
+    def get_shape_desc(self, frontend_type):
+        shape_desc = None
+        if frontend_type == "pt":  # PyTorch
+            shape_desc = [("input", self.shape)]
+        elif frontend_type == "onnx":  # ONNX
+            shape_desc = {
+                "input": self.shape,
+                "h0": self.h0_shape,
+                "c0": self.c0_shape,
+            }
+        return shape_desc
+
+    def get_tvm_inputs(self, dtype):
+        return {
+            "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
+            "h0": tvm.nd.array(self.dummy_inputs[1][0].numpy().astype(dtype)),
+            "c0": tvm.nd.array(self.dummy_inputs[1][1].numpy().astype(dtype)),
+        }
+
+
+def compare(input, gold_data, rtol=1e-5, atol=1e-5):
+    tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
+
+
+def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)):
+    def get_model(
+        rnn_type,
+        rnn_mod,
+        args,
+    ):
+        # Fill args
+        if "b" in rnn_mod:
+            args["bidirectional"] = True
+        if "s" in rnn_mod:
+            args["layer_num"] = num_layers
+
+        if rnn_type == "GRU":
+            RNN_Model_selector = GRU_Model
+        elif rnn_type == "LSTM":
+            RNN_Model_selector = LSTM_Model
+            if "p" in rnn_mod:
+                args["proj_size"] = lstm_projection_size
+
+        return RNN_Model_selector(**args)
+
+    def get_onnx_model(model):
+        onnx_io = io.BytesIO()
+        with torch.no_grad():
+            input_names = model.get_input_names()
+            inputs = model.get_dummy_inputs()
+
+            # default export (without dynamic input)
+            torch.onnx.export(model, inputs, onnx_io, input_names=input_names)
+
+        onnx_io.seek(0, 0)
+        return onnx.load_model(onnx_io)
+
+    model = None
+    dtype = "float32"
+    device = torch.device("cpu")
+    for batch_first in (True, False):
+        for use_bias in (True, False):
+            for rnd_weights in [True]:  # (True, False):
+                model_inputs = {
+                    "batch_first": batch_first,
+                    "use_bias": use_bias,
+                    "rnd_weights_init": rnd_weights,
+                }
+                model = get_model(rnn_type, rnn_mod, model_inputs)
+                model.to(device)
+                model.eval()
+
+                # Get golden output from original model
+                dummy_inputs = model.get_dummy_inputs()
+                golden_output = model.forward(dummy_inputs[0].to(device)).detach().cpu().numpy()
+
+                tvm_output = None
+                for format in ["pt"]:  # ["pt", "onnx"]:
+                    shape_desc = model.get_shape_desc(format)
+                    if format == "pt":
+                        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
+                        traced_script_module = torch.jit.trace(model, dummy_inputs[0]).eval()
+
+                        # Import model to Relay
+                        mod, params = relay.frontend.from_pytorch(traced_script_module, shape_desc)
+                    elif format == "onnx":
+                        try:
+                            onnx_model = get_onnx_model(model)
+                        except:
+                            print(
+                                "WARNING: torch.onnx.export does not support conversion LSTM with projection "
+                                "from pytorch! TODO: waiting for the support and correct test after that."
+                            )
+                            continue
+
+                        # Import model to Relay
+                        mod, params = relay.frontend.from_onnx(onnx_model, shape_desc)
+
+                    # Model compilation by tvm
+                    with tvm.transform.PassContext(opt_level=3):
+                        lib = relay.build(mod, target=target, params=params)
+
+                    # Inference of the model with given input data
+                    m = graph_executor.GraphModule(lib["default"](dev))
+
+                    # Set inputs
+                    tvm_inputs = model.get_tvm_inputs(dtype)
+                    m.set_input(**tvm_inputs)
+                    # Execute
+                    m.run()
+                    # Get outputs (converted to numpy array)
+                    tvm_output = m.get_output(0).numpy()
+
+                    compare(tvm_output, golden_output)
+
+
[email protected]_gpu
+def test_rnns():
+    for target, dev in tvm.testing.enabled_targets():
+        # RNN types: GRU, LSTM
+        # GRU modifications: unidirectional, stacked, bidirectional, stacked bidirectional
+        for mod_type in ["uni", "s", "b", "sb"]:
+            check_rnn("GRU", mod_type, target, dev)
+        # LSTM modifications: unidirectional, stacked, bidirectional, stacked bidirectional,
+        # and all these types with projection ("p", "sp", "bp", "sbp")
+        # The latter are skipped for test acceleration
+        for mod_type in ["uni", "s", "b", "sb"]:
+            check_rnn("LSTM", mod_type, target, dev)
+
+
+if __name__ == "__main__":
+    test_rnns()
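
The unified suite keeps its __main__ entry point, so it can be run directly
(assuming a local TVM build with the PyTorch frontend available):

    python tests/python/frontend/pytorch/test_rnns.py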
