MargaretQian commented on code in PR #10949:
URL: https://github.com/apache/tvm/pull/10949#discussion_r846976559
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,208 @@ def _impl_v1(cls, inputs, attr, params):
return Gelu._impl_v1([inp], attr, params)
+class EmbedLayerNormalization(OnnxOpConverter):
+    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer embeds the input tokens, sums them, and applies layer normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr.get("epsilon", 1e-12)
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * seq_len, dtype="int64")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        ln = SkipLayerNormalization._compute_layer_norm(vec_sum, eps, gamma, beta)
Review Comment:
nit: maybe instead of referencing SkipLayerNormalization here, you could
create a LayerNormalization base class that contains `_compute_layer_norm`?
sort of like how `Pool` is the base class for `MaxPool`/`AveragePool` etc?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]