anwang2009 commented on code in PR #10949:
URL: https://github.com/apache/tvm/pull/10949#discussion_r847576581
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,208 @@ def _impl_v1(cls, inputs, attr, params):
return Gelu._impl_v1([inp], attr, params)
+class EmbedLayerNormalization(OnnxOpConverter):
+    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer embeds the input tokens, sums them, and applies layer normalization.
+ """
+
+ @classmethod
+ def _impl_v1(cls, inputs, attr, params):
+ input_ids = inputs[0]
+ segment_ids = inputs[1]
+ word_emb = inputs[2]
+ pos_emb = inputs[3]
+ segment_emb = inputs[4]
+ gamma = inputs[5]
+ beta = inputs[6]
+
+ mask = inputs[7]
+ pos_ids = inputs[8]
+
+ eps = attr.get("epsilon", 1e-12)
+
+ (batch_size, seq_len) = infer_shape(input_ids)
+
+ if segment_ids:
+ assert segment_emb
+
+ if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * seq_len, dtype="int64")
+
+ word_vec = _op.take(word_emb, input_ids, axis=0)
+ segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+ pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+ vec_sum = _op.add(word_vec, pos_vec)
+ if segment_ids:
+ vec_sum = _op.add(vec_sum, segment_vec)
+
+        ln = SkipLayerNormalization._compute_layer_norm(vec_sum, eps, gamma, beta)
Review Comment:
Consider redefining `_compute_layer_norm` as a module-level function in this file — I wouldn't put it in a separate class, since semantically LayerNorm is not an ONNX operator.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]