This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 182db0f252 [Relax][PyTorch] Add rnn_tanh.input converter (#19837)
182db0f252 is described below
commit 182db0f25251670df2086231539e7c21d35d746e
Author: Neo Chien <[email protected]>
AuthorDate: Sun Jun 28 10:56:13 2026 +0800
[Relax][PyTorch] Add rnn_tanh.input converter (#19837)
Hi Committers,
This PR addresses the `rnn_tanh.input` part of issue
https://github.com/apache/tvm/issues/18364. Any suggestions would be
appreciated if you are available.
---------
Co-authored-by: cchung100m <[email protected]>
---
.../frontend/torch/exported_program_translator.py | 195 +++++++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 57 ++++++
2 files changed, 252 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0ff4410941..481be4d94f 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -918,6 +918,200 @@ class ExportedProgramImporter(BaseFXGraphImporter):
return output
+ def _rnn_tanh_cell_unroll(
+ self,
+ input_reshaped,
+ weight_ih,
+ weight_hh,
+ bias_ih,
+ bias_hh,
+ h_prev,
+ seq_len,
+ reverse=False,
+ ):
+ """Unroll vanilla tanh-RNN cells for a single direction."""
+ # Transpose weights for matmul: (hidden_size, in) -> (in, hidden_size)
+ weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih,
axes=[1, 0]))
+ weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh,
axes=[1, 0]))
+
+ bias = None
+ if bias_ih is not None and bias_hh is not None:
+ bias = self.block_builder.emit(relax.op.add(bias_ih, bias_hh))
+
+ outputs = []
+ time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
+
+ for t in time_steps:
+ # Input at time t: (batch_size, input_size)
+ x_t = self.block_builder.emit(
+ relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0,
mode="clip")
+ )
+
+ # h_t = tanh(W_ih @ x_t + W_hh @ h_{t-1} + (b_ih + b_hh))
+ ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t,
weight_ih_t))
+ hh =
self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+ ih_hh = self.block_builder.emit(relax.op.add(ih, hh))
+ if bias is not None:
+ ih_hh = self.block_builder.emit(relax.op.add(ih_hh, bias))
+ h_t = self.block_builder.emit(relax.op.tanh(ih_hh))
+
+ outputs.append(h_t)
+ h_prev = h_t
+
+ if reverse:
+ outputs = outputs[::-1]
+
+ output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+ # 'h_prev' is the hidden state after the final processed time step
(this direction's h_n)
+ # independent of the output-sequence ordering above.
+ return output, h_prev
+
+ def _rnn_tanh(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ hx = args[1] if len(args) > 1 else None
+ params = args[2] if len(args) > 2 else None
+ has_biases = args[3] if len(args) > 3 else True
+ num_layers = args[4] if len(args) > 4 else 1
+ _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
+ _train = args[6] if len(args) > 6 else False # Not used in inference
+ bidirectional = args[7] if len(args) > 7 else False
+ batch_first = args[8] if len(args) > 8 else False
+
+ if num_layers > 1:
+ raise NotImplementedError("Multi-layer RNN is not yet supported")
+
+ def _node_meta(fx_node):
+ meta = fx_node.meta
+ return meta["val"] if "val" in meta else meta["tensor_meta"]
+
+ input_meta = _node_meta(node.args[0])
+ input_shape = list(input_meta.shape)
+ if batch_first:
+ batch_size, seq_len, input_size = input_shape
+ else:
+ seq_len, batch_size, input_size = input_shape
+
+ if not isinstance(seq_len, int):
+ raise NotImplementedError("Dynamic sequence length is not
supported for rnn_tanh")
+
+ # params per direction: weight_ih, weight_hh, [bias_ih, bias_hh]
+ params_per_direction = 4 if has_biases else 2
+
+ # A vanilla RNN has a single gate, so weight_ih has shape
(hidden_size, input_size)
+ if params and len(params) >= 2:
+ hidden_size = int(_node_meta(node.args[2][0]).shape[0])
+ else:
+ hidden_size = 16
+
+ dtype = self._convert_data_type(input_meta.dtype)
+
+ # Forward direction weights
+ if params and len(params) >= params_per_direction:
+ weight_ih_fwd = params[0]
+ weight_hh_fwd = params[1]
+ bias_ih_fwd = params[2] if has_biases else None
+ bias_hh_fwd = params[3] if has_biases else None
+ else:
+ weight_ih_fwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((hidden_size, input_size)),
dtype)
+ )
+ weight_hh_fwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((hidden_size, hidden_size)),
dtype)
+ )
+ bias_ih_fwd = None
+ bias_hh_fwd = None
+
+ # Backward direction weights if bidirectional
+ if bidirectional:
+ if params and len(params) >= params_per_direction * 2:
+ weight_ih_bwd = params[params_per_direction]
+ weight_hh_bwd = params[params_per_direction + 1]
+ bias_ih_bwd = params[params_per_direction + 2] if has_biases
else None
+ bias_hh_bwd = params[params_per_direction + 3] if has_biases
else None
+ else:
+ weight_ih_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((hidden_size, input_size)),
dtype)
+ )
+ weight_hh_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((hidden_size,
hidden_size)), dtype)
+ )
+ bias_ih_bwd = None
+ bias_hh_bwd = None
+ else:
+ weight_ih_bwd = None
+ weight_hh_bwd = None
+ bias_ih_bwd = None
+ bias_hh_bwd = None
+
+ # Initial hidden states
+ if hx is not None:
+ h_prev_fwd = self.block_builder.emit(
+ relax.op.take(hx, relax.const(0, "int64"), axis=0, mode="clip")
+ )
+ h_prev_bwd = (
+ self.block_builder.emit(
+ relax.op.take(hx, relax.const(1, "int64"), axis=0,
mode="clip")
+ )
+ if bidirectional
+ else None
+ )
+ else:
+ h_prev_fwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)),
dtype)
+ )
+ h_prev_bwd = (
+ self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)),
dtype)
+ )
+ if bidirectional
+ else None
+ )
+
+ # Reshape input to (seq_len, batch_size, input_size)
+ input_reshaped = (
+ self.block_builder.emit(relax.op.permute_dims(input_tensor,
axes=[1, 0, 2]))
+ if batch_first
+ else input_tensor
+ )
+
+ # Process forward direction
+ output_fwd, h_n_fwd = self._rnn_tanh_cell_unroll(
+ input_reshaped,
+ weight_ih_fwd,
+ weight_hh_fwd,
+ bias_ih_fwd,
+ bias_hh_fwd,
+ h_prev_fwd,
+ seq_len,
+ reverse=False,
+ )
+
+ # Process backward direction if bidirectional
+ if bidirectional:
+ output_bwd, h_n_bwd = self._rnn_tanh_cell_unroll(
+ input_reshaped,
+ weight_ih_bwd,
+ weight_hh_bwd,
+ bias_ih_bwd,
+ bias_hh_bwd,
+ h_prev_bwd,
+ seq_len,
+ reverse=True,
+ )
+ # Concatenate forward and backward outputs along the feature
dimension
+ output = self.block_builder.emit(relax.op.concat([output_fwd,
output_bwd], axis=2))
+ h_n = self.block_builder.emit(relax.op.stack([h_n_fwd, h_n_bwd],
axis=0))
+ else:
+ output = output_fwd
+ h_n = self.block_builder.emit(relax.op.expand_dims(h_n_fwd,
axis=0))
+
+ # Reshape the output back to batch_first if needed (h_n is
layout-independent).
+ if batch_first:
+ output = self.block_builder.emit(relax.op.permute_dims(output,
axes=[1, 0, 2]))
+
+ return self.block_builder.emit(relax.Tuple([output, h_n]))
+
########## Manipulation ##########
def _narrow(self, node: fx.Node) -> relax.Var:
@@ -1703,6 +1897,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"linear.default": self._linear,
"lstm.input": self._lstm,
"gru.input": self._gru,
+ "rnn_tanh.input": self._rnn_tanh,
"max_pool1d.default": self._max_pool1d,
"max_pool2d.default": self._max_pool2d,
"max_pool2d_with_indices.default": self._max_pool2d_with_indices,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 1ac6e9e20f..d91aa46f21 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -8552,6 +8552,63 @@ def test_gru():
tvm.testing.assert_allclose(pytorch_output4.numpy(), tvm_output4_np,
rtol=1e-4, atol=1e-5)
[email protected](not env.has_llvm(), reason="need llvm")
+def test_rnn_tanh():
+ target = tvm.target.Target("llvm")
+
+ def _check(rnn_kwargs, x_shape, seed):
+ class RNNWithState(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.rnn = nn.RNN(nonlinearity="tanh", num_layers=1,
**rnn_kwargs)
+
+ def forward(self, x):
+ output, h_n = self.rnn(x)
+ return output, h_n
+
+ torch.manual_seed(seed)
+ x = torch.randn(*x_shape, dtype=torch.float32)
+ model = RNNWithState()
+ with torch.no_grad():
+ pt_out, pt_hn = model(x)
+
+ exported_program = export(model, args=(x,))
+ mod = from_exported_program(exported_program,
run_ep_decomposition=False)
+ ex = relax.build(mod, target)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ tvm_outputs = vm["main"](tvm.runtime.tensor(x.numpy()))
+ tvm_out_np = tvm_outputs[0].numpy()
+ tvm_hn_np = tvm_outputs[1].numpy()
+
+ assert pt_out.shape == tvm_out_np.shape, (
+ f"output shape mismatch: PyTorch {tuple(pt_out.shape)} vs TVM
{tvm_out_np.shape}"
+ )
+ assert pt_hn.shape == tvm_hn_np.shape, (
+ f"h_n shape mismatch: PyTorch {tuple(pt_hn.shape)} vs TVM
{tvm_hn_np.shape}"
+ )
+ tvm.testing.assert_allclose(pt_out.numpy(), tvm_out_np, rtol=1e-4,
atol=1e-5)
+ tvm.testing.assert_allclose(pt_hn.numpy(), tvm_hn_np, rtol=1e-4,
atol=1e-5)
+
+ # batch_first, unidirectional
+ _check(
+ {"input_size": 4, "hidden_size": 8, "batch_first": True,
"bidirectional": False},
+ (2, 3, 4),
+ seed=42,
+ )
+ # seq-first (batch_first=False), unidirectional
+ _check(
+ {"input_size": 3, "hidden_size": 6, "batch_first": False,
"bidirectional": False},
+ (4, 2, 3),
+ seed=43,
+ )
+ # bidirectional, batch_first
+ _check(
+ {"input_size": 4, "hidden_size": 8, "batch_first": True,
"bidirectional": True},
+ (2, 3, 4),
+ seed=44,
+ )
+
+
def test_dynamic_shape_with_range_constraints():
class DynamicModel(torch.nn.Module):
def forward(self, x1, x2):