SebastianBoblestETAS commented on code in PR #11183:
URL: https://github.com/apache/tvm/pull/11183#discussion_r867767234
##########
python/tvm/relay/frontend/tflite.py:
##########
@@ -2710,6 +2743,145 @@ def convert_unpack(self, op):
return squeezed
+ def convert_unidirectional_sequence_lstm(self, op):
+ """Long Short Term Memory for TFLite implementation."""
+ if self.is_quantized(op):
+ raise tvm.error.OpNotImplemented(
+ "TFlite quantized UNIDIRECTIONALSEQUENCELSTM operator is not
supported yet."
+ )
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) >= 20, "input tensors length should be >= 20"
+
+ # Extract input tensor from saved model
+ input_tensor = input_tensors[0]
+
+ # Extract tensors from input tensors from saved model
+ # Input weights
+ input_input_weights = input_tensors[1]
+ input_forget_weights = input_tensors[2]
+ input_cell_weights = input_tensors[3]
+ input_output_weights = input_tensors[4]
+ # Recurrent weights
+ recurrent_input_weights = input_tensors[5]
+ recurrent_forget_weights = input_tensors[6]
+ recurrent_cell_weights = input_tensors[7]
+ recurrent_output_weights = input_tensors[8]
+ # Bias weights
+ input_gate_bias = input_tensors[12]
+ forget_gate_bias = input_tensors[13]
+ cell_gate_bias = input_tensors[14]
+ output_gate_bias = input_tensors[15]
+ # State input
+ output_state_in = input_tensors[18]
+ cell_state_in = input_tensors[19]
+
+ # Extract output tensor from saved model
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+ X_steps = self.unbind(input_tensor, axis=1)
+ weights_dict = {}
+
+ # hidden_state_weights is equivalent to output_state_in in tflite model
+ out_state_in_shape = tuple(self.get_tensor_shape(output_state_in))
+ out_state_in_dtype =
self.get_tensor_type_str(output_state_in.tensor.Type())
+ out_state_in_expr = _op.zeros(out_state_in_shape,
dtype=out_state_in_dtype)
+ weights_dict["hidden_state"] = _op.split(out_state_in_expr, 1)[0]
Review Comment:
A Relay LSTM operator might be valuable, but it should then support all
frontends, right? For now, we decided to use the existing `lstm_cell`, which
is also used by the ONNX frontend.
If Relay gains a native LSTM operator later, we can adapt the parsers
to use it instead.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]