AndrewZhaoLuo commented on code in PR #12213:
URL: https://github.com/apache/tvm/pull/12213#discussion_r934979601
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -2756,6 +2757,138 @@ def _activation_needs_beta(cls, activation):
]
return activation.decode("utf-8") in needs_beta
+ @classmethod
+ def bidir_rnn_cell(
+ cls,
+ input_seqs,
+ weight_dicts,
+ acts,
+ ):
+ """
+ Bidirectional RNN cell
+ """
+ seq_len = len(input_seqs)
+ forward_outputs, fw_H_t = rnn_cell(
+ input_seqs,
+ **weight_dicts[0],
+ act=acts[0],
+ )
+
+ reverse_outputs, rev_H_t = rnn_cell(
+ input_seqs,
+ **weight_dicts[1],
+ act=acts[1],
+ backwards=True,
+ )
+
+ final_outputs = []
+ for i in range(seq_len):
+ final_outputs.append(
+                _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
+ )
+
+ return (
+ _op.stack(final_outputs, axis=0),
+ _op.stack([fw_H_t, rev_H_t], axis=0),
+ )
+
+ @classmethod
+ def _impl_v7(cls, inputs, attr, params):
Review Comment:
why target v7 when we have an operator v14?
v7 vs v14 don't seem very different; v14 just cuts out the `hidden_size`
attribute, which we don't use anyway.
Can we change this to _impl_v14?
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -2756,6 +2757,138 @@ def _activation_needs_beta(cls, activation):
]
return activation.decode("utf-8") in needs_beta
+ @classmethod
+ def bidir_rnn_cell(
+ cls,
+ input_seqs,
+ weight_dicts,
+ acts,
+ ):
+ """
+ Bidirectional RNN cell
+ """
+ seq_len = len(input_seqs)
+ forward_outputs, fw_H_t = rnn_cell(
+ input_seqs,
+ **weight_dicts[0],
+ act=acts[0],
+ )
+
+ reverse_outputs, rev_H_t = rnn_cell(
+ input_seqs,
+ **weight_dicts[1],
+ act=acts[1],
+ backwards=True,
+ )
+
+ final_outputs = []
+ for i in range(seq_len):
+ final_outputs.append(
+                _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
+ )
+
+ return (
+ _op.stack(final_outputs, axis=0),
+ _op.stack([fw_H_t, rev_H_t], axis=0),
+ )
+
+ @classmethod
+ def _impl_v7(cls, inputs, attr, params):
+        # Unpack inputs, note that if optional and not provided then value will be None.
+ X = inputs[0]
+ Wp = inputs[1]
+ Rp = inputs[2]
+ Bp = inputs[3]
+ # Sequence length currently unused as it can be inferred from shapes.
+ # sequence_lens = inputs['sequence_lens']
+ Hp_0 = inputs[5]
+
+ num_directions = infer_shape(Wp)[0]
+ W_dtype = infer_type(Wp).checked_type.dtype
+
+ if num_directions not in [1, 2]:
+ raise ValueError("num_directions must be either 1 or 2!")
+
+ X_shape = infer_shape(X)
+ hidden_size = infer_shape(Rp)[-1]
+ batch_size = X_shape[1]
+
+ if Hp_0 is None:
+            Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
+
+ if "activations" in attr:
Review Comment:
Can you refactor this as it seems identical to what's in LSTM?
##########
tests/python/frontend/onnx/test_forward.py:
##########
@@ -5212,7 +5341,7 @@ def verify_eyelike(indata, dynamic=False):
"test_reduce_sum_keepdims_random",
"test_reduce_sum_negative_axes_keepdims_example",
"test_reduce_sum_negative_axes_keepdims_random",
- "test_rnn_seq_length",
+ "test_rnn_batchwise",
Review Comment:
Why does this fail?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]