This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 182db0f252 [Relax][PyTorch] Add rnn_tanh.input converter (#19837)
182db0f252 is described below

commit 182db0f25251670df2086231539e7c21d35d746e
Author: Neo Chien <[email protected]>
AuthorDate: Sun Jun 28 10:56:13 2026 +0800

    [Relax][PyTorch] Add rnn_tanh.input converter (#19837)
    
    Hi Committers,
    
    This PR adds a converter for `rnn_tanh.input`, addressing that part of
    issue https://github.com/apache/tvm/issues/18364. Any suggestions would
    be appreciated.
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 .../frontend/torch/exported_program_translator.py  | 195 +++++++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   |  57 ++++++
 2 files changed, 252 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0ff4410941..481be4d94f 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -918,6 +918,200 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
         return output
 
+    def _rnn_tanh_cell_unroll(
+        self,
+        input_reshaped,
+        weight_ih,
+        weight_hh,
+        bias_ih,
+        bias_hh,
+        h_prev,
+        seq_len,
+        reverse=False,
+    ):
+        """Unroll vanilla tanh-RNN cells for a single direction."""
+        # Transpose weights for matmul: (hidden_size, in) -> (in, hidden_size)
+        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, 
axes=[1, 0]))
+        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, 
axes=[1, 0]))
+
+        bias = None
+        if bias_ih is not None and bias_hh is not None:
+            bias = self.block_builder.emit(relax.op.add(bias_ih, bias_hh))
+
+        outputs = []
+        time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
+
+        for t in time_steps:
+            # Input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, 
mode="clip")
+            )
+
+            # h_t = tanh(W_ih @ x_t + W_hh @ h_{t-1} + (b_ih + b_hh))
+            ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, 
weight_ih_t))
+            hh = 
self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+            ih_hh = self.block_builder.emit(relax.op.add(ih, hh))
+            if bias is not None:
+                ih_hh = self.block_builder.emit(relax.op.add(ih_hh, bias))
+            h_t = self.block_builder.emit(relax.op.tanh(ih_hh))
+
+            outputs.append(h_t)
+            h_prev = h_t
+
+        if reverse:
+            outputs = outputs[::-1]
+
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+        # 'h_prev' is the hidden state after the final processed time step 
(this direction' s h_n)
+        # independent of the output-sequence ordering above.
+        return output, h_prev
+
+    def _rnn_tanh(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        _train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer RNN is not yet supported")
+
+        def _node_meta(fx_node):
+            meta = fx_node.meta
+            return meta["val"] if "val" in meta else meta["tensor_meta"]
+
+        input_meta = _node_meta(node.args[0])
+        input_shape = list(input_meta.shape)
+        if batch_first:
+            batch_size, seq_len, input_size = input_shape
+        else:
+            seq_len, batch_size, input_size = input_shape
+
+        if not isinstance(seq_len, int):
+            raise NotImplementedError("Dynamic sequence length is not 
supported for rnn_tanh")
+
+        # params per direction: weight_ih, weight_hh, [bias_ih, bias_hh]
+        params_per_direction = 4 if has_biases else 2
+
+        # A vanilla RNN has a single gate, so weight_ih has shape 
(hidden_size, input_size)
+        if params and len(params) >= 2:
+            hidden_size = int(_node_meta(node.args[2][0]).shape[0])
+        else:
+            hidden_size = 16
+
+        dtype = self._convert_data_type(input_meta.dtype)
+
+        # Forward direction weights
+        if params and len(params) >= params_per_direction:
+            weight_ih_fwd = params[0]
+            weight_hh_fwd = params[1]
+            bias_ih_fwd = params[2] if has_biases else None
+            bias_hh_fwd = params[3] if has_biases else None
+        else:
+            weight_ih_fwd = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((hidden_size, input_size)), 
dtype)
+            )
+            weight_hh_fwd = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((hidden_size, hidden_size)), 
dtype)
+            )
+            bias_ih_fwd = None
+            bias_hh_fwd = None
+
+        # Backward direction weights if bidirectional
+        if bidirectional:
+            if params and len(params) >= params_per_direction * 2:
+                weight_ih_bwd = params[params_per_direction]
+                weight_hh_bwd = params[params_per_direction + 1]
+                bias_ih_bwd = params[params_per_direction + 2] if has_biases 
else None
+                bias_hh_bwd = params[params_per_direction + 3] if has_biases 
else None
+            else:
+                weight_ih_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((hidden_size, input_size)), 
dtype)
+                )
+                weight_hh_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((hidden_size, 
hidden_size)), dtype)
+                )
+                bias_ih_bwd = None
+                bias_hh_bwd = None
+        else:
+            weight_ih_bwd = None
+            weight_hh_bwd = None
+            bias_ih_bwd = None
+            bias_hh_bwd = None
+
+        # Initial hidden states
+        if hx is not None:
+            h_prev_fwd = self.block_builder.emit(
+                relax.op.take(hx, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+            h_prev_bwd = (
+                self.block_builder.emit(
+                    relax.op.take(hx, relax.const(1, "int64"), axis=0, 
mode="clip")
+                )
+                if bidirectional
+                else None
+            )
+        else:
+            h_prev_fwd = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), 
dtype)
+            )
+            h_prev_bwd = (
+                self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), 
dtype)
+                )
+                if bidirectional
+                else None
+            )
+
+        # Reshape input to (seq_len, batch_size, input_size)
+        input_reshaped = (
+            self.block_builder.emit(relax.op.permute_dims(input_tensor, 
axes=[1, 0, 2]))
+            if batch_first
+            else input_tensor
+        )
+
+        # Process forward direction
+        output_fwd, h_n_fwd = self._rnn_tanh_cell_unroll(
+            input_reshaped,
+            weight_ih_fwd,
+            weight_hh_fwd,
+            bias_ih_fwd,
+            bias_hh_fwd,
+            h_prev_fwd,
+            seq_len,
+            reverse=False,
+        )
+
+        # Process backward direction if bidirectional
+        if bidirectional:
+            output_bwd, h_n_bwd = self._rnn_tanh_cell_unroll(
+                input_reshaped,
+                weight_ih_bwd,
+                weight_hh_bwd,
+                bias_ih_bwd,
+                bias_hh_bwd,
+                h_prev_bwd,
+                seq_len,
+                reverse=True,
+            )
+            # Concatenate forward and backward outputs along the feature 
dimension
+            output = self.block_builder.emit(relax.op.concat([output_fwd, 
output_bwd], axis=2))
+            h_n = self.block_builder.emit(relax.op.stack([h_n_fwd, h_n_bwd], 
axis=0))
+        else:
+            output = output_fwd
+            h_n = self.block_builder.emit(relax.op.expand_dims(h_n_fwd, 
axis=0))
+
+        # Reshape the output back to batch_first if needed (h_n is 
layout-independent).
+        if batch_first:
+            output = self.block_builder.emit(relax.op.permute_dims(output, 
axes=[1, 0, 2]))
+
+        return self.block_builder.emit(relax.Tuple([output, h_n]))
+
     ########## Manipulation ##########
 
     def _narrow(self, node: fx.Node) -> relax.Var:
