This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e3933804ee [Relax][Frontend][TFLite] Support sequence LSTM and RNN
operators (#19634)
e3933804ee is described below
commit e3933804ee15239d38527a18c69c507d76af8ffe
Author: YinHanke <[email protected]>
AuthorDate: Sun May 31 00:49:53 2026 +0800
[Relax][Frontend][TFLite] Support sequence LSTM and RNN operators (#19634)
## Summary
Add three TFLite sequence recurrent operators to the Relax frontend, all
float32-only. The two LSTM operators support only the coupled input-forget
gate (FULL kernel); the RNN operator has no gates.
- UNIDIRECTIONAL_SEQUENCE_LSTM
- BIDIRECTIONAL_SEQUENCE_RNN
- BIDIRECTIONAL_SEQUENCE_LSTM
From #19519.
## Changes
- **UNIDIRECTIONAL_SEQUENCE_LSTM**: same layout as single-step LSTM,
unrolls over
time and stacks per-step hidden states. Supports time_major, cell_clip,
proj_clip,
and fused activation.
- **BIDIRECTIONAL_SEQUENCE_RNN**: separate fw/bw RNN cells, backward
scans in
reverse. Supports merge_outputs (concat fw + bw) and split outputs via
Tuple.
- **BIDIRECTIONAL_SEQUENCE_LSTM**: 48-input operator with fw/bw LSTM
cells sharing
the same input tensor. States at indices 35-38.
- All converters propagate final states to exp_tab for multi-step
correctness.
- Separate input-gate weights (non-coupled LSTM), peephole, projection,
layer norm, and aux inputs are not supported (the converters raise
`tvm.error.OpNotImplemented`).
## Testing
- `test_unidirectional_sequence_lstm_none_activation` — output shape
[batch, time, num_units]
- `test_bidirectional_sequence_rnn_none_activation` —
merge_outputs=True, shape [batch, time, 2*num_units]
- `test_bidirectional_sequence_lstm_none_activation` —
merge_outputs=True, shape [batch, time, 2*num_units]
```bash
python -m pytest tests/python/relax/test_frontend_tflite.py -k
"sequence_lstm or sequence_rnn" -v
```
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 670 ++++++++++++----
tests/python/relax/test_frontend_tflite.py | 892 +++++++++++++++++++++
2 files changed, 1425 insertions(+), 137 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index c479ec83c1..7046e43bbe 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -200,6 +200,8 @@ class OperatorConverter:
"AVERAGE_POOL_2D": functools.partial(self.convert_pool2d,
pool_type="average"),
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
"BATCH_MATMUL": self.convert_batch_matmul,
+ "BIDIRECTIONAL_SEQUENCE_LSTM":
self.convert_bidirectional_sequence_lstm,
+ "BIDIRECTIONAL_SEQUENCE_RNN":
self.convert_bidirectional_sequence_rnn,
"BITCAST": self.convert_bitcast,
"BROADCAST_TO": self.convert_broadcast_to,
"BROADCAST_ARGS": self.convert_broadcast_args,
@@ -404,7 +406,7 @@ class OperatorConverter:
"UNSORTED_SEGMENT_PROD": functools.partial(
self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD",
reduction="mul"
),
- # "UNIDIRECTIONAL_SEQUENCE_LSTM":
self.convert_unidirectional_sequence_lstm,
+ "UNIDIRECTIONAL_SEQUENCE_LSTM":
self.convert_unidirectional_sequence_lstm,
"VAR_HANDLE": self.convert_var_handle,
"WHERE": self.convert_select,
"WHILE": self.convert_while,
@@ -5510,153 +5512,547 @@ class OperatorConverter:
# Stack timestep outputs: [batch, time, num_units].
return relax.op.stack(outputs, axis=1)
- """
def convert_unidirectional_sequence_lstm(self, op):
- ### Long Short Term Memory for TFLite implementation. ###
+ """Convert TFLite UNIDIRECTIONAL_SEQUENCE_LSTM.
+
+ Inputs (24 tensors, same layout as single-step LSTM):
+ [0] input [batch, time, input_size]
+ [1] input_to_input_weights [num_units, input_size] (optional)
+ [2] input_to_forget_weights [num_units, input_size]
+ [3] input_to_cell_weights [num_units, input_size]
+ [4] input_to_output_weights [num_units, input_size]
+ [5] recurrent_to_input_weights [num_units, num_units] (optional)
+ [6] recurrent_to_forget_weights [num_units, num_units]
+ [7] recurrent_to_cell_weights [num_units, num_units]
+ [8] recurrent_to_output_weights [num_units, num_units]
+ [9] cell_to_input_weights [num_units] (optional)
+ [10] cell_to_forget_weights [num_units] (optional)
+ [11] cell_to_output_weights [num_units] (optional)
+ [12] input_gate_bias [num_units] (optional)
+ [13] forget_gate_bias [num_units]
+ [14] cell_gate_bias [num_units]
+ [15] output_gate_bias [num_units]
+ [16] projection_weights [num_units, num_units] (optional)
+ [17] projection_bias [num_units] (optional)
+ [18] output_state [batch, num_units] (variable)
+ [19] cell_state [batch, num_units] (variable)
+ [20-23] optional layer norm weights
+
+ Output:
+ [0] output [batch, time, num_units]
+
+ Uses coupled input-forget gate (i = 1 - f) for the FULL kernel.
+ """
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.UnidirectionalSequenceLSTMOptions import
UnidirectionalSequenceLSTMOptions
+
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
- "TFlite quantized UNIDIRECTIONALSEQUENCELSTM operator is not
supported yet."
+ "TFLite quantized UNIDIRECTIONAL_SEQUENCE_LSTM is not
supported yet."
)
input_tensors = self.get_input_tensors(op)
- assert len(input_tensors) == 24, "input tensors length should be == 24"
+ assert len(input_tensors) == 24, (
+ f"input tensors length should be 24, got {len(input_tensors)}"
+ )
- # Extract input tensor from saved model
- input_tensor = input_tensors[0]
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() ==
BuiltinOptions.UnidirectionalSequenceLSTMOptions
+ op_options = op.BuiltinOptions()
+ lstm_opts = UnidirectionalSequenceLSTMOptions()
+ lstm_opts.Init(op_options.Bytes, op_options.Pos)
+ time_major = lstm_opts.TimeMajor()
+ fused_activation_fn = lstm_opts.FusedActivationFunction()
+ cell_clip = lstm_opts.CellClip()
+ proj_clip = lstm_opts.ProjClip()
+
+ # Only coupled input-forget gate is supported.
+ if input_tensors[1].tensor_idx != -1 or input_tensors[5].tensor_idx !=
-1:
+ raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM
is supported.")
+ if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]):
+ raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not
supported yet.")
+ if any(input_tensors[idx].tensor_idx != -1 for idx in [16, 17]):
+ raise tvm.error.OpNotImplemented("TFLite projection LSTM is not
supported yet.")
+ if any(input_tensors[idx].tensor_idx != -1 for idx in [20, 21, 22,
23]):
+ raise tvm.error.OpNotImplemented("TFLite layer-norm LSTM is not
supported yet.")
+
+ # Weights (transposed once outside the loop).
+ w_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[2]))
+ w_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[3]))
+ w_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[4]))
+ r_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[6]))
+ r_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[7]))
+ r_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[8]))
+
+ # Biases.
+ b_f = self.get_tensor_expr(input_tensors[13])
+ b_c = self.get_tensor_expr(input_tensors[14])
+ b_o = self.get_tensor_expr(input_tensors[15])
+
+ # Initial states.
+ h = self.get_tensor_expr(input_tensors[18])
+ c = self.get_tensor_expr(input_tensors[19])
+
+ # Resolve the input expression; normalise to batch-major [batch, time,
input_size].
+ in_expr = self.get_tensor_expr(input_tensors[0])
+ in_shape = self.get_tensor_shape(input_tensors[0])
+ if time_major:
+ in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+ num_steps = int(in_shape[0])
+ else:
+ num_steps = int(in_shape[1])
+
+ # Unroll over the time axis.
+ if num_steps == 1:
+ steps = [relax.op.squeeze(in_expr, axis=[1])]
+ else:
+ splits = relax.op.split(in_expr, num_steps, axis=1)
+ steps = [relax.op.squeeze(splits[i], axis=[1]) for i in
range(num_steps)]
+
+ one = relax.const(1.0, "float32")
+ outputs = []
+ for x_t in steps:
+ f = relax.op.sigmoid(
+ relax.op.add(
+ relax.op.add(
+ relax.op.matmul(x_t, w_f_t),
+ relax.op.matmul(h, r_f_t),
+ ),
+ b_f,
+ )
+ )
+ i = relax.op.subtract(one, f)
+ g = self.convert_fused_activation_function(
+ relax.op.add(
+ relax.op.add(relax.op.matmul(x_t, w_c_t),
relax.op.matmul(h, r_c_t)),
+ b_c,
+ ),
+ fused_activation_fn,
+ )
+ o = relax.op.sigmoid(
+ relax.op.add(
+ relax.op.add(
+ relax.op.matmul(x_t, w_o_t),
+ relax.op.matmul(h, r_o_t),
+ ),
+ b_o,
+ )
+ )
+
+ c_new = relax.op.add(relax.op.multiply(f, c), relax.op.multiply(i,
g))
+ if cell_clip > 0.0:
+ c_new = relax.op.clip(c_new, -cell_clip, cell_clip)
+
+ h_new = relax.op.multiply(
+ o, self.convert_fused_activation_function(c_new,
fused_activation_fn)
+ )
+ if proj_clip > 0.0:
+ h_new = relax.op.clip(h_new, -proj_clip, proj_clip)
+ outputs.append(h_new)
+ h, c = h_new, c_new
+
+ h_out = relax.op.stack(outputs, axis=1)
+ if time_major:
+ h_out = relax.op.permute_dims(h_out, [1, 0, 2])
+
+ # Update state tensors in the expression table for subsequent ops.
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[18].tensor_idx),
+ h,
+ force_override=True,
+ )
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[19].tensor_idx),
+ c,
+ force_override=True,
+ )
+
+ return h_out
+
+ def convert_bidirectional_sequence_rnn(self, op):
+ """Convert TFLite BIDIRECTIONAL_SEQUENCE_RNN.
+
+ Inputs (12 tensors; optional aux inputs [9]-[11] are not supported):
+ [0] input [batch, time, input_size]
+ [1] fw_weights [num_units, input_size]
+ [2] fw_recurrent_weights [num_units, num_units]
+ [3] fw_bias [num_units]
+ [4] fw_hidden_state [batch, num_units] (variable)
+ [5] bw_weights [num_units, input_size]
+ [6] bw_recurrent_weights [num_units, num_units]
+ [7] bw_bias [num_units]
+ [8] bw_hidden_state [batch, num_units] (variable)
+
+ Output (merge_outputs=True):
+ [0] output [batch, time, 2 * num_units] (fw and bw concatenated)
+
+ Output (merge_outputs=False):
+ [0] fw_output [batch, time, num_units]
+ [1] bw_output [batch, time, num_units]
+ """
+ from tflite.BidirectionalSequenceRNNOptions import
BidirectionalSequenceRNNOptions
+ from tflite.BuiltinOptions import BuiltinOptions
+
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented(
+ "TFLite quantized BIDIRECTIONAL_SEQUENCE_RNN is not supported
yet."
+ )
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 12, (
+ f"input tensors length should be 12, got {len(input_tensors)}"
+ )
- # Extract tensors from input tensors from saved model
- # Input weights
- input_input_weights = input_tensors[1]
- input_forget_weights = input_tensors[2]
- input_cell_weights = input_tensors[3]
- input_output_weights = input_tensors[4]
- # Recurrent weights
- recurrent_input_weights = input_tensors[5]
- recurrent_forget_weights = input_tensors[6]
- recurrent_cell_weights = input_tensors[7]
- recurrent_output_weights = input_tensors[8]
- # inputs 9, 10, 11, 16, 17, 20, 21, 22, 23 are not occupied
- # there locations are -1 in the flatbuffer
- # Bias weights
- input_gate_bias = input_tensors[12]
- forget_gate_bias = input_tensors[13]
- cell_gate_bias = input_tensors[14]
- output_gate_bias = input_tensors[15]
-
- # State input
- output_state_in = input_tensors[18]
- cell_state_in = input_tensors[19]
-
- # Extract output tensor from saved model
output_tensors = self.get_output_tensors(op)
- assert len(output_tensors) == 1, "output tensors length should be 1"
- X_steps = self.unbind(input_tensor, axis=1)
- weights_dict = {}
-
- # hidden_state_weights is equivalent to output_state_in in tflite model
- out_state_in_shape = tuple(self.get_tensor_shape(output_state_in))
- out_state_in_dtype =
self.get_tensor_type_str(output_state_in.tensor.Type())
- out_state_in_expr = relax.op.zeros(out_state_in_shape,
dtype=out_state_in_dtype)
- weights_dict["hidden_state"] = relax.op.split(out_state_in_expr, 1)[0]
-
- # cell_state_weights is equivalent to output_state_in tflite model
- cell_state_in_shape = tuple(self.get_tensor_shape(cell_state_in))
- cell_state_in_dtype =
self.get_tensor_type_str(cell_state_in.tensor.Type())
- cell_state_in_expr = relax.op.zeros(cell_state_in_shape,
dtype=cell_state_in_dtype)
- weights_dict["cell_state"] = relax.op.split(cell_state_in_expr, 1)[0]
-
- # Process weight matrix of input: w_inp
- # Concatenate of [input_input_weight, input_forget_weights,
- # input_cell_weights, input_output_weights]
- input_input_weights_default_values =
self.get_tensor_value(input_input_weights)
- input_input_weights_op = relax.op.split(
- relax.op.const(input_input_weights_default_values.tolist()), 1
- )
- input_output_weights_default_values =
self.get_tensor_value(input_output_weights)
- input_output_weights_op = relax.op.split(
- relax.op.const(input_output_weights_default_values.tolist()), 1
- )
- input_forget_weights_default_values =
self.get_tensor_value(input_forget_weights)
- input_forget_weights_op = relax.op.split(
- relax.op.const(input_forget_weights_default_values.tolist()), 1
- )
- input_cell_weights_default_values =
self.get_tensor_value(input_cell_weights)
- input_cell_weights_op = relax.op.split(
- _op.const(input_cell_weights_default_values.tolist()), 1
- )
- weights_dict["w_inp"] = relax.op.concat(
- [
- relax.op.squeeze(input_input_weights_op[0]),
- relax.op.squeeze(input_forget_weights_op[0]),
- relax.op.squeeze(input_cell_weights_op[0]),
- relax.op.squeeze(input_output_weights_op[0]),
- ],
- axis=0,
- )
-
- # Process weight matrix of hidden state:
- # w_hid to support lstm_cell function. Not used in tflite
- recurrent_input_weights_values =
self.get_tensor_value(recurrent_input_weights)
- recurrent_input_weights_op = relax.op.split(
- relax.op.const(recurrent_input_weights_values.tolist()), 1
- )
- recurrent_output_weights_values =
self.get_tensor_value(recurrent_output_weights)
- recurrent_output_weights_op = relax.op.split(
- relax.op.const(recurrent_output_weights_values.tolist()), 1
- )
- recurrent_forget_weights_values =
self.get_tensor_value(recurrent_forget_weights)
- recurrent_forget_weights_op = relax.op.split(
- relax.op.const(recurrent_forget_weights_values.tolist()), 1
- )
- recurrent_cell_weights_values =
self.get_tensor_value(recurrent_cell_weights)
- recurrent_cell_weights_op = relax.op.split(
- _op.const(recurrent_cell_weights_values.tolist()), 1
- )
- weights_dict["w_hid"] = relax.op.concat(
- [
- recurrent_input_weights_op[0],
- recurrent_forget_weights_op[0],
- recurrent_cell_weights_op[0],
- recurrent_output_weights_op[0],
- ],
- axis=0,
- )
-
- # Process weight matrix of bias: b_inp
- input_gate_bias_values = self.get_tensor_value(input_gate_bias)
- input_gate_bias_op =
relax.op.split(_op.const(input_gate_bias_values.tolist()), 1)
- output_gate_bias_values = self.get_tensor_value(output_gate_bias)
- output_gate_bias_op =
relax.op.split(_op.const(output_gate_bias_values.tolist()), 1)
- forget_gate_bias_values = self.get_tensor_value(forget_gate_bias)
- forget_gate_bias_op =
relax.op.split(_op.const(forget_gate_bias_values.tolist()), 1)
- cell_gate_bias_values = self.get_tensor_value(cell_gate_bias)
- cell_gate_bias_op =
relax.op.split(_op.const(cell_gate_bias_values.tolist()), 1)
- weights_dict["b_inp"] = relax.op.concat(
- [
- input_gate_bias_op[0],
- forget_gate_bias_op[0],
- cell_gate_bias_op[0],
- output_gate_bias_op[0],
- ],
- axis=0,
- )
-
- # Process weight matrix of hidden bias:
- # b_hid (with the same shape as b_inp)
- gate_bias_dtype =
self.get_tensor_type_str(input_gate_bias.tensor.Type())
- weights_dict["b_hid"] = relax.op.split(
- relax.op.const(
- np.zeros(self._infer_shape(weights_dict["b_inp"]),
dtype=gate_bias_dtype),
- dtype=gate_bias_dtype,
- ),
- 1,
- )[0]
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() ==
BuiltinOptions.BidirectionalSequenceRNNOptions
+ op_options = op.BuiltinOptions()
+ rnn_opts = BidirectionalSequenceRNNOptions()
+ rnn_opts.Init(op_options.Bytes, op_options.Pos)
+ time_major = rnn_opts.TimeMajor()
+ fused_activation_fn = rnn_opts.FusedActivationFunction()
+ merge_outputs = rnn_opts.MergeOutputs()
+ if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]):
+ raise tvm.error.OpNotImplemented(
+ "TFLite BIDIRECTIONAL_SEQUENCE_RNN aux input is not supported
yet."
+ )
- outputs, _, _ = lstm_cell(input_seqs=X_steps, **weights_dict)
+ # Forward weights and biases.
+ fw_weights_expr = self.get_tensor_expr(input_tensors[1])
+ fw_recurrent_expr = self.get_tensor_expr(input_tensors[2])
+ fw_bias_expr = self.get_tensor_expr(input_tensors[3])
+ fw_w_t = relax.op.permute_dims(fw_weights_expr)
+ fw_wr_t = relax.op.permute_dims(fw_recurrent_expr)
- output = relax.op.stack(outputs, axis=1)
- return output
- """
+ # Backward weights and biases.
+ bw_weights_expr = self.get_tensor_expr(input_tensors[5])
+ bw_recurrent_expr = self.get_tensor_expr(input_tensors[6])
+ bw_bias_expr = self.get_tensor_expr(input_tensors[7])
+ bw_w_t = relax.op.permute_dims(bw_weights_expr)
+ bw_wr_t = relax.op.permute_dims(bw_recurrent_expr)
+
+ # Resolve the input expression; normalise to batch-major [batch, time,
input_size].
+ in_expr = self.get_tensor_expr(input_tensors[0])
+ in_shape = self.get_tensor_shape(input_tensors[0])
+ if time_major:
+ in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+ num_steps = int(in_shape[0])
+ else:
+ num_steps = int(in_shape[1])
+
+ # Initial hidden states.
+ def _get_hidden_state(tensor):
+ if self.has_expr(tensor.tensor_idx) or (
+ tensor.buffer is not None and tensor.buffer.DataLength() > 0
+ ):
+ return self.get_tensor_expr(tensor)
+ dtype = self.get_tensor_type_str(tensor.tensor.Type())
+ h_shape = tuple(to_int_list(self.get_tensor_shape(tensor)))
+ return relax.op.zeros(h_shape, dtype=dtype)
+
+ fw_h = _get_hidden_state(input_tensors[4])
+ bw_h = _get_hidden_state(input_tensors[8])
+
+ # Unroll over the time axis.
+ if num_steps == 1:
+ steps = [relax.op.squeeze(in_expr, axis=[1])]
+ else:
+ splits = relax.op.split(in_expr, num_steps, axis=1)
+ steps = [relax.op.squeeze(splits[i], axis=[1]) for i in
range(num_steps)]
+
+ # Forward pass.
+ fw_outputs = []
+ for x_t in steps:
+ gates = relax.op.add(
+ relax.op.add(relax.op.matmul(x_t, fw_w_t),
relax.op.matmul(fw_h, fw_wr_t)),
+ fw_bias_expr,
+ )
+ fw_h = self.convert_fused_activation_function(gates,
fused_activation_fn)
+ fw_outputs.append(fw_h)
+
+ # Backward pass (process steps in reverse).
+ bw_outputs = []
+ for x_t in reversed(steps):
+ gates = relax.op.add(
+ relax.op.add(relax.op.matmul(x_t, bw_w_t),
relax.op.matmul(bw_h, bw_wr_t)),
+ bw_bias_expr,
+ )
+ bw_h = self.convert_fused_activation_function(gates,
fused_activation_fn)
+ bw_outputs.append(bw_h)
+ bw_outputs.reverse()
+
+ fw_stacked = relax.op.stack(fw_outputs, axis=1) # [batch, time,
num_units]
+ bw_stacked = relax.op.stack(bw_outputs, axis=1) # [batch, time,
num_units]
+ if time_major:
+ fw_stacked = relax.op.permute_dims(fw_stacked, [1, 0, 2])
+ bw_stacked = relax.op.permute_dims(bw_stacked, [1, 0, 2])
+
+ # Update state tensors in the expression table for subsequent ops.
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[4].tensor_idx),
+ fw_h,
+ force_override=True,
+ )
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[8].tensor_idx),
+ bw_h,
+ force_override=True,
+ )
+
+ if merge_outputs:
+ return relax.op.concat([fw_stacked, bw_stacked], axis=-1)
+ else:
+ return relax.Tuple([fw_stacked, bw_stacked])
+
+ def convert_bidirectional_sequence_lstm(self, op):
+ """Convert TFLite BIDIRECTIONAL_SEQUENCE_LSTM.
+
+ Inputs (48 tensors, indices 0-17 forward LSTM, 18-34 backward LSTM,
35-38 states,
+ 39-47 optional aux inputs, which are not supported):
+
+ Forward LSTM cell (indices 0-17, same layout as single-step LSTM):
+ [0] input (shared) [batch, time, input_size]
+ [1] fw_input_to_input_weights (optional)
+ [2] fw_input_to_forget_weights
+ [3] fw_input_to_cell_weights
+ [4] fw_input_to_output_weights
+ [5] fw_recurrent_to_input_wts (optional)
+ [6] fw_recurrent_to_forget_wts
+ [7] fw_recurrent_to_cell_wts
+ [8] fw_recurrent_to_output_wts
+ [9-11] fw cell_to_*_weights (optional, not supported)
+ [12] fw_input_gate_bias (optional)
+ [13] fw_forget_gate_bias
+ [14] fw_cell_gate_bias
+ [15] fw_output_gate_bias
+ [16] fw_projection_weights (optional, not supported)
+ [17] fw_projection_bias (optional, not supported)
+
+ Backward LSTM cell (indices 18-34, same layout as fw):
+ [19] bw_input_to_forget_weights
+ [20] bw_input_to_cell_weights
+ [21] bw_input_to_output_weights
+ [23] bw_recurrent_to_forget_wts
+ [24] bw_recurrent_to_cell_wts
+ [25] bw_recurrent_to_output_wts
+ [30] bw_forget_gate_bias
+ [31] bw_cell_gate_bias
+ [32] bw_output_gate_bias
+
+ State tensors:
+ [35] fw_activation_state [batch, num_units]
+ [36] fw_cell_state [batch, num_units]
+ [37] bw_activation_state [batch, num_units]
+ [38] bw_cell_state [batch, num_units]
+
+ Output (merge_outputs=True):
+ [0] output [batch, time, 2 * num_units]
+
+ Output (merge_outputs=False):
+ [0] fw_output [batch, time, num_units]
+ [1] bw_output [batch, time, num_units]
+ """
+ from tflite.BidirectionalSequenceLSTMOptions import
BidirectionalSequenceLSTMOptions
+ from tflite.BuiltinOptions import BuiltinOptions
+
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented(
+ "TFLite quantized BIDIRECTIONAL_SEQUENCE_LSTM is not supported
yet."
+ )
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 48, (
+ f"input tensors length should be 48, got {len(input_tensors)}"
+ )
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() ==
BuiltinOptions.BidirectionalSequenceLSTMOptions
+ op_options = op.BuiltinOptions()
+ lstm_opts = BidirectionalSequenceLSTMOptions()
+ lstm_opts.Init(op_options.Bytes, op_options.Pos)
+ time_major = lstm_opts.TimeMajor()
+ fused_activation_fn = lstm_opts.FusedActivationFunction()
+ merge_outputs = lstm_opts.MergeOutputs()
+ cell_clip = lstm_opts.CellClip()
+ proj_clip = lstm_opts.ProjClip()
+
+ # ── Forward LSTM weights (transposed once outside the loop) ──
+ if input_tensors[1].tensor_idx != -1 or input_tensors[5].tensor_idx !=
-1:
+ raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM
is supported.")
+ if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]):
+ raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not
supported yet.")
+ if any(input_tensors[idx].tensor_idx != -1 for idx in [16, 17]):
+ raise tvm.error.OpNotImplemented("TFLite projection LSTM is not
supported yet.")
+
+ fw_w_f_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[2]))
+ fw_w_c_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[3]))
+ fw_w_o_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[4]))
+ fw_r_f_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[6]))
+ fw_r_c_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[7]))
+ fw_r_o_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[8]))
+ fw_b_f = self.get_tensor_expr(input_tensors[13])
+ fw_b_c = self.get_tensor_expr(input_tensors[14])
+ fw_b_o = self.get_tensor_expr(input_tensors[15])
+
+ # ── Backward LSTM weights (transposed once outside the loop) ──
+ if input_tensors[18].tensor_idx != -1 or input_tensors[22].tensor_idx
!= -1:
+ raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM
is supported.")
+ if any(input_tensors[idx].tensor_idx != -1 for idx in [26, 27, 28]):
+ raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not
supported yet.")
+ if any(input_tensors[idx].tensor_idx != -1 for idx in [33, 34]):
+ raise tvm.error.OpNotImplemented("TFLite projection LSTM is not
supported yet.")
+ if any(input_tensors[idx].tensor_idx != -1 for idx in range(39, 48)):
+ raise tvm.error.OpNotImplemented(
+ "TFLite BIDIRECTIONAL_SEQUENCE_LSTM aux input is not supported
yet."
+ )
+
+ bw_w_f_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[19]))
+ bw_w_c_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[20]))
+ bw_w_o_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[21]))
+ bw_r_f_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[23]))
+ bw_r_c_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[24]))
+ bw_r_o_t =
relax.op.permute_dims(self.get_tensor_expr(input_tensors[25]))
+ bw_b_f = self.get_tensor_expr(input_tensors[30])
+ bw_b_c = self.get_tensor_expr(input_tensors[31])
+ bw_b_o = self.get_tensor_expr(input_tensors[32])
+
+ # ── Initial states ──
+ fw_h = self.get_tensor_expr(input_tensors[35])
+ fw_c = self.get_tensor_expr(input_tensors[36])
+ bw_h = self.get_tensor_expr(input_tensors[37])
+ bw_c = self.get_tensor_expr(input_tensors[38])
+
+ # ── Unroll input ──
+ in_expr = self.get_tensor_expr(input_tensors[0])
+ in_shape = self.get_tensor_shape(input_tensors[0])
+ if time_major:
+ in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+ num_steps = int(in_shape[0])
+ else:
+ num_steps = int(in_shape[1])
+
+ if num_steps == 1:
+ steps = [relax.op.squeeze(in_expr, axis=[1])]
+ else:
+ splits = relax.op.split(in_expr, num_steps, axis=1)
+ steps = [relax.op.squeeze(splits[i], axis=[1]) for i in
range(num_steps)]
+
+ one = relax.const(1.0, "float32")
+
+ def _lstm_step(x_t, h, c, w_f_t, w_c_t, w_o_t, r_f_t, r_c_t, r_o_t,
b_f, b_c, b_o):
+ """Single LSTM step with coupled input-forget gate."""
+ f = relax.op.sigmoid(
+ relax.op.add(
+ relax.op.add(
+ relax.op.matmul(x_t, w_f_t),
+ relax.op.matmul(h, r_f_t),
+ ),
+ b_f,
+ )
+ )
+ i = relax.op.subtract(one, f)
+ g = self.convert_fused_activation_function(
+ relax.op.add(
+ relax.op.add(relax.op.matmul(x_t, w_c_t),
relax.op.matmul(h, r_c_t)),
+ b_c,
+ ),
+ fused_activation_fn,
+ )
+ o = relax.op.sigmoid(
+ relax.op.add(
+ relax.op.add(
+ relax.op.matmul(x_t, w_o_t),
+ relax.op.matmul(h, r_o_t),
+ ),
+ b_o,
+ )
+ )
+ c_new = relax.op.add(relax.op.multiply(f, c), relax.op.multiply(i,
g))
+ if cell_clip > 0.0:
+ c_new = relax.op.clip(c_new, -cell_clip, cell_clip)
+ h_new = relax.op.multiply(
+ o, self.convert_fused_activation_function(c_new,
fused_activation_fn)
+ )
+ if proj_clip > 0.0:
+ h_new = relax.op.clip(h_new, -proj_clip, proj_clip)
+ return h_new, c_new
+
+ # ── Forward pass ──
+ fw_outputs = []
+ for x_t in steps:
+ fw_h, fw_c = _lstm_step(
+ x_t,
+ fw_h,
+ fw_c,
+ fw_w_f_t,
+ fw_w_c_t,
+ fw_w_o_t,
+ fw_r_f_t,
+ fw_r_c_t,
+ fw_r_o_t,
+ fw_b_f,
+ fw_b_c,
+ fw_b_o,
+ )
+ fw_outputs.append(fw_h)
+
+ # ── Backward pass ──
+ bw_outputs = []
+ for x_t in reversed(steps):
+ bw_h, bw_c = _lstm_step(
+ x_t,
+ bw_h,
+ bw_c,
+ bw_w_f_t,
+ bw_w_c_t,
+ bw_w_o_t,
+ bw_r_f_t,
+ bw_r_c_t,
+ bw_r_o_t,
+ bw_b_f,
+ bw_b_c,
+ bw_b_o,
+ )
+ bw_outputs.append(bw_h)
+ bw_outputs.reverse()
+
+ fw_stacked = relax.op.stack(fw_outputs, axis=1)
+ bw_stacked = relax.op.stack(bw_outputs, axis=1)
+ if time_major:
+ fw_stacked = relax.op.permute_dims(fw_stacked, [1, 0, 2])
+ bw_stacked = relax.op.permute_dims(bw_stacked, [1, 0, 2])
+
+ # Update state tensors in the expression table for subsequent ops.
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[35].tensor_idx),
+ fw_h,
+ force_override=True,
+ )
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[36].tensor_idx),
+ fw_c,
+ force_override=True,
+ )
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[37].tensor_idx),
+ bw_h,
+ force_override=True,
+ )
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, input_tensors[38].tensor_idx),
+ bw_c,
+ force_override=True,
+ )
+
+ if merge_outputs:
+ return relax.op.concat([fw_stacked, bw_stacked], axis=-1)
+ else:
+ return relax.Tuple([fw_stacked, bw_stacked])
def convert_batch_to_space_nd(self, op):
"""batch_to_space_nd implementation."""
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index e9ccea7ad1..05a6c1e5e5 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3723,6 +3723,15 @@ _tfl_tensor_type = _get_tflite_schema_enum("TensorType")
_tfl_lstm_options = _get_tflite_schema_module("LSTMOptions")
_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
_tfl_svdf_options = _get_tflite_schema_module("SVDFOptions")
+_tfl_unidirectional_sequence_lstm_options = _get_tflite_schema_module(
+ "UnidirectionalSequenceLSTMOptions"
+)
+_tfl_bidirectional_sequence_rnn_options = _get_tflite_schema_module(
+ "BidirectionalSequenceRNNOptions"
+)
+_tfl_bidirectional_sequence_lstm_options = _get_tflite_schema_module(
+ "BidirectionalSequenceLSTMOptions"
+)
_DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
_DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
@@ -11052,6 +11061,889 @@ def test_svdf_shared_state_updates_exp_tab():
tvm.ir.assert_structural_equal(mod, Expected)
+# ── UNIDIRECTIONAL_SEQUENCE_LSTM ─────────────────────────────────────────────
+
+
def _build_unidirectional_sequence_lstm_model(
    batch,
    time,
    input_size,
    num_units,
    input_to_forget_weights,
    input_to_cell_weights,
    input_to_output_weights,
    recurrent_to_forget_weights,
    recurrent_to_cell_weights,
    recurrent_to_output_weights,
    forget_gate_bias,
    cell_bias,
    output_gate_bias,
    activation,
    *,
    time_major=False,
    cell_clip=0.0,
    proj_clip=0.0,
    projection_weights=None,
):
    """Serialize a TFLite model containing one UNIDIRECTIONAL_SEQUENCE_LSTM op.

    Only forget/cell/output gate weights and biases are wired in; every other
    operator slot is marked absent (-1) unless ``projection_weights`` is given.

    Tensor layout of the generated subgraph:
        0       input          [batch, time, input_size] (or [time, batch, input_size])
        1 - 3   input_to_{forget,cell,output}_weights      [num_units, input_size]
        4 - 6   recurrent_to_{forget,cell,output}_weights  [num_units, num_units]
        7 - 9   {forget_gate,cell,output_gate}_bias        [num_units]
        10      output_state   [batch, num_units]  (model input)
        11      cell_state     [batch, num_units]  (model input)
        12      output         [batch, time, num_units] (or [time, batch, num_units])
        13      projection_weights [num_units, num_units]  (only when provided)

    Parameters
    ----------
    batch, time, input_size, num_units
        Dimensions used to declare the tensor shapes.
    input_to_*_weights, recurrent_to_*_weights, *_bias
        Arrays whose ``tobytes()`` content becomes constant buffers 1-9.
    activation
        Value stored as the fused activation function of the op options.
    time_major, cell_clip, proj_clip
        Values stored in the op options; ``time_major`` also selects the
        declared input/output layout.
    projection_weights
        Optional array; when not None it is added as an extra constant tensor
        and connected to operator input slot 16.

    Returns
    -------
    The result of ``_finish_tflite_model`` for the assembled model.
    """
    builder = flatbuffers.Builder(4096)

    # Operator options table.
    opts = _tfl_unidirectional_sequence_lstm_options
    opts.UnidirectionalSequenceLSTMOptionsStart(builder)
    opts.UnidirectionalSequenceLSTMOptionsAddFusedActivationFunction(builder, activation)
    opts.UnidirectionalSequenceLSTMOptionsAddTimeMajor(builder, time_major)
    opts.UnidirectionalSequenceLSTMOptionsAddCellClip(builder, cell_clip)
    opts.UnidirectionalSequenceLSTMOptionsAddProjClip(builder, proj_clip)
    options_offset = opts.UnidirectionalSequenceLSTMOptionsEnd(builder)

    op_code_offset = _build_operator_code(
        builder, _tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_LSTM
    )

    def _float_tensor(buffer_index, dims):
        """Emit a non-variable FLOAT32 tensor of shape ``dims`` backed by ``buffer_index``."""
        dims_offset = _tflite_shape(builder, dims)
        _tfl_tensor.TensorStart(builder)
        _tfl_tensor.TensorAddBuffer(builder, buffer_index)
        _tfl_tensor.TensorAddHasRank(builder, True)
        _tfl_tensor.TensorAddIsVariable(builder, False)
        _tfl_tensor.TensorAddShape(builder, dims_offset)
        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
        return _tfl_tensor.TensorEnd(builder)

    if time_major:
        leading_dims = [time, batch]
    else:
        leading_dims = [batch, time]
    input_shape = leading_dims + [input_size]
    output_shape = leading_dims + [num_units]

    # Constant tensors 1-9 in declaration order; tensor i is backed by buffer i.
    constant_arrays = (
        input_to_forget_weights,
        input_to_cell_weights,
        input_to_output_weights,
        recurrent_to_forget_weights,
        recurrent_to_cell_weights,
        recurrent_to_output_weights,
        forget_gate_bias,
        cell_bias,
        output_gate_bias,
    )
    constant_shapes = (
        [[num_units, input_size] for _ in range(3)]
        + [[num_units, num_units] for _ in range(3)]
        + [[num_units] for _ in range(3)]
    )

    tensors = [_float_tensor(0, input_shape)]  # 0: input
    for buffer_index, dims in enumerate(constant_shapes, start=1):
        tensors.append(_float_tensor(buffer_index, dims))
    tensors.append(_float_tensor(0, [batch, num_units]))  # 10: output_state (model input)
    tensors.append(_float_tensor(0, [batch, num_units]))  # 11: cell_state (model input)
    tensors.append(_float_tensor(0, output_shape))  # 12: output

    # 24 operator input slots; -1 marks an absent optional input.
    op_inputs = [-1] * 24
    slot_to_tensor = {
        0: 0,
        2: 1,
        3: 2,
        4: 3,
        6: 4,
        7: 5,
        8: 6,
        13: 7,
        14: 8,
        15: 9,
        18: 10,
        19: 11,
    }
    for slot, tensor_index in slot_to_tensor.items():
        op_inputs[slot] = tensor_index

    buffers = [_build_buffer(builder)]  # 0: empty
    buffers.extend(_build_buffer(builder, array.tobytes()) for array in constant_arrays)

    if projection_weights is not None:
        # The projection tensor refers to the buffer appended right after it.
        projection_buffer_index = len(buffers)
        tensors.append(_float_tensor(projection_buffer_index, [num_units, num_units]))
        op_inputs[16] = len(tensors) - 1
        buffers.append(_build_buffer(builder, projection_weights.tobytes()))

    operator_offset = _build_operator(
        builder,
        0,
        op_inputs,
        [12],
        builtin_options_type=_tfl_builtin_options.UnidirectionalSequenceLSTMOptions,
        builtin_options=options_offset,
    )

    subgraph_offset = _build_subgraph(
        builder,
        tensors=tensors,
        operators=[operator_offset],
        inputs=[0, 10, 11],
        outputs=[12],
    )

    return _finish_tflite_model(
        builder,
        subgraph=subgraph_offset,
        operator_codes=[op_code_offset],
        buffers=buffers,
    )
