This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 934c4a4869 [Relax][PyTorch] Add support for bidirectional GRU (#18532)
934c4a4869 is described below
commit 934c4a4869e931a61913fa061b878d5628002d43
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 30 13:08:51 2025 +0800
[Relax][PyTorch] Add support for bidirectional GRU (#18532)
## How
- Implement bidirectional GRU support by factoring the per-direction time-step unrolling into a `_gru_cell_unroll` helper, running it once forward and once reversed, and concatenating the two outputs along the feature dimension (multi-layer GRU remains unsupported and now raises `NotImplementedError`); a usage sketch follows below.
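
As a minimal usage sketch (mirroring the test added in this commit; the `llvm` target and the shapes are illustrative), a single-layer bidirectional `nn.GRU` exported with `torch.export` can now be imported and executed through the Relax frontend:

```python
# Hedged sketch: export a bidirectional GRU and run it through Relax,
# assuming target="llvm" and illustrative input shapes.
import torch
import torch.nn as nn
from torch.export import export

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class BidirectionalGRU(nn.Module):
    def __init__(self):
        super().__init__()
        # batch_first and bidirectional both exercise the new translator path
        self.gru = nn.GRU(input_size=4, hidden_size=5, num_layers=1,
                          batch_first=True, bidirectional=True)

    def forward(self, x):
        y, _ = self.gru(x)  # output carries 2 * hidden_size features
        return y


x = torch.randn(2, 3, 4)  # (batch, seq_len, input_size)
mod = from_exported_program(export(BidirectionalGRU(), args=(x,)))
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](tvm.runtime.tensor(x.numpy()))  # (2, 3, 10); may be returned as a 1-tuple
```

The output feature dimension doubles to `2 * hidden_size` because the forward and reverse passes are concatenated, matching PyTorch's bidirectional GRU semantics.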
---
.../frontend/torch/exported_program_translator.py | 486 ++++++++++-----------
.../relax/test_frontend_from_exported_program.py | 86 ++++
2 files changed, 327 insertions(+), 245 deletions(-)
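
For reference, here is a minimal NumPy sketch (names and layout are illustrative, not the translator's API) of the per-direction recurrence that the translator unrolls, using PyTorch's gate order reset/update/new; the bidirectional result is the forward and reverse sequences concatenated along the feature axis, as in the `relax.op.concat(..., axis=2)` emitted in the patch below:

```python
# Hedged NumPy reference for one GRU direction, following the equations in the diff.
import numpy as np


def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_direction(x, w_ih, w_hh, b_ih, b_hh, h0, reverse=False):
    # x: (seq_len, batch, input); w_ih: (3*hidden, input); w_hh: (3*hidden, hidden)
    w_ir, w_iz, w_in = np.split(w_ih, 3)  # PyTorch gate order: reset, update, new
    w_hr, w_hz, w_hn = np.split(w_hh, 3)
    b_ir, b_iz, b_in = np.split(b_ih, 3)
    b_hr, b_hz, b_hn = np.split(b_hh, 3)
    h, outs = h0, []
    steps = range(x.shape[0] - 1, -1, -1) if reverse else range(x.shape[0])
    for t in steps:
        x_t = x[t]
        r = _sigmoid(x_t @ w_ir.T + b_ir + h @ w_hr.T + b_hr)
        z = _sigmoid(x_t @ w_iz.T + b_iz + h @ w_hz.T + b_hz)
        n = np.tanh(x_t @ w_in.T + b_in + r * (h @ w_hn.T + b_hn))
        h = (1.0 - z) * n + z * h
        outs.append(h)
    if reverse:
        outs = outs[::-1]  # realign reversed-direction outputs with time order
    return np.stack(outs, axis=0)  # (seq_len, batch, hidden)


# Bidirectional: out = np.concatenate([forward_pass, reverse_pass], axis=2)
```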
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 940ce9be81..2ec61796c3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -609,292 +609,288 @@ class ExportedProgramImporter(BaseFXGraphImporter):
output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
return output
- def _gru(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- input_tensor = args[0]
- hx = args[1] if len(args) > 1 else None
- params = args[2] if len(args) > 2 else None
- has_biases = args[3] if len(args) > 3 else True
- num_layers = args[4] if len(args) > 4 else 1
- _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
- _train = args[6] if len(args) > 6 else False # Not used in inference
- bidirectional = args[7] if len(args) > 7 else False
- batch_first = args[8] if len(args) > 8 else False
-
- if bidirectional:
- raise NotImplementedError("Bidirectional GRU is not yet supported")
-
- input_shape = self.shape_of(input_tensor)
- if batch_first:
- batch_size, seq_len, input_size = input_shape
- else:
- seq_len, batch_size, input_size = input_shape
-
- if isinstance(seq_len, tvm.tir.IntImm):
- seq_len = seq_len.value
- if isinstance(batch_size, tvm.tir.IntImm):
- batch_size = batch_size.value
- if isinstance(input_size, tvm.tir.IntImm):
- input_size = input_size.value
+ def _gru_cell_unroll(
+ self,
+ input_reshaped,
+ weight_ih,
+ weight_hh,
+ bias_ih,
+ bias_hh,
+ h_prev,
+ seq_len,
+ hidden_size,
+ dtype,
+ reverse=False,
+ ):
+ """Unroll GRU cells for a single direction."""
+ gate_size = hidden_size
- if params and len(params) >= 2:
- # For multi-layer, we need to extract the first layer's weights
- # to determine hidden size
- if num_layers > 1:
- # Multi-layer: params[0] is first layer's weight_ih
- weight_ih = params[0]
- else:
- # Single layer: params[0] is weight_ih
- weight_ih = params[0]
- # Extract hidden size from weight dimensions
- # weight_ih has shape (3 * hidden_size, input_size)
- weight_ih_shape = self.shape_of(weight_ih)
- hidden_size = weight_ih_shape[0] // 3  # 3 gates: reset, update, new
- else:
- # Fallback to a default hidden size
- hidden_size = 16
+ # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
+ # Reset gate weights
+ weight_ih_r = self.block_builder.emit(
+ relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
+ )
+ weight_hh_r = self.block_builder.emit(
+ relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
+ )
- # Implement actual GRU computation using Relax operations
- # GRU equations:
- # r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
- # z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
- # n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
- # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
- dtype = input_tensor.struct_info.dtype
+ # Update gate weights
+ weight_ih_z = self.block_builder.emit(
+ relax.op.strided_slice(weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size])
+ )
+ weight_hh_z = self.block_builder.emit(
+ relax.op.strided_slice(weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size])
+ )
- # Reshape input for processing
- if batch_first:
- # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
- input_reshaped = self.block_builder.emit(
- relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
- )
- else:
- input_reshaped = input_tensor
+ # New gate weights
+ weight_ih_n = self.block_builder.emit(
+ relax.op.strided_slice(weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size])
+ )
+ weight_hh_n = self.block_builder.emit(
+ relax.op.strided_slice(weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size])
+ )
- # Initialize hidden states for all layers
- if hx is not None:
- # hx shape: (num_layers, batch_size, hidden_size)
- h_states = []
- for layer in range(num_layers):
- h_layer = self.block_builder.emit(
- relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip")
- )
- h_states.append(h_layer)
- else:
- h_states = []
- for layer in range(num_layers):
- h_layer = self.block_builder.emit(
- relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
- )
- h_states.append(h_layer)
+ # Transpose weights for matmul
+ weight_ih_r_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_r, axes=[1, 0]))
+ weight_hh_r_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_r, axes=[1, 0]))
+ weight_ih_z_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_z, axes=[1, 0]))
+ weight_hh_z_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_z, axes=[1, 0]))
+ weight_ih_n_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_n, axes=[1, 0]))
+ weight_hh_n_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_n, axes=[1, 0]))
outputs = []
+ time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
- for t in range(seq_len):
+ for t in time_steps:
# Get input at time t: (batch_size, input_size)
x_t = self.block_builder.emit(
relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
)
- # Process through each layer
- current_input = x_t
- new_h_states = []
-
- for layer in range(num_layers):
- # Get layer parameters
- if params and len(params) >= 4 * num_layers:
- # Multi-layer case: params are organized as
- # [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, layer1_ih, ...]
- param_offset = layer * 4
- weight_ih = params[param_offset]
- weight_hh = params[param_offset + 1]
- bias_ih = params[param_offset + 2] if has_biases else None
- bias_hh = params[param_offset + 3] if has_biases else None
- elif params and len(params) >= 4:
- # Single layer case
- weight_ih = params[0]
- weight_hh = params[1]
- bias_ih = params[2] if has_biases else None
- bias_hh = params[3] if has_biases else None
- else:
- # Fallback: create zero weights
- weight_ih = self.block_builder.emit(
- relax.op.zeros(
- relax.ShapeExpr(
- (3 * hidden_size, input_size if layer == 0 else hidden_size)
- ),
- dtype,
- )
- )
- weight_hh = self.block_builder.emit(
- relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
- )
- bias_ih = None
- bias_hh = None
-
- # Get previous hidden state for this layer
- h_prev = h_states[layer]
-
- # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
- gate_size = hidden_size
-
- # Reset gate weights
- weight_ih_r = self.block_builder.emit(
- relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
+ # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
+ r_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_r_t))
+ r_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t))
+ if bias_ih is not None and bias_hh is not None:
+ bias_ih_r = self.block_builder.emit(
+ relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
)
- weight_hh_r = self.block_builder.emit(
- relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
+ bias_hh_r = self.block_builder.emit(
+ relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
)
+ r_t = self.block_builder.emit(
+ relax.op.sigmoid(
+ relax.op.add(relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r)
+ )
+ )
+ else:
+ r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))
- # Update gate weights
- weight_ih_z = self.block_builder.emit(
+ # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
+ z_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_z_t))
+ z_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t))
+ if bias_ih is not None and bias_hh is not None:
+ bias_ih_z = self.block_builder.emit(
relax.op.strided_slice(
- weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
+ bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
)
)
- weight_hh_z = self.block_builder.emit(
+ bias_hh_z = self.block_builder.emit(
relax.op.strided_slice(
- weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
+ bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
+ )
+ )
+ z_t = self.block_builder.emit(
+ relax.op.sigmoid(
+ relax.op.add(relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z)
)
)
+ else:
+ z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))
- # New gate weights
- weight_ih_n = self.block_builder.emit(
+ # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
+ n_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_n_t))
+ n_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t))
+ if bias_ih is not None and bias_hh is not None:
+ bias_ih_n = self.block_builder.emit(
relax.op.strided_slice(
- weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+ bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
)
)
- weight_hh_n = self.block_builder.emit(
+ bias_hh_n = self.block_builder.emit(
relax.op.strided_slice(
- weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+ bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
)
)
-
- # Transpose weights for matmul
- weight_ih_r_t = self.block_builder.emit(
- relax.op.permute_dims(weight_ih_r, axes=[1, 0])
- )
- weight_hh_r_t = self.block_builder.emit(
- relax.op.permute_dims(weight_hh_r, axes=[1, 0])
- )
- weight_ih_z_t = self.block_builder.emit(
- relax.op.permute_dims(weight_ih_z, axes=[1, 0])
- )
- weight_hh_z_t = self.block_builder.emit(
- relax.op.permute_dims(weight_hh_z, axes=[1, 0])
- )
- weight_ih_n_t = self.block_builder.emit(
- relax.op.permute_dims(weight_ih_n, axes=[1, 0])
- )
- weight_hh_n_t = self.block_builder.emit(
- relax.op.permute_dims(weight_hh_n, axes=[1, 0])
- )
-
- # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
- r_ih = self.block_builder.emit(
- relax.op.linear_algebra.matmul(current_input, weight_ih_r_t)
- )
- r_hh = self.block_builder.emit(
- relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)
- )
- if bias_ih is not None and bias_hh is not None:
- bias_ih_r = self.block_builder.emit(
- relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
- )
- bias_hh_r = self.block_builder.emit(
- relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
- )
- r_t = self.block_builder.emit(
- relax.op.sigmoid(
- relax.op.add(
- relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r
- )
+ n_t = self.block_builder.emit(
+ relax.op.tanh(
+ relax.op.add(
+ relax.op.add(n_ih, bias_ih_n),
+ relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
)
)
- else:
- r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))
-
- # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
- z_ih = self.block_builder.emit(
- relax.op.linear_algebra.matmul(current_input, weight_ih_z_t)
)
- z_hh = self.block_builder.emit(
- relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)
+ else:
+ n_t = self.block_builder.emit(
+ relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
)
- if bias_ih is not None and bias_hh is not None:
- bias_ih_z = self.block_builder.emit(
- relax.op.strided_slice(
- bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
- )
- )
- bias_hh_z = self.block_builder.emit(
- relax.op.strided_slice(
- bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
- )
- )
- z_t = self.block_builder.emit(
- relax.op.sigmoid(
- relax.op.add(
- relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z
- )
- )
- )
- else:
- z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))
- # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
- n_ih = self.block_builder.emit(
- relax.op.linear_algebra.matmul(current_input, weight_ih_n_t)
+ # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+ one_minus_z = self.block_builder.emit(relax.op.subtract(relax.const(1.0, dtype), z_t))
+ h_t = self.block_builder.emit(
+ relax.op.add(relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev))
+ )
+
+ outputs.append(h_t)
+ h_prev = h_t
+
+ if reverse:
+ outputs = outputs[::-1]
+
+ output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+ return output
+
+ def _gru(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ hx = args[1] if len(args) > 1 else None
+ params = args[2] if len(args) > 2 else None
+ has_biases = args[3] if len(args) > 3 else True
+ num_layers = args[4] if len(args) > 4 else 1
+ _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
+ _train = args[6] if len(args) > 6 else False # Not used in inference
+ bidirectional = args[7] if len(args) > 7 else False
+ batch_first = args[8] if len(args) > 8 else False
+
+ if num_layers > 1:
+ raise NotImplementedError("Multi-layer GRU is not yet supported")
+
+ input_shape = self.shape_of(input_tensor)
+ if batch_first:
+ batch_size, seq_len, input_size = input_shape
+ else:
+ seq_len, batch_size, input_size = input_shape
+
+ seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len
+ batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size
+ input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size
+
+ # Extract hidden size from parameters
+ # For bidirectional: params has weights for both directions
+ # params_per_direction = 4 if has_biases else 2 (weight_ih, weight_hh, [bias_ih, bias_hh])
+ params_per_direction = 4 if has_biases else 2
+
+ if params and len(params) >= 2:
+ # Extract hidden size from weight dimensions
+ # weight_ih has shape (3 * hidden_size, input_size)
+ weight_ih_shape = self.shape_of(params[0])
+ hidden_size = weight_ih_shape[0] // 3  # 3 gates: reset, update, new
+ else:
+ # Fallback to a default hidden size
+ hidden_size = 16
+
+ dtype = input_tensor.struct_info.dtype
+
+ # Extract forward direction weights
+ if params and len(params) >= params_per_direction:
+ weight_ih_fwd = params[0]
+ weight_hh_fwd = params[1]
+ bias_ih_fwd = params[2] if has_biases else None
+ bias_hh_fwd = params[3] if has_biases else None
+ else:
+ # Fallback: create zero weights
+ weight_ih_fwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((3 * hidden_size, input_size)), dtype)
+ )
+ weight_hh_fwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
+ )
+ bias_ih_fwd = None
+ bias_hh_fwd = None
+
+ # Extract or create backward direction weights if bidirectional
+ if bidirectional:
+ if params and len(params) >= params_per_direction * 2:
+ weight_ih_bwd = params[params_per_direction]
+ weight_hh_bwd = params[params_per_direction + 1]
+ bias_ih_bwd = params[params_per_direction + 2] if has_biases else None
+ bias_hh_bwd = params[params_per_direction + 3] if has_biases else None
+ else:
+ # Fallback: create zero weights
+ weight_ih_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((3 * hidden_size, input_size)), dtype)
)
- n_hh = self.block_builder.emit(
- relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)
+ weight_hh_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
)
- if bias_ih is not None and bias_hh is not None:
- bias_ih_n = self.block_builder.emit(
- relax.op.strided_slice(
- bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
- )
- )
- bias_hh_n = self.block_builder.emit(
- relax.op.strided_slice(
- bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
- )
- )
- n_t = self.block_builder.emit(
- relax.op.tanh(
- relax.op.add(
- relax.op.add(n_ih, bias_ih_n),
- relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
- )
- )
- )
- else:
- n_t = self.block_builder.emit(
- relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
- )
+ bias_ih_bwd = None
+ bias_hh_bwd = None
+ else:
+ weight_ih_bwd = None
+ weight_hh_bwd = None
+ bias_ih_bwd = None
+ bias_hh_bwd = None
- # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
- one_minus_z = self.block_builder.emit(
- relax.op.subtract(relax.const(1.0, dtype), z_t)
+ # Initialize hidden states
+ if hx is not None:
+ h_prev_fwd = self.block_builder.emit(
+ relax.op.take(hx, relax.const(0, "int64"), axis=0, mode="clip")
+ )
+ if bidirectional:
+ h_prev_bwd = self.block_builder.emit(
+ relax.op.take(hx, relax.const(1, "int64"), axis=0, mode="clip")
)
- h_t = self.block_builder.emit(
- relax.op.add(
- relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)
- )
+ else:
+ h_prev_bwd = None
+ else:
+ h_prev_fwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+ )
+ if bidirectional:
+ h_prev_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
)
+ else:
+ h_prev_bwd = None
- new_h_states.append(h_t)
-
- current_input = h_t
-
- # Update hidden states for next time step
- h_states = new_h_states
+ # Reshape input for processing
+ input_reshaped = (
+ self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2]))
+ if batch_first
+ else input_tensor
+ )
- # Store output (from the last layer)
- outputs.append(h_states[-1])
+ # Process forward direction
+ output_fwd = self._gru_cell_unroll(
+ input_reshaped,
+ weight_ih_fwd,
+ weight_hh_fwd,
+ bias_ih_fwd,
+ bias_hh_fwd,
+ h_prev_fwd,
+ seq_len,
+ hidden_size,
+ dtype,
+ reverse=False,
+ )
- # Stack outputs: (seq_len, batch_size, hidden_size)
- output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+ # Process backward direction if bidirectional
+ if bidirectional:
+ output_bwd = self._gru_cell_unroll(
+ input_reshaped,
+ weight_ih_bwd,
+ weight_hh_bwd,
+ bias_ih_bwd,
+ bias_hh_bwd,
+ h_prev_bwd,
+ seq_len,
+ hidden_size,
+ dtype,
+ reverse=True,
+ )
+ # Concatenate forward and backward outputs along feature dimension
+ output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2))
+ else:
+ output = output_fwd
# Reshape back to batch_first if needed
if batch_first:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 7397b3f21a..0658dbfaf3 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -8011,6 +8011,92 @@ def test_gru():
assert pytorch_output2.shape == tvm_output2_np.shape
tvm.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+ # Test bidirectional GRU with batch_first=True
+ class BidirectionalGRU(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.gru = nn.GRU(
+ input_size=4,
+ hidden_size=5,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=True,
+ )
+
+ def forward(self, x):
+ y, _ = self.gru(x)
+ return y
+
+ torch.manual_seed(44)
+ x3 = torch.randn(2, 3, 4, dtype=torch.float32)
+ model3 = BidirectionalGRU()
+ with torch.no_grad():
+ pytorch_output3 = model3(x3)
+
+ # Verify output shape is correct (hidden_size * 2 due to bidirectional)
+ assert pytorch_output3.shape == (
+ 2,
+ 3,
+ 10,
+ ), f"Expected shape (2, 3, 10), got {pytorch_output3.shape}"
+
+ exported_program3 = export(model3, args=(x3,))
+ mod3 = from_exported_program(exported_program3)
+ ex3 = relax.build(mod3, target)
+ vm3 = relax.VirtualMachine(ex3, tvm.cpu())
+ x3_tvm = tvm.runtime.tensor(x3.numpy())
+ tvm_output3 = vm3["main"](x3_tvm)
+ if hasattr(tvm_output3, "numpy"):
+ tvm_output3_np = tvm_output3.numpy()
+ else:
+ tvm_output3_np = tvm_output3[0].numpy()
+ assert (
+ pytorch_output3.shape == tvm_output3_np.shape
+ ), f"Shape mismatch: PyTorch {pytorch_output3.shape} vs TVM {tvm_output3_np.shape}"
+ tvm.testing.assert_allclose(pytorch_output3.numpy(), tvm_output3_np, rtol=1e-4, atol=1e-5)
+
+ # Test bidirectional GRU with batch_first=False
+ class SeqFirstBidirectionalGRU(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.gru = nn.GRU(
+ input_size=3,
+ hidden_size=4,
+ num_layers=1,
+ batch_first=False,
+ bidirectional=True,
+ )
+
+ def forward(self, x):
+ y, _ = self.gru(x)
+ return y
+
+ torch.manual_seed(45)
+ x4 = torch.randn(4, 2, 3, dtype=torch.float32)  # (seq_len, batch, input_size)
+ model4 = SeqFirstBidirectionalGRU()
+ with torch.no_grad():
+ pytorch_output4 = model4(x4)
+
+ # Verify output shape (seq_len, batch, hidden_size * 2)
+ assert pytorch_output4.shape == (
+ 4,
+ 2,
+ 8,
+ ), f"Expected shape (4, 2, 8), got {pytorch_output4.shape}"
+
+ exported_program4 = export(model4, args=(x4,))
+ mod4 = from_exported_program(exported_program4)
+ ex4 = relax.build(mod4, target)
+ vm4 = relax.VirtualMachine(ex4, tvm.cpu())
+ x4_tvm = tvm.runtime.tensor(x4.numpy())
+ tvm_output4 = vm4["main"](x4_tvm)
+ if hasattr(tvm_output4, "numpy"):
+ tvm_output4_np = tvm_output4.numpy()
+ else:
+ tvm_output4_np = tvm_output4[0].numpy()
+ assert pytorch_output4.shape == tvm_output4_np.shape
+ tvm.testing.assert_allclose(pytorch_output4.numpy(), tvm_output4_np, rtol=1e-4, atol=1e-5)
+
def test_dynamic_shape_with_range_constraints():
class DynamicModel(torch.nn.Module):