This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b8d324eb1 [Relax][PyTorch] Support gru op for ExportedProgram importer (#18360)
3b8d324eb1 is described below

commit 3b8d324eb151e9fcb78f44746a1a4a2ab62cf02e
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Oct 6 10:59:40 2025 -0400

    [Relax][PyTorch] Support gru op for ExportedProgram importer (#18360)
---
 .../frontend/torch/exported_program_translator.py  | 295 +++++++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   |  71 +++++
 2 files changed, 366 insertions(+)
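
For orientation, here is a minimal usage sketch. It mirrors the test added below and is not part of the commit itself; the wrapper class name and tensor sizes are illustrative, and the imports are assumed to match those used by the test file:

    import torch
    from torch import nn
    from torch.export import export

    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program

    class TinyGRU(nn.Module):  # illustrative wrapper, not part of the commit
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(input_size=4, hidden_size=8, num_layers=1, batch_first=True)

        def forward(self, x):
            y, _ = self.gru(x)  # keep only the output sequence
            return y

    x = torch.randn(2, 3, 4, dtype=torch.float32)
    exported = export(TinyGRU(), args=(x,))
    mod = from_exported_program(exported)  # dispatches to the new "gru.input" handler
    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    out = vm["main"](tvm.runtime.tensor(x.numpy()))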

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c9c55eb8d6..a84c35e622 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -391,6 +391,300 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
         return output
 
+    def _gru(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        _train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+
+        if bidirectional:
+            raise NotImplementedError("Bidirectional GRU is not yet supported")
+
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            batch_size, seq_len, input_size = input_shape
+        else:
+            seq_len, batch_size, input_size = input_shape
+
+        if isinstance(seq_len, tvm.tir.IntImm):
+            seq_len = seq_len.value
+        if isinstance(batch_size, tvm.tir.IntImm):
+            batch_size = batch_size.value
+        if isinstance(input_size, tvm.tir.IntImm):
+            input_size = input_size.value
+
+        if params and len(params) >= 2:
+            # params[0] is the first layer's weight_ih for both single- and
+            # multi-layer GRUs
+            weight_ih = params[0]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (3 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 3  # 3 gates: reset, update, new
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+
+        # Implement actual GRU computation using Relax operations
+        # GRU equations:
+        # r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
+        # z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
+        # n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
+        # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+        dtype = input_tensor.struct_info.dtype
+
+        # Reshape input for processing
+        if batch_first:
+            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+            input_reshaped = self.block_builder.emit(
+                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+            )
+        else:
+            input_reshaped = input_tensor
+
+        # Initialize hidden states for all layers
+        if hx is not None:
+            # hx shape: (num_layers, batch_size, hidden_size)
+            h_states = []
+            for layer in range(num_layers):
+                h_layer = self.block_builder.emit(
+                    relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip")
+                )
+                h_states.append(h_layer)
+        else:
+            h_states = []
+            for layer in range(num_layers):
+                h_layer = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+                )
+                h_states.append(h_layer)
+
+        outputs = []
+
+        for t in range(seq_len):
+            # Get input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+
+            # Process through each layer
+            current_input = x_t
+            new_h_states = []
+
+            for layer in range(num_layers):
+                # Get layer parameters
+                if params and len(params) >= 4 * num_layers:
+                    # Multi-layer case: params are organized as
+                    # [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, layer1_ih, ...]
+                    param_offset = layer * 4
+                    weight_ih = params[param_offset]
+                    weight_hh = params[param_offset + 1]
+                    bias_ih = params[param_offset + 2] if has_biases else None
+                    bias_hh = params[param_offset + 3] if has_biases else None
+                elif params and len(params) >= 4:
+                    # Single layer case
+                    weight_ih = params[0]
+                    weight_hh = params[1]
+                    bias_ih = params[2] if has_biases else None
+                    bias_hh = params[3] if has_biases else None
+                else:
+                    # Fallback: create zero weights
+                    weight_ih = self.block_builder.emit(
+                        relax.op.zeros(
+                            relax.ShapeExpr(
+                                (3 * hidden_size, input_size if layer == 0 else hidden_size)
+                            ),
+                            dtype,
+                        )
+                    )
+                    weight_hh = self.block_builder.emit(
+                        relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
+                    )
+                    bias_ih = None
+                    bias_hh = None
+
+                # Get previous hidden state for this layer
+                h_prev = h_states[layer]
+
+                # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
+                gate_size = hidden_size
+
+                # Reset gate weights
+                weight_ih_r = self.block_builder.emit(
+                    relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
+                )
+                weight_hh_r = self.block_builder.emit(
+                    relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
+                )
+
+                # Update gate weights
+                weight_ih_z = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                    )
+                )
+                weight_hh_z = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                    )
+                )
+
+                # New gate weights
+                weight_ih_n = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                    )
+                )
+                weight_hh_n = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                    )
+                )
+
+                # Transpose weights for matmul
+                weight_ih_r_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_r, axes=[1, 0])
+                )
+                weight_hh_r_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_r, axes=[1, 0])
+                )
+                weight_ih_z_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_z, axes=[1, 0])
+                )
+                weight_hh_z_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_z, axes=[1, 0])
+                )
+                weight_ih_n_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_n, axes=[1, 0])
+                )
+                weight_hh_n_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_n, axes=[1, 0])
+                )
+
+                # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
+                r_ih = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(current_input, weight_ih_r_t)
+                )
+                r_hh = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)
+                )
+                if bias_ih is not None and bias_hh is not None:
+                    bias_ih_r = self.block_builder.emit(
+                        relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
+                    )
+                    bias_hh_r = self.block_builder.emit(
+                        relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
+                    )
+                    r_t = self.block_builder.emit(
+                        relax.op.sigmoid(
+                            relax.op.add(
+                                relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r
+                            )
+                        )
+                    )
+                else:
+                    r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))
+
+                # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
+                z_ih = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(current_input, weight_ih_z_t)
+                )
+                z_hh = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)
+                )
+                if bias_ih is not None and bias_hh is not None:
+                    bias_ih_z = self.block_builder.emit(
+                        relax.op.strided_slice(
+                            bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                        )
+                    )
+                    bias_hh_z = self.block_builder.emit(
+                        relax.op.strided_slice(
+                            bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                        )
+                    )
+                    z_t = self.block_builder.emit(
+                        relax.op.sigmoid(
+                            relax.op.add(
+                                relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z
+                            )
+                        )
+                    )
+                else:
+                    z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))
+
+                # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
+                n_ih = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(current_input, weight_ih_n_t)
+                )
+                n_hh = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)
+                )
+                if bias_ih is not None and bias_hh is not None:
+                    bias_ih_n = self.block_builder.emit(
+                        relax.op.strided_slice(
+                            bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                        )
+                    )
+                    bias_hh_n = self.block_builder.emit(
+                        relax.op.strided_slice(
+                            bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                        )
+                    )
+                    n_t = self.block_builder.emit(
+                        relax.op.tanh(
+                            relax.op.add(
+                                relax.op.add(n_ih, bias_ih_n),
+                                relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
+                            )
+                        )
+                    )
+                else:
+                    n_t = self.block_builder.emit(
+                        relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
+                    )
+
+                # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+                one_minus_z = self.block_builder.emit(
+                    relax.op.subtract(relax.const(1.0, dtype), z_t)
+                )
+                h_t = self.block_builder.emit(
+                    relax.op.add(
+                        relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)
+                    )
+                )
+
+                new_h_states.append(h_t)
+
+                current_input = h_t
+
+            # Update hidden states for next time step
+            h_states = new_h_states
+
+            # Store output (from the last layer)
+            outputs.append(h_states[-1])
+
+        # Stack outputs: (seq_len, batch_size, hidden_size)
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+
+        # Reshape back to batch_first if needed
+        if batch_first:
+            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
+            output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
+
+        return output
+
     ########## Manipulation ##########
 
     def _narrow(self, node: fx.Node) -> relax.Var:
@@ -652,6 +946,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "layer_norm.default": self._layer_norm,
             "linear.default": self._linear,
             "lstm.input": self._lstm,
+            "gru.input": self._gru,
             "max_pool1d.default": self._max_pool1d,
             "max_pool2d.default": self._max_pool2d,
             "max_pool3d.default": self._max_pool3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b35af088b5..657ade455b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6050,5 +6050,76 @@ def test_tensor_none_tuple():
     verify_model(TensorNoneModel(), example_args, {}, Expected)
 
 
+def test_gru():
+    class BasicGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=4,
+                hidden_size=8,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(42)
+    x = torch.randn(2, 3, 4, dtype=torch.float32)
+    model = BasicGRU()
+    with torch.no_grad():
+        pytorch_output = model(x)
+    exported_program = export(model, args=(x,))
+    mod = from_exported_program(exported_program)
+    target = tvm.target.Target("llvm")
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_tvm = tvm.runtime.tensor(x.numpy())
+    tvm_output = vm["main"](x_tvm)
+    if hasattr(tvm_output, "numpy"):
+        tvm_output_np = tvm_output.numpy()
+    else:
+        tvm_output_np = tvm_output[0].numpy()
+    assert (
+        pytorch_output.shape == tvm_output_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
+    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)
+
+    class SeqFirstGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=3,
+                hidden_size=6,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(43)
+    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+    model2 = SeqFirstGRU()
+    with torch.no_grad():
+        pytorch_output2 = model2(x2)
+    exported_program2 = export(model2, args=(x2,))
+    mod2 = from_exported_program(exported_program2)
+    ex2 = relax.build(mod2, target)
+    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
+    x2_tvm = tvm.runtime.tensor(x2.numpy())
+    tvm_output2 = vm2["main"](x2_tvm)
+    if hasattr(tvm_output2, "numpy"):
+        tvm_output2_np = tvm_output2.numpy()
+    else:
+        tvm_output2_np = tvm_output2[0].numpy()
+    assert pytorch_output2.shape == tvm_output2_np.shape
+    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
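
Two practical limits follow from the implementation above: bidirectional GRUs are rejected with NotImplementedError, and the time loop is unrolled with a Python range(seq_len), so the sequence length effectively has to be static at import time. A minimal sketch of the bidirectional case (the module name is illustrative, and the imports are assumed to match the test file's):

    import torch
    from torch import nn
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class BiGRU(nn.Module):  # illustrative; not part of the commit
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(input_size=4, hidden_size=8, bidirectional=True, batch_first=True)

        def forward(self, x):
            return self.gru(x)[0]

    ep = export(BiGRU(), args=(torch.randn(2, 3, 4),))
    try:
        from_exported_program(ep)
    except NotImplementedError as err:
        print(err)  # "Bidirectional GRU is not yet supported"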
