This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c77db7889 [Relax][PyTorch] Add support for bidirectional LSTM (#18516)
1c77db7889 is described below

commit 1c77db78891b81b22ff8f1404546d1bee4fc1bb1
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 28 18:00:52 2025 +0800

    [Relax][PyTorch] Add support for bidirectional LSTM (#18516)
---
 .../frontend/torch/exported_program_translator.py  | 269 +++++++++++++--------
 .../relax/test_frontend_from_exported_program.py   | 106 ++++----
 2 files changed, 222 insertions(+), 153 deletions(-)
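
For readers skimming the patch: the frontend previously raised NotImplementedError for
bidirectional=True; this change lowers each direction with an unrolled LSTM cell and
concatenates the two outputs along the feature axis. A minimal, illustrative sketch of
how the new path can be exercised end to end (it mirrors the test added in this commit;
the module name and shapes are placeholders, not part of the patch):

    import torch
    import tvm
    from tvm import relax
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class BiLSTM(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # bidirectional=True is the case enabled by this commit
            self.lstm = torch.nn.LSTM(input_size=4, hidden_size=8, num_layers=1,
                                      batch_first=True, bidirectional=True)

        def forward(self, x):
            y, _ = self.lstm(x)
            return y  # (batch, seq_len, 2 * hidden_size)

    model = BiLSTM()
    x = torch.randn(2, 3, 4, dtype=torch.float32)
    ep = export(model, args=(x,))            # torch.export path used by the importer
    mod = from_exported_program(ep)          # translated via the code changed below
    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    out = vm["main"](tvm.runtime.tensor(x.numpy()))  # numerically close to model(x)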

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 04e5330ce6..fc0ca18209 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -378,6 +378,75 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             align_corners=align_corners,
         )
 
+    def _lstm_cell_unroll(
+        self,
+        input_reshaped,
+        weight_ih,
+        weight_hh,
+        bias_ih,
+        bias_hh,
+        h_prev,
+        c_prev,
+        seq_len,
+        hidden_size,
+        reverse=False,
+    ):
+        """Unroll LSTM cells for a single direction."""
+        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+        outputs = []
+        time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
+
+        for t in time_steps:
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+
+            gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+            if bias_ih is not None:
+                gates = self.block_builder.emit(relax.op.add(gates, bias_ih))
+            if bias_hh is not None:
+                gates = self.block_builder.emit(relax.op.add(gates, bias_hh))
+
+            i_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[hidden_size])
+            )
+            f_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[hidden_size], end=[2 * hidden_size])
+            )
+            g_gate = self.block_builder.emit(
+                relax.op.strided_slice(
+                    gates, axes=[1], begin=[2 * hidden_size], end=[3 * hidden_size]
+                )
+            )
+            o_gate = self.block_builder.emit(
+                relax.op.strided_slice(
+                    gates, axes=[1], begin=[3 * hidden_size], end=[4 * hidden_size]
+                )
+            )
+
+            i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
+            f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
+            g_t = self.block_builder.emit(relax.op.tanh(g_gate))
+            o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
+
+            c_t = self.block_builder.emit(
+                relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+            )
+            h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
+
+            outputs.append(h_t)
+            h_prev = h_t
+            c_prev = c_t
+
+        if reverse:
+            outputs = outputs[::-1]
+
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+        return output
+
     def _lstm(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         input_tensor = args[0]
@@ -385,39 +454,30 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         params = args[2] if len(args) > 2 else None
         has_biases = args[3] if len(args) > 3 else True
         num_layers = args[4] if len(args) > 4 else 1
-        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
-        _train = args[6] if len(args) > 6 else False  # Not used in inference
         bidirectional = args[7] if len(args) > 7 else False
         batch_first = args[8] if len(args) > 8 else False
-        if bidirectional:
-            raise NotImplementedError("Bidirectional LSTM is not yet 
supported")
+
         if num_layers > 1:
             raise NotImplementedError("Multi-layer LSTM is not yet supported")
+
         input_shape = self.shape_of(input_tensor)
         if batch_first:
-            # Input shape: (batch, seq_len, input_size)
             batch_size, seq_len, input_size = input_shape
         else:
-            # Input shape: (seq_len, batch, input_size)
             seq_len, batch_size, input_size = input_shape
 
-        if isinstance(seq_len, tvm.tir.IntImm):
-            seq_len = seq_len.value
-        if isinstance(batch_size, tvm.tir.IntImm):
-            batch_size = batch_size.value
-        if isinstance(input_size, tvm.tir.IntImm):
-            input_size = input_size.value
+        seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len
+        batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size
+        input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size
         # Extract hidden size from the LSTM parameters
         # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
         # weight_ih shape: (4 * hidden_size, input_size)
         # weight_hh shape: (4 * hidden_size, hidden_size)
         if params and len(params) >= 2:
-            weight_ih = params[0]
-            weight_hh = params[1]
             # Extract hidden size from weight dimensions
             # weight_ih has shape (4 * hidden_size, input_size)
-            weight_ih_shape = self.shape_of(weight_ih)
-            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+            weight_ih_shape = self.shape_of(params[0])
+            hidden_size = weight_ih_shape[0] // 4
         else:
             # Fallback to a default hidden size
             hidden_size = 16
@@ -430,109 +490,120 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         # c_t = f_t * c_{t-1} + i_t * g_t
         # h_t = o_t * tanh(c_t)
         dtype = input_tensor.struct_info.dtype
-        if params and len(params) >= 4:
-            weight_ih = params[0]  # (4 * hidden_size, input_size)
-            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
-            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
-            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
+        params_per_direction = 4 if has_biases else 2
+
+        # Extract or create forward direction weights
+        if params and len(params) >= 2:
+            weight_ih_fwd = params[0]
+            weight_hh_fwd = params[1]
+            bias_ih_fwd = params[2] if has_biases and len(params) > 2 else None
+            bias_hh_fwd = params[3] if has_biases and len(params) > 3 else None
         else:
             # Fallback: create zero weights
-            weight_ih = self.block_builder.emit(
+            weight_ih_fwd = self.block_builder.emit(
                 relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
             )
-            weight_hh = self.block_builder.emit(
+            weight_hh_fwd = self.block_builder.emit(
                 relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
             )
-            bias_ih = None
-            bias_hh = None
-        # Initialize hidden and cell states
+            bias_ih_fwd = None
+            bias_hh_fwd = None
+
+        # Extract or create backward direction weights if bidirectional
+        if bidirectional:
+            if params and len(params) >= params_per_direction * 2:
+                weight_ih_bwd = params[params_per_direction]
+                weight_hh_bwd = params[params_per_direction + 1]
+                bias_ih_bwd = params[params_per_direction + 2] if has_biases else None
+                bias_hh_bwd = params[params_per_direction + 3] if has_biases else None
+            else:
+                # Fallback: create zero weights
+                weight_ih_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+                )
+                weight_hh_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+                )
+                bias_ih_bwd = None
+                bias_hh_bwd = None
+        else:
+            weight_ih_bwd = None
+            weight_hh_bwd = None
+            bias_ih_bwd = None
+            bias_hh_bwd = None
+
         if hx is not None and len(hx) >= 2:
-            h_0 = hx[0]  # (num_layers, batch_size, hidden_size)
-            c_0 = hx[1]  # (num_layers, batch_size, hidden_size)
-            # Extract the first layer's hidden state
-            h_prev = self.block_builder.emit(
+            h_0, c_0 = hx[0], hx[1]
+            h_prev_fwd = self.block_builder.emit(
                 relax.op.take(h_0, relax.const(0, "int64"), axis=0, 
mode="clip")
             )
-            c_prev = self.block_builder.emit(
+            c_prev_fwd = self.block_builder.emit(
                 relax.op.take(c_0, relax.const(0, "int64"), axis=0, 
mode="clip")
             )
+            if bidirectional:
+                h_prev_bwd = self.block_builder.emit(
+                    relax.op.take(h_0, relax.const(1, "int64"), axis=0, 
mode="clip")
+                )
+                c_prev_bwd = self.block_builder.emit(
+                    relax.op.take(c_0, relax.const(1, "int64"), axis=0, 
mode="clip")
+                )
+            else:
+                h_prev_bwd = None
+                c_prev_bwd = None
         else:
-            h_prev = self.block_builder.emit(
+            h_prev_fwd = self.block_builder.emit(
                 relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
             )
-            c_prev = self.block_builder.emit(
+            c_prev_fwd = self.block_builder.emit(
                 relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
             )
-        # Reshape input for processing
-        if batch_first:
-            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
-            input_reshaped = self.block_builder.emit(
-                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
-            )
-        else:
-            input_reshaped = input_tensor
-        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
-        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
-        outputs = []
-        for t in range(seq_len):
-            # Get input at time t: (batch_size, input_size)
-            x_t = self.block_builder.emit(
-                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
-            )
-            # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
-            # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
-            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
-
-            # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
-            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
-            # Add biases if present
-            if bias_ih is not None and bias_hh is not None:
-                gates = self.block_builder.emit(
-                    relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
-                )
-            elif bias_ih is not None:
-                gates = self.block_builder.emit(
-                    relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+            if bidirectional:
+                h_prev_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
                 )
-            elif bias_hh is not None:
-                gates = self.block_builder.emit(
-                    relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+                c_prev_bwd = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
                 )
             else:
-                gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
-            # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
-            gate_size = hidden_size
-            i_gate = self.block_builder.emit(
-                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
-            )
-            f_gate = self.block_builder.emit(
-                relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
-            )
-            g_gate = self.block_builder.emit(
-                relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
-            )
-            o_gate = self.block_builder.emit(
-                relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
-            )
-            # Apply activations
-            i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
-            f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
-            g_t = self.block_builder.emit(relax.op.tanh(g_gate))
-            o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
-            # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
-            c_t = self.block_builder.emit(
-                relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+                h_prev_bwd = None
+                c_prev_bwd = None
+
+        input_reshaped = (
+            self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2]))
+            if batch_first
+            else input_tensor
+        )
+
+        output_fwd = self._lstm_cell_unroll(
+            input_reshaped,
+            weight_ih_fwd,
+            weight_hh_fwd,
+            bias_ih_fwd,
+            bias_hh_fwd,
+            h_prev_fwd,
+            c_prev_fwd,
+            seq_len,
+            hidden_size,
+            reverse=False,
+        )
+
+        if bidirectional:
+            output_bwd = self._lstm_cell_unroll(
+                input_reshaped,
+                weight_ih_bwd,
+                weight_hh_bwd,
+                bias_ih_bwd,
+                bias_hh_bwd,
+                h_prev_bwd,
+                c_prev_bwd,
+                seq_len,
+                hidden_size,
+                reverse=True,
             )
-            # Update hidden state: h_t = o_t * tanh(c_t)
-            h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
-            # Store output
-            outputs.append(h_t)
-            # Update for next iteration
-            h_prev = h_t
-            c_prev = c_t
-        # Stack outputs: (seq_len, batch_size, hidden_size)
-        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
-        # Reshape back to batch_first if needed
+            output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2))
+        else:
+            output = output_fwd
+
         if batch_first:
             # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
             output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
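
As a reference for the translator changes above (an illustrative sketch, not code from
the patch): PyTorch's weight_ih_l0 has shape (4 * hidden_size, input_size) with the
gates stacked in input/forget/cell/output order, which is why the fused gates tensor is
sliced into four hidden_size-wide chunks. The same single-direction recurrence in plain
NumPy, with hypothetical names, looks like:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_direction(x, w_ih, w_hh, b_ih, b_hh, h, c, reverse=False):
        # x: (seq_len, batch, input_size); w_ih: (4H, input_size); w_hh: (4H, H)
        H = w_hh.shape[1]
        steps = range(x.shape[0] - 1, -1, -1) if reverse else range(x.shape[0])
        outs = []
        for t in steps:
            gates = x[t] @ w_ih.T + h @ w_hh.T + b_ih + b_hh     # (batch, 4H)
            i, f = gates[:, 0:H], gates[:, H:2 * H]              # input / forget
            g, o = gates[:, 2 * H:3 * H], gates[:, 3 * H:4 * H]  # cell / output
            c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
            h = sigmoid(o) * np.tanh(c)
            outs.append(h)
        if reverse:
            outs = outs[::-1]              # restore forward time order
        return np.stack(outs, axis=0)      # (seq_len, batch, H)

A bidirectional result is then np.concatenate([forward, backward], axis=2), matching the
relax.op.concat(..., axis=2) emitted in _lstm.
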
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index fe3ff28aea..8ff46bf611 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -57,6 +57,37 @@ def verify_model(
     tvm.ir.assert_structural_equal(mod, expected, map_free_vars=map_free_vars)
 
 
+def verify_model_numerically(torch_model, example_args, rtol=1e-7, atol=1e-7):
+    """Verify model by comparing numerical outputs between PyTorch and TVM."""
+    with torch.no_grad():
+        pytorch_output = torch_model(*example_args)
+
+    exported_program = export(torch_model, args=example_args)
+    mod = from_exported_program(exported_program)
+    target = tvm.target.Target("llvm")
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    tvm_args = [tvm.runtime.tensor(arg.numpy()) for arg in example_args]
+    tvm_output = vm["main"](*tvm_args)
+
+    if hasattr(tvm_output, "numpy"):
+        tvm_output_np = tvm_output.numpy()
+    else:
+        tvm_output_np = tvm_output[0].numpy()
+
+    pytorch_output_np = (
+        pytorch_output.numpy()
+        if isinstance(pytorch_output, torch.Tensor)
+        else pytorch_output[0].numpy()
+    )
+
+    assert (
+        pytorch_output_np.shape == tvm_output_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output_np.shape} vs TVM 
{tvm_output_np.shape}"
+    tvm.testing.assert_allclose(pytorch_output_np, tvm_output_np, rtol=rtol, atol=atol)
+
+
 operator_basic_unary = [
     (torch.abs, R.abs),
     (torch.acos, R.acos),
@@ -7831,75 +7862,42 @@ def test_sparse_mm():
     verify_model(SparseMatrixMultiply(), example_args, {}, Expected)
 
 
+@tvm.testing.requires_llvm
 def test_lstm():
-    class BasicLSTM(nn.Module):
-        def __init__(self):
+    class LSTM(nn.Module):
+        def __init__(self, input_size, hidden_size, batch_first, bidirectional):
             super().__init__()
             self.lstm = nn.LSTM(
-                input_size=4,
-                hidden_size=8,
+                input_size=input_size,
+                hidden_size=hidden_size,
                 num_layers=1,
-                batch_first=True,
-                bidirectional=False,
+                batch_first=batch_first,
+                bidirectional=bidirectional,
             )
 
         def forward(self, x):
             y, _ = self.lstm(x)
             return y
 
+    # Unidirectional LSTM with batch_first=True
     torch.manual_seed(42)
     x = torch.randn(2, 3, 4, dtype=torch.float32)
-    model = BasicLSTM()
-    with torch.no_grad():
-        pytorch_output = model(x)
-    exported_program = export(model, args=(x,))
-    mod = from_exported_program(exported_program)
-    target = tvm.target.Target("llvm")
-    ex = relax.build(mod, target)
-    vm = relax.VirtualMachine(ex, tvm.cpu())
-    x_tvm = tvm.runtime.tensor(x.numpy())
-    tvm_output = vm["main"](x_tvm)
-    if hasattr(tvm_output, "numpy"):
-        tvm_output_np = tvm_output.numpy()
-    else:
-        tvm_output_np = tvm_output[0].numpy()
-    assert (
-        pytorch_output.shape == tvm_output_np.shape
-    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM 
{tvm_output_np.shape}"
-    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)
-
-    class SeqFirstLSTM(nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.lstm = nn.LSTM(
-                input_size=3,
-                hidden_size=6,
-                num_layers=1,
-                batch_first=False,
-                bidirectional=False,
-            )
-
-        def forward(self, x):
-            y, _ = self.lstm(x)
-            return y
+    verify_model_numerically(LSTM(4, 8, batch_first=True, bidirectional=False), (x,))
 
+    # Unidirectional LSTM with batch_first=False
     torch.manual_seed(43)
     x2 = torch.randn(4, 2, 3, dtype=torch.float32)
-    model2 = SeqFirstLSTM()
-    with torch.no_grad():
-        pytorch_output2 = model2(x2)
-    exported_program2 = export(model2, args=(x2,))
-    mod2 = from_exported_program(exported_program2)
-    ex2 = relax.build(mod2, target)
-    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
-    x2_tvm = tvm.runtime.tensor(x2.numpy())
-    tvm_output2 = vm2["main"](x2_tvm)
-    if hasattr(tvm_output2, "numpy"):
-        tvm_output2_np = tvm_output2.numpy()
-    else:
-        tvm_output2_np = tvm_output2[0].numpy()
-    assert pytorch_output2.shape == tvm_output2_np.shape
-    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+    verify_model_numerically(LSTM(3, 6, batch_first=False, bidirectional=False), (x2,))
+
+    # Bidirectional LSTM with batch_first=True
+    torch.manual_seed(44)
+    x3 = torch.randn(2, 3, 4, dtype=torch.float32)
+    verify_model_numerically(LSTM(4, 8, batch_first=True, bidirectional=True), (x3,))
+
+    # Bidirectional LSTM with batch_first=False
+    torch.manual_seed(45)
+    x4 = torch.randn(4, 2, 3, dtype=torch.float32)
+    verify_model_numerically(LSTM(3, 6, batch_first=False, bidirectional=True), (x4,))
 
 
 def test_tensor_none_tuple():