+
+
def test_unidirectional_sequence_lstm_none_activation():
    """UNIDIRECTIONAL_SEQUENCE_LSTM with NONE activation emits no tanh.

    Checks the printed module for exactly two ``R.sigmoid`` calls, no
    ``R.tanh`` and at least one ``R.multiply``.
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 1, 2, 2

    input_weights = (
        np.eye(num_units, input_size, dtype=np.float32),  # forget
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),  # cell
        np.array([[0.5, -0.25], [0.75, 0.5]], dtype=np.float32),  # output
    )
    recurrent_weights = (
        np.eye(num_units, dtype=np.float32),  # forget
        np.array([[0.5, 0.0], [0.0, 0.25]], dtype=np.float32),  # cell
        np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32),  # output
    )
    biases = tuple(np.zeros(num_units, dtype=np.float32) for _ in range(3))

    model_buffer = _build_unidirectional_sequence_lstm_model(
        batch,
        time,
        input_size,
        num_units,
        *input_weights,
        *recurrent_weights,
        *biases,
        ActivationFunctionType.NONE,
    )
    mod = _load_model_from_buffer(model_buffer)

    printed = mod.script(show_meta=True)
    assert printed.count("R.sigmoid") == 2
    assert "R.tanh" not in printed
    assert "R.multiply" in printed
+
+
def test_unidirectional_sequence_lstm_tanh_activation():
    """UNIDIRECTIONAL_SEQUENCE_LSTM with TANH activation emits tanh in the cell.

    Checks the printed module for exactly two ``R.sigmoid`` calls, exactly two
    ``R.tanh`` calls and at least one ``R.multiply``.
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 1, 2, 2

    input_weights = (
        np.eye(num_units, input_size, dtype=np.float32),  # forget
        np.array([[1.0, -1.0], [0.25, 0.5]], dtype=np.float32),  # cell
        np.array([[0.5, 0.5], [-0.5, 1.0]], dtype=np.float32),  # output
    )
    recurrent_weights = (
        np.eye(num_units, dtype=np.float32),  # forget
        np.array([[0.0, 0.1], [0.2, 0.0]], dtype=np.float32),  # cell
        np.array([[0.3, 0.0], [0.0, 0.4]], dtype=np.float32),  # output
    )
    biases = tuple(np.zeros(num_units, dtype=np.float32) for _ in range(3))

    model_buffer = _build_unidirectional_sequence_lstm_model(
        batch,
        time,
        input_size,
        num_units,
        *input_weights,
        *recurrent_weights,
        *biases,
        ActivationFunctionType.TANH,
    )
    mod = _load_model_from_buffer(model_buffer)

    printed = mod.script(show_meta=True)
    assert printed.count("R.sigmoid") == 2
    assert printed.count("R.tanh") == 2
    assert "R.multiply" in printed
