gemini-code-assist[bot] commented on code in PR #19601:
URL: https://github.com/apache/tvm/pull/19601#discussion_r3296953746
##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -4477,6 +4478,92 @@ def convert_unpack(self, op):
return squeezed
+ def convert_unidirectional_sequence_rnn(self, op):
+ """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
+
+ Inputs (5 tensors):
+ [0] input [batch, time, input_size] (or [time, batch,
input_size] if time_major)
+ [1] input_weights [num_units, input_size]
+ [2] recurrent_weights [num_units, num_units]
+ [3] bias [num_units]
+ [4] hidden_state [batch, num_units] (variable, zero-initialised)
+
+ Output:
+ [0] output [batch, time, num_units]
+
+ Cell equation:
+ h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
+ """
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.SequenceRNNOptions import SequenceRNNOptions
+
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented(
+ "TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported
yet."
+ )
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 5, "input tensors length should be 5"
+
+ input_tensor = input_tensors[0]
+ weights_tensor = input_tensors[1]
+ recurrent_tensor = input_tensors[2]
+ bias_tensor = input_tensors[3]
+ hidden_state_tensor = input_tensors[4]
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
+ op_options = op.BuiltinOptions()
+ seq_rnn_options = SequenceRNNOptions()
+ seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
+ time_major = seq_rnn_options.TimeMajor()
+ fused_activation_fn = seq_rnn_options.FusedActivationFunction()
+
+ # Constant weight/bias expressions.
+ weights_expr = self.get_tensor_expr(weights_tensor) # [num_units,
input_size]
+ recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units,
num_units]
+ bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
+
+ # Transpose to [input_size, num_units] and [num_units, num_units] for
x @ W.T.
+ w_t = relax.op.permute_dims(weights_expr)
+ wr_t = relax.op.permute_dims(recurrent_expr)
+
+ # Resolve the input expression; normalise to batch-major [batch, time,
input_size].
+ in_expr = self.get_tensor_expr(input_tensor)
+ in_shape = to_int_list(self.get_tensor_shape(input_tensor))
+ if time_major:
+ in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+ num_steps = in_shape[0]
+ else:
+ num_steps = in_shape[1]
+
+ # Zero-initialised hidden state [batch, num_units].
+ h_shape =
tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
+ h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
+ h = relax.op.zeros(h_shape, dtype=h_dtype)
Review Comment:

The current implementation hardcodes the initial hidden state to zeros,
which ignores any initial values provided in the TFLite model's buffer or the
possibility of the hidden state being a graph input (common in stateful
models). It should use `self.get_tensor_expr(hidden_state_tensor)` to respect
the model's definition, falling back to zeros only if the tensor is an
uninitialized variable.
```suggestion
h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
if self.has_expr(hidden_state_tensor.tensor_idx) or hidden_state_tensor.buffer.DataLength() > 0:
    h = self.get_tensor_expr(hidden_state_tensor)
else:
    h_shape = self.get_tensor_shape(hidden_state_tensor)
    h = relax.op.zeros(relax.ShapeExpr(h_shape), dtype=h_dtype)
```
##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -4477,6 +4478,92 @@ def convert_unpack(self, op):
return squeezed
+ def convert_unidirectional_sequence_rnn(self, op):
+ """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
+
+ Inputs (5 tensors):
+ [0] input [batch, time, input_size] (or [time, batch,
input_size] if time_major)
+ [1] input_weights [num_units, input_size]
+ [2] recurrent_weights [num_units, num_units]
+ [3] bias [num_units]
+ [4] hidden_state [batch, num_units] (variable, zero-initialised)
+
+ Output:
+ [0] output [batch, time, num_units]
+
+ Cell equation:
+ h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
+ """
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.SequenceRNNOptions import SequenceRNNOptions
+
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented(
+ "TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported
yet."
+ )
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 5, "input tensors length should be 5"
+
+ input_tensor = input_tensors[0]
+ weights_tensor = input_tensors[1]
+ recurrent_tensor = input_tensors[2]
+ bias_tensor = input_tensors[3]
+ hidden_state_tensor = input_tensors[4]
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
+ op_options = op.BuiltinOptions()
+ seq_rnn_options = SequenceRNNOptions()
+ seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
+ time_major = seq_rnn_options.TimeMajor()
+ fused_activation_fn = seq_rnn_options.FusedActivationFunction()
+
+ # Constant weight/bias expressions.
+ weights_expr = self.get_tensor_expr(weights_tensor) # [num_units,
input_size]
+ recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units,
num_units]
+ bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
Review Comment:

