This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1c77db7889 [Relax][PyTorch] Add support for bidirectional LSTM (#18516)
1c77db7889 is described below
commit 1c77db78891b81b22ff8f1404546d1bee4fc1bb1
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 28 18:00:52 2025 +0800
[Relax][PyTorch] Add support for bidirectional LSTM (#18516)
---
.../frontend/torch/exported_program_translator.py | 269 +++++++++++++--------
.../relax/test_frontend_from_exported_program.py | 106 ++++----
2 files changed, 222 insertions(+), 153 deletions(-)
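Note: PyTorch's bidirectional nn.LSTM runs the sequence both forward and backward and concatenates the two hidden-state sequences along the feature axis, so the output feature size doubles to 2 * hidden_size. A minimal PyTorch sketch of the shape behaviour this commit reproduces (sizes chosen to mirror the new test, for illustration only):

import torch
import torch.nn as nn

# batch_first=True: input is (batch, seq_len, input_size)
lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1,
               batch_first=True, bidirectional=True)
x = torch.randn(2, 3, 4)
y, (h_n, c_n) = lstm(x)
print(y.shape)    # torch.Size([2, 3, 16]): forward and backward outputs concatenated
print(h_n.shape)  # torch.Size([2, 2, 8]): (num_layers * num_directions, batch, hidden_size)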
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 04e5330ce6..fc0ca18209 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -378,6 +378,75 @@ class ExportedProgramImporter(BaseFXGraphImporter):
align_corners=align_corners,
)
+ def _lstm_cell_unroll(
+ self,
+ input_reshaped,
+ weight_ih,
+ weight_hh,
+ bias_ih,
+ bias_hh,
+ h_prev,
+ c_prev,
+ seq_len,
+ hidden_size,
+ reverse=False,
+ ):
+ """Unroll LSTM cells for a single direction."""
+ weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+ weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+ outputs = []
+ time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
+
+ for t in time_steps:
+ x_t = self.block_builder.emit(
+ relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+ )
+ ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+ hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+
+ gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+ if bias_ih is not None:
+ gates = self.block_builder.emit(relax.op.add(gates, bias_ih))
+ if bias_hh is not None:
+ gates = self.block_builder.emit(relax.op.add(gates, bias_hh))
+
+ i_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[0], end=[hidden_size])
+ )
+ f_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[hidden_size], end=[2 * hidden_size])
+ )
+ g_gate = self.block_builder.emit(
+ relax.op.strided_slice(
+ gates, axes=[1], begin=[2 * hidden_size], end=[3 * hidden_size]
+ )
+ )
+ o_gate = self.block_builder.emit(
+ relax.op.strided_slice(
+ gates, axes=[1], begin=[3 * hidden_size], end=[4 * hidden_size]
+ )
+ )
+
+ i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
+ f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
+ g_t = self.block_builder.emit(relax.op.tanh(g_gate))
+ o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
+
+ c_t = self.block_builder.emit(
+ relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+ )
+ h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
+
+ outputs.append(h_t)
+ h_prev = h_t
+ c_prev = c_t
+
+ if reverse:
+ outputs = outputs[::-1]
+
+ output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+ return output
+
def _lstm(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
input_tensor = args[0]
@@ -385,39 +454,30 @@ class ExportedProgramImporter(BaseFXGraphImporter):
params = args[2] if len(args) > 2 else None
has_biases = args[3] if len(args) > 3 else True
num_layers = args[4] if len(args) > 4 else 1
- _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
- _train = args[6] if len(args) > 6 else False # Not used in inference
bidirectional = args[7] if len(args) > 7 else False
batch_first = args[8] if len(args) > 8 else False
- if bidirectional:
- raise NotImplementedError("Bidirectional LSTM is not yet supported")
+
if num_layers > 1:
raise NotImplementedError("Multi-layer LSTM is not yet supported")
+
input_shape = self.shape_of(input_tensor)
if batch_first:
- # Input shape: (batch, seq_len, input_size)
batch_size, seq_len, input_size = input_shape
else:
- # Input shape: (seq_len, batch, input_size)
seq_len, batch_size, input_size = input_shape
- if isinstance(seq_len, tvm.tir.IntImm):
- seq_len = seq_len.value
- if isinstance(batch_size, tvm.tir.IntImm):
- batch_size = batch_size.value
- if isinstance(input_size, tvm.tir.IntImm):
- input_size = input_size.value
+ seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len
+ batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size
+ input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size
# Extract hidden size from the LSTM parameters
# The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
# weight_ih shape: (4 * hidden_size, input_size)
# weight_hh shape: (4 * hidden_size, hidden_size)
if params and len(params) >= 2:
- weight_ih = params[0]
- weight_hh = params[1]
# Extract hidden size from weight dimensions
# weight_ih has shape (4 * hidden_size, input_size)
- weight_ih_shape = self.shape_of(weight_ih)
- hidden_size = weight_ih_shape[0] // 4 # 4 gates: input, forget, cell, output
+ weight_ih_shape = self.shape_of(params[0])
+ hidden_size = weight_ih_shape[0] // 4
else:
# Fallback to a default hidden size
hidden_size = 16
@@ -430,109 +490,120 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# c_t = f_t * c_{t-1} + i_t * g_t
# h_t = o_t * tanh(c_t)
dtype = input_tensor.struct_info.dtype
- if params and len(params) >= 4:
- weight_ih = params[0] # (4 * hidden_size, input_size)
- weight_hh = params[1] # (4 * hidden_size, hidden_size)
- bias_ih = params[2] if has_biases else None # (4 * hidden_size,)
- bias_hh = params[3] if has_biases else None # (4 * hidden_size,)
+ params_per_direction = 4 if has_biases else 2
+
+ # Extract or create forward direction weights
+ if params and len(params) >= 2:
+ weight_ih_fwd = params[0]
+ weight_hh_fwd = params[1]
+ bias_ih_fwd = params[2] if has_biases and len(params) > 2 else None
+ bias_hh_fwd = params[3] if has_biases and len(params) > 3 else None
else:
# Fallback: create zero weights
- weight_ih = self.block_builder.emit(
+ weight_ih_fwd = self.block_builder.emit(
relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
)
- weight_hh = self.block_builder.emit(
+ weight_hh_fwd = self.block_builder.emit(
relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
)
- bias_ih = None
- bias_hh = None
- # Initialize hidden and cell states
+ bias_ih_fwd = None
+ bias_hh_fwd = None
+
+ # Extract or create backward direction weights if bidirectional
+ if bidirectional:
+ if params and len(params) >= params_per_direction * 2:
+ weight_ih_bwd = params[params_per_direction]
+ weight_hh_bwd = params[params_per_direction + 1]
+ bias_ih_bwd = params[params_per_direction + 2] if has_biases else None
+ bias_hh_bwd = params[params_per_direction + 3] if has_biases else None
+ else:
+ # Fallback: create zero weights
+ weight_ih_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+ )
+ weight_hh_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+ )
+ bias_ih_bwd = None
+ bias_hh_bwd = None
+ else:
+ weight_ih_bwd = None
+ weight_hh_bwd = None
+ bias_ih_bwd = None
+ bias_hh_bwd = None
+
if hx is not None and len(hx) >= 2:
- h_0 = hx[0] # (num_layers, batch_size, hidden_size)
- c_0 = hx[1] # (num_layers, batch_size, hidden_size)
- # Extract the first layer's hidden state
- h_prev = self.block_builder.emit(
+ h_0, c_0 = hx[0], hx[1]
+ h_prev_fwd = self.block_builder.emit(
relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
)
- c_prev = self.block_builder.emit(
+ c_prev_fwd = self.block_builder.emit(
relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
)
+ if bidirectional:
+ h_prev_bwd = self.block_builder.emit(
+ relax.op.take(h_0, relax.const(1, "int64"), axis=0, mode="clip")
+ )
+ c_prev_bwd = self.block_builder.emit(
+ relax.op.take(c_0, relax.const(1, "int64"), axis=0, mode="clip")
+ )
+ else:
+ h_prev_bwd = None
+ c_prev_bwd = None
else:
- h_prev = self.block_builder.emit(
+ h_prev_fwd = self.block_builder.emit(
relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
)
- c_prev = self.block_builder.emit(
+ c_prev_fwd = self.block_builder.emit(
relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
)
- # Reshape input for processing
- if batch_first:
- # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
- input_reshaped = self.block_builder.emit(
- relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
- )
- else:
- input_reshaped = input_tensor
- weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
- weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
- outputs = []
- for t in range(seq_len):
- # Get input at time t: (batch_size, input_size)
- x_t = self.block_builder.emit(
- relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
- )
- # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
- # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
- ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
-
- # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
- hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
- # Add biases if present
- if bias_ih is not None and bias_hh is not None:
- gates = self.block_builder.emit(
- relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
- )
- elif bias_ih is not None:
- gates = self.block_builder.emit(
- relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+ if bidirectional:
+ h_prev_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
)
- elif bias_hh is not None:
- gates = self.block_builder.emit(
- relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+ c_prev_bwd = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
)
else:
- gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
- # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
- gate_size = hidden_size
- i_gate = self.block_builder.emit(
- relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
- )
- f_gate = self.block_builder.emit(
- relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
- )
- g_gate = self.block_builder.emit(
- relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
- )
- o_gate = self.block_builder.emit(
- relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
- )
- # Apply activations
- i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
- f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
- g_t = self.block_builder.emit(relax.op.tanh(g_gate))
- o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
- # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
- c_t = self.block_builder.emit(
- relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+ h_prev_bwd = None
+ c_prev_bwd = None
+
+ input_reshaped = (
+ self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2]))
+ if batch_first
+ else input_tensor
+ )
+
+ output_fwd = self._lstm_cell_unroll(
+ input_reshaped,
+ weight_ih_fwd,
+ weight_hh_fwd,
+ bias_ih_fwd,
+ bias_hh_fwd,
+ h_prev_fwd,
+ c_prev_fwd,
+ seq_len,
+ hidden_size,
+ reverse=False,
+ )
+
+ if bidirectional:
+ output_bwd = self._lstm_cell_unroll(
+ input_reshaped,
+ weight_ih_bwd,
+ weight_hh_bwd,
+ bias_ih_bwd,
+ bias_hh_bwd,
+ h_prev_bwd,
+ c_prev_bwd,
+ seq_len,
+ hidden_size,
+ reverse=True,
)
- # Update hidden state: h_t = o_t * tanh(c_t)
- h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
- # Store output
- outputs.append(h_t)
- # Update for next iteration
- h_prev = h_t
- c_prev = c_t
- # Stack outputs: (seq_len, batch_size, hidden_size)
- output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
- # Reshape back to batch_first if needed
+ output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2))
+ else:
+ output = output_fwd
+
if batch_first:
# (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
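Note on the math emitted by the new _lstm_cell_unroll helper above: it unrolls the standard LSTM recurrence using PyTorch's gate packing (i, f, g, o stacked along the first weight dimension); the backward direction simply iterates t from seq_len - 1 down to 0 and reverses the collected outputs before the final concat along the feature axis. A rough NumPy sketch of one time step (illustrative names only, not TVM API):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, w_ih, w_hh, b_ih, b_hh):
    # gates: (batch, 4 * hidden_size), packed as [i | f | g | o]
    gates = x_t @ w_ih.T + h_prev @ w_hh.T + b_ih + b_hh
    i, f, g, o = np.split(gates, 4, axis=1)
    c_t = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # c_t = f * c_{t-1} + i * g
    h_t = sigmoid(o) * np.tanh(c_t)                      # h_t = o * tanh(c_t)
    return h_t, c_t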
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index fe3ff28aea..8ff46bf611 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -57,6 +57,37 @@ def verify_model(
tvm.ir.assert_structural_equal(mod, expected, map_free_vars=map_free_vars)
+def verify_model_numerically(torch_model, example_args, rtol=1e-7, atol=1e-7):
+ """Verify model by comparing numerical outputs between PyTorch and TVM."""
+ with torch.no_grad():
+ pytorch_output = torch_model(*example_args)
+
+ exported_program = export(torch_model, args=example_args)
+ mod = from_exported_program(exported_program)
+ target = tvm.target.Target("llvm")
+ ex = relax.build(mod, target)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ tvm_args = [tvm.runtime.tensor(arg.numpy()) for arg in example_args]
+ tvm_output = vm["main"](*tvm_args)
+
+ if hasattr(tvm_output, "numpy"):
+ tvm_output_np = tvm_output.numpy()
+ else:
+ tvm_output_np = tvm_output[0].numpy()
+
+ pytorch_output_np = (
+ pytorch_output.numpy()
+ if isinstance(pytorch_output, torch.Tensor)
+ else pytorch_output[0].numpy()
+ )
+
+ assert (
+ pytorch_output_np.shape == tvm_output_np.shape
+ ), f"Shape mismatch: PyTorch {pytorch_output_np.shape} vs TVM {tvm_output_np.shape}"
+ tvm.testing.assert_allclose(pytorch_output_np, tvm_output_np, rtol=rtol, atol=atol)
+
+
operator_basic_unary = [
(torch.abs, R.abs),
(torch.acos, R.acos),
@@ -7831,75 +7862,42 @@ def test_sparse_mm():
verify_model(SparseMatrixMultiply(), example_args, {}, Expected)
+@tvm.testing.requires_llvm
def test_lstm():
- class BasicLSTM(nn.Module):
- def __init__(self):
+ class LSTM(nn.Module):
+ def __init__(self, input_size, hidden_size, batch_first, bidirectional):
super().__init__()
self.lstm = nn.LSTM(
- input_size=4,
- hidden_size=8,
+ input_size=input_size,
+ hidden_size=hidden_size,
num_layers=1,
- batch_first=True,
- bidirectional=False,
+ batch_first=batch_first,
+ bidirectional=bidirectional,
)
def forward(self, x):
y, _ = self.lstm(x)
return y
+ # Unidirectional LSTM with batch_first=True
torch.manual_seed(42)
x = torch.randn(2, 3, 4, dtype=torch.float32)
- model = BasicLSTM()
- with torch.no_grad():
- pytorch_output = model(x)
- exported_program = export(model, args=(x,))
- mod = from_exported_program(exported_program)
- target = tvm.target.Target("llvm")
- ex = relax.build(mod, target)
- vm = relax.VirtualMachine(ex, tvm.cpu())
- x_tvm = tvm.runtime.tensor(x.numpy())
- tvm_output = vm["main"](x_tvm)
- if hasattr(tvm_output, "numpy"):
- tvm_output_np = tvm_output.numpy()
- else:
- tvm_output_np = tvm_output[0].numpy()
- assert (
- pytorch_output.shape == tvm_output_np.shape
- ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
- np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)
-
- class SeqFirstLSTM(nn.Module):
- def __init__(self):
- super().__init__()
- self.lstm = nn.LSTM(
- input_size=3,
- hidden_size=6,
- num_layers=1,
- batch_first=False,
- bidirectional=False,
- )
-
- def forward(self, x):
- y, _ = self.lstm(x)
- return y
+ verify_model_numerically(LSTM(4, 8, batch_first=True, bidirectional=False), (x,))
+ # Unidirectional LSTM with batch_first=False
torch.manual_seed(43)
x2 = torch.randn(4, 2, 3, dtype=torch.float32)
- model2 = SeqFirstLSTM()
- with torch.no_grad():
- pytorch_output2 = model2(x2)
- exported_program2 = export(model2, args=(x2,))
- mod2 = from_exported_program(exported_program2)
- ex2 = relax.build(mod2, target)
- vm2 = relax.VirtualMachine(ex2, tvm.cpu())
- x2_tvm = tvm.runtime.tensor(x2.numpy())
- tvm_output2 = vm2["main"](x2_tvm)
- if hasattr(tvm_output2, "numpy"):
- tvm_output2_np = tvm_output2.numpy()
- else:
- tvm_output2_np = tvm_output2[0].numpy()
- assert pytorch_output2.shape == tvm_output2_np.shape
- np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+ verify_model_numerically(LSTM(3, 6, batch_first=False, bidirectional=False), (x2,))
+
+ # Bidirectional LSTM with batch_first=True
+ torch.manual_seed(44)
+ x3 = torch.randn(2, 3, 4, dtype=torch.float32)
+ verify_model_numerically(LSTM(4, 8, batch_first=True, bidirectional=True), (x3,))
+
+ # Bidirectional LSTM with batch_first=False
+ torch.manual_seed(45)
+ x4 = torch.randn(4, 2, 3, dtype=torch.float32)
+ verify_model_numerically(LSTM(3, 6, batch_first=False, bidirectional=True), (x4,))
def test_tensor_none_tuple():
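End-to-end, the import path exercised by the new test looks roughly like the following sketch (assumes an LLVM-enabled TVM build; the wrapper module here is illustrative, not part of the commit):

import torch
import torch.nn as nn
from torch.export import export

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program

class BiLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1,
                            batch_first=True, bidirectional=True)

    def forward(self, x):
        y, _ = self.lstm(x)
        return y

x = torch.randn(2, 3, 4)
mod = from_exported_program(export(BiLSTM(), args=(x,)))          # translate the exported graph to Relax
ex = relax.build(mod, tvm.target.Target("llvm"))
vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](tvm.runtime.tensor(x.numpy()))                    # output (or first tuple element) has shape (2, 3, 16)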