This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cc352a4c34 [ONNX] Extend converter for Attention from Microsoft onnxruntime contrib opset (#13797)
cc352a4c34 is described below
commit cc352a4c34dbfc5b9d8fc561472c73e1c627666d
Author: Alexey Gladyshev <[email protected]>
AuthorDate: Thu Jan 19 20:51:48 2023 +0300
[ONNX] Extend converter for Attention from Microsoft onnxruntime contrib opset (#13797)
* add type & shape checking
* add base class for Attention converter
* add support for 'past' input
* add support for 'unidirectional' attribute
* fix for 'huggingface implementation'
* add common method for calculating Attention
* expand test coverage for Attention
---
python/tvm/relay/frontend/onnx.py | 517 ++++++++++++++++++-----------
tests/python/frontend/onnx/test_forward.py | 92 +++--
2 files changed, 392 insertions(+), 217 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c4eb7774d7..ffd31317e9 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1297,7 +1297,213 @@ class SkipLayerNormalization(OnnxOpConverter):
return _expr.TupleWrapper(_expr.Tuple([output, placeholder,
placeholder]), 3)
-class Attention(OnnxOpConverter):
+class OrtAttentionBase:
+ """
+ Base class for Attention and QAttention from Microsoft onnxruntime contrib opset.
+ """
+
+ @classmethod
+ def _check_input_embeddings(cls, input_emb, valid_types, **kwargs):
+ assert infer_type(input_emb).checked_type.dtype in valid_types
+ assert (
+ len(infer_shape(input_emb)) == 3
+ ), "Input should be 3D tensor with shape (batch_size, sequence_length, input_hidden_size)"
+ (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
+ assert input_hidden > 0, (
+ "The weight tensor has (input_hidden_size, 3 * output_hidden_size) shape, so it doesn't"
+ f" make sense to have ({input_hidden}, 3 * output_hidden_size) weight tensor."
+ )
+ assert seq_len > 0, (
+ "The output tensor has (batch_size, sequence_length, hidden_size) shape,"
+ f" so it doesn't make sense to have (batch_size, {seq_len}, hidden_size) output."
+ )
+
+ return batch_size, seq_len, input_hidden
+
+ @classmethod
+ def _check_weights(cls, weight, valid_types, **kwargs):
+ assert infer_type(weight).checked_type.dtype in valid_types
+ assert len(infer_shape(weight)) == 2, (
+ "Weight should be 2D input tensor with shape (input_hidden_size, 3 * hidden_size), "
+ "hidden_size = num_heads * head_size"
+ )
+ (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
+ assert kwargs["input_hidden"] == input_hidden_weight
+ assert out_hidden_x3 % 3 == 0, "output hidden shape should be
divisible by 3: W_Q, W_K, W_V"
+ out_hidden = out_hidden_x3 // 3
+ assert (
+ out_hidden % kwargs["num_heads"] == 0
+ ), "output hidden size should be divisible by number of attention
heads"
+ head_size = out_hidden // kwargs["num_heads"]
+
+ return out_hidden_x3, out_hidden, head_size
+
+ @classmethod
+ def _check_bias(cls, bias, valid_types, **kwargs):
+ assert infer_type(bias).checked_type.dtype in valid_types
+ assert (
+ len(infer_shape(bias)) == 1
+ ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
+ (out_hidden_x3_bias,) = infer_shape(bias)
+ assert kwargs["out_hidden_x3"] == out_hidden_x3_bias
+
+ @classmethod
+ def _check_mask_index(cls, mask_index, valid_types, **kwargs):
+ assert infer_type(mask_index).checked_type.dtype in valid_types
+ mask_index_shape = infer_shape(mask_index)
+ assert (
+ len(mask_index_shape) == 2
+ and mask_index_shape[0] == kwargs["batch_size"]
+ and mask_index_shape[1] >= kwargs["seq_len"]
+ ), "currently only support (batch_size, past_sequence_len + sequence_length) mask index"
+
+ return mask_index_shape[1]
+
+ @classmethod
+ def _check_past(cls, past, valid_types, **kwargs):
+ assert infer_type(past).checked_type.dtype in valid_types
+ past_shape = infer_shape(past)
+ assert len(past_shape) == 5, "past should be 5D tensor"
+ assert (
+ past_shape[0] == 2
+ and past_shape[1] == kwargs["batch_size"]
+ and past_shape[2] == kwargs["num_heads"]
+ and past_shape[3] + kwargs["seq_len"] == kwargs["total_seq_len"]
+ and past_shape[4] == kwargs["head_size"]
+ )
+ past_seq_len = past_shape[3]
+ return past_seq_len
+
+ @classmethod
+ def _split_into_heads(cls, tensor, batch_size, seq_len, num_heads, head_size):
+ """
+ In the implementation of Multi-head attention we just split queries,
keys, and values
+ we compute for a single-head attention into several parts:
+ (batch_size, num_heads, seq_len, head_size)
+ """
+ tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads,
head_size))
+
+ # (batch_size, num_heads, seq_len, head_size)
+ tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+ return tensor
+
+ @classmethod
+ def _merge_first_dimensions(cls, tensor):
+ """
+ nn.batch_matmul is expecting 3D tensor:
+ (batch_size * num_heads, past_seq_len + seq_len, head_size)
+ """
+ return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+ @classmethod
+ def _create_unidirectional_mask(cls, left_value, right_value, past_seq_len, seq_len, dtype):
+ """
+ [lhs rhs rhs ... rhs rhs]
+ [lhs lhs rhs ... rhs rhs]
+ [lhs lhs lhs ... rhs rhs]
+ .........................
+ [lhs lhs lhs ... lhs rhs]
+ [lhs lhs lhs ... lhs lhs]
+ """
+ numpy_unidirectional_mask = np.array(
+ [
+ np.concatenate(
+ [
+ np.full(past_seq_len + s_i + 1, left_value),
+ np.full(seq_len - s_i - 1, right_value),
+ ]
+ )
+ for s_i in range(seq_len)
+ ]
+ )
+ unidirectional_mask = _op.const(numpy_unidirectional_mask, dtype=dtype)
+ unidirectional_mask = _op.expand_dims(unidirectional_mask, 0,
num_newaxis=2)
+
+ return unidirectional_mask
+
+ @classmethod
+ def _compute_attention(cls, Q, K, V, mask_index, **kwargs):
+ # Compute Attention scores
+ att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False,
transpose_b=True)
+ score_dtype = infer_type(att_scores).checked_type.dtype
+ att_scores = _op.divide(
+ att_scores,
+ _op.const(
+ np.sqrt(kwargs["head_size"]),
dtype=infer_type(att_scores).checked_type.dtype
+ ),
+ )
+ att_scores = _op.reshape(
+ att_scores,
+ (
+ kwargs["batch_size"],
+ kwargs["num_heads"],
+ kwargs["seq_len"],
+ kwargs["past_seq_len"] + kwargs["seq_len"],
+ ),
+ )
+
+ # Build the attention mask
+ att_mask = _op.cast(mask_index, score_dtype)
+ # Attention mask has value 0 or 1. Here we convert 0 to -10000, and 1
to 0.
+ att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+ att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+ # Expand for att_scores broadcast
+ # (batch_size, past_seq_len + seq_len) -> (batch_size, 1, seq_len,
past_seq_len + seq_len)
+ att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+ att_mask = _op.concatenate([att_mask] * kwargs["seq_len"], axis=2)
+
+ if kwargs["unidirectional"]:
+ att_mask = _op.add(
+ att_mask,
+ cls._create_unidirectional_mask(
+ 0, -10000, kwargs["past_seq_len"], kwargs["seq_len"],
score_dtype
+ ),
+ )
+
+ # Apply the mask
+ att_scores = _op.add(att_scores, att_mask)
+ # TODO(agladyshev):
+ # Comment from ORT source code
(onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h):
+ # "Fix unidirectional mask to be parity with huggingface
implementation"
+ if kwargs["unidirectional"]:
+ att_scores = _op.multiply(
+ att_scores,
+ cls._create_unidirectional_mask(
+ 1, 0, kwargs["past_seq_len"], kwargs["seq_len"],
score_dtype
+ ),
+ )
+ att_scores = _op.add(
+ att_scores,
+ _op.multiply(
+ att_mask,
+ cls._create_unidirectional_mask(
+ 0, 1, kwargs["past_seq_len"], kwargs["seq_len"],
score_dtype
+ ),
+ ),
+ )
+
+ # Compute Softmax
+ att_scores = _op.reshape(
+ att_scores,
+ (
+ kwargs["batch_size"] * kwargs["num_heads"],
+ kwargs["seq_len"],
+ kwargs["past_seq_len"] + kwargs["seq_len"],
+ ),
+ )
+ att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+ # Compute output
+ output = _op.nn.batch_matmul(att_probs, V, transpose_a=False,
transpose_b=False)
+ output = _op.reverse_reshape(output, (-1, kwargs["num_heads"], 0, 0))
+ output = _op.transpose(output, axes=[0, 2, 1, 3])
+ output = _op.reshape(output, (0, 0, kwargs["out_hidden"]))
+
+ return output
+
+
+class Attention(OrtAttentionBase, OnnxOpConverter):
"""Operator converter for Attention from Microsoft onnxruntime contrib
opset.
This is the self-attention mechanism used in transformer models.
@@ -1305,16 +1511,30 @@ class Attention(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ # ************************* Read attrs *************************
num_heads = attr["num_heads"]
+ unidirectional = attr["unidirectional"]
+
+ assert (
+ "past_present_share_buffer" not in attr
+ ), "share past and present buffers are not currently supported"
assert (
"qkv_hidden_sizes" not in attr
), "different hidden sizes for Q, K, V are not currently supported"
- assert "unidirectional" not in attr, "unidirectional attention not
current supported"
+ # ************************* Read inputs *************************
# (batch, seq, in_hidden)
input_emb = inputs[0]
- # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+ # TODO(agladyshev):
+ # ORT documentation says:
+ # The weights for input projection of Q, K and V are merged.
+ # The data is stacked on the second dimension.
+ # Its shape is (input_hidden_size, hidden_size + hidden_size +
v_hidden_size).
+ # Here hidden_size is the hidden dimension of Q and K, and
v_hidden_size is that of V.
+ # However, in our case, we consider that hidden_size ==
v_hidden_size.
+ # Therefore, weight has the following shape:
+ # (in_hidden, 3 * out_hidden), where out_hidden = num_heads *
head_size
weight = inputs[1]
# (3 * out_hidden,)
@@ -1325,7 +1545,7 @@ class Attention(OnnxOpConverter):
# 3. ( batch, seq, past_seq + seq,)
# 4. ( batch,)
# 5. (2 * batch,)
- # For now, we only support case 2.
+ # TODO: For now, we only support case 2.
mask_index = inputs[3]
# (2, batch, num_heads, past_seq, head_size)
@@ -1333,28 +1553,47 @@ class Attention(OnnxOpConverter):
# (batch, num_heads, seq, seq)
extra_add = inputs[5]
+ assert extra_add is None, "extra add to QxK not currently supported"
- (batch_size, seq_len, _) = infer_shape(input_emb)
- (out_hidden_x3,) = infer_shape(bias)
- assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
- out_hidden = out_hidden_x3 // 3
- assert (
- out_hidden % num_heads == 0
- ), "output hidden size should be divisible by number of attention
heads"
- head_size = out_hidden // num_heads
+ # When past_present_share_buffer is used,
+ # it is required to specify past_sequence_length (could be 0)
+ past_seq_len = inputs[6]
+ assert past_seq_len is None, "past sequence length not currently
supported"
+
+ # ************************* Parse inputs *************************
+ t = ["float32", "float16"]
+ m = ["int32"]
+
+ # input
+ batch_size, seq_len, input_hidden = cls._check_input_embeddings(input_emb, t)
+
+ # weight
+ out_hidden_x3, out_hidden, head_size = cls._check_weights(
+ weight, t, num_heads=num_heads, input_hidden=input_hidden
+ )
+ # bias
+ cls._check_bias(bias, t, out_hidden_x3=out_hidden_x3)
+
+ # mask_index
assert (
mask_index is not None
), "Attention import currently only supports required mask_index"
- mask_index_shape = infer_shape(mask_index)
- assert (
- len(mask_index_shape) == 2
- and mask_index_shape[0] == batch_size
- and mask_index_shape[1] == seq_len
- ), "currently only support (batch_size, sequence_length) mask index"
+ total_seq_len = cls._check_mask_index(mask_index, m, batch_size=batch_size, seq_len=seq_len)
- assert past is None, "past K, V state is not currently supported"
- assert extra_add is None, "extra add to QxK not currently supported"
+ # past
+ if past_seq_len is None:
+ past_seq_len = 0
+ if past is not None:
+ past_seq_len = cls._check_past(
+ past,
+ t,
+ batch_size=batch_size,
+ num_heads=num_heads,
+ seq_len=seq_len,
+ total_seq_len=total_seq_len,
+ head_size=head_size,
+ )
# split weight and biases and do the matmuls
w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
@@ -1365,53 +1604,44 @@ class Attention(OnnxOpConverter):
K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
- # massage tensors in preparation for batched matmul
- def massage(tensor):
- tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads,
head_size))
-
- # (batch_size, num_heads, seq_len, head_size)
- tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+ Q = cls._split_into_heads(Q, batch_size, seq_len, num_heads, head_size)
+ K = cls._split_into_heads(K, batch_size, seq_len, num_heads, head_size)
+ V = cls._split_into_heads(V, batch_size, seq_len, num_heads, head_size)
- # (batch_size * num_heads, seq_len, head_size)
- return _op.reverse_reshape(tensor, (-1, 0, 0))
-
- Q = massage(Q)
- K = massage(K)
- V = massage(V)
+ # Concatenate (past_K, past_V) with (K, V) by sequence axis:
+ # (batch_size, num_heads, past_sequence_length + sequence_length,
head_size)
+ if past is not None and past_seq_len > 0:
+ K_past, V_past = _op.split(past, 2, axis=0)
+ K = _op.concatenate([_op.squeeze(K_past, axis=[0]), K], axis=2)
+ V = _op.concatenate([_op.squeeze(V_past, axis=[0]), V], axis=2)
- K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
- V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
- present = _op.stack([K_present, V_present], axis=0)
+ # Prepare present state for Key and Value with shape
+ # (2, batch_size, num_heads, past_sequence_length + sequence_length,
head_size)
+ present = _op.stack([K, V], axis=0)
- att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False,
transpose_b=True)
- score_dtype = infer_type(att_scores).checked_type.dtype
- att_scores = _op.divide(
- att_scores,
- _op.const(np.sqrt(head_size),
dtype=infer_type(att_scores).checked_type.dtype),
+ Q = cls._merge_first_dimensions(Q)
+ K = cls._merge_first_dimensions(K)
+ V = cls._merge_first_dimensions(V)
+
+ # Compute Attention output
+ output = cls._compute_attention(
+ Q,
+ K,
+ V,
+ mask_index,
+ unidirectional=unidirectional,
+ batch_size=batch_size,
+ out_hidden=out_hidden,
+ num_heads=num_heads,
+ head_size=head_size,
+ seq_len=seq_len,
+ past_seq_len=past_seq_len,
)
- att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len,
seq_len))
-
- # build the attention mask
- att_mask = _op.cast(mask_index, score_dtype)
- att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
- att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
- att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
-
- # apply the mask
- att_scores = _op.add(att_scores, att_mask)
- att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len,
seq_len))
-
- att_probs = _op.nn.softmax(att_scores, axis=-1)
-
- output = _op.nn.batch_matmul(att_probs, V, transpose_a=False,
transpose_b=False)
- output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
- output = _op.transpose(output, axes=[0, 2, 1, 3])
- output = _op.reshape(output, (0, 0, out_hidden))
return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
-class QAttention(OnnxOpConverter):
+class QAttention(OrtAttentionBase, OnnxOpConverter):
"""Operator converter for QAttention from Microsoft onnxruntime contrib
opset.
This is the self-attention mechanism used in transformer models.
@@ -1473,42 +1703,15 @@ class QAttention(OnnxOpConverter):
t4 = ["int32"]
# input
- assert infer_type(input_emb).checked_type.dtype in t1
- assert (
- len(infer_shape(input_emb)) == 3
- ), "Input should be 3D tensor with shape (batch_size, sequence_length,
input_hidden_size)"
- (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
- assert input_hidden > 0, (
- "The weight tensor has (input_hidden_size, 3 * output_hidden_size)
shape, so it doesn't"
- f" make sense to have ({input_hidden}, 3 * output_hidden_size)
weight tensor."
- )
- assert seq_len > 0, (
- "The output tensor has (batch_size, sequence_length, hidden_size)
shape,"
- f" so it doesn't make sense to have (batch_size, {seq_len},
hidden_size) output."
- )
+ batch_size, seq_len, input_hidden = cls._check_input_embeddings(input_emb, t1)
# weight
- assert infer_type(weight).checked_type.dtype in t2
- assert len(infer_shape(weight)) == 2, (
- "Weight should be 2D input tensor with shape (input_hidden_size, 3
* hidden_size), "
- "hidden_size = num_heads * head_size"
+ out_hidden_x3, out_hidden, head_size = cls._check_weights(
+ weight, t2, num_heads=num_heads, input_hidden=input_hidden
)
- (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
- assert input_hidden == input_hidden_weight
- assert out_hidden_x3 % 3 == 0, "output hidden shape should be
divisible by 3: W_Q, W_K, W_V"
- out_hidden = out_hidden_x3 // 3
- assert (
- out_hidden % num_heads == 0
- ), "output hidden size should be divisible by number of attention
heads"
- head_size = out_hidden // num_heads
# bias
- assert infer_type(bias).checked_type.dtype in t3
- assert (
- len(infer_shape(bias)) == 1
- ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
- (out_hidden_x3_bias,) = infer_shape(bias)
- assert out_hidden_x3 == out_hidden_x3_bias
+ cls._check_bias(bias, t3, out_hidden_x3=out_hidden_x3)
# input_scale
assert infer_type(input_scale).checked_type.dtype in t3
@@ -1527,13 +1730,9 @@ class QAttention(OnnxOpConverter):
assert (
mask_index is not None
), "Attention import currently only supports required mask_index"
- assert infer_type(mask_index).checked_type.dtype in t4
- mask_index_shape = infer_shape(mask_index)
- assert (
- len(mask_index_shape) == 2
- and mask_index_shape[0] == batch_size
- and mask_index_shape[1] >= seq_len
- ), "currently only support (batch_size, sequence_length) mask index"
+ total_seq_len = cls._check_mask_index(
+ mask_index, t4, batch_size=batch_size, seq_len=seq_len
+ )
# TODO(agladyshev): int32 required for qnn.batch_matmul
(QnnBatchMatmulRel)
zero_point_zero = _expr.const(0, "int32")
@@ -1557,17 +1756,15 @@ class QAttention(OnnxOpConverter):
# past (2, batch_size, num_heads, past_sequence_length, head_size)
past_seq_len = 0
if past is not None:
- assert infer_type(past).checked_type.dtype in t3
- past_shape = infer_shape(past)
- assert len(past_shape) == 5, "past should be 5D tensor"
- assert (
- past_shape[0] == 2
- and past_shape[1] == batch_size
- and past_shape[2] == num_heads
- and past_shape[3] + seq_len == mask_index_shape[1]
- and past_shape[4] == head_size
+ past_seq_len = cls._check_past(
+ past,
+ t3,
+ batch_size=batch_size,
+ num_heads=num_heads,
+ seq_len=seq_len,
+ total_seq_len=total_seq_len,
+ head_size=head_size,
)
- past_seq_len = past_shape[3]
# ************************* Create Relay *************************
# Add batch dimension for QNN Batch Matmul
@@ -1604,22 +1801,9 @@ class QAttention(OnnxOpConverter):
input_emb, w_V, input_scale, weight_scale, input_zero_point,
weight_zero_point, b_V
)
- def split_into_heads(tensor):
- """
- In the implementation of Multi-head attention we just split
queries, keys, and values
- we compute for a single-head attention into several parts:
- (batch_size, num_heads, seq_len, head_size)
- """
- tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads,
head_size))
-
- # (batch_size, num_heads, seq_len, head_size)
- tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
-
- return tensor
-
- Q = split_into_heads(Q)
- K = split_into_heads(K)
- V = split_into_heads(V)
+ Q = cls._split_into_heads(Q, batch_size, seq_len, num_heads, head_size)
+ K = cls._split_into_heads(K, batch_size, seq_len, num_heads, head_size)
+ V = cls._split_into_heads(V, batch_size, seq_len, num_heads, head_size)
# Concatenate (past_K, past_V) with (K, V) by sequence axis:
# (batch_size, num_heads, past_sequence_length + sequence_length,
head_size)
@@ -1632,78 +1816,25 @@ class QAttention(OnnxOpConverter):
# (2, batch_size, num_heads, past_sequence_length + sequence_length,
head_size)
present = _op.stack([K, V], axis=0)
- def merge_first_dimensions(tensor):
- """
- nn.batch_matmul is expecting 3D tensor:
- (batch_size * num_heads, past_seq_len + seq_len, head_size)
- """
- return _op.reverse_reshape(tensor, (-1, 0, 0))
-
- Q = merge_first_dimensions(Q)
- K = merge_first_dimensions(K)
- V = merge_first_dimensions(V)
-
- att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False,
transpose_b=True)
- score_dtype = infer_type(att_scores).checked_type.dtype
- att_scores = _op.divide(
- att_scores,
- _op.const(np.sqrt(head_size),
dtype=infer_type(att_scores).checked_type.dtype),
- )
- att_scores = _op.reshape(
- att_scores, (batch_size, num_heads, seq_len, past_seq_len +
seq_len)
+ Q = cls._merge_first_dimensions(Q)
+ K = cls._merge_first_dimensions(K)
+ V = cls._merge_first_dimensions(V)
+
+ # Compute Attention output
+ output = cls._compute_attention(
+ Q,
+ K,
+ V,
+ mask_index,
+ unidirectional=unidirectional,
+ batch_size=batch_size,
+ out_hidden=out_hidden,
+ num_heads=num_heads,
+ head_size=head_size,
+ seq_len=seq_len,
+ past_seq_len=past_seq_len,
)
- # Build the attention mask
- att_mask = _op.cast(mask_index, score_dtype)
- # Attention mask has value 0 or 1. Here we convert 0 to -10000, and 1
to 0.
- att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
- att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
- # Expand for att_scores broadcast
- # (batch_size, past_seq_len + seq_len) -> (batch_size, 1, seq_len,
past_seq_len + seq_len)
- att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
- att_mask = _op.concatenate([att_mask] * seq_len, axis=2)
-
- def create_unidirectional_mask(left_value, right_value):
- numpy_unidirectional_mask = np.array(
- [
- np.concatenate(
- [
- np.full(past_seq_len + s_i + 1, left_value),
- np.full(seq_len - s_i - 1, right_value),
- ]
- )
- for s_i in range(seq_len)
- ]
- )
- unidirectional_mask = _op.const(numpy_unidirectional_mask,
dtype=score_dtype)
- unidirectional_mask = _op.expand_dims(unidirectional_mask, 0,
num_newaxis=2)
-
- return unidirectional_mask
-
- if unidirectional:
- att_mask = _op.add(att_mask, create_unidirectional_mask(0, -10000))
-
- # Apply the mask
- att_scores = _op.add(att_scores, att_mask)
- # TODO(agladyshev):
- # Comment from ORT source code
(onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h):
- # "Fix unidirectional mask to be parity with huggingface
implementation"
- if unidirectional:
- att_scores = _op.multiply(att_scores,
create_unidirectional_mask(1, 0))
- att_scores = _op.add(att_scores, create_unidirectional_mask(0,
-10000))
-
- # Compute Softmax
- att_scores = _op.reshape(
- att_scores, (batch_size * num_heads, seq_len, past_seq_len +
seq_len)
- )
- att_probs = _op.nn.softmax(att_scores, axis=-1)
-
- # Compute output
- output = _op.nn.batch_matmul(att_probs, V, transpose_a=False,
transpose_b=False)
- output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
- output = _op.transpose(output, axes=[0, 2, 1, 3])
- output = _op.reshape(output, (0, 0, out_hidden))
-
return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a84de82f3b..f5b5f7c65c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5878,30 +5878,47 @@ def test_embedlayernormalization(target, dev):
def test_attention(target, dev):
"""test_attention"""
- def verify_attention(input_, weight, bias, mask_index, num_heads):
+ def verify_attention(_unidirectional, _input, _weight, _bias,
_mask_index=None, _past=None):
+ input_names = ["input", "weight", "bias"]
+ if _mask_index is not None:
+ input_names.append("mask_index")
+ if _past is not None:
+ input_names.append("past")
+
node = onnx.helper.make_node(
"Attention",
- inputs=["input", "weight", "bias", "mask_index"],
+ inputs=input_names,
outputs=["output", "present"],
domain="com.microsoft",
num_heads=num_heads,
+ unidirectional=_unidirectional,
)
+ past_shape = (2, batch_size, num_heads, past_sequence_length,
head_size)
present_output_shape = (2, batch_size, num_heads, sequence_length,
head_size)
+ inputs_info = [
+ helper.make_tensor_value_info("input", TensorProto.FLOAT,
list(_input.shape)),
+ helper.make_tensor_value_info("weight", TensorProto.FLOAT,
list(_weight.shape)),
+ helper.make_tensor_value_info("bias", TensorProto.FLOAT,
list(_bias.shape)),
+ ]
+ if _mask_index is not None:
+ inputs_info.append(
+ helper.make_tensor_value_info(
+ "mask_index", TensorProto.INT32, list(_mask_index.shape)
+ ),
+ )
+ if _past is not None:
+ inputs_info.append(
+ helper.make_tensor_value_info("past", TensorProto.FLOAT,
list(past_shape))
+ )
+
graph = helper.make_graph(
[node],
"attention_test",
- inputs=[
- helper.make_tensor_value_info("input", TensorProto.FLOAT,
list(input_.shape)),
- helper.make_tensor_value_info("weight", TensorProto.FLOAT,
list(weight.shape)),
- helper.make_tensor_value_info("bias", TensorProto.FLOAT,
list(bias.shape)),
- helper.make_tensor_value_info(
- "mask_index", TensorProto.INT32, list(mask_index.shape)
- ),
- ],
+ inputs=inputs_info,
outputs=[
- helper.make_tensor_value_info("output", TensorProto.FLOAT,
list(input_.shape)),
+ helper.make_tensor_value_info("output", TensorProto.FLOAT,
list(_input.shape)),
helper.make_tensor_value_info(
"present", TensorProto.FLOAT, list(present_output_shape)
),
@@ -5910,31 +5927,58 @@ def test_attention(target, dev):
model = helper.make_model(graph, producer_name="attention_test")
+ inputs = [_input, _weight, _bias]
+ if _mask_index is not None:
+ inputs.append(_mask_index)
+ if _past is not None:
+ inputs.append(_past)
+
# "present" output should be nullptr when the "past" input isn't
included,
# but ort requires an output shape to be specified?
verify_with_ort_with_inputs(
model,
- [input_, weight, bias, mask_index],
- [input_.shape, present_output_shape],
+ inputs,
+ [_input.shape, present_output_shape],
target=target,
dev=dev,
rtol=1e-4,
atol=1e-4,
)
- hidden_size = 384
- batch_size = 4
- sequence_length = 4
- num_heads = 12
- head_size = 32
+ batch_size = 11
+ num_heads = 13
+ head_size = 37
+ sequence_length = 7
+ input_hidden_size = 147
+ weight_hidden_size = num_heads * head_size
+ past_sequence_length = 17
- dtype = "float32"
- input_array = np.random.random((batch_size, sequence_length,
hidden_size)).astype(dtype)
- weight = np.random.normal(size=(hidden_size, 3 *
hidden_size)).astype(dtype) * 0.1
- bias = np.random.randn(3 * hidden_size).astype(dtype)
- mask_index = np.full((batch_size, sequence_length), 1).astype("int32")
+ total_sequence_length = past_sequence_length + sequence_length
- verify_attention(input_array, weight, bias, mask_index, num_heads)
+ # Required inputs
+ input_array = np.random.normal(size=(batch_size, sequence_length,
input_hidden_size)).astype(
+ "float32"
+ )
+ weight = (
+ np.random.normal(size=(input_hidden_size, 3 *
weight_hidden_size)).astype("float32") * 0.1
+ )
+ bias = np.random.randn(3 * weight_hidden_size).astype("float32")
+
+ # Optional inputs
+ past = np.random.random((2, batch_size, num_heads, past_sequence_length,
head_size)).astype(
+ "float32"
+ )
+
+ for unidirectional in [0, 1]:
+ for have_past in [False, True]:
+ if not have_past:
+ mask_index = np.random.randint(0, 2, (batch_size,
sequence_length)).astype("int32")
+ verify_attention(unidirectional, input_array, weight, bias,
mask_index)
+ else:
+ mask_index = np.random.randint(0, 2, (batch_size,
total_sequence_length)).astype(
+ "int32"
+ )
+ verify_attention(unidirectional, input_array, weight, bias,
mask_index, past)
@tvm.testing.parametrize_targets