This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cc352a4c34 [ONNX] Extend converter for Attention from Microsoft onnxruntime contrib opset (#13797)
cc352a4c34 is described below

commit cc352a4c34dbfc5b9d8fc561472c73e1c627666d
Author: Alexey Gladyshev <[email protected]>
AuthorDate: Thu Jan 19 20:51:48 2023 +0300

    [ONNX] Extend converter for Attention from Microsoft onnxruntime contrib opset (#13797)
    
    * add type & shape checking
    
    * add base class for Attention converter
    
    * add support for 'past' input
    
    * add support for 'unidirectional' attribute
    
    * fix for 'huggingface implementation'
    
    * add common method for calculating Attention
    
    * expand test coverage for Attention
---
 python/tvm/relay/frontend/onnx.py          | 517 ++++++++++++++++++-----------
 tests/python/frontend/onnx/test_forward.py |  92 +++--
 2 files changed, 392 insertions(+), 217 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c4eb7774d7..ffd31317e9 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1297,7 +1297,213 @@ class SkipLayerNormalization(OnnxOpConverter):
         return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
 
 
-class Attention(OnnxOpConverter):
+class OrtAttentionBase:
+    """
+    Base class for Attention and QAttention from Microsoft onnxruntime contrib opset.
+    """
+
+    @classmethod
+    def _check_input_embeddings(cls, input_emb, valid_types, **kwargs):
+        assert infer_type(input_emb).checked_type.dtype in valid_types
+        assert (
+            len(infer_shape(input_emb)) == 3
+        ), "Input should be 3D tensor with shape (batch_size, sequence_length, input_hidden_size)"
+        (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
+        assert input_hidden > 0, (
+            "The weight tensor has (input_hidden_size, 3 * output_hidden_size) shape, so it doesn't"
+            f" make sense to have ({input_hidden}, 3 * output_hidden_size) weight tensor."
+        )
+        assert seq_len > 0, (
+            "The output tensor has (batch_size, sequence_length, hidden_size) shape,"
+            f" so it doesn't make sense to have (batch_size, {seq_len}, hidden_size) output."
+        )
+
+        return batch_size, seq_len, input_hidden
+
+    @classmethod
+    def _check_weights(cls, weight, valid_types, **kwargs):
+        assert infer_type(weight).checked_type.dtype in valid_types
+        assert len(infer_shape(weight)) == 2, (
+            "Weight should be 2D input tensor with shape (input_hidden_size, 3 * hidden_size), "
+            "hidden_size = num_heads * head_size"
+        )
+        (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
+        assert kwargs["input_hidden"] == input_hidden_weight
+        assert out_hidden_x3 % 3 == 0, "output hidden shape should be divisible by 3: W_Q, W_K, W_V"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % kwargs["num_heads"] == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // kwargs["num_heads"]
+
+        return out_hidden_x3, out_hidden, head_size
+
+    @classmethod
+    def _check_bias(cls, bias, valid_types, **kwargs):
+        assert infer_type(bias).checked_type.dtype in valid_types
+        assert (
+            len(infer_shape(bias)) == 1
+        ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
+        (out_hidden_x3_bias,) = infer_shape(bias)
+        assert kwargs["out_hidden_x3"] == out_hidden_x3_bias
+
+    @classmethod
+    def _check_mask_index(cls, mask_index, valid_types, **kwargs):
+        assert infer_type(mask_index).checked_type.dtype in valid_types
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == kwargs["batch_size"]
+            and mask_index_shape[1] >= kwargs["seq_len"]
+        ), "currently only support (batch_size, past_sequence_len + sequence_length) mask index"
+
+        return mask_index_shape[1]
+
+    @classmethod
+    def _check_past(cls, past, valid_types, **kwargs):
+        assert infer_type(past).checked_type.dtype in valid_types
+        past_shape = infer_shape(past)
+        assert len(past_shape) == 5, "past should be 5D tensor"
+        assert (
+            past_shape[0] == 2
+            and past_shape[1] == kwargs["batch_size"]
+            and past_shape[2] == kwargs["num_heads"]
+            and past_shape[3] + kwargs["seq_len"] == kwargs["total_seq_len"]
+            and past_shape[4] == kwargs["head_size"]
+        )
+        past_seq_len = past_shape[3]
+        return past_seq_len
+
+    @classmethod
+    def _split_into_heads(cls, tensor, batch_size, seq_len, num_heads, head_size):
+        """
+        In the implementation of Multi-head attention we just split queries, keys, and values
+        we compute for a single-head attention into several parts:
+        (batch_size, num_heads, seq_len, head_size)
+        """
+        tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
+
+        # (batch_size, num_heads, seq_len, head_size)
+        tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+        return tensor
+
+    @classmethod
+    def _merge_first_dimensions(cls, tensor):
+        """
+        nn.batch_matmul is expecting 3D tensor:
+        (batch_size * num_heads, past_seq_len + seq_len, head_size)
+        """
+        return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+    @classmethod
+    def _create_unidirectional_mask(cls, left_value, right_value, past_seq_len, seq_len, dtype):
+        """
+        [lhs rhs rhs ... rhs rhs]
+        [lhs lhs rhs ... rhs rhs]
+        [lhs lhs lhs ... rhs rhs]
+        .........................
+        [lhs lhs lhs ... lhs rhs]
+        [lhs lhs lhs ... lhs lhs]
+        """
+        numpy_unidirectional_mask = np.array(
+            [
+                np.concatenate(
+                    [
+                        np.full(past_seq_len + s_i + 1, left_value),
+                        np.full(seq_len - s_i - 1, right_value),
+                    ]
+                )
+                for s_i in range(seq_len)
+            ]
+        )
+        unidirectional_mask = _op.const(numpy_unidirectional_mask, dtype=dtype)
+        unidirectional_mask = _op.expand_dims(unidirectional_mask, 0, num_newaxis=2)
+
+        return unidirectional_mask
+
+    @classmethod
+    def _compute_attention(cls, Q, K, V, mask_index, **kwargs):
+        # Compute Attention scores
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(
+                np.sqrt(kwargs["head_size"]), dtype=infer_type(att_scores).checked_type.dtype
+            ),
+        )
+        att_scores = _op.reshape(
+            att_scores,
+            (
+                kwargs["batch_size"],
+                kwargs["num_heads"],
+                kwargs["seq_len"],
+                kwargs["past_seq_len"] + kwargs["seq_len"],
+            ),
+        )
+
+        # Build the attention mask
+        att_mask = _op.cast(mask_index, score_dtype)
+        # Attention mask has value 0 or 1. Here we convert 0 to -10000, and 1 to 0.
+        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+        # Expand for att_scores broadcast
+        # (batch_size, past_seq_len + seq_len) -> (batch_size, 1, seq_len, past_seq_len + seq_len)
+        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+        att_mask = _op.concatenate([att_mask] * kwargs["seq_len"], axis=2)
+
+        if kwargs["unidirectional"]:
+            att_mask = _op.add(
+                att_mask,
+                cls._create_unidirectional_mask(
+                    0, -10000, kwargs["past_seq_len"], kwargs["seq_len"], score_dtype
+                ),
+            )
+
+        # Apply the mask
+        att_scores = _op.add(att_scores, att_mask)
+        # TODO(agladyshev):
+        #   Comment from ORT source code (onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h):
+        #   "Fix unidirectional mask to be parity with huggingface implementation"
+        if kwargs["unidirectional"]:
+            att_scores = _op.multiply(
+                att_scores,
+                cls._create_unidirectional_mask(
+                    1, 0, kwargs["past_seq_len"], kwargs["seq_len"], score_dtype
+                ),
+            )
+            att_scores = _op.add(
+                att_scores,
+                _op.multiply(
+                    att_mask,
+                    cls._create_unidirectional_mask(
+                        0, 1, kwargs["past_seq_len"], kwargs["seq_len"], score_dtype
+                    ),
+                ),
+            )
+
+        # Compute Softmax
+        att_scores = _op.reshape(
+            att_scores,
+            (
+                kwargs["batch_size"] * kwargs["num_heads"],
+                kwargs["seq_len"],
+                kwargs["past_seq_len"] + kwargs["seq_len"],
+            ),
+        )
+        att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+        # Compute output
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, kwargs["num_heads"], 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, kwargs["out_hidden"]))
+
+        return output
+
+
+class Attention(OrtAttentionBase, OnnxOpConverter):
     """Operator converter for Attention from Microsoft onnxruntime contrib opset.
 
     This is the self-attention mechanism used in transformer models.
@@ -1305,16 +1511,30 @@ class Attention(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        # ************************* Read attrs *************************
         num_heads = attr["num_heads"]
+        unidirectional = attr["unidirectional"]
+
+        assert (
+            "past_present_share_buffer" not in attr
+        ), "share past and present buffers are not currently supported"
         assert (
             "qkv_hidden_sizes" not in attr
         ), "different hidden sizes for Q, K, V are not currently supported"
-        assert "unidirectional" not in attr, "unidirectional attention not current supported"
 
+        # ************************* Read inputs *************************
         # (batch, seq, in_hidden)
         input_emb = inputs[0]
 
-        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        # TODO(agladyshev):
+        #   ORT documentation says:
+        #       The weights for input projection of Q, K and V are merged.
+        #       The data is stacked on the second dimension.
+        #       Its shape is (input_hidden_size, hidden_size + hidden_size + v_hidden_size).
+        #       Here hidden_size is the hidden dimension of Q and K, and v_hidden_size is that of V.
+        #   However, in our case, we consider that hidden_size == v_hidden_size.
+        #   Therefore, weight has the following shape:
+        #       (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
         weight = inputs[1]
 
         # (3 * out_hidden,)
@@ -1325,7 +1545,7 @@ class Attention(OnnxOpConverter):
         # 3. (    batch,            seq, past_seq + seq,)
         # 4. (    batch,)
         # 5. (2 * batch,)
-        # For now, we only support case 2.
+        # TODO: For now, we only support case 2.
         mask_index = inputs[3]
 
         # (2, batch, num_heads, past_seq, head_size)
@@ -1333,28 +1553,47 @@ class Attention(OnnxOpConverter):
 
         # (batch, num_heads, seq, seq)
         extra_add = inputs[5]
+        assert extra_add is None, "extra add to QxK not currently supported"
 
-        (batch_size, seq_len, _) = infer_shape(input_emb)
-        (out_hidden_x3,) = infer_shape(bias)
-        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
-        out_hidden = out_hidden_x3 // 3
-        assert (
-            out_hidden % num_heads == 0
-        ), "output hidden size should be divisible by number of attention heads"
-        head_size = out_hidden // num_heads
+        # When past_present_share_buffer is used,
+        # it is required to specify past_sequence_length (could be 0)
+        past_seq_len = inputs[6]
+        assert past_seq_len is None, "past sequence length not currently supported"
+
+        # ************************* Parse inputs *************************
+        t = ["float32", "float16"]
+        m = ["int32"]
+
+        # input
+        batch_size, seq_len, input_hidden = cls._check_input_embeddings(input_emb, t)
+
+        # weight
+        out_hidden_x3, out_hidden, head_size = cls._check_weights(
+            weight, t, num_heads=num_heads, input_hidden=input_hidden
+        )
 
+        # bias
+        cls._check_bias(bias, t, out_hidden_x3=out_hidden_x3)
+
+        # mask_index
         assert (
             mask_index is not None
         ), "Attention import currently only supports required mask_index"
-        mask_index_shape = infer_shape(mask_index)
-        assert (
-            len(mask_index_shape) == 2
-            and mask_index_shape[0] == batch_size
-            and mask_index_shape[1] == seq_len
-        ), "currently only support (batch_size, sequence_length) mask index"
+        total_seq_len = cls._check_mask_index(mask_index, m, batch_size=batch_size, seq_len=seq_len)
 
-        assert past is None, "past K, V state is not currently supported"
-        assert extra_add is None, "extra add to QxK not currently supported"
+        # past
+        if past_seq_len is None:
+            past_seq_len = 0
+        if past is not None:
+            past_seq_len = cls._check_past(
+                past,
+                t,
+                batch_size=batch_size,
+                num_heads=num_heads,
+                seq_len=seq_len,
+                total_seq_len=total_seq_len,
+                head_size=head_size,
+            )
 
         # split weight and biases and do the matmuls
         w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
@@ -1365,53 +1604,44 @@ class Attention(OnnxOpConverter):
         K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
         V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
 
-        # massage tensors in preparation for batched matmul
-        def massage(tensor):
-            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
-
-            # (batch_size, num_heads, seq_len, head_size)
-            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+        Q = cls._split_into_heads(Q, batch_size, seq_len, num_heads, head_size)
+        K = cls._split_into_heads(K, batch_size, seq_len, num_heads, head_size)
+        V = cls._split_into_heads(V, batch_size, seq_len, num_heads, head_size)
 
-            # (batch_size * num_heads, seq_len, head_size)
-            return _op.reverse_reshape(tensor, (-1, 0, 0))
-
-        Q = massage(Q)
-        K = massage(K)
-        V = massage(V)
+        # Concatenate (past_K, past_V) with (K, V) by sequence axis:
+        # (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
+        if past is not None and past_seq_len > 0:
+            K_past, V_past = _op.split(past, 2, axis=0)
+            K = _op.concatenate([_op.squeeze(K_past, axis=[0]), K], axis=2)
+            V = _op.concatenate([_op.squeeze(V_past, axis=[0]), V], axis=2)
 
-        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
-        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
-        present = _op.stack([K_present, V_present], axis=0)
+        # Prepare present state for Key and Value with shape
+        # (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
+        present = _op.stack([K, V], axis=0)
 
-        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
-        score_dtype = infer_type(att_scores).checked_type.dtype
-        att_scores = _op.divide(
-            att_scores,
-            _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype),
+        Q = cls._merge_first_dimensions(Q)
+        K = cls._merge_first_dimensions(K)
+        V = cls._merge_first_dimensions(V)
+
+        # Compute Attention output
+        output = cls._compute_attention(
+            Q,
+            K,
+            V,
+            mask_index,
+            unidirectional=unidirectional,
+            batch_size=batch_size,
+            out_hidden=out_hidden,
+            num_heads=num_heads,
+            head_size=head_size,
+            seq_len=seq_len,
+            past_seq_len=past_seq_len,
         )
-        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
-
-        # build the attention mask
-        att_mask = _op.cast(mask_index, score_dtype)
-        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
-        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
-        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
-
-        # apply the mask
-        att_scores = _op.add(att_scores, att_mask)
-        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
-
-        att_probs = _op.nn.softmax(att_scores, axis=-1)
-
-        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
-        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
-        output = _op.transpose(output, axes=[0, 2, 1, 3])
-        output = _op.reshape(output, (0, 0, out_hidden))
 
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
-class QAttention(OnnxOpConverter):
+class QAttention(OrtAttentionBase, OnnxOpConverter):
     """Operator converter for QAttention from Microsoft onnxruntime contrib opset.
 
     This is the self-attention mechanism used in transformer models.
@@ -1473,42 +1703,15 @@ class QAttention(OnnxOpConverter):
         t4 = ["int32"]
 
         # input
-        assert infer_type(input_emb).checked_type.dtype in t1
-        assert (
-            len(infer_shape(input_emb)) == 3
-        ), "Input should be 3D tensor with shape (batch_size, sequence_length, input_hidden_size)"
-        (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
-        assert input_hidden > 0, (
-            "The weight tensor has (input_hidden_size, 3 * output_hidden_size) shape, so it doesn't"
-            f" make sense to have ({input_hidden}, 3 * output_hidden_size) weight tensor."
-        )
-        assert seq_len > 0, (
-            "The output tensor has (batch_size, sequence_length, hidden_size) shape,"
-            f" so it doesn't make sense to have (batch_size, {seq_len}, hidden_size) output."
-        )
+        batch_size, seq_len, input_hidden = cls._check_input_embeddings(input_emb, t1)
 
         # weight
-        assert infer_type(weight).checked_type.dtype in t2
-        assert len(infer_shape(weight)) == 2, (
-            "Weight should be 2D input tensor with shape (input_hidden_size, 3 * hidden_size), "
-            "hidden_size = num_heads * head_size"
+        out_hidden_x3, out_hidden, head_size = cls._check_weights(
+            weight, t2, num_heads=num_heads, input_hidden=input_hidden
         )
-        (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
-        assert input_hidden == input_hidden_weight
-        assert out_hidden_x3 % 3 == 0, "output hidden shape should be divisible by 3: W_Q, W_K, W_V"
-        out_hidden = out_hidden_x3 // 3
-        assert (
-            out_hidden % num_heads == 0
-        ), "output hidden size should be divisible by number of attention heads"
-        head_size = out_hidden // num_heads
 
         # bias
-        assert infer_type(bias).checked_type.dtype in t3
-        assert (
-            len(infer_shape(bias)) == 1
-        ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
-        (out_hidden_x3_bias,) = infer_shape(bias)
-        assert out_hidden_x3 == out_hidden_x3_bias
+        cls._check_bias(bias, t3, out_hidden_x3=out_hidden_x3)
 
         # input_scale
         assert infer_type(input_scale).checked_type.dtype in t3
@@ -1527,13 +1730,9 @@ class QAttention(OnnxOpConverter):
         assert (
             mask_index is not None
         ), "Attention import currently only supports required mask_index"
-        assert infer_type(mask_index).checked_type.dtype in t4
-        mask_index_shape = infer_shape(mask_index)
-        assert (
-            len(mask_index_shape) == 2
-            and mask_index_shape[0] == batch_size
-            and mask_index_shape[1] >= seq_len
-        ), "currently only support (batch_size, sequence_length) mask index"
+        total_seq_len = cls._check_mask_index(
+            mask_index, t4, batch_size=batch_size, seq_len=seq_len
+        )
 
         # TODO(agladyshev): int32 required for qnn.batch_matmul (QnnBatchMatmulRel)
         zero_point_zero = _expr.const(0, "int32")
@@ -1557,17 +1756,15 @@ class QAttention(OnnxOpConverter):
         # past (2, batch_size, num_heads, past_sequence_length, head_size)
         past_seq_len = 0
         if past is not None:
-            assert infer_type(past).checked_type.dtype in t3
-            past_shape = infer_shape(past)
-            assert len(past_shape) == 5, "past should be 5D tensor"
-            assert (
-                past_shape[0] == 2
-                and past_shape[1] == batch_size
-                and past_shape[2] == num_heads
-                and past_shape[3] + seq_len == mask_index_shape[1]
-                and past_shape[4] == head_size
+            past_seq_len = cls._check_past(
+                past,
+                t3,
+                batch_size=batch_size,
+                num_heads=num_heads,
+                seq_len=seq_len,
+                total_seq_len=total_seq_len,
+                head_size=head_size,
             )
-            past_seq_len = past_shape[3]
 
         # ************************* Create Relay *************************
         # Add batch dimension for QNN Batch Matmul
@@ -1604,22 +1801,9 @@ class QAttention(OnnxOpConverter):
             input_emb, w_V, input_scale, weight_scale, input_zero_point, weight_zero_point, b_V
         )
 
-        def split_into_heads(tensor):
-            """
-            In the implementation of Multi-head attention we just split queries, keys, and values
-            we compute for a single-head attention into several parts:
-            (batch_size, num_heads, seq_len, head_size)
-            """
-            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
-
-            # (batch_size, num_heads, seq_len, head_size)
-            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
-
-            return tensor
-
-        Q = split_into_heads(Q)
-        K = split_into_heads(K)
-        V = split_into_heads(V)
+        Q = cls._split_into_heads(Q, batch_size, seq_len, num_heads, head_size)
+        K = cls._split_into_heads(K, batch_size, seq_len, num_heads, head_size)
+        V = cls._split_into_heads(V, batch_size, seq_len, num_heads, head_size)
 
         # Concatenate (past_K, past_V) with (K, V) by sequence axis:
         # (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
@@ -1632,78 +1816,25 @@ class QAttention(OnnxOpConverter):
         # (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
         present = _op.stack([K, V], axis=0)
 
-        def merge_first_dimensions(tensor):
-            """
-            nn.batch_matmul is expecting 3D tensor:
-            (batch_size * num_heads, past_seq_len + seq_len, head_size)
-            """
-            return _op.reverse_reshape(tensor, (-1, 0, 0))
-
-        Q = merge_first_dimensions(Q)
-        K = merge_first_dimensions(K)
-        V = merge_first_dimensions(V)
-
-        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
-        score_dtype = infer_type(att_scores).checked_type.dtype
-        att_scores = _op.divide(
-            att_scores,
-            _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype),
-        )
-        att_scores = _op.reshape(
-            att_scores, (batch_size, num_heads, seq_len, past_seq_len + seq_len)
+        Q = cls._merge_first_dimensions(Q)
+        K = cls._merge_first_dimensions(K)
+        V = cls._merge_first_dimensions(V)
+
+        # Compute Attention output
+        output = cls._compute_attention(
+            Q,
+            K,
+            V,
+            mask_index,
+            unidirectional=unidirectional,
+            batch_size=batch_size,
+            out_hidden=out_hidden,
+            num_heads=num_heads,
+            head_size=head_size,
+            seq_len=seq_len,
+            past_seq_len=past_seq_len,
         )
 
-        # Build the attention mask
-        att_mask = _op.cast(mask_index, score_dtype)
-        # Attention mask has value 0 or 1. Here we convert 0 to -10000, and 1 to 0.
-        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
-        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
-        # Expand for att_scores broadcast
-        # (batch_size, past_seq_len + seq_len) -> (batch_size, 1, seq_len, past_seq_len + seq_len)
-        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
-        att_mask = _op.concatenate([att_mask] * seq_len, axis=2)
-
-        def create_unidirectional_mask(left_value, right_value):
-            numpy_unidirectional_mask = np.array(
-                [
-                    np.concatenate(
-                        [
-                            np.full(past_seq_len + s_i + 1, left_value),
-                            np.full(seq_len - s_i - 1, right_value),
-                        ]
-                    )
-                    for s_i in range(seq_len)
-                ]
-            )
-            unidirectional_mask = _op.const(numpy_unidirectional_mask, dtype=score_dtype)
-            unidirectional_mask = _op.expand_dims(unidirectional_mask, 0, num_newaxis=2)
-
-            return unidirectional_mask
-
-        if unidirectional:
-            att_mask = _op.add(att_mask, create_unidirectional_mask(0, -10000))
-
-        # Apply the mask
-        att_scores = _op.add(att_scores, att_mask)
-        # TODO(agladyshev):
-        #   Comment from ORT source code (onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h):
-        #   "Fix unidirectional mask to be parity with huggingface implementation"
-        if unidirectional:
-            att_scores = _op.multiply(att_scores, create_unidirectional_mask(1, 0))
-            att_scores = _op.add(att_scores, create_unidirectional_mask(0, -10000))
-
-        # Compute Softmax
-        att_scores = _op.reshape(
-            att_scores, (batch_size * num_heads, seq_len, past_seq_len + seq_len)
-        )
-        att_probs = _op.nn.softmax(att_scores, axis=-1)
-
-        # Compute output
-        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
-        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
-        output = _op.transpose(output, axes=[0, 2, 1, 3])
-        output = _op.reshape(output, (0, 0, out_hidden))
-
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a84de82f3b..f5b5f7c65c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5878,30 +5878,47 @@ def test_embedlayernormalization(target, dev):
 def test_attention(target, dev):
     """test_attention"""
 
-    def verify_attention(input_, weight, bias, mask_index, num_heads):
+    def verify_attention(_unidirectional, _input, _weight, _bias, _mask_index=None, _past=None):
+        input_names = ["input", "weight", "bias"]
+        if _mask_index is not None:
+            input_names.append("mask_index")
+        if _past is not None:
+            input_names.append("past")
+
         node = onnx.helper.make_node(
             "Attention",
-            inputs=["input", "weight", "bias", "mask_index"],
+            inputs=input_names,
             outputs=["output", "present"],
             domain="com.microsoft",
             num_heads=num_heads,
+            unidirectional=_unidirectional,
         )
 
+        past_shape = (2, batch_size, num_heads, past_sequence_length, head_size)
         present_output_shape = (2, batch_size, num_heads, sequence_length, head_size)
 
+        inputs_info = [
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, list(_input.shape)),
+            helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(_weight.shape)),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(_bias.shape)),
+        ]
+        if _mask_index is not None:
+            inputs_info.append(
+                helper.make_tensor_value_info(
+                    "mask_index", TensorProto.INT32, list(_mask_index.shape)
+                ),
+            )
+        if _past is not None:
+            inputs_info.append(
+                helper.make_tensor_value_info("past", TensorProto.FLOAT, list(past_shape))
+            )
+
         graph = helper.make_graph(
             [node],
             "attention_test",
-            inputs=[
-                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input_.shape)),
-                helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(weight.shape)),
-                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
-                helper.make_tensor_value_info(
-                    "mask_index", TensorProto.INT32, list(mask_index.shape)
-                ),
-            ],
+            inputs=inputs_info,
             outputs=[
-                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input_.shape)),
+                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(_input.shape)),
                 helper.make_tensor_value_info(
                     "present", TensorProto.FLOAT, list(present_output_shape)
                 ),