+
+
def test_unidirectional_sequence_lstm_time_major():
    """UNIDIRECTIONAL_SEQUENCE_LSTM keeps the time-major layout on input and output.

    Only the imported function's first parameter shape and return shape are
    checked.
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 3, 2, 2
    weights = np.eye(num_units, input_size, dtype=np.float32)
    recurrent = np.eye(num_units, dtype=np.float32)
    bias = np.zeros(num_units, dtype=np.float32)
    # Same array reused for the forget, cell and output gate of each group.
    gate_params = (weights,) * 3 + (recurrent,) * 3 + (bias,) * 3

    model_buffer = _build_unidirectional_sequence_lstm_model(
        batch,
        time,
        input_size,
        num_units,
        *gate_params,
        ActivationFunctionType.NONE,
        time_major=True,
    )
    mod = _load_model_from_buffer(model_buffer)

    main_fn = mod["main"]
    input_dims = tuple(int(d) for d in main_fn.params[0].struct_info.shape)
    output_dims = tuple(int(d) for d in main_fn.ret_struct_info.shape)
    assert input_dims == (time, batch, input_size)
    assert output_dims == (time, batch, num_units)
+
+
def test_unidirectional_sequence_lstm_rejects_projection():
    """Importing a UNIDIRECTIONAL_SEQUENCE_LSTM with projection weights must fail.

    Expects ``tvm.error.OpNotImplemented`` whose message matches
    "projection LSTM".
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 2, 2, 2
    weights = np.eye(num_units, input_size, dtype=np.float32)
    recurrent = np.eye(num_units, dtype=np.float32)
    bias = np.zeros(num_units, dtype=np.float32)
    gate_params = (weights,) * 3 + (recurrent,) * 3 + (bias,) * 3

    model_buffer = _build_unidirectional_sequence_lstm_model(
        batch,
        time,
        input_size,
        num_units,
        *gate_params,
        ActivationFunctionType.NONE,
        projection_weights=np.eye(num_units, dtype=np.float32),
    )

    with pytest.raises(tvm.error.OpNotImplemented, match="projection LSTM"):
        _load_model_from_buffer(model_buffer)
