gemini-code-assist[bot] commented on code in PR #19601:
URL: https://github.com/apache/tvm/pull/19601#discussion_r3296953746


##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -4477,6 +4478,92 @@ def convert_unpack(self, op):
 
         return squeezed
 
+    def convert_unidirectional_sequence_rnn(self, op):
+        """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
+
+        Inputs (5 tensors):
+          [0] input          [batch, time, input_size]  (or [time, batch, 
input_size] if time_major)
+          [1] input_weights  [num_units, input_size]
+          [2] recurrent_weights [num_units, num_units]
+          [3] bias           [num_units]
+          [4] hidden_state   [batch, num_units]  (variable, zero-initialised)
+
+        Output:
+          [0] output  [batch, time, num_units]
+
+        Cell equation:
+          h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.SequenceRNNOptions import SequenceRNNOptions
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported 
yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 5, "input tensors length should be 5"
+
+        input_tensor = input_tensors[0]
+        weights_tensor = input_tensors[1]
+        recurrent_tensor = input_tensors[2]
+        bias_tensor = input_tensors[3]
+        hidden_state_tensor = input_tensors[4]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) >= 1, "output tensors length should be at 
least 1"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
+        op_options = op.BuiltinOptions()
+        seq_rnn_options = SequenceRNNOptions()
+        seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
+        time_major = seq_rnn_options.TimeMajor()
+        fused_activation_fn = seq_rnn_options.FusedActivationFunction()
+
+        # Constant weight/bias expressions.
+        weights_expr = self.get_tensor_expr(weights_tensor)  # [num_units, 
input_size]
+        recurrent_expr = self.get_tensor_expr(recurrent_tensor)  # [num_units, 
num_units]
+        bias_expr = self.get_tensor_expr(bias_tensor)  # [num_units]
+
+        # Transpose to [input_size, num_units] and [num_units, num_units] for 
x @ W.T.
+        w_t = relax.op.permute_dims(weights_expr)
+        wr_t = relax.op.permute_dims(recurrent_expr)
+
+        # Resolve the input expression; normalise to batch-major [batch, time, 
input_size].
+        in_expr = self.get_tensor_expr(input_tensor)
+        in_shape = to_int_list(self.get_tensor_shape(input_tensor))
+        if time_major:
+            in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+            num_steps = in_shape[0]
+        else:
+            num_steps = in_shape[1]
+
+        # Zero-initialised hidden state [batch, num_units].
+        h_shape = 
tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
+        h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
+        h = relax.op.zeros(h_shape, dtype=h_dtype)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The current implementation hardcodes the initial hidden state to zeros, 
which ignores any initial values provided in the TFLite model's buffer or the 
possibility of the hidden state being a graph input (common in stateful 
models). It should use `self.get_tensor_expr(hidden_state_tensor)` to respect 
the model's definition, falling back to zeros only if the tensor is an 
uninitialized variable.
   
   ```suggestion
           h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
           if self.has_expr(hidden_state_tensor.tensor_idx) or hidden_state_tensor.buffer.DataLength() > 0:
               h = self.get_tensor_expr(hidden_state_tensor)
           else:
               h_shape = self.get_tensor_shape(hidden_state_tensor)
               h = relax.op.zeros(relax.ShapeExpr(h_shape), dtype=h_dtype)
   ```



##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -4477,6 +4478,92 @@ def convert_unpack(self, op):
 
         return squeezed
 
+    def convert_unidirectional_sequence_rnn(self, op):
+        """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
+
+        Inputs (5 tensors):
+          [0] input          [batch, time, input_size]  (or [time, batch, 
input_size] if time_major)
+          [1] input_weights  [num_units, input_size]
+          [2] recurrent_weights [num_units, num_units]
+          [3] bias           [num_units]
+          [4] hidden_state   [batch, num_units]  (variable, zero-initialised)
+
+        Output:
+          [0] output  [batch, time, num_units]
+
+        Cell equation:
+          h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.SequenceRNNOptions import SequenceRNNOptions
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported 
yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 5, "input tensors length should be 5"
+
+        input_tensor = input_tensors[0]
+        weights_tensor = input_tensors[1]
+        recurrent_tensor = input_tensors[2]
+        bias_tensor = input_tensors[3]
+        hidden_state_tensor = input_tensors[4]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) >= 1, "output tensors length should be at 
least 1"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
+        op_options = op.BuiltinOptions()
+        seq_rnn_options = SequenceRNNOptions()
+        seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
+        time_major = seq_rnn_options.TimeMajor()
+        fused_activation_fn = seq_rnn_options.FusedActivationFunction()
+
+        # Constant weight/bias expressions.
+        weights_expr = self.get_tensor_expr(weights_tensor)  # [num_units, 
input_size]
+        recurrent_expr = self.get_tensor_expr(recurrent_tensor)  # [num_units, 
num_units]
+        bias_expr = self.get_tensor_expr(bias_tensor)  # [num_units]

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   TFLite operators often have optional tensors (indicated by a tensor index of 
-1). The converter should check if the bias tensor is provided before calling 
`get_tensor_expr` to avoid a potential crash during model import.
   
   ```suggestion
           if bias_tensor.tensor_idx != -1:
               bias_expr = self.get_tensor_expr(bias_tensor)
           else:
               num_units = int(self.get_tensor_shape(weights_tensor)[0])
               bias_dtype = self.get_tensor_type_str(weights_tensor.tensor.Type())
               bias_expr = relax.op.zeros(relax.ShapeExpr((num_units,)), dtype=bias_dtype)
   ```