@@ -5910,31 +5927,58 @@ def test_attention(target, dev):
 
         model = helper.make_model(graph, producer_name="attention_test")
 
+        inputs = [_input, _weight, _bias]
+        if _mask_index is not None:
+            inputs.append(_mask_index)
+        if _past is not None:
+            inputs.append(_past)
+
         # "present" output should be nullptr when the "past" input isn't included,
         # but ort requires an output shape to be specified?
         verify_with_ort_with_inputs(
             model,
-            [input_, weight, bias, mask_index],
-            [input_.shape, present_output_shape],
+            inputs,
+            [_input.shape, present_output_shape],
             target=target,
             dev=dev,
             rtol=1e-4,
             atol=1e-4,
         )
 
-    hidden_size = 384
-    batch_size = 4
-    sequence_length = 4
-    num_heads = 12
-    head_size = 32
+    batch_size = 11
+    num_heads = 13
+    head_size = 37
+    sequence_length = 7
+    input_hidden_size = 147
+    weight_hidden_size = num_heads * head_size
+    past_sequence_length = 17
 
-    dtype = "float32"
-    input_array = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
-    weight = np.random.normal(size=(hidden_size, 3 * hidden_size)).astype(dtype) * 0.1
-    bias = np.random.randn(3 * hidden_size).astype(dtype)
-    mask_index = np.full((batch_size, sequence_length), 1).astype("int32")
+    total_sequence_length = past_sequence_length + sequence_length
 
-    verify_attention(input_array, weight, bias, mask_index, num_heads)
+    # Required inputs
+    input_array = np.random.normal(size=(batch_size, sequence_length, input_hidden_size)).astype(
+        "float32"
+    )
+    weight = (
+        np.random.normal(size=(input_hidden_size, 3 * weight_hidden_size)).astype("float32") * 0.1
+    )
+    bias = np.random.randn(3 * weight_hidden_size).astype("float32")
+
+    # Optional inputs
+    past = np.random.random((2, batch_size, num_heads, past_sequence_length, head_size)).astype(
+        "float32"
+    )
+
+    for unidirectional in [0, 1]:
+        for have_past in [False, True]:
+            if not have_past:
+                mask_index = np.random.randint(0, 2, (batch_size, sequence_length)).astype("int32")
+                verify_attention(unidirectional, input_array, weight, bias, mask_index)
+            else:
+                mask_index = np.random.randint(0, 2, (batch_size, total_sequence_length)).astype(
+                    "int32"
+                )
+                verify_attention(unidirectional, input_array, weight, bias, mask_index, past)
 
 
 @tvm.testing.parametrize_targets

Reply via email to