+
+
+# ── BIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────
+
+
def _build_bidirectional_sequence_rnn_model(
    batch,
    time,
    input_size,
    num_units,
    fw_weights,
    fw_recurrent_weights,
    fw_bias,
    bw_weights,
    bw_recurrent_weights,
    bw_bias,
    activation,
    *,
    time_major=False,
    merge_outputs=True,
    with_aux_input=False,
):
    """Serialize a TFLite model containing one BIDIRECTIONAL_SEQUENCE_RNN op.

    Fixed tensors (also the first nine operator inputs, in this order):
        0 - input                 [batch, time, input_size] (or time-major)
        1 - fw_weights            [num_units, input_size]
        2 - fw_recurrent_weights  [num_units, num_units]
        3 - fw_bias               [num_units]
        4 - fw_hidden_state       [batch, num_units]  (model input)
        5 - bw_weights            [num_units, input_size]
        6 - bw_recurrent_weights  [num_units, num_units]
        7 - bw_bias               [num_units]
        8 - bw_hidden_state       [batch, num_units]  (model input)

    Tensors appended after those, so their indices depend on the flags:
        * with ``with_aux_input``: aux_input, fw_aux_weights, bw_aux_weights
          (zero-filled constants) wired to operator inputs 9-11; otherwise
          those three operator inputs are -1.
        * one output tensor of width ``2 * num_units`` when ``merge_outputs``
          is true, otherwise two output tensors of width ``num_units``.

    Parameters
    ----------
    batch, time, input_size, num_units
        Dimensions used to declare the tensor shapes.
    fw_*/bw_* arrays
        Arrays whose ``tobytes()`` content becomes constant buffers 1-6.
    activation, time_major, merge_outputs
        Values stored in the op options.
    with_aux_input
        Whether to attach the three auxiliary tensors.

    Returns
    -------
    The result of ``_finish_tflite_model`` for the assembled model.
    """
    builder = flatbuffers.Builder(4096)

    # Operator options table.
    opts = _tfl_bidirectional_sequence_rnn_options
    opts.BidirectionalSequenceRNNOptionsStart(builder)
    opts.BidirectionalSequenceRNNOptionsAddTimeMajor(builder, time_major)
    opts.BidirectionalSequenceRNNOptionsAddFusedActivationFunction(builder, activation)
    opts.BidirectionalSequenceRNNOptionsAddMergeOutputs(builder, merge_outputs)
    options_offset = opts.BidirectionalSequenceRNNOptionsEnd(builder)

    op_code_offset = _build_operator_code(
        builder, _tfl_builtin_operator.BIDIRECTIONAL_SEQUENCE_RNN
    )

    def _float_tensor(buffer_index, dims):
        """Emit a non-variable FLOAT32 tensor of shape ``dims`` backed by ``buffer_index``."""
        dims_offset = _tflite_shape(builder, dims)
        _tfl_tensor.TensorStart(builder)
        _tfl_tensor.TensorAddBuffer(builder, buffer_index)
        _tfl_tensor.TensorAddHasRank(builder, True)
        _tfl_tensor.TensorAddIsVariable(builder, False)
        _tfl_tensor.TensorAddShape(builder, dims_offset)
        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
        return _tfl_tensor.TensorEnd(builder)

    def _cell_tensors(first_buffer):
        """Weights, recurrent weights, bias and hidden state of one direction."""
        return [
            _float_tensor(first_buffer, [num_units, input_size]),
            _float_tensor(first_buffer + 1, [num_units, num_units]),
            _float_tensor(first_buffer + 2, [num_units]),
            _float_tensor(0, [batch, num_units]),  # hidden state (model input)
        ]

    if time_major:
        leading_dims = [time, batch]
    else:
        leading_dims = [batch, time]
    input_shape = leading_dims + [input_size]
    if merge_outputs:
        output_shape = leading_dims + [num_units * 2]
    else:
        output_shape = leading_dims + [num_units]

    tensors = [_float_tensor(0, input_shape)]  # 0: input
    tensors += _cell_tensors(1)  # 1-4: forward cell, buffers 1-3
    tensors += _cell_tensors(4)  # 5-8: backward cell, buffers 4-6

    buffers = [_build_buffer(builder)]  # 0: empty
    for array in (
        fw_weights,
        fw_recurrent_weights,
        fw_bias,
        bw_weights,
        bw_recurrent_weights,
        bw_bias,
    ):
        buffers.append(_build_buffer(builder, array.tobytes()))

    # Operator inputs 0-8 map one-to-one to tensors 0-8; 9-11 are the aux slots.
    op_inputs = [*range(9), -1, -1, -1]
    if with_aux_input:
        aux_shapes = [input_shape, [num_units, input_size], [num_units, input_size]]
        first_aux_tensor = len(tensors)
        first_aux_buffer = len(buffers)
        for offset, dims in enumerate(aux_shapes):
            tensors.append(_float_tensor(first_aux_buffer + offset, dims))
        op_inputs[9:12] = list(range(first_aux_tensor, first_aux_tensor + 3))
        for dims in aux_shapes:
            buffers.append(
                _build_buffer(builder, np.zeros(dims, dtype=np.float32).tobytes())
            )

    output_count = 1 if merge_outputs else 2
    first_output = len(tensors)
    for _ in range(output_count):
        tensors.append(_float_tensor(0, output_shape))
    outputs = list(range(first_output, first_output + output_count))

    operator_offset = _build_operator(
        builder,
        0,
        op_inputs,
        outputs,
        builtin_options_type=_tfl_builtin_options.BidirectionalSequenceRNNOptions,
        builtin_options=options_offset,
    )

    subgraph_offset = _build_subgraph(
        builder,
        tensors=tensors,
        operators=[operator_offset],
        inputs=[0, 4, 8],
        outputs=outputs,
    )

    return _finish_tflite_model(
        builder,
        subgraph=subgraph_offset,
        operator_codes=[op_code_offset],
        buffers=buffers,
    )
