KJlaccHoeUM9l commented on code in PR #13654:
URL: https://github.com/apache/tvm/pull/13654#discussion_r1058146999


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -1379,6 +1379,298 @@ def massage(tensor):
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
+class QAttention(OnnxOpConverter):
+    """Operator converter for QAttention from Microsoft onnxruntime contrib 
opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # ************************* Read attrs *************************
+        num_heads = attr["num_heads"]
+        unidirectional = attr["unidirectional"]
+
+        # ************************* Read inputs *************************
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_scale = inputs[3]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column 
quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_scale = inputs[4]
+
+        # TODO(agladyshev):
+        #  ORT documentation says that shape is (batch,),
+        #  but in ORT source code we have following comment:
+        #       1. (batch_size)
+        #       2. (2 * batch_size)
+        #       3. (batch_size, 1)
+        #       4. (1, 1)
+        #       5. (batch_size, past_sequence_length + sequence_length)
+        #  In practice, for GPT-2 the shape is (batch, past_seq_length + seq_length).
+        #  Currently only (batch, past_seq_length + seq_length) shape is supported.
+        mask_index = inputs[5]
+
+        # Scalar, which means a per-tensor/layer quantization

Review Comment:
   done



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -1379,6 +1379,298 @@ def massage(tensor):
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
+class QAttention(OnnxOpConverter):
+    """Operator converter for QAttention from Microsoft onnxruntime contrib 
opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # ************************* Read attrs *************************
+        num_heads = attr["num_heads"]
+        unidirectional = attr["unidirectional"]
+
+        # ************************* Read inputs *************************
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_scale = inputs[3]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column 
quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_scale = inputs[4]
+
+        # TODO(agladyshev):
+        #  ORT documentation says that shape is (batch,),
+        #  but in ORT source code we have following comment:
+        #       1. (batch_size)
+        #       2. (2 * batch_size)
+        #       3. (batch_size, 1)
+        #       4. (1, 1)
+        #       5. (batch_size, past_sequence_length + sequence_length)
+        #  In practice, for GPT-2 the shape is (batch, past_seq_length + seq_length).
+        #  Currently only (batch, past_seq_length + seq_length) shape is supported.
+        mask_index = inputs[5]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_zero_point = inputs[6]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column 
quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_zero_point = inputs[7]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[8]
+
+        # ************************* Parse inputs *************************
+        t1 = ["int8", "uint8"]
+        t2 = ["int8", "uint8"]
+        t3 = ["float32", "float16"]
+        t4 = ["int32"]
+
+        # input
+        assert infer_type(input_emb).checked_type.dtype in t1
+        assert (
+            len(infer_shape(input_emb)) == 3
+        ), "Input should be 3D tensor with shape (batch_size, sequence_length, 
input_hidden_size)"
+        (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
+        assert input_hidden > 0, (
+            "The weight tensor has (input_hidden_size, 3 * output_hidden_size) 
shape, so it doesn't"
+            f" make sense to have ({input_hidden}, 3 * output_hidden_size) 
weight tensor."
+        )
+        assert seq_len > 0, (
+            "The output tensor has (batch_size, sequence_length, hidden_size) 
shape,"
+            f" so it doesn't make sense to have (batch_size, {seq_len}, 
hidden_size) output."
+        )
+
+        # weight
+        assert infer_type(weight).checked_type.dtype in t2
+        assert len(infer_shape(weight)) == 2, (
+            "Weight should be 2D input tensor with shape (input_hidden_size, 3 
* hidden_size), "
+            "hidden_size = num_heads * head_size"
+        )
+        (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
+        assert input_hidden == input_hidden_weight
+        assert out_hidden_x3 % 3 == 0, "output hidden shape should be 
divisible by 3: W_Q, W_K, W_V"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention 
heads"
+        head_size = out_hidden // num_heads
+
+        # bias
+        assert infer_type(bias).checked_type.dtype in t3
+        assert (
+            len(infer_shape(bias)) == 1
+        ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
+        (out_hidden_x3_bias,) = infer_shape(bias)
+        assert out_hidden_x3 == out_hidden_x3_bias
+
+        # input_scale
+        assert infer_type(input_scale).checked_type.dtype in t3
+        input_scale = get_scalar(
+            input_scale, params, 
dtype=infer_type(input_scale).checked_type.dtype
+        )
+
+        # weight_scale
+        assert infer_type(weight_scale).checked_type.dtype in t3
+        # TODO(agladyshev): now QNN Batch Matmul only supports scalar types 
for scale and zero_point
+        weight_scale = get_scalar(
+            weight_scale, params, 
dtype=infer_type(weight_scale).checked_type.dtype
+        )
+
+        # mask_index
+        assert (
+            mask_index is not None
+        ), "Attention import currently only supports required mask_index"
+        assert infer_type(mask_index).checked_type.dtype in t4
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] >= seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
+
+        # TODO(agladyshev): int32 required for qnn.batch_matmul 
(QnnBatchMatmulRel)
+        zero_point_zero = _expr.const(0, "int32")
+
+        # input_zero_point
+        if input_zero_point is None:
+            input_zero_point = zero_point_zero
+        else:
+            assert infer_type(input_zero_point).checked_type.dtype in t1
+            # TODO(agladyshev): int32 required for qnn.batch_matmul 
(QnnBatchMatmulRel)
+            input_zero_point = get_scalar(input_zero_point, params, 
dtype="int32")
+
+        # weight_zero_point
+        if weight_zero_point is None:
+            weight_zero_point = zero_point_zero
+        else:
+            assert infer_type(weight_zero_point).checked_type.dtype in t2
+            # TODO(agladyshev): int32 required for qnn.batch_matmul 
(QnnBatchMatmulRel)
+            weight_zero_point = get_scalar(weight_zero_point, params, 
dtype="int32")
+
+        # past (2, batch_size, num_heads, past_sequence_length, head_size)
+        past_seq_len = 0
+        if past is not None:
+            assert infer_type(past).checked_type.dtype in t3
+            past_shape = infer_shape(past)
+            assert len(past_shape) == 5, "past should be 5D tensor"
+            assert (
+                past_shape[0] == 2
+                and past_shape[1] == batch_size
+                and past_shape[2] == num_heads
+                and past_shape[3] + seq_len == mask_index_shape[1]
+                and past_shape[4] == head_size
+            )
+            past_seq_len = past_shape[3]
+
+        # ************************* Create Relay *************************
+        # Add batch dimension for QNN Batch Matmul
+        weight = _op.expand_dims(weight, 0, num_newaxis=1)
+        weight = _op.concatenate([weight] * batch_size, axis=0)
+
+        # Split weight and biases and do the Matmul
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=-1)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=-1)
+
+        def qmatmul_dequantize_bias(
+            lhs, rhs, lhs_scale, rhs_scale, lhs_zero_point, rhs_zero_point, 
bias
+        ):
+            rhs_transposed = _op.transpose(rhs, axes=[0, 2, 1])  # QNN Batch 
Matmul do: X * Y^T
+            result = _qnn.op.batch_matmul(
+                lhs, rhs_transposed, lhs_zero_point, rhs_zero_point, 
lhs_scale, rhs_scale
+            )
+            result = _qnn.op.dequantize(
+                result,
+                _op.multiply(lhs_scale, rhs_scale),
+                zero_point_zero,
+                axis=-1,  # TODO(agladyshev): what is 'axis' parameter for?

Review Comment:
   done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to