##########
tests/python/relax/test_frontend_tflite.py:
##########
@@ -6730,6 +6732,213 @@ def main(
 
     tvm.ir.assert_structural_equal(mod, Expected)
 
+    tvm.ir.assert_structural_equal(mod, Expected)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This added assertion duplicates the identical `tvm.ir.assert_structural_equal(mod, Expected)` call on line 6733, two lines above it; remove the duplicate.



##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -4477,6 +4478,92 @@ def convert_unpack(self, op):
 
         return squeezed
 
+    def convert_unidirectional_sequence_rnn(self, op):
+        """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
+
+        Inputs (5 tensors):
+          [0] input          [batch, time, input_size]  (or [time, batch, 
input_size] if time_major)
+          [1] input_weights  [num_units, input_size]
+          [2] recurrent_weights [num_units, num_units]
+          [3] bias           [num_units]
+          [4] hidden_state   [batch, num_units]  (variable, zero-initialised)
+
+        Output:
+          [0] output  [batch, time, num_units]
+
+        Cell equation:
+          h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.SequenceRNNOptions import SequenceRNNOptions
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported 
yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 5, "input tensors length should be 5"
+
+        input_tensor = input_tensors[0]
+        weights_tensor = input_tensors[1]
+        recurrent_tensor = input_tensors[2]
+        bias_tensor = input_tensors[3]
+        hidden_state_tensor = input_tensors[4]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) >= 1, "output tensors length should be at 
least 1"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
+        op_options = op.BuiltinOptions()
+        seq_rnn_options = SequenceRNNOptions()
+        seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
+        time_major = seq_rnn_options.TimeMajor()
+        fused_activation_fn = seq_rnn_options.FusedActivationFunction()
+
+        # Constant weight/bias expressions.
+        weights_expr = self.get_tensor_expr(weights_tensor)  # [num_units, 
input_size]
+        recurrent_expr = self.get_tensor_expr(recurrent_tensor)  # [num_units, 
num_units]
+        bias_expr = self.get_tensor_expr(bias_tensor)  # [num_units]
+
+        # Transpose to [input_size, num_units] and [num_units, num_units] for 
x @ W.T.
+        w_t = relax.op.permute_dims(weights_expr)
+        wr_t = relax.op.permute_dims(recurrent_expr)
+
+        # Resolve the input expression; normalise to batch-major [batch, time, 
input_size].
+        in_expr = self.get_tensor_expr(input_tensor)
+        in_shape = to_int_list(self.get_tensor_shape(input_tensor))
+        if time_major:
+            in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+            num_steps = in_shape[0]
+        else:
+            num_steps = in_shape[1]

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Using `to_int_list` on the entire input shape is overly restrictive as it 
forces all dimensions (including batch size) to be static. Only the sequence 
length dimension needs to be a static integer to support unrolling at graph 
construction time. It is better to extract only the required dimension.
   
   ```suggestion
           in_shape = self.get_tensor_shape(input_tensor)
           if time_major:
               in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
               num_steps = int(in_shape[0])
           else:
               num_steps = int(in_shape[1])
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to