This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cfa55251b2 [Relay][Frontend][Onnx] Add RNN operation for ONNX frontend (#12213)
cfa55251b2 is described below

commit cfa55251b2163b77e98726e73c4c3298b1903899
Author: zhang-yi-chi <[email protected]>
AuthorDate: Fri Aug 5 02:18:02 2022 +0800

    [Relay][Frontend][Onnx] Add RNN operation for ONNX frontend (#12213)
    
    * Add RNN operation for ONNX frontend.
    
    * link checks
    
    * rm test_rnn_batchwise in unsupported_onnx_tests
    
    * merge similar codes to class methods
    
    * implement opset 14 and refactor test_forward
    
    * reformat verify_rnn_helper
    
    Co-authored-by: 张亦驰 <[email protected]>
---
 python/tvm/relay/frontend/onnx.py          | 304 ++++++++++++++++----------
 tests/python/frontend/onnx/test_forward.py | 328 +++++++++++------------------
 2 files changed, 313 insertions(+), 319 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index e78e65dc4e..e885fb89ab 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -53,6 +53,7 @@ from .common import (
     infer_value,
     lstm_cell,
     new_var,
+    rnn_cell,
     shape_of,
     try_resolve_var_to_const,
     unbind,
@@ -2723,7 +2724,7 @@ class Expand(OnnxOpConverter):
 
 
 class RNN(OnnxOpConverter):
-    """Operator converter for RNNs such as LSTM and GRU."""
+    """Operator converter for RNNs such as RNN, LSTM and GRU."""
 
     @classmethod
     def _activation_helper(cls, activation, alpha, beta):
@@ -2756,35 +2757,27 @@ class RNN(OnnxOpConverter):
         ]
         return activation.decode("utf-8") in needs_beta
 
-
-class LSTM(RNN):
-    """Operator converter for LSTM"""
-
     @classmethod
-    def bidir_lstm_cell(
+    def bidir_rnn_cell(
         cls,
         input_seqs,
         weight_dicts,
         acts,
     ):
         """
-        Bidirectional LSTM cell
+        Bidirectional RNN cell
         """
         seq_len = len(input_seqs)
-        forward_outputs, fw_H_t, fw_C_t = lstm_cell(
+        forward_outputs, fw_H_t = rnn_cell(
             input_seqs,
             **weight_dicts[0],
-            f_act=acts[0],
-            g_act=acts[1],
-            h_act=acts[2],
+            act=acts[0],
         )
 
-        reverse_outputs, rev_H_t, rev_C_t = lstm_cell(
+        reverse_outputs, rev_H_t = rnn_cell(
             input_seqs,
             **weight_dicts[1],
-            f_act=acts[3],
-            g_act=acts[4],
-            h_act=acts[5],
+            act=acts[1],
             backwards=True,
         )
 
@@ -2797,44 +2790,24 @@ class LSTM(RNN):
         return (
             _op.stack(final_outputs, axis=0),
             _op.stack([fw_H_t, rev_H_t], axis=0),
-            _op.stack([fw_C_t, rev_C_t], axis=0),
         )
 
     @classmethod
-    def _impl_v7(cls, inputs, attr, params):
-        # Unpack inputs, note that if optional and not provided then value will be None.
-        X = inputs[0]
-        Wp = inputs[1]
-        Rp = inputs[2]
-        Bp = inputs[3]
-        # Sequence length currently unused as it can be inferred from shapes.
-        # sequence_lens = inputs['sequence_lens']
-        Hp_0 = inputs[5]
-        Cp_0 = inputs[6]
-        Pp = inputs[7]
-
-        num_directions = infer_shape(Wp)[0]
-        W_dtype = infer_type(Wp).checked_type.dtype
-
-        if num_directions not in [1, 2]:
-            raise ValueError("num_directions must be either 1 or 2!")
-
-        X_shape = infer_shape(X)
-        hidden_size = infer_shape(Rp)[-1]
-        batch_size = X_shape[1]
-
-        # Initialize state if not provided.
-        # Otherwise remove bidirectional axis.
-        if Hp_0 is None:
-            Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
-        if Cp_0 is None:
-            Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
+    def _default_activations(cls, num_directions):
+        return [_op.tanh] * num_directions
 
+    @classmethod
+    def _get_activations(cls, attr, multiplier, num_directions, rnn_type):
+        """
+        Activation functions
+        """
         if "activations" in attr:
             activations = attr["activations"]
-            if len(activations) != 3 * num_directions:
+            if len(activations) != multiplier * num_directions:
                 raise NotImplementedError(
-                    f"LSTM assumes 3 * num_directions activation functions are provided"
+                    "{} assumes {} * num_directions activation functions are provided".format(
+                        rnn_type, multiplier
+                    )
                 )
             alpha_loc = 0
             alphas = attr.get("activation_alpha", [])
@@ -2845,7 +2818,7 @@ class LSTM(RNN):
             if isinstance(betas, float):
                 betas = [betas]
             acts = []
-            for i in range(3 * num_directions):
+            for i in range(multiplier * num_directions):
                 alpha = None
                 beta = None
                 activation = activations[i]
@@ -2857,18 +2830,171 @@ class LSTM(RNN):
                     beta_loc += 1
                 acts.append(cls._activation_helper(activation, alpha, beta))
         else:
-            acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions
+            acts = cls._default_activations(num_directions)
+        return acts
+
+    @classmethod
+    def _inputs_helper(cls, inputs, layout):
+        """
+        Process inputs
+        """
+        # Unpack inputs, note that if optional and not provided then value will be None.
+        X = inputs[0]
+        Wp = inputs[1]
+        Rp = inputs[2]
+        Bp = inputs[3]
+        # Sequence length currently unused as it can be inferred from shapes.
+        # sequence_lens = inputs['sequence_lens']
+        Hp_0 = inputs[5]
+
+        num_directions = infer_shape(Wp)[0]
+
+        if num_directions not in [1, 2]:
+            raise ValueError("num_directions must be either 1 or 2!")
+
+        if layout == 1:
+            X = _op.transpose(X, axes=(1, 0))
+
+        # Initialize state if not provided.
+        if Hp_0 is None:
+            W_dtype = infer_type(Wp).checked_type.dtype
+            X_shape = infer_shape(X)
+            hidden_size = infer_shape(Rp)[-1]
+            batch_size = X_shape[1]
+            Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
+        elif layout == 1:
+            Hp_0 = _op.transpose(Hp_0, axes=(1, 0))
 
         # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
         X_steps = unbind(X, axis=0)
 
         H_ts = _op.split(Hp_0, num_directions)
-        C_ts = _op.split(Cp_0, num_directions)
         Ws = _op.split(Wp, num_directions)
         Rs = _op.split(Rp, num_directions)
 
+        Bs = None
         if Bp is not None:
             Bs = _op.split(Bp, num_directions)
+        return X_steps, H_ts, Ws, Rs, Bs, num_directions
+
+    @classmethod
+    def _impl_common(cls, inputs, attr, layout):
+        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        acts = cls._get_activations(attr, 1, num_directions, "RNN")
+
+        weights_dicts = []
+        for i in range(num_directions):
+            weights_dict = {}
+
+            weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
+
+            weights_dict["w_inp"] = _op.squeeze(Ws[i], axis=[0])
+            weights_dict["w_hid"] = _op.squeeze(Rs[i], axis=[0])
+            if Bs is not None:
+                Bi, Bh = _op.split(Bs[i], 2, -1)
+                weights_dict["b_inp"] = _op.squeeze(Bi, axis=[0])
+                weights_dict["b_hid"] = _op.squeeze(Bh, axis=[0])
+            weights_dicts.append(weights_dict)
+
+        if num_directions == 2:
+            output, H = RNN.bidir_rnn_cell(
+                input_seqs=X_steps,
+                weight_dicts=weights_dicts,
+                acts=acts,
+            )
+        else:
+            # outputs shape = [seqs_num, (batch_size, hidden_size)]
+            outputs, H = rnn_cell(
+                input_seqs=X_steps,
+                **weights_dicts[0],
+                act=acts[0],
+            )
+
+            # output shape = (seqs_num, num_directions, batch_size, hidden_size)
+            output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
+            H = _op.expand_dims(H, axis=0)
+
+        if layout == 1:
+            output = _op.transpose(output, axes=(1, 0))
+            H = _op.transpose(H, axes=(1, 0))
+        return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)
+
+    @classmethod
+    def _impl_v7(cls, inputs, attr, params):
+        return cls._impl_common(inputs, attr, 0)
+
+    @classmethod
+    def _impl_v14(cls, inputs, attr, params):
+        layout = attr.get("layout", 0)
+        return cls._impl_common(inputs, attr, layout)
+
+
+class LSTM(RNN):
+    """Operator converter for LSTM"""
+
+    @classmethod
+    def bidir_lstm_cell(
+        cls,
+        input_seqs,
+        weight_dicts,
+        acts,
+    ):
+        """
+        Bidirectional LSTM cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t, fw_C_t = lstm_cell(
+            input_seqs,
+            **weight_dicts[0],
+            f_act=acts[0],
+            g_act=acts[1],
+            h_act=acts[2],
+        )
+
+        reverse_outputs, rev_H_t, rev_C_t = lstm_cell(
+            input_seqs,
+            **weight_dicts[1],
+            f_act=acts[3],
+            g_act=acts[4],
+            h_act=acts[5],
+            backwards=True,
+        )
+
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
+            )
+
+        return (
+            _op.stack(final_outputs, axis=0),
+            _op.stack([fw_H_t, rev_H_t], axis=0),
+            _op.stack([fw_C_t, rev_C_t], axis=0),
+        )
+
+    @classmethod
+    def _default_activations(cls, num_directions):
+        return [_op.sigmoid, _op.tanh, _op.tanh] * num_directions
+
+    @classmethod
+    def _impl_common(cls, inputs, attr, layout):
+        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        acts = cls._get_activations(attr, 3, num_directions, "LSTM")
+
+        # cell state
+        Cp_0 = inputs[6]
+        if Cp_0 is None:
+            C_ts = _expr.TupleWrapper(
+                _expr.Tuple([_op.zeros_like(H_ts[i]) for i in range(num_directions)]),
+                num_directions,
+            )
+        else:
+            if layout == 1:
+                Cp_0 = _op.transpose(Cp_0, axes=(1, 0))
+            C_ts = _op.split(Cp_0, num_directions)
+
+        # peepholes
+        Pp = inputs[7]
         if Pp is not None:
             p_i, p_o, p_f = _op.split(Pp, 3, axis=1)
 
@@ -2888,7 +3014,7 @@ class LSTM(RNN):
             weights_dict["w_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
             mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4)
             weights_dict["w_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
-            if Bp is not None:
+            if Bs is not None:
                 Bi, Bh = _op.split(Bs[i], 2, -1)
                 mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4)
                 weights_dict["b_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
@@ -2921,6 +3047,10 @@ class LSTM(RNN):
             H = _op.expand_dims(H, axis=0)
             C = _op.expand_dims(C, axis=0)
 
+        if layout == 1:
+            output = _op.transpose(output, axes=(1, 0))
+            H = _op.transpose(H, axes=(1, 0))
+            C = _op.transpose(C, axes=(1, 0))
         return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3)
 
 
@@ -2965,68 +3095,14 @@ class GRU(RNN):
         )
 
     @classmethod
-    def _impl_v7(cls, inputs, attr, params):
-        # Unpack inputs, note that if optional and not provided then value will be None.
-        X = inputs[0]
-        Wp = inputs[1]
-        Rp = inputs[2]
-        Bp = inputs[3]
-        # Sequence length currently unused as it can be inferred from shapes.
-        # sequence_lens = inputs['sequence_lens']
-        Hp_0 = inputs[5]
-        linear_before_reset = attr.get("linear_before_reset", 0)
-
-        num_directions = infer_shape(Wp)[0]
-        W_dtype = infer_type(Wp).checked_type.dtype
-
-        if num_directions not in [1, 2]:
-            raise ValueError("num_directions must be either 1 or 2!")
-
-        X_shape = infer_shape(X)
-        hidden_size = infer_shape(Rp)[-1]
-        batch_size = X_shape[1]
-
-        if Hp_0 is None:
-            Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
+    def _default_activations(cls, num_directions):
+        return [_op.sigmoid, _op.tanh] * num_directions
 
-        if "activations" in attr:
-            activations = attr["activations"]
-            if len(activations) != 2 * num_directions:
-                raise NotImplementedError(
-                    "GRU assumes 2 * num_directions activation functions are provided"
-                )
-            alpha_loc = 0
-            alphas = attr.get("activation_alpha", [])
-            if isinstance(alphas, float):
-                alphas = [alphas]
-            beta_loc = 0
-            betas = attr.get("activation_beta", [])
-            if isinstance(betas, float):
-                betas = [betas]
-            acts = []
-            for i in range(2 * num_directions):
-                alpha = None
-                beta = None
-                activation = activations[i]
-                if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
-                    alpha = alphas[alpha_loc]
-                    alpha_loc += 1
-                if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
-                    beta = betas[beta_loc]
-                    beta_loc += 1
-                acts.append(cls._activation_helper(activation, alpha, beta))
-        else:
-            acts = [_op.sigmoid, _op.tanh] * 2
-
-        # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
-        X_steps = unbind(X, axis=0)
-
-        H_ts = _op.split(Hp_0, num_directions)
-        Ws = _op.split(Wp, num_directions)
-        Rs = _op.split(Rp, num_directions)
-
-        if Bp is not None:
-            Bs = _op.split(Bp, num_directions)
+    @classmethod
+    def _impl_common(cls, inputs, attr, layout):
+        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        acts = cls._get_activations(attr, 2, num_directions, "GRU")
+        linear_before_reset = attr.get("linear_before_reset", 0)
 
         weights_dicts = []
         for i in range(num_directions):
@@ -3040,7 +3116,7 @@ class GRU(RNN):
             weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0)
             matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3)
             weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0)
-            if Bp is not None:
+            if Bs is not None:
                 Bi, Bh = _op.split(Bs[i], 2, -1)
                 matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3)
                 weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0)
@@ -3067,6 +3143,9 @@ class GRU(RNN):
             output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
             H = _op.expand_dims(H, axis=0)
 
+        if layout == 1:
+            output = _op.transpose(output, axes=(1, 0))
+            H = _op.transpose(H, axes=(1, 0))
         return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)
 
 
@@ -5301,6 +5380,7 @@ def _get_convert_map(opset):
         "Flatten": Flatten.get_converter(opset),
         "LRN": LRN.get_converter(opset),
         # Recurrent Layers
+        "RNN": RNN.get_converter(opset),
         "LSTM": LSTM.get_converter(opset),
         "GRU": GRU.get_converter(opset),
         # defs/vision
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index e500f0902c..3d0cc3414c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3749,12 +3749,15 @@ def verify_rnn(
     use_peep=False,
     linear_before_reset=False,
     directions=1,
+    layout=0,
     rtol=1e-5,
     atol=1e-5,
     target=None,
     dev=None,
 ):
-    if rnn_type == "LSTM":
+    if rnn_type == "RNN":
+        multiplier = 1
+    elif rnn_type == "LSTM":
         multiplier = 4
     elif rnn_type == "GRU":
         multiplier = 3
@@ -3787,7 +3790,10 @@ def verify_rnn(
             proto_type = dtype_map[np_arr.dtype.name]
             input_tensors.append(helper.make_tensor_value_info(name, proto_type, shape))
 
-        x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32")
+        if layout == 1:
+            x_np = np.random.uniform(size=(batch_size, seq_length, input_size)).astype("float32")
+        else:
+            x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32")
         w_np = np.random.uniform(size=(directions, multiplier * hidden_size, input_size)).astype(
             "float32"
         )
@@ -3809,15 +3815,25 @@ def verify_rnn(
             sequence_np = np.repeat(seq_length, batch_size).astype("int32")
             register(sequence_np, "sequence_lens")
 
-            initial_h_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype(
-                "float32"
-            )
+            if layout == 1:
+                initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype(
+                    "float32"
+                )
+            else:
+                initial_h_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype(
+                    "float32"
+                )
             register(initial_h_np, "initial_h")
 
             if rnn_type == "LSTM":
-                initial_c_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype(
-                    "float32"
-                )
+                if layout == 1:
+                    initial_c_np = np.random.uniform(
+                        size=(batch_size, directions, hidden_size)
+                    ).astype("float32")
+                else:
+                    initial_c_np = np.random.uniform(
+                        size=(directions, batch_size, hidden_size)
+                    ).astype("float32")
                 register(initial_c_np, "initial_c")
 
         if use_peep and rnn_type == "LSTM":
@@ -3839,11 +3855,18 @@ def verify_rnn(
             graph_outputs.append(helper.make_tensor_value_info(name, proto_type, list(shape)))
             output_shapes.append(list(shape))
 
-        register("Y", [seq_length, directions, batch_size, hidden_size], TensorProto.FLOAT)
-        register("Y_h", [directions, batch_size, hidden_size], TensorProto.FLOAT)
+        if layout == 1:
+            register("Y", [directions, seq_length, batch_size, hidden_size], TensorProto.FLOAT)
+            register("Y_h", [batch_size, directions, hidden_size], TensorProto.FLOAT)
+        else:
+            register("Y", [seq_length, directions, batch_size, hidden_size], TensorProto.FLOAT)
+            register("Y_h", [directions, batch_size, hidden_size], TensorProto.FLOAT)
 
         if rnn_type == "LSTM":
-            register("Y_c", [directions, batch_size, hidden_size], TensorProto.FLOAT)
+            if layout == 1:
+                register("Y_c", [batch_size, directions, hidden_size], TensorProto.FLOAT)
+            else:
+                register("Y_c", [directions, batch_size, hidden_size], TensorProto.FLOAT)
 
         return output_names, graph_outputs, output_shapes
 
@@ -3867,6 +3890,9 @@ def verify_rnn(
     if linear_before_reset and rnn_type == "GRU":
         lbr_attr = helper.make_attribute("linear_before_reset", 1)
         rnn_node.attribute.append(lbr_attr)
+    if layout == 1:
+        layout_attr = helper.make_attribute("layout", 1)
+        rnn_node.attribute.append(layout_attr)
 
     graph = helper.make_graph([rnn_node], "rnn_test", inputs=input_tensors, outputs=graph_outputs)
 
@@ -3877,8 +3903,13 @@ def verify_rnn(
     )
 
 
-@tvm.testing.parametrize_targets
-def test_lstm(target, dev):
+def verify_rnn_helper(target, dev, rnn_type):
+    num_activations = 1
+    if rnn_type == "GRU":
+        num_activations = 2
+    elif rnn_type == "LSTM":
+        num_activations = 3
+
     for directions in [1, 2]:
         # No bias.
         verify_rnn(
@@ -3887,7 +3918,7 @@ def test_lstm(target, dev):
             input_size=16,
             hidden_size=32,
             use_bias=False,
-            rnn_type="LSTM",
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
@@ -3899,7 +3930,7 @@ def test_lstm(target, dev):
             input_size=16,
             hidden_size=32,
             use_bias=True,
-            rnn_type="LSTM",
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
@@ -3911,7 +3942,7 @@ def test_lstm(target, dev):
             input_size=16,
             hidden_size=40,
             use_bias=True,
-            rnn_type="LSTM",
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
@@ -3923,7 +3954,7 @@ def test_lstm(target, dev):
             input_size=16,
             hidden_size=32,
             use_bias=True,
-            rnn_type="LSTM",
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
@@ -3935,7 +3966,7 @@ def test_lstm(target, dev):
             input_size=16,
             hidden_size=128,
             use_bias=True,
-            rnn_type="LSTM",
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
@@ -3947,7 +3978,7 @@ def test_lstm(target, dev):
             input_size=64,
             hidden_size=32,
             use_bias=True,
-            rnn_type="LSTM",
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
@@ -3955,50 +3986,59 @@ def test_lstm(target, dev):
 
         # Different activation testing.
         # Default value hardsigmoid.
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=16,
-            hidden_size=32,
-            use_bias=False,
-            activations=["HardSigmoid", "Tanh", "Tanh"] * directions,
-            rnn_type="LSTM",
-            directions=directions,
-            target=target,
-            dev=dev,
-        )
+        # TODO: onnxruntime <= v1.12.0 has wrong default value of all activation functions
+        if rnn_type != "RNN":
+            activations = ["HardSigmoid", "Tanh", "Tanh"][0:num_activations] * directions
+            verify_rnn(
+                seq_length=2,
+                batch_size=1,
+                input_size=16,
+                hidden_size=32,
+                use_bias=False,
+                activations=activations,
+                rnn_type=rnn_type,
+                directions=directions,
+                target=target,
+                dev=dev,
+            )
         # Multiple parametrized activations.
+        activations = ["HardSigmoid", "LeakyRelu", "Tanh"][0:num_activations] * directions
+        alphas = [2.0, 0.5, 0.0][0:num_activations] * directions
+        betas = [0.3, 0.0, 0.0][0:num_activations] * directions
         verify_rnn(
             seq_length=2,
             batch_size=1,
             input_size=16,
             hidden_size=32,
             use_bias=False,
-            activations=["HardSigmoid", "LeakyRelu", "Tanh"] * directions,
-            alphas=[2.0, 0.5, 0.0] * directions,
-            betas=[0.3, 0.0, 0.0] * directions,
-            rnn_type="LSTM",
+            activations=activations,
+            alphas=alphas,
+            betas=betas,
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
         )
         # All parametrized with new Affine activation.
+        activations = ["Affine", "LeakyRelu", "HardSigmoid"][0:num_activations] * directions
+        alphas = [0.8, 2.0, 0.5][0:num_activations] * directions
+        betas = [0.0, 0.3, 0.0][0:num_activations] * directions
         verify_rnn(
             seq_length=2,
             batch_size=1,
             input_size=16,
             hidden_size=32,
             use_bias=False,
-            activations=["HardSigmoid", "LeakyRelu", "Affine"] * directions,
-            alphas=[2.0, 0.5, 0.8] * directions,
-            betas=[0.3, 0.1, 0.0] * directions,
-            rnn_type="LSTM",
+            activations=activations,
+            alphas=alphas,
+            betas=betas,
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
         )
 
-        # Testing with initial state and peepholes
+        # Testing with initial state
         verify_rnn(
             seq_length=2,
             batch_size=1,
@@ -4006,182 +4046,57 @@ def test_lstm(target, dev):
             hidden_size=32,
             use_bias=True,
             use_initial_state=True,
-            rnn_type="LSTM",
+            rnn_type=rnn_type,
             directions=directions,
             target=target,
             dev=dev,
         )
 
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=16,
-            hidden_size=32,
-            use_bias=True,
-            use_initial_state=True,
-            use_peep=True,
-            rnn_type="LSTM",
-            directions=directions,
-            target=target,
-            dev=dev,
-        )
+        # Testing layout
+        # TODO: onnxruntime <= 1.12.0 doesn't support layout == 1
+        # verify_rnn(
+        #     seq_length=2,
+        #     batch_size=1,
+        #     input_size=16,
+        #     hidden_size=32,
+        #     use_bias=True,
+        #     rnn_type="RNN",
+        #     directions=directions,
+        #     layout=1,
+        #     target=target,
+        #     dev=dev,
+        # )
+
+        # Testing with peepholes
+        if rnn_type == "LSTM":
+            verify_rnn(
+                seq_length=2,
+                batch_size=1,
+                input_size=16,
+                hidden_size=32,
+                use_bias=True,
+                use_initial_state=True,
+                use_peep=True,
+                rnn_type="LSTM",
+                directions=directions,
+                target=target,
+                dev=dev,
+            )
 
 
 @tvm.testing.parametrize_targets
-def test_gru(target, dev):
-    # Set seed for test reproduction
-    np.random.seed(137)
-    for directions in [1, 2]:
-        # No bias.
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=16,
-            hidden_size=32,
-            use_bias=False,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-6,
-            atol=1e-6,
-            target=target,
-            dev=dev,
-        )
-        # large batch. linear before reset
-        verify_rnn(
-            seq_length=4,
-            batch_size=8,
-            input_size=16,
-            hidden_size=32,
-            use_bias=True,
-            rnn_type="GRU",
-            linear_before_reset=True,
-            directions=directions,
-            target=target,
-            dev=dev,
-        )
-        # Non power of two.
-        verify_rnn(
-            seq_length=3,
-            batch_size=3,
-            input_size=16,
-            hidden_size=40,
-            use_bias=True,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-6,
-            atol=1e-6,
-            target=target,
-            dev=dev,
-        )
-        # Long sequence.
-        verify_rnn(
-            seq_length=8,
-            batch_size=1,
-            input_size=16,
-            hidden_size=32,
-            use_bias=True,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-6,
-            atol=1e-6,
-            target=target,
-            dev=dev,
-        )
-        # Large hidden.
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=16,
-            hidden_size=128,
-            use_bias=True,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-6,
-            atol=1e-6,
-            target=target,
-            dev=dev,
-        )
-        # Large input.
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=64,
-            hidden_size=32,
-            use_bias=True,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-6,
-            atol=1e-6,
-            target=target,
-            dev=dev,
-        )
+def test_rnn(target, dev):
+    verify_rnn_helper(target, dev, "RNN")
 
-        # Different activation testing.
-        # Default value hardsigmoid.
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=16,
-            hidden_size=32,
-            use_bias=False,
-            activations=["HardSigmoid", "Softsign"] * directions,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-6,
-            atol=1e-6,
-            target=target,
-            dev=dev,
-        )
-        # Multiple parametrized activations.
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=16,
-            hidden_size=32,
-            use_bias=False,
-            activations=["HardSigmoid", "LeakyRelu"] * directions,
-            alphas=[2.0, 0.5] * directions,
-            betas=[0.3, 0.0] * directions,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-8,
-            atol=1e-8,
-            target=target,
-            dev=dev,
-        )
-        # All parametrized with new Affine activation.
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=16,
-            hidden_size=32,
-            use_bias=False,
-            activations=["HardSigmoid", "Affine"] * directions,
-            alphas=[2.0, 0.8] * directions,
-            betas=[0.3, 0.1] * directions,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-8,
-            atol=1e-8,
-            target=target,
-            dev=dev,
-        )
 
-        # Testing with initial state
-        verify_rnn(
-            seq_length=2,
-            batch_size=1,
-            input_size=16,
-            hidden_size=32,
-            use_bias=True,
-            use_initial_state=True,
-            rnn_type="GRU",
-            directions=directions,
-            rtol=1e-6,
-            atol=1e-6,
-            target=target,
-            dev=dev,
-        )
+@tvm.testing.parametrize_targets
+def test_lstm(target, dev):
+    verify_rnn_helper(target, dev, "LSTM")
+
+
+@tvm.testing.parametrize_targets
+def test_gru(target, dev):
+    verify_rnn_helper(target, dev, "GRU")
 
 
 @tvm.testing.parametrize_targets
@@ -5213,7 +5128,6 @@ unsupported_onnx_tests = [
     "test_reduce_sum_keepdims_random",
     "test_reduce_sum_negative_axes_keepdims_example",
     "test_reduce_sum_negative_axes_keepdims_random",
-    "test_rnn_seq_length",
     "test_sequence_insert_at_back",
     "test_sequence_insert_at_front",
     "test_simple_rnn_batchwise",

Reply via email to