+
+
def test_bidirectional_sequence_rnn_none_activation():
    """BIDIRECTIONAL_SEQUENCE_RNN with NONE activation lowers to the expected module.

    The imported module is compared structurally against a hand-written
    TVMScript module for a single time step with merged outputs.
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 1, 2, 2

    # Forward cell parameters (referenced by name inside ``Expected``).
    fw_w = np.array([[1.0, 0.0], [0.5, -1.0]], dtype=np.float32)
    fw_r = np.array([[0.25, 0.0], [0.0, 0.5]], dtype=np.float32)
    fw_b = np.zeros(num_units, dtype=np.float32)
    # Backward cell parameters (referenced by name inside ``Expected``).
    bw_w = np.array([[0.0, 1.0], [-0.5, 0.75]], dtype=np.float32)
    bw_r = np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32)
    bw_b = np.zeros(num_units, dtype=np.float32)

    model_buffer = _build_bidirectional_sequence_rnn_model(
        batch,
        time,
        input_size,
        num_units,
        *(fw_w, fw_r, fw_b),
        *(bw_w, bw_r, bw_b),
        ActivationFunctionType.NONE,
    )
    mod = _load_model_from_buffer(model_buffer)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 1, 2), dtype="float32"),
            fw_h: R.Tensor((2, 2), dtype="float32"),
            bw_h: R.Tensor((2, 2), dtype="float32"),
        ) -> R.Tensor((2, 1, 4), dtype="float32"):
            R.func_attr({"num_input": 3})
            with R.dataflow():
                x_t: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1])
                fw_w_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
                    R.const(fw_w), axes=None
                )
                fw_x: R.Tensor((2, 2), dtype="float32") = R.matmul(
                    x_t, fw_w_t, out_dtype="void"
                )
                fw_r_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
                    R.const(fw_r), axes=None
                )
                fw_h_proj: R.Tensor((2, 2), dtype="float32") = R.matmul(
                    fw_h, fw_r_t, out_dtype="void"
                )
                fw_out: R.Tensor((2, 2), dtype="float32") = R.add(
                    R.add(fw_x, fw_h_proj), R.const(fw_b)
                )
                fw_stacked: R.Tensor((2, 1, 2), dtype="float32") = R.stack(
                    (fw_out,), axis=1
                )
                bw_w_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
                    R.const(bw_w), axes=None
                )
                bw_x: R.Tensor((2, 2), dtype="float32") = R.matmul(
                    x_t, bw_w_t, out_dtype="void"
                )
                bw_r_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
                    R.const(bw_r), axes=None
                )
                bw_h_proj: R.Tensor((2, 2), dtype="float32") = R.matmul(
                    bw_h, bw_r_t, out_dtype="void"
                )
                bw_out: R.Tensor((2, 2), dtype="float32") = R.add(
                    R.add(bw_x, bw_h_proj), R.const(bw_b)
                )
                bw_stacked: R.Tensor((2, 1, 2), dtype="float32") = R.stack(
                    (bw_out,), axis=1
                )
                gv: R.Tensor((2, 1, 4), dtype="float32") = R.concat(
                    (fw_stacked, bw_stacked), axis=-1
                )
                R.output(gv)
            return gv

    tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_bidirectional_sequence_rnn_time_major():
    """BIDIRECTIONAL_SEQUENCE_RNN keeps the time-major layout on input and output.

    With merged outputs (the builder default) the last output dimension is
    ``2 * num_units``.
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 3, 2, 2
    # Both directions share the same (weights, recurrent, bias) triple.
    cell_params = (
        np.eye(num_units, input_size, dtype=np.float32),
        np.eye(num_units, dtype=np.float32),
        np.zeros(num_units, dtype=np.float32),
    )

    model_buffer = _build_bidirectional_sequence_rnn_model(
        batch,
        time,
        input_size,
        num_units,
        *cell_params,
        *cell_params,
        ActivationFunctionType.NONE,
        time_major=True,
    )
    mod = _load_model_from_buffer(model_buffer)

    main_fn = mod["main"]
    input_dims = tuple(int(d) for d in main_fn.params[0].struct_info.shape)
    output_dims = tuple(int(d) for d in main_fn.ret_struct_info.shape)
    assert input_dims == (time, batch, input_size)
    assert output_dims == (time, batch, num_units * 2)
