zhang-yi-chi commented on code in PR #12213:
URL: https://github.com/apache/tvm/pull/12213#discussion_r935527341
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -2756,6 +2757,138 @@ def _activation_needs_beta(cls, activation):
]
return activation.decode("utf-8") in needs_beta
@classmethod
def bidir_rnn_cell(
    cls,
    input_seqs,
    weight_dicts,
    acts,
):
    """
    Run one bidirectional RNN layer over a sequence.

    Evaluates a forward pass and a backward pass with ``rnn_cell`` and
    interleaves the per-step outputs so that step ``t`` pairs the
    forward output at ``t`` with the backward output at
    ``seq_len - 1 - t`` (the backward cell consumes the sequence in
    reverse).

    Parameters
    ----------
    input_seqs : sequence
        Per-time-step input tensors.
    weight_dicts : sequence of dict
        Two keyword-argument weight dicts: index 0 for the forward
        direction, index 1 for the backward direction.
    acts : sequence
        Two activations, one per direction.

    Returns
    -------
    tuple
        The per-step outputs stacked with a leading direction axis,
        and the two final hidden states stacked over directions.
    """
    # Forward direction consumes the sequence as given.
    fwd_outputs, fwd_hidden = rnn_cell(
        input_seqs,
        **weight_dicts[0],
        act=acts[0],
    )
    # Backward direction walks the sequence in reverse.
    bwd_outputs, bwd_hidden = rnn_cell(
        input_seqs,
        **weight_dicts[1],
        act=acts[1],
        backwards=True,
    )

    # Pair step t of the forward pass with step seq_len - 1 - t of the
    # backward pass, stacking each pair on a new leading direction axis.
    stacked_steps = [
        _op.stack([fwd, bwd], axis=0)
        for fwd, bwd in zip(fwd_outputs, reversed(bwd_outputs))
    ]

    return (
        _op.stack(stacked_steps, axis=0),
        _op.stack([fwd_hidden, bwd_hidden], axis=0),
    )
+
+ @classmethod
+ # Importer entry point for the ONNX RNN operator at opset version 7.
+ # NOTE(review): this hunk is truncated by the quoted diff — the body of
+ # the final "activations" branch continues beyond this excerpt.
+ def _impl_v7(cls, inputs, attr, params):
+ # Unpack inputs, note that if optional and not provided then value
will be None.
+ X = inputs[0]
+ Wp = inputs[1]
+ Rp = inputs[2]
+ Bp = inputs[3]
+ # Sequence length currently unused as it can be inferred from shapes.
+ # sequence_lens = inputs['sequence_lens']
+ Hp_0 = inputs[5]
+
+ # Leading dim of W is the number of directions; dtype of W is reused
+ # for any tensors this importer materializes (e.g. the zero H_0 below).
+ num_directions = infer_shape(Wp)[0]
+ W_dtype = infer_type(Wp).checked_type.dtype
+
+ # Per the ONNX spec only unidirectional (1) or bidirectional (2)
+ # layouts exist — presumably forward/reverse collapse to 1; verify.
+ if num_directions not in [1, 2]:
+ raise ValueError("num_directions must be either 1 or 2!")
+
+ # assumes X is laid out (seq, batch, input_size) so batch is axis 1,
+ # and R's trailing dim is hidden_size — TODO confirm against spec.
+ X_shape = infer_shape(X)
+ hidden_size = infer_shape(Rp)[-1]
+ batch_size = X_shape[1]
+
+ # Default the optional initial hidden state to zeros when absent.
+ if Hp_0 is None:
+ Hp_0 = _op.zeros((num_directions, batch_size, hidden_size),
W_dtype)
+
+ if "activations" in attr:
Review Comment:
I refactored both onnx.py and test_forward.py.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]