altanh commented on code in PR #10949:
URL: https://github.com/apache/tvm/pull/10949#discussion_r846523259
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,192 @@ def _impl_v1(cls, inputs, attr, params):
return Gelu._impl_v1([inp], attr, params)
+class EmbedLayerNormalization(OnnxOpConverter):
+ @classmethod
+ def _impl_v1(cls, inputs, attr, params):
+ input_ids = inputs[0]
+ segment_ids = inputs[1]
+ word_emb = inputs[2]
+ pos_emb = inputs[3]
+ segment_emb = inputs[4]
+ gamma = inputs[5]
+ beta = inputs[6]
+
+ mask = inputs[7]
Review Comment:
The inputs to these converters are actually wrapped in an `onnx_input`
object, which returns `None` for out-of-bounds accesses, so we can safely
index past the last provided input here. (We also can't reliably use `len`
here, because trailing optional inputs aren't counted.)
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,192 @@ def _impl_v1(cls, inputs, attr, params):
return Gelu._impl_v1([inp], attr, params)
+class EmbedLayerNormalization(OnnxOpConverter):
+ @classmethod
+ def _impl_v1(cls, inputs, attr, params):
+ input_ids = inputs[0]
+ segment_ids = inputs[1]
+ word_emb = inputs[2]
+ pos_emb = inputs[3]
+ segment_emb = inputs[4]
+ gamma = inputs[5]
+ beta = inputs[6]
+
+ mask = inputs[7]
+ pos_ids = inputs[8]
+
+ eps = attr["epsilon"] if "epsilon" in attr else 1e-12
+
+ (batch_size, seq_len) = infer_shape(input_ids)
+
+ if segment_ids:
+ assert segment_emb
+
+ if pos_ids is None:
+ pos_ids = _op.const([list(range(seq_len))] * seq_len,
dtype="int64")
+
+ word_vec = _op.take(word_emb, input_ids, axis=0)
+ segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+ pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+ vec_sum = _op.add(word_vec, pos_vec)
+ if segment_ids:
+ vec_sum = _op.add(vec_sum, segment_vec)
+
+ eps_dtype = infer_type(word_emb).checked_type.dtype
+
+ u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
+ ln = _op.divide(
+ _op.subtract(vec_sum, u),
+ _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+ )
+ ln = _op.multiply(ln, gamma) + beta
+
+ mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
+ if mask:
+ # calculate number of words per sentence
+ mask_index = _op.sum(mask, axis=1)
+
+ return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+ @classmethod
+ def _impl_v1(cls, inputs, attr, params):
+ data = inputs[0]
+ skip = inputs[1]
+ gamma = inputs[2]
+ beta = inputs[3]
Review Comment:
See the previous comment; the same applies here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]