+
+
def test_bidirectional_sequence_rnn_rejects_aux_input():
    """Importing a BIDIRECTIONAL_SEQUENCE_RNN with auxiliary tensors must fail.

    Expects ``tvm.error.OpNotImplemented`` whose message matches "aux input".
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 2, 2, 2
    # Both directions share the same (weights, recurrent, bias) triple.
    cell_params = (
        np.eye(num_units, input_size, dtype=np.float32),
        np.eye(num_units, dtype=np.float32),
        np.zeros(num_units, dtype=np.float32),
    )

    model_buffer = _build_bidirectional_sequence_rnn_model(
        batch,
        time,
        input_size,
        num_units,
        *cell_params,
        *cell_params,
        ActivationFunctionType.NONE,
        with_aux_input=True,
    )

    with pytest.raises(tvm.error.OpNotImplemented, match="aux input"):
        _load_model_from_buffer(model_buffer)
+
+
+# ── BIDIRECTIONAL_SEQUENCE_LSTM ──────────────────────────────────────────────
+
+
def _build_bidirectional_sequence_lstm_model(
    batch,
    time,
    input_size,
    num_units,
    fw_w_f,
    fw_w_c,
    fw_w_o,
    fw_r_f,
    fw_r_c,
    fw_r_o,
    fw_b_f,
    fw_b_c,
    fw_b_o,
    bw_w_f,
    bw_w_c,
    bw_w_o,
    bw_r_f,
    bw_r_c,
    bw_r_o,
    bw_b_f,
    bw_b_c,
    bw_b_o,
    activation,
    *,
    time_major=False,
    merge_outputs=True,
    cell_clip=0.0,
    proj_clip=0.0,
    with_aux_input=False,
):
    """Build a TFLite flatbuffer model with one BIDIRECTIONAL_SEQUENCE_LSTM op.

    The operator takes 48 inputs:
        slot 0        shared input
        slots 1-17    forward LSTM parameters
        slots 18-34   backward LSTM parameters
        slots 35-38   fw activation, fw cell, bw activation, bw cell state
        slots 39-47   auxiliary inputs (all -1 unless ``with_aux_input``)
    Only forget/cell/output gate weights and biases are wired in for each
    direction; every other parameter slot is -1.

    Tensor layout:
        0        input
        1 - 9    forward  w_f, w_c, w_o, r_f, r_c, r_o, b_f, b_c, b_o
        10 - 18  backward w_f, w_c, w_o, r_f, r_c, r_o, b_f, b_c, b_o
        19 - 22  state tensors (model inputs)
        23       output (forward output when ``merge_outputs`` is false)
        then, in this order and only when requested: the backward output
        tensor (``merge_outputs=False``) and the aux input tensor
        (``with_aux_input=True``).

    Parameters
    ----------
    batch, time, input_size, num_units
        Dimensions used to declare the tensor shapes.
    fw_*/bw_* arrays
        Arrays whose ``tobytes()`` content becomes constant buffers 1-18.
    activation, time_major, merge_outputs, cell_clip, proj_clip
        Values stored in the op options.
    with_aux_input
        Whether to attach an auxiliary input tensor at operator slot 39.

    Returns
    -------
    The result of ``_finish_tflite_model`` for the assembled model.
    """
    builder = flatbuffers.Builder(8192)

    opts = _tfl_bidirectional_sequence_lstm_options
    opts.BidirectionalSequenceLSTMOptionsStart(builder)
    opts.BidirectionalSequenceLSTMOptionsAddFusedActivationFunction(builder, activation)
    opts.BidirectionalSequenceLSTMOptionsAddTimeMajor(builder, time_major)
    opts.BidirectionalSequenceLSTMOptionsAddMergeOutputs(builder, merge_outputs)
    opts.BidirectionalSequenceLSTMOptionsAddCellClip(builder, cell_clip)
    opts.BidirectionalSequenceLSTMOptionsAddProjClip(builder, proj_clip)
    lstm_opts = opts.BidirectionalSequenceLSTMOptionsEnd(builder)

    lstm_op_code = _build_operator_code(
        builder, _tfl_builtin_operator.BIDIRECTIONAL_SEQUENCE_LSTM
    )

    def _t(buf_idx, shape):
        """Emit a non-variable FLOAT32 tensor of ``shape`` backed by buffer ``buf_idx``."""
        shape_vec = _tflite_shape(builder, shape)
        _tfl_tensor.TensorStart(builder)
        _tfl_tensor.TensorAddBuffer(builder, buf_idx)
        _tfl_tensor.TensorAddHasRank(builder, True)
        _tfl_tensor.TensorAddIsVariable(builder, False)
        _tfl_tensor.TensorAddShape(builder, shape_vec)
        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
        return _tfl_tensor.TensorEnd(builder)

    input_shape = [time, batch, input_size] if time_major else [batch, time, input_size]
    output_size = num_units * 2 if merge_outputs else num_units
    output_shape = ([time, batch] if time_major else [batch, time]) + [output_size]

    tensors = [
        _t(0, input_shape),  # 0: input
        _t(1, [num_units, input_size]),  # 1: fw_w_f
        _t(2, [num_units, input_size]),  # 2: fw_w_c
        _t(3, [num_units, input_size]),  # 3: fw_w_o
        _t(4, [num_units, num_units]),  # 4: fw_r_f
        _t(5, [num_units, num_units]),  # 5: fw_r_c
        _t(6, [num_units, num_units]),  # 6: fw_r_o
        _t(7, [num_units]),  # 7: fw_b_f
        _t(8, [num_units]),  # 8: fw_b_c
        _t(9, [num_units]),  # 9: fw_b_o
        _t(10, [num_units, input_size]),  # 10: bw_w_f
        _t(11, [num_units, input_size]),  # 11: bw_w_c
        _t(12, [num_units, input_size]),  # 12: bw_w_o
        _t(13, [num_units, num_units]),  # 13: bw_r_f
        _t(14, [num_units, num_units]),  # 14: bw_r_c
        _t(15, [num_units, num_units]),  # 15: bw_r_o
        _t(16, [num_units]),  # 16: bw_b_f
        _t(17, [num_units]),  # 17: bw_b_c
        _t(18, [num_units]),  # 18: bw_b_o
        _t(0, [batch, num_units]),  # 19: fw_activation_state (model input)
        _t(0, [batch, num_units]),  # 20: fw_cell_state (model input)
        _t(0, [batch, num_units]),  # 21: bw_activation_state (model input)
        _t(0, [batch, num_units]),  # 22: bw_cell_state (model input)
        _t(0, output_shape),  # 23: output (fw_output when not merged)
    ]

    outputs = [23]
    if not merge_outputs:
        # Split mode needs a separate backward output tensor, mirroring the
        # BIDIRECTIONAL_SEQUENCE_RNN builder; the converter returns a
        # (forward, backward) pair in this mode.
        tensors.append(_t(0, output_shape))
        outputs.append(len(tensors) - 1)

    # Build operator inputs: 48 total, with unsupported optional inputs set to -1.
    fw_inputs = [0, -1, 1, 2, 3, -1, 4, 5, 6, -1, -1, -1, -1, 7, 8, 9, -1, -1]
    bw_inputs = [-1, 10, 11, 12, -1, 13, 14, 15, -1, -1, -1, -1, 16, 17, 18, -1, -1]
    states = [19, 20, 21, 22]
    aux_inputs = [-1] * 9
    if with_aux_input:
        tensors.append(_t(0, input_shape))
        aux_inputs[0] = len(tensors) - 1
    lstm_inputs = fw_inputs + bw_inputs + states + aux_inputs

    lstm_op = _build_operator(
        builder,
        0,
        lstm_inputs,
        outputs,
        builtin_options_type=_tfl_builtin_options.BidirectionalSequenceLSTMOptions,
        builtin_options=lstm_opts,
    )

    subgraph = _build_subgraph(
        builder,
        tensors=tensors,
        operators=[lstm_op],
        inputs=[0, 19, 20, 21, 22],
        outputs=outputs,
    )

    # Buffer i backs constant tensor i (1-18); buffer 0 is the shared empty buffer.
    constant_arrays = (
        fw_w_f,
        fw_w_c,
        fw_w_o,
        fw_r_f,
        fw_r_c,
        fw_r_o,
        fw_b_f,
        fw_b_c,
        fw_b_o,
        bw_w_f,
        bw_w_c,
        bw_w_o,
        bw_r_f,
        bw_r_c,
        bw_r_o,
        bw_b_f,
        bw_b_c,
        bw_b_o,
    )
    buffers = [_build_buffer(builder)]  # 0: empty
    buffers.extend(_build_buffer(builder, array.tobytes()) for array in constant_arrays)

    return _finish_tflite_model(
        builder,
        subgraph=subgraph,
        operator_codes=[lstm_op_code],
        buffers=buffers,
    )
+
+
def test_bidirectional_sequence_lstm_none_activation():
    """BIDIRECTIONAL_SEQUENCE_LSTM with NONE activation emits no tanh in either cell.

    Checks the printed module for exactly four ``R.sigmoid`` calls, no
    ``R.tanh``, exactly two ``R.stack`` calls and at least one ``R.concat``.
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 1, 2, 2

    def _identity_or_ramp(rows, cols):
        # Identity for square shapes, otherwise a deterministic 0.0, 0.1, ... ramp.
        if rows == cols:
            return np.eye(rows, dtype=np.float32)
        return np.arange(rows * cols, dtype=np.float32).reshape(rows, cols) / 10.0

    # Order per direction: w_f, w_c, w_o, r_f, r_c, r_o, b_f, b_c, b_o.
    fw_params = (
        _identity_or_ramp(num_units, input_size),
        np.array([[1.0, -0.5], [0.25, 0.75]], dtype=np.float32),
        np.array([[0.5, 0.25], [-0.25, 1.0]], dtype=np.float32),
        _identity_or_ramp(num_units, num_units),
        np.array([[0.2, 0.0], [0.0, 0.3]], dtype=np.float32),
        np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32),
        np.zeros(num_units, dtype=np.float32),
        np.zeros(num_units, dtype=np.float32),
        np.zeros(num_units, dtype=np.float32),
    )
    bw_params = (
        np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32),
        np.array([[0.5, 0.5], [-0.5, 1.0]], dtype=np.float32),
        np.array([[0.25, -0.25], [0.75, 0.5]], dtype=np.float32),
        np.array([[0.4, 0.0], [0.0, 0.6]], dtype=np.float32),
        np.array([[0.3, 0.0], [0.0, 0.2]], dtype=np.float32),
        np.array([[0.2, 0.0], [0.0, 0.1]], dtype=np.float32),
        np.zeros(num_units, dtype=np.float32),
        np.zeros(num_units, dtype=np.float32),
        np.zeros(num_units, dtype=np.float32),
    )

    model_buffer = _build_bidirectional_sequence_lstm_model(
        batch,
        time,
        input_size,
        num_units,
        *fw_params,
        *bw_params,
        ActivationFunctionType.NONE,
    )
    mod = _load_model_from_buffer(model_buffer)

    printed = mod.script(show_meta=True)
    assert printed.count("R.sigmoid") == 4
    assert "R.tanh" not in printed
    assert printed.count("R.stack") == 2
    assert "R.concat" in printed