TFLite operators often have optional tensors (indicated by a tensor index of
-1). The converter should check if the bias tensor is provided before calling
`get_tensor_expr` to avoid a potential crash during model import.
```suggestion
if bias_tensor.tensor_idx != -1:
    bias_expr = self.get_tensor_expr(bias_tensor)
else:
    num_units = int(self.get_tensor_shape(weights_tensor)[0])
    bias_dtype = self.get_tensor_type_str(weights_tensor.tensor.Type())
    bias_expr = relax.op.zeros(relax.ShapeExpr((num_units,)), dtype=bias_dtype)
```
##########
tests/python/relax/test_frontend_tflite.py:
##########
@@ -6730,6 +6732,213 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)
+ tvm.ir.assert_structural_equal(mod, Expected)
Review Comment:

This line is a duplicate of the assertion on line 6733.
##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -4477,6 +4478,92 @@ def convert_unpack(self, op):
return squeezed
+ def convert_unidirectional_sequence_rnn(self, op):
+ """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
+
+ Inputs (5 tensors):
+ [0] input [batch, time, input_size] (or [time, batch,
input_size] if time_major)
+ [1] input_weights [num_units, input_size]
+ [2] recurrent_weights [num_units, num_units]
+ [3] bias [num_units]
+ [4] hidden_state [batch, num_units] (variable, zero-initialised)
+
+ Output:
+ [0] output [batch, time, num_units]
+
+ Cell equation:
+ h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
+ """
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.SequenceRNNOptions import SequenceRNNOptions
+
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented(
+ "TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported
yet."
+ )
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 5, "input tensors length should be 5"
+
+ input_tensor = input_tensors[0]
+ weights_tensor = input_tensors[1]
+ recurrent_tensor = input_tensors[2]
+ bias_tensor = input_tensors[3]
+ hidden_state_tensor = input_tensors[4]
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) >= 1, "output tensors length should be at
least 1"
+
+ assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
+ op_options = op.BuiltinOptions()
+ seq_rnn_options = SequenceRNNOptions()
+ seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
+ time_major = seq_rnn_options.TimeMajor()
+ fused_activation_fn = seq_rnn_options.FusedActivationFunction()
+
+ # Constant weight/bias expressions.
+ weights_expr = self.get_tensor_expr(weights_tensor) # [num_units,
input_size]
+ recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units,
num_units]
+ bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
+
+ # Transpose to [input_size, num_units] and [num_units, num_units] for
x @ W.T.
+ w_t = relax.op.permute_dims(weights_expr)
+ wr_t = relax.op.permute_dims(recurrent_expr)
+
+ # Resolve the input expression; normalise to batch-major [batch, time,
input_size].
+ in_expr = self.get_tensor_expr(input_tensor)
+ in_shape = to_int_list(self.get_tensor_shape(input_tensor))
+ if time_major:
+ in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+ num_steps = in_shape[0]
+ else:
+ num_steps = in_shape[1]
Review Comment:

Using `to_int_list` on the entire input shape is overly restrictive as it
forces all dimensions (including batch size) to be static. Only the sequence
length dimension needs to be a static integer to support unrolling at graph
construction time. It is better to extract only the required dimension.
```suggestion
in_shape = self.get_tensor_shape(input_tensor)
if time_major:
    in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
    num_steps = int(in_shape[0])
else:
    num_steps = int(in_shape[1])
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]