This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3b8d324eb1 [Relax][PyTorch] Support gru op for ExportedProgram
importer (#18360)
3b8d324eb1 is described below
commit 3b8d324eb151e9fcb78f44746a1a4a2ab62cf02e
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Oct 6 10:59:40 2025 -0400
[Relax][PyTorch] Support gru op for ExportedProgram importer (#18360)
---
.../frontend/torch/exported_program_translator.py | 295 +++++++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 71 +++++
2 files changed, 366 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c9c55eb8d6..a84c35e622 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -391,6 +391,300 @@ class ExportedProgramImporter(BaseFXGraphImporter):
output = self.block_builder.emit(relax.op.permute_dims(output,
axes=[1, 0, 2]))
return output
+    def _gru(self, node: fx.Node) -> relax.Var:
+        """Lower a ``gru.input`` node to an unrolled chain of Relax ops.
+
+        The node's positional arguments are read as ``(input, hx, params,
+        has_biases, num_layers, dropout, train, bidirectional, batch_first)``.
+        ``dropout`` and ``train`` are not used in inference and are ignored.
+
+        ``params`` is the flat weight list, laid out layer by layer as
+        ``[weight_ih, weight_hh, bias_ih, bias_hh]`` when ``has_biases`` is true
+        and as ``[weight_ih, weight_hh]`` otherwise.  Every tensor stacks the
+        three gates along axis 0 in the order reset, update, new (r, z, n).
+
+        The recurrence is unrolled at import time with ``range(seq_len)``, so
+        the sequence length has to be a static integer.  For each time step and
+        layer the emitted ops compute::
+
+            r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)
+            z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)
+            n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))
+            h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+
+        Parameters
+        ----------
+        node : fx.Node
+            The ``gru.input`` call node.
+
+        Returns
+        -------
+        relax.Var
+            The last layer's hidden state at every time step, stacked to
+            ``(seq_len, batch, hidden)``, or ``(batch, seq_len, hidden)`` when
+            ``batch_first`` is set.  The final hidden state is not returned.
+
+        Raises
+        ------
+        NotImplementedError
+            If ``bidirectional`` is true.
+        """
+        args = self.retrieve_args(node)
+        num_args = len(args)
+        input_tensor = args[0]
+        hx = args[1] if num_args > 1 else None
+        params = args[2] if num_args > 2 else None
+        has_biases = args[3] if num_args > 3 else True
+        num_layers = args[4] if num_args > 4 else 1
+        # args[5] (dropout) and args[6] (train) are not used in inference.
+        bidirectional = args[7] if num_args > 7 else False
+        batch_first = args[8] if num_args > 8 else False
+
+        if bidirectional:
+            raise NotImplementedError("Bidirectional GRU is not yet supported")
+
+        emit = self.block_builder.emit
+        dtype = input_tensor.struct_info.dtype
+
+        # Static dimensions are unwrapped to Python ints; anything else is kept as-is.
+        dims = [
+            dim.value if isinstance(dim, tvm.tir.IntImm) else dim
+            for dim in self.shape_of(input_tensor)
+        ]
+        if batch_first:
+            batch_size, seq_len, input_size = dims
+        else:
+            seq_len, batch_size, input_size = dims
+
+        num_params = len(params) if params else 0
+        if num_params >= 2:
+            # The first layer's weight_ih has shape (3 * hidden_size, input_size).
+            hidden_size = self.shape_of(params[0])[0] // 3
+        else:
+            # No weights to inspect: fall back to a default hidden size.
+            hidden_size = 16
+
+        # With biases each layer owns four entries of ``params``, otherwise only two.
+        # Using a fixed stride of four would skip the real weights of a bias-free GRU.
+        params_per_layer = 4 if has_biases else 2
+
+        # Row ranges of the reset, update and new gates inside a stacked tensor.
+        gate_bounds = (
+            (0, hidden_size),
+            (hidden_size, 2 * hidden_size),
+            (2 * hidden_size, 3 * hidden_size),
+        )
+
+        def split_gates(tensor):
+            """Slice a stacked tensor into its (r, z, n) parts along axis 0."""
+            return [
+                emit(relax.op.strided_slice(tensor, axes=[0], begin=[lo], end=[hi]))
+                for lo, hi in gate_bounds
+            ]
+
+        def prepare_layer(layer):
+            """Emit the transposed per-gate weights and the bias slices of one layer."""
+            if num_params >= params_per_layer * num_layers:
+                base = layer * params_per_layer
+            elif num_params >= params_per_layer:
+                # Fewer entries than layers: every layer reads the first layer's entries.
+                base = 0
+            else:
+                base = None
+
+            bias_ih = bias_hh = None
+            if base is None:
+                # Fallback: create zero weights.
+                in_features = input_size if layer == 0 else hidden_size
+                weight_ih = emit(
+                    relax.op.zeros(relax.ShapeExpr((3 * hidden_size, in_features)), dtype)
+                )
+                weight_hh = emit(
+                    relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
+                )
+            else:
+                weight_ih = params[base]
+                weight_hh = params[base + 1]
+                if has_biases:
+                    bias_ih = params[base + 2]
+                    bias_hh = params[base + 3]
+
+            # Transpose each gate's weight so it can be the right operand of matmul.
+            weight_ih_t = [
+                emit(relax.op.permute_dims(w, axes=[1, 0])) for w in split_gates(weight_ih)
+            ]
+            weight_hh_t = [
+                emit(relax.op.permute_dims(w, axes=[1, 0])) for w in split_gates(weight_hh)
+            ]
+            if bias_ih is not None and bias_hh is not None:
+                return weight_ih_t, weight_hh_t, split_gates(bias_ih), split_gates(bias_hh)
+            return weight_ih_t, weight_hh_t, None, None
+
+        def gru_cell(x_t, h_prev, weight_ih_t, weight_hh_t, bias_ih_g, bias_hh_g):
+            """Emit one GRU step for one layer and return the new hidden state."""
+            matmul = relax.op.linear_algebra.matmul
+            add = relax.op.add
+            multiply = relax.op.multiply
+
+            xr, xz, xn = [emit(matmul(x_t, w)) for w in weight_ih_t]
+            hr, hz, hn = [emit(matmul(h_prev, w)) for w in weight_hh_t]
+
+            if bias_ih_g is not None and bias_hh_g is not None:
+                b_ir, b_iz, b_in = bias_ih_g
+                b_hr, b_hz, b_hn = bias_hh_g
+                r_t = emit(relax.op.sigmoid(add(add(add(xr, b_ir), hr), b_hr)))
+                z_t = emit(relax.op.sigmoid(add(add(add(xz, b_iz), hz), b_hz)))
+                n_t = emit(
+                    relax.op.tanh(add(add(xn, b_in), multiply(r_t, add(hn, b_hn))))
+                )
+            else:
+                r_t = emit(relax.op.sigmoid(add(xr, hr)))
+                z_t = emit(relax.op.sigmoid(add(xz, hz)))
+                n_t = emit(relax.op.tanh(add(xn, multiply(r_t, hn))))
+
+            # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+            one_minus_z = emit(relax.op.subtract(relax.const(1.0, dtype), z_t))
+            return emit(add(multiply(one_minus_z, n_t), multiply(z_t, h_prev)))
+
+        # The recurrence below always walks axis 0 as time.
+        if batch_first:
+            # (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+            sequence = emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2]))
+        else:
+            sequence = input_tensor
+
+        # Initial hidden state of every layer.
+        if hx is not None:
+            # hx shape: (num_layers, batch_size, hidden_size)
+            h_states = [
+                emit(relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip"))
+                for layer in range(num_layers)
+            ]
+        else:
+            h_states = [
+                emit(relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype))
+                for _ in range(num_layers)
+            ]
+
+        # The gate weights do not depend on the time step, so they are sliced and
+        # transposed once per layer instead of once per (time step, layer) pair.
+        layer_weights = [prepare_layer(layer) for layer in range(num_layers)]
+
+        outputs = []
+        for t in range(seq_len):
+            # Input at time t: (batch_size, input_size)
+            layer_input = emit(
+                relax.op.take(sequence, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+
+            next_states = []
+            for h_prev, weights in zip(h_states, layer_weights):
+                # Each layer consumes the hidden state produced by the layer below it.
+                layer_input = gru_cell(layer_input, h_prev, *weights)
+                next_states.append(layer_input)
+
+            h_states = next_states
+            # Only the last layer contributes to the output sequence.
+            outputs.append(h_states[-1])
+
+        # Stack outputs: (seq_len, batch_size, hidden_size)
+        output = emit(relax.op.stack(outputs, axis=0))
+
+        if batch_first:
+            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
+            output = emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
+
+        return output
+
########## Manipulation ##########
def _narrow(self, node: fx.Node) -> relax.Var:
@@ -652,6 +946,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"layer_norm.default": self._layer_norm,
"linear.default": self._linear,
"lstm.input": self._lstm,
+ "gru.input": self._gru,
"max_pool1d.default": self._max_pool1d,
"max_pool2d.default": self._max_pool2d,
"max_pool3d.default": self._max_pool3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b35af088b5..657ade455b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6050,5 +6050,76 @@ def test_tensor_none_tuple():
verify_model(TensorNoneModel(), example_args, {}, Expected)
+def test_gru():
+    """Check imported single-layer GRUs against PyTorch for both input layouts."""
+
+    class BasicGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=4,
+                hidden_size=8,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    class SeqFirstGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=3,
+                hidden_size=6,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    llvm_target = tvm.target.Target("llvm")
+
+    def run_both(model, inputs):
+        """Return the PyTorch output and the NumPy output of the imported module."""
+        with torch.no_grad():
+            reference = model(inputs)
+        imported = from_exported_program(export(model, args=(inputs,)))
+        executable = relax.build(imported, llvm_target)
+        machine = relax.VirtualMachine(executable, tvm.cpu())
+        result = machine["main"](tvm.runtime.tensor(inputs.numpy()))
+        # The result either exposes ``numpy`` itself or is indexed to reach the tensor.
+        if hasattr(result, "numpy"):
+            return reference, result.numpy()
+        return reference, result[0].numpy()
+
+    # Batch-first layout: input is (batch, seq_len, input_size).
+    torch.manual_seed(42)
+    x = torch.randn(2, 3, 4, dtype=torch.float32)
+    expected, actual = run_both(BasicGRU(), x)
+    assert (
+        expected.shape == actual.shape
+    ), f"Shape mismatch: PyTorch {expected.shape} vs TVM {actual.shape}"
+    np.testing.assert_allclose(expected.numpy(), actual, rtol=1e-4, atol=1e-5)
+
+    # Sequence-first layout: input is (seq_len, batch, input_size).
+    torch.manual_seed(43)
+    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+    expected2, actual2 = run_both(SeqFirstGRU(), x2)
+    assert expected2.shape == actual2.shape
+    np.testing.assert_allclose(expected2.numpy(), actual2, rtol=1e-4, atol=1e-5)
+
+
if __name__ == "__main__":
tvm.testing.main()