MargaretQian commented on code in PR #10949:
URL: https://github.com/apache/tvm/pull/10949#discussion_r846537572


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,208 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)
 
 
class EmbedLayerNormalization(OnnxOpConverter):
    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.

    This layer embeds the input tokens (word + position + optional segment
    embeddings), sums the embeddings, and applies layer normalization.

    Returns a 3-tuple of (normalized output, mask index, embedding sum).
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        input_ids = inputs[0]
        segment_ids = inputs[1]
        word_emb = inputs[2]
        pos_emb = inputs[3]
        segment_emb = inputs[4]
        gamma = inputs[5]
        beta = inputs[6]

        # Optional inputs: attention mask and explicit position ids.
        mask = inputs[7]
        pos_ids = inputs[8]

        eps = attr.get("epsilon", 1e-12)

        (batch_size, seq_len) = infer_shape(input_ids)

        # A segment-embedding table is required whenever segment ids are given.
        if segment_ids:
            assert segment_emb

        if pos_ids is None:
            # Default position ids are [0, 1, ..., seq_len - 1] for every
            # sentence in the batch: shape (batch_size, seq_len), matching
            # input_ids. FIX: the row must be replicated batch_size times,
            # not seq_len times — the original was wrong whenever
            # batch_size != seq_len.
            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int64")

        word_vec = _op.take(word_emb, input_ids, axis=0)
        pos_vec = _op.take(pos_emb, pos_ids, axis=0)

        vec_sum = _op.add(word_vec, pos_vec)
        if segment_ids:
            # FIX: only gather segment embeddings when segment ids were
            # actually provided; the original called _op.take unconditionally,
            # which fails when segment_ids is None.
            segment_vec = _op.take(segment_emb, segment_ids, axis=0)
            vec_sum = _op.add(vec_sum, segment_vec)

        # Use the embedding table's dtype for the epsilon constant so the
        # add inside the normalization does not mix dtypes.
        eps_dtype = infer_type(word_emb).checked_type.dtype

        # Layer normalization over the last (hidden) axis:
        # (x - mean) / sqrt(var + eps) * gamma + beta
        u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
        ln = _op.divide(
            _op.subtract(vec_sum, u),
            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
        )
        ln = _op.multiply(ln, gamma) + beta

        # mask_index holds the number of valid (non-padding) tokens per
        # sentence; all zeros when no mask is supplied.
        mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
        if mask:
            # calculate number of words per sentence
            mask_index = _op.sum(mask, axis=1)

        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+    """Operator converter for SkipLayerNormalization from Microsoft 
onnxruntime contrib opset.
+
+    This layer sums the two input tensors (along with optional bias), and 
applies layer
+    normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]
+        bias = inputs[4]
+
+        eps = attr.get("epsilon", 1e-12)
+
+        x = _op.add(data, skip)
+        if bias is not None:
+            x = _op.add(x, bias)
+
+        eps_dtype = infer_type(x).checked_type.dtype
+
+        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+        output = _op.divide(
+            _op.subtract(x, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
+        output = _op.multiply(output, gamma)

Review Comment:
   nit: this is basically the same normalization calculation as lines 877-886 above, right? If it's easy, could we pull it out into a common helper function?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to