+
+
def test_bidirectional_sequence_lstm_time_major():
    """BIDIRECTIONAL_SEQUENCE_LSTM keeps the time-major layout on input and output.

    With merged outputs (the builder default) the last output dimension is
    ``2 * num_units``.
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 3, 2, 2
    weights = np.eye(num_units, input_size, dtype=np.float32)
    recurrent = np.eye(num_units, dtype=np.float32)
    bias = np.zeros(num_units, dtype=np.float32)
    # One direction's nine parameters; reused for forward and backward.
    cell_params = (weights,) * 3 + (recurrent,) * 3 + (bias,) * 3

    model_buffer = _build_bidirectional_sequence_lstm_model(
        batch,
        time,
        input_size,
        num_units,
        *cell_params,
        *cell_params,
        ActivationFunctionType.NONE,
        time_major=True,
    )
    mod = _load_model_from_buffer(model_buffer)

    main_fn = mod["main"]
    input_dims = tuple(int(d) for d in main_fn.params[0].struct_info.shape)
    output_dims = tuple(int(d) for d in main_fn.ret_struct_info.shape)
    assert input_dims == (time, batch, input_size)
    assert output_dims == (time, batch, num_units * 2)
+
+
def test_bidirectional_sequence_lstm_rejects_aux_input():
    """Importing a BIDIRECTIONAL_SEQUENCE_LSTM with an auxiliary input must fail.

    Expects ``tvm.error.OpNotImplemented`` whose message matches "aux input".
    """
    from tflite.ActivationFunctionType import ActivationFunctionType

    batch, time, input_size, num_units = 2, 2, 2, 2
    weights = np.eye(num_units, input_size, dtype=np.float32)
    recurrent = np.eye(num_units, dtype=np.float32)
    bias = np.zeros(num_units, dtype=np.float32)
    # One direction's nine parameters; reused for forward and backward.
    cell_params = (weights,) * 3 + (recurrent,) * 3 + (bias,) * 3

    model_buffer = _build_bidirectional_sequence_lstm_model(
        batch,
        time,
        input_size,
        num_units,
        *cell_params,
        *cell_params,
        ActivationFunctionType.NONE,
        with_aux_input=True,
    )

    with pytest.raises(tvm.error.OpNotImplemented, match="aux input"):
        _load_model_from_buffer(model_buffer)
+
+
# ── UNIDIRECTIONAL_SEQUENCE_RNN
───────────────────────────────────────────────