This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 934c4a4869 [Relax][PyTorch] Add support for bidirectional GRU (#18532)
934c4a4869 is described below

commit 934c4a4869e931a61913fa061b878d5628002d43
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 30 13:08:51 2025 +0800

    [Relax][PyTorch] Add support for bidirectional GRU (#18532)
    
    ## How
    
    - Implement bidirectional GRU: factor the per-direction unrolling into a
      new `_gru_cell_unroll` helper, run it forward and in reverse when
      `bidirectional=True`, and concatenate the two outputs along the feature
      dimension (see the usage sketch below)
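
    A minimal usage sketch, mirroring the new test case (the "llvm" target
    string and the class/variable names below are illustrative, not part of
    this change):

    ```python
    import torch
    import torch.nn as nn
    from torch.export import export

    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program


    class BidirectionalGRU(nn.Module):
        """Single-layer bidirectional GRU returning only the output sequence."""

        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=4, hidden_size=5, batch_first=True, bidirectional=True
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y


    # Export the PyTorch model and import it into Relax.
    x = torch.randn(2, 3, 4)  # (batch, seq_len, input_size)
    mod = from_exported_program(export(BidirectionalGRU(), args=(x,)))

    # Build and run on CPU; the last output dimension is hidden_size * 2 = 10
    # because the forward and backward outputs are concatenated.
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    out = vm["main"](tvm.runtime.tensor(x.numpy()))
    ```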
---
 .../frontend/torch/exported_program_translator.py  | 486 ++++++++++-----------
 .../relax/test_frontend_from_exported_program.py   |  86 ++++
 2 files changed, 327 insertions(+), 245 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 940ce9be81..2ec61796c3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -609,292 +609,288 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
         return output
 
-    def _gru(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        input_tensor = args[0]
-        hx = args[1] if len(args) > 1 else None
-        params = args[2] if len(args) > 2 else None
-        has_biases = args[3] if len(args) > 3 else True
-        num_layers = args[4] if len(args) > 4 else 1
-        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
-        _train = args[6] if len(args) > 6 else False  # Not used in inference
-        bidirectional = args[7] if len(args) > 7 else False
-        batch_first = args[8] if len(args) > 8 else False
-
-        if bidirectional:
-            raise NotImplementedError("Bidirectional GRU is not yet supported")
-
-        input_shape = self.shape_of(input_tensor)
-        if batch_first:
-            batch_size, seq_len, input_size = input_shape
-        else:
-            seq_len, batch_size, input_size = input_shape
-
-        if isinstance(seq_len, tvm.tir.IntImm):
-            seq_len = seq_len.value
-        if isinstance(batch_size, tvm.tir.IntImm):
-            batch_size = batch_size.value
-        if isinstance(input_size, tvm.tir.IntImm):
-            input_size = input_size.value
+    def _gru_cell_unroll(
+        self,
+        input_reshaped,
+        weight_ih,
+        weight_hh,
+        bias_ih,
+        bias_hh,
+        h_prev,
+        seq_len,
+        hidden_size,
+        dtype,
+        reverse=False,
+    ):
+        """Unroll GRU cells for a single direction."""
+        gate_size = hidden_size
 
-        if params and len(params) >= 2:
-            # For multi-layer, we need to extract the first layer's weights
-            # to determine hidden size
-            if num_layers > 1:
-                # Multi-layer: params[0] is first layer's weight_ih
-                weight_ih = params[0]
-            else:
-                # Single layer: params[0] is weight_ih
-                weight_ih = params[0]
-            # Extract hidden size from weight dimensions
-            # weight_ih has shape (3 * hidden_size, input_size)
-            weight_ih_shape = self.shape_of(weight_ih)
-            hidden_size = weight_ih_shape[0] // 3  # 3 gates: reset, update, new
-        else:
-            # Fallback to a default hidden size
-            hidden_size = 16
+        # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
+        # Reset gate weights
+        weight_ih_r = self.block_builder.emit(
+            relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
+        )
+        weight_hh_r = self.block_builder.emit(
+            relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
+        )
 
-        # Implement actual GRU computation using Relax operations
-        # GRU equations:
-        # r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
-        # z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
-        # n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
-        # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
-        dtype = input_tensor.struct_info.dtype
+        # Update gate weights
+        weight_ih_z = self.block_builder.emit(
+            relax.op.strided_slice(weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size])
+        )
+        weight_hh_z = self.block_builder.emit(
+            relax.op.strided_slice(weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size])
+        )
 
-        # Reshape input for processing
-        if batch_first:
-            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
-            input_reshaped = self.block_builder.emit(
-                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
-            )
-        else:
-            input_reshaped = input_tensor
+        # New gate weights
+        weight_ih_n = self.block_builder.emit(
+            relax.op.strided_slice(weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size])
+        )
+        weight_hh_n = self.block_builder.emit(
+            relax.op.strided_slice(weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size])
+        )
 
-        # Initialize hidden states for all layers
-        if hx is not None:
-            # hx shape: (num_layers, batch_size, hidden_size)
-            h_states = []
-            for layer in range(num_layers):
-                h_layer = self.block_builder.emit(
-                    relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip")
-                )
-                h_states.append(h_layer)
-        else:
-            h_states = []
-            for layer in range(num_layers):
-                h_layer = self.block_builder.emit(
-                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
-                )
-                h_states.append(h_layer)
+        # Transpose weights for matmul
+        weight_ih_r_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_r, axes=[1, 0]))
+        weight_hh_r_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_r, axes=[1, 0]))
+        weight_ih_z_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_z, axes=[1, 0]))
+        weight_hh_z_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_z, axes=[1, 0]))
+        weight_ih_n_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_n, axes=[1, 0]))
+        weight_hh_n_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_n, axes=[1, 0]))
 
         outputs = []
+        time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
 
-        for t in range(seq_len):
+        for t in time_steps:
             # Get input at time t: (batch_size, input_size)
             x_t = self.block_builder.emit(
                 relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
             )
 
-            # Process through each layer
-            current_input = x_t
-            new_h_states = []
-
-            for layer in range(num_layers):
-                # Get layer parameters
-                if params and len(params) >= 4 * num_layers:
-                    # Multi-layer case: params are organized as
-                    # [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, layer1_ih, ...]
-                    param_offset = layer * 4
-                    weight_ih = params[param_offset]
-                    weight_hh = params[param_offset + 1]
-                    bias_ih = params[param_offset + 2] if has_biases else None
-                    bias_hh = params[param_offset + 3] if has_biases else None
-                elif params and len(params) >= 4:
-                    # Single layer case
-                    weight_ih = params[0]
-                    weight_hh = params[1]
-                    bias_ih = params[2] if has_biases else None
-                    bias_hh = params[3] if has_biases else None
-                else:
-                    # Fallback: create zero weights
-                    weight_ih = self.block_builder.emit(
-                        relax.op.zeros(
-                            relax.ShapeExpr(
-                                (3 * hidden_size, input_size if layer == 0 else hidden_size)
-                            ),
-                            dtype,
-                        )
-                    )
-                    weight_hh = self.block_builder.emit(
-                        relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
-                    )
-                    bias_ih = None
-                    bias_hh = None
-
-                # Get previous hidden state for this layer
-                h_prev = h_states[layer]
-
-                # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
-                gate_size = hidden_size
-
-                # Reset gate weights
-                weight_ih_r = self.block_builder.emit(
-                    relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
+            # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
+            r_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_r_t))
+            r_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t))
+            if bias_ih is not None and bias_hh is not None:
+                bias_ih_r = self.block_builder.emit(
+                    relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
                 )
-                weight_hh_r = self.block_builder.emit(
-                    relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
+                bias_hh_r = self.block_builder.emit(
+                    relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
                 )
+                r_t = self.block_builder.emit(
+                    relax.op.sigmoid(
+                        relax.op.add(relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r)
+                    )
+                )
+            else:
+                r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))
 
-                # Update gate weights
-                weight_ih_z = self.block_builder.emit(
+            # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
+            z_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_z_t))
+            z_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t))
+            if bias_ih is not None and bias_hh is not None:
+                bias_ih_z = self.block_builder.emit(
                     relax.op.strided_slice(
-                        weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                        bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
                     )
                 )
-                weight_hh_z = self.block_builder.emit(
+                bias_hh_z = self.block_builder.emit(
                     relax.op.strided_slice(
-                        weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                        bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                    )
+                )
+                z_t = self.block_builder.emit(
+                    relax.op.sigmoid(
+                        relax.op.add(relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z)
                     )
                 )
+            else:
+                z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))
 
-                # New gate weights
-                weight_ih_n = self.block_builder.emit(
+            # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
+            n_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_n_t))
+            n_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t))
+            if bias_ih is not None and bias_hh is not None:
+                bias_ih_n = self.block_builder.emit(
                     relax.op.strided_slice(
-                        weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                        bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                     )
                 )
-                weight_hh_n = self.block_builder.emit(
+                bias_hh_n = self.block_builder.emit(
                     relax.op.strided_slice(
-                        weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                        bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                     )
                 )
-
-                # Transpose weights for matmul
-                weight_ih_r_t = self.block_builder.emit(
-                    relax.op.permute_dims(weight_ih_r, axes=[1, 0])
-                )
-                weight_hh_r_t = self.block_builder.emit(
-                    relax.op.permute_dims(weight_hh_r, axes=[1, 0])
-                )
-                weight_ih_z_t = self.block_builder.emit(
-                    relax.op.permute_dims(weight_ih_z, axes=[1, 0])
-                )
-                weight_hh_z_t = self.block_builder.emit(
-                    relax.op.permute_dims(weight_hh_z, axes=[1, 0])
-                )
-                weight_ih_n_t = self.block_builder.emit(
-                    relax.op.permute_dims(weight_ih_n, axes=[1, 0])
-                )
-                weight_hh_n_t = self.block_builder.emit(
-                    relax.op.permute_dims(weight_hh_n, axes=[1, 0])
-                )
-
-                # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
-                r_ih = self.block_builder.emit(
-                    relax.op.linear_algebra.matmul(current_input, weight_ih_r_t)
-                )
-                r_hh = self.block_builder.emit(
-                    relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)
-                )
-                if bias_ih is not None and bias_hh is not None:
-                    bias_ih_r = self.block_builder.emit(
-                        relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
-                    )
-                    bias_hh_r = self.block_builder.emit(
-                        relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
-                    )
-                    r_t = self.block_builder.emit(
-                        relax.op.sigmoid(
-                            relax.op.add(
-                                relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r
-                            )
+                n_t = self.block_builder.emit(
+                    relax.op.tanh(
+                        relax.op.add(
+                            relax.op.add(n_ih, bias_ih_n),
+                            relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
                         )
                     )
-                else:
-                    r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))
-
-                # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
-                z_ih = self.block_builder.emit(
-                    relax.op.linear_algebra.matmul(current_input, weight_ih_z_t)
                 )
-                z_hh = self.block_builder.emit(
-                    relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)
+            else:
+                n_t = self.block_builder.emit(
+                    relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
                 )
-                if bias_ih is not None and bias_hh is not None:
-                    bias_ih_z = self.block_builder.emit(
-                        relax.op.strided_slice(
-                            bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
-                        )
-                    )
-                    bias_hh_z = self.block_builder.emit(
-                        relax.op.strided_slice(
-                            bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
-                        )
-                    )
-                    z_t = self.block_builder.emit(
-                        relax.op.sigmoid(
-                            relax.op.add(
-                                relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z
-                            )
-                        )
-                    )
-                else:
-                    z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))
 
-                # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
-                n_ih = self.block_builder.emit(
-                    relax.op.linear_algebra.matmul(current_input, weight_ih_n_t)
+            # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+            one_minus_z = self.block_builder.emit(relax.op.subtract(relax.const(1.0, dtype), z_t))
+            h_t = self.block_builder.emit(
+                relax.op.add(relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev))
+            )
+
+            outputs.append(h_t)
+            h_prev = h_t
+
+        if reverse:
+            outputs = outputs[::-1]
+
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+        return output
+
+    def _gru(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        _train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer GRU is not yet supported")
+
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            batch_size, seq_len, input_size = input_shape
+        else:
+            seq_len, batch_size, input_size = input_shape
+
+        seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len
+        batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size
+        input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size
+
+        # Extract hidden size from parameters
+        # For bidirectional: params has weights for both directions
+        # params_per_direction = 4 if has_biases else 2 (weight_ih, weight_hh, [bias_ih, bias_hh])
+        params_per_direction = 4 if has_biases else 2
+
+        if params and len(params) >= 2:
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (3 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(params[0])
+            hidden_size = weight_ih_shape[0] // 3  # 3 gates: reset, update, new
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+
+        dtype = input_tensor.struct_info.dtype
+
+        # Extract forward direction weights
+        if params and len(params) >= params_per_direction:
+            weight_ih_fwd = params[0]
+            weight_hh_fwd = params[1]
+            bias_ih_fwd = params[2] if has_biases else None
+            bias_hh_fwd = params[3] if has_biases else None
+        else:
+            # Fallback: create zero weights
+            weight_ih_fwd = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((3 * hidden_size, input_size)), dtype)
+            )
+            weight_hh_fwd = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
+            )
+            bias_ih_fwd = None
+            bias_hh_fwd = None
+
+        # Extract or create backward direction weights if bidirectional
+        if bidirectional:
+            if params and len(params) >= params_per_direction * 2:
+                weight_ih_bwd = params[params_per_direction]
+                weight_hh_bwd = params[params_per_direction + 1]
+                bias_ih_bwd = params[params_per_direction + 2] if has_biases else None
+                bias_hh_bwd = params[params_per_direction + 3] if has_biases else None
+            else:
+                # Fallback: create zero weights
+                weight_ih_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((3 * hidden_size, input_size)), dtype)
                 )
-                n_hh = self.block_builder.emit(
-                    relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)
+                weight_hh_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
                 )
-                if bias_ih is not None and bias_hh is not None:
-                    bias_ih_n = self.block_builder.emit(
-                        relax.op.strided_slice(
-                            bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
-                        )
-                    )
-                    bias_hh_n = self.block_builder.emit(
-                        relax.op.strided_slice(
-                            bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
-                        )
-                    )
-                    n_t = self.block_builder.emit(
-                        relax.op.tanh(
-                            relax.op.add(
-                                relax.op.add(n_ih, bias_ih_n),
-                                relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
-                            )
-                        )
-                    )
-                else:
-                    n_t = self.block_builder.emit(
-                        relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
-                    )
+                bias_ih_bwd = None
+                bias_hh_bwd = None
+        else:
+            weight_ih_bwd = None
+            weight_hh_bwd = None
+            bias_ih_bwd = None
+            bias_hh_bwd = None
 
-                # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
-                one_minus_z = self.block_builder.emit(
-                    relax.op.subtract(relax.const(1.0, dtype), z_t)
+        # Initialize hidden states
+        if hx is not None:
+            h_prev_fwd = self.block_builder.emit(
+                relax.op.take(hx, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+            if bidirectional:
+                h_prev_bwd = self.block_builder.emit(
+                    relax.op.take(hx, relax.const(1, "int64"), axis=0, mode="clip")
                 )
-                h_t = self.block_builder.emit(
-                    relax.op.add(
-                        relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)
-                    )
+            else:
+                h_prev_bwd = None
+        else:
+            h_prev_fwd = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+            if bidirectional:
+                h_prev_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
                 )
+            else:
+                h_prev_bwd = None
 
-                new_h_states.append(h_t)
-
-                current_input = h_t
-
-            # Update hidden states for next time step
-            h_states = new_h_states
+        # Reshape input for processing
+        input_reshaped = (
+            self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2]))
+            if batch_first
+            else input_tensor
+        )
 
-            # Store output (from the last layer)
-            outputs.append(h_states[-1])
+        # Process forward direction
+        output_fwd = self._gru_cell_unroll(
+            input_reshaped,
+            weight_ih_fwd,
+            weight_hh_fwd,
+            bias_ih_fwd,
+            bias_hh_fwd,
+            h_prev_fwd,
+            seq_len,
+            hidden_size,
+            dtype,
+            reverse=False,
+        )
 
-        # Stack outputs: (seq_len, batch_size, hidden_size)
-        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+        # Process backward direction if bidirectional
+        if bidirectional:
+            output_bwd = self._gru_cell_unroll(
+                input_reshaped,
+                weight_ih_bwd,
+                weight_hh_bwd,
+                bias_ih_bwd,
+                bias_hh_bwd,
+                h_prev_bwd,
+                seq_len,
+                hidden_size,
+                dtype,
+                reverse=True,
+            )
+            # Concatenate forward and backward outputs along feature dimension
+            output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2))
+        else:
+            output = output_fwd
 
         # Reshape back to batch_first if needed
         if batch_first:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 7397b3f21a..0658dbfaf3 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -8011,6 +8011,92 @@ def test_gru():
     assert pytorch_output2.shape == tvm_output2_np.shape
     tvm.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
 
+    # Test bidirectional GRU with batch_first=True
+    class BidirectionalGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=4,
+                hidden_size=5,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=True,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(44)
+    x3 = torch.randn(2, 3, 4, dtype=torch.float32)
+    model3 = BidirectionalGRU()
+    with torch.no_grad():
+        pytorch_output3 = model3(x3)
+
+    # Verify output shape is correct (hidden_size * 2 due to bidirectional)
+    assert pytorch_output3.shape == (
+        2,
+        3,
+        10,
+    ), f"Expected shape (2, 3, 10), got {pytorch_output3.shape}"
+
+    exported_program3 = export(model3, args=(x3,))
+    mod3 = from_exported_program(exported_program3)
+    ex3 = relax.build(mod3, target)
+    vm3 = relax.VirtualMachine(ex3, tvm.cpu())
+    x3_tvm = tvm.runtime.tensor(x3.numpy())
+    tvm_output3 = vm3["main"](x3_tvm)
+    if hasattr(tvm_output3, "numpy"):
+        tvm_output3_np = tvm_output3.numpy()
+    else:
+        tvm_output3_np = tvm_output3[0].numpy()
+    assert (
+        pytorch_output3.shape == tvm_output3_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output3.shape} vs TVM {tvm_output3_np.shape}"
+    tvm.testing.assert_allclose(pytorch_output3.numpy(), tvm_output3_np, rtol=1e-4, atol=1e-5)
+
+    # Test bidirectional GRU with batch_first=False
+    class SeqFirstBidirectionalGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=3,
+                hidden_size=4,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=True,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(45)
+    x4 = torch.randn(4, 2, 3, dtype=torch.float32)  # (seq_len, batch, input_size)
+    model4 = SeqFirstBidirectionalGRU()
+    with torch.no_grad():
+        pytorch_output4 = model4(x4)
+
+    # Verify output shape (seq_len, batch, hidden_size * 2)
+    assert pytorch_output4.shape == (
+        4,
+        2,
+        8,
+    ), f"Expected shape (4, 2, 8), got {pytorch_output4.shape}"
+
+    exported_program4 = export(model4, args=(x4,))
+    mod4 = from_exported_program(exported_program4)
+    ex4 = relax.build(mod4, target)
+    vm4 = relax.VirtualMachine(ex4, tvm.cpu())
+    x4_tvm = tvm.runtime.tensor(x4.numpy())
+    tvm_output4 = vm4["main"](x4_tvm)
+    if hasattr(tvm_output4, "numpy"):
+        tvm_output4_np = tvm_output4.numpy()
+    else:
+        tvm_output4_np = tvm_output4[0].numpy()
+    assert pytorch_output4.shape == tvm_output4_np.shape
+    tvm.testing.assert_allclose(pytorch_output4.numpy(), tvm_output4_np, rtol=1e-4, atol=1e-5)
+
 
 def test_dynamic_shape_with_range_constraints():
     class DynamicModel(torch.nn.Module):