@@ -1703,6 +1897,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "linear.default": self._linear,
             "lstm.input": self._lstm,
             "gru.input": self._gru,
+            "rnn_tanh.input": self._rnn_tanh,
             "max_pool1d.default": self._max_pool1d,
             "max_pool2d.default": self._max_pool2d,
             "max_pool2d_with_indices.default": self._max_pool2d_with_indices,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 1ac6e9e20f..d91aa46f21 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -8552,6 +8552,63 @@ def test_gru():
     tvm.testing.assert_allclose(pytorch_output4.numpy(), tvm_output4_np, 
rtol=1e-4, atol=1e-5)
 
 
@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
def test_rnn_tanh():
    target = tvm.target.Target("llvm")

    def _check(rnn_kwargs, x_shape, seed, with_h0=False):
        """Export a single-layer tanh ``nn.RNN``, run it with TVM, and compare to PyTorch.

        Both ``output`` and ``h_n`` are compared. When ``with_h0`` is True, an explicit
        random initial hidden state is fed as a second graph input; otherwise PyTorch
        defaults it to zeros.
        """

        class RNNWithState(nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = nn.RNN(nonlinearity="tanh", num_layers=1, **rnn_kwargs)

            def forward(self, x, h0=None):
                output, h_n = self.rnn(x, h0)
                return output, h_n

        torch.manual_seed(seed)
        x = torch.randn(*x_shape, dtype=torch.float32)
        model = RNNWithState()
        inputs = (x,)
        if with_h0:
            num_directions = 2 if rnn_kwargs.get("bidirectional", False) else 1
            batch = x_shape[0] if rnn_kwargs.get("batch_first", False) else x_shape[1]
            h0 = torch.randn(
                num_directions, batch, rnn_kwargs["hidden_size"], dtype=torch.float32
            )
            inputs = (x, h0)

        with torch.no_grad():
            pt_out, pt_hn = model(*inputs)

        exported_program = export(model, args=inputs)
        mod = from_exported_program(exported_program, run_ep_decomposition=False)
        ex = relax.build(mod, target)
        vm = relax.VirtualMachine(ex, tvm.cpu())
        tvm_outputs = vm["main"](*[tvm.runtime.tensor(t.numpy()) for t in inputs])
        tvm_out_np = tvm_outputs[0].numpy()
        tvm_hn_np = tvm_outputs[1].numpy()

        assert pt_out.shape == tvm_out_np.shape, (
            f"output shape mismatch: PyTorch {tuple(pt_out.shape)} vs TVM {tvm_out_np.shape}"
        )
        assert pt_hn.shape == tvm_hn_np.shape, (
            f"h_n shape mismatch: PyTorch {tuple(pt_hn.shape)} vs TVM {tvm_hn_np.shape}"
        )
        tvm.testing.assert_allclose(pt_out.numpy(), tvm_out_np, rtol=1e-4, atol=1e-5)
        tvm.testing.assert_allclose(pt_hn.numpy(), tvm_hn_np, rtol=1e-4, atol=1e-5)

    # batch_first, unidirectional
    _check(
        {"input_size": 4, "hidden_size": 8, "batch_first": True, "bidirectional": False},
        (2, 3, 4),
        seed=42,
    )
    # seq-first (batch_first=False), unidirectional
    _check(
        {"input_size": 3, "hidden_size": 6, "batch_first": False, "bidirectional": False},
        (4, 2, 3),
        seed=43,
    )
    # bidirectional, batch_first
    _check(
        {"input_size": 4, "hidden_size": 8, "batch_first": True, "bidirectional": True},
        (2, 3, 4),
        seed=44,
    )
    # bidirectional, seq-first
    _check(
        {"input_size": 3, "hidden_size": 5, "batch_first": False, "bidirectional": True},
        (4, 2, 3),
        seed=45,
    )
    # explicit initial hidden state, unidirectional and bidirectional
    _check(
        {"input_size": 4, "hidden_size": 6, "batch_first": True, "bidirectional": False},
        (2, 3, 4),
        seed=46,
        with_h0=True,
    )
    _check(
        {"input_size": 3, "hidden_size": 4, "batch_first": False, "bidirectional": True},
        (5, 2, 3),
        seed=47,
        with_h0=True,
    )
+
+
 def test_dynamic_shape_with_range_constraints():
     class DynamicModel(torch.nn.Module):
         def forward(self, x1, x2):

Reply via email to