ekalda commented on code in PR #16981:
URL: https://github.com/apache/tvm/pull/16981#discussion_r1606797062
##########
tests/python/relay/strategy/arm_cpu/test_dense.py:
##########
@@ -102,21 +102,22 @@ class TestDense(BasicDenseTests):
"data_shape,weight_shape",
[
((32, 32), (32, 32)),
- ((2, 35), (6, 35)),
((3, 3), (68, 3)),
+ ((2, 35), (6, 35)),
Review Comment:
I suppose this reordering is an artifact of experimenting with the tests, and
the order of the test cases doesn't actually make a difference?
##########
python/tvm/topi/arm_cpu/dense_alter_op.py:
##########
@@ -52,23 +54,25 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
), "matmul_sme.arm_cpu requires weights be a Relay Constant"
weight_dtype = tinfos[1].dtype
- weight_data = inputs[1].data.numpy()
- interleaved = weight_data.transpose()
- encoded_weight = relay.const(interleaved, weight_dtype)
+ encoded_weight = inputs[1]
+ transpose_b = True
Review Comment:
Maybe worth adding a comment here explaining why we need to do these
conditional transposes?
##########
python/tvm/tir/op.py:
##########
@@ -3370,6 +3370,22 @@ def get_active_lane_mask(dtype, base, limit):
return call_intrin(dtype, "tir.get_active_lane_mask", base, limit)
+def get_vscale_factor(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
+ """
+ Create a datatype dependent scalable expression.
+
+ Parameters
+ ----------
+ dtype : tvm.DataType
Review Comment:
Nit:
```suggestion
dtype : Union[str, tvm.DataType]
```
##########
python/tvm/tir/op.py:
##########
@@ -3370,6 +3370,22 @@ def get_active_lane_mask(dtype, base, limit):
return call_intrin(dtype, "tir.get_active_lane_mask", base, limit)
+def get_vscale_factor(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
Review Comment:
Nit: In the rest of the codebase, "vscale factor" refers to the integer
multiplier in front of `vscale`, so it would be good not to deviate from this
convention. Maybe rename this to something else, e.g. `get_vscale_expr`.
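
To illustrate the distinction, here is a minimal sketch in plain Python (no TVM
imports). The helper names `vscale_multiplier` and `vscale_expr_str` are
hypothetical, and the formula `min_size // bits` is an assumption inferred from
the `8 * vscale` (fp16) and `4 * vscale` (fp32) shapes mentioned in the SME
docstring in this PR, not something the hunk above confirms.

```python
def vscale_multiplier(dtype_bits: int, min_size: int = 128) -> int:
    """Return the integer "vscale factor", i.e. the multiplier in front of vscale.

    Assumption: the multiplier is the number of elements of width `dtype_bits`
    that fit into a scalable vector of `min_size` bits.
    """
    if dtype_bits <= 0 or min_size % dtype_bits != 0:
        raise ValueError("dtype_bits must be positive and divide min_size")
    return min_size // dtype_bits


def vscale_expr_str(dtype_bits: int, min_size: int = 128) -> str:
    """Render the full scalable expression (multiplier times vscale) as text."""
    return f"{vscale_multiplier(dtype_bits, min_size)} * vscale"
```

Under this reading, the "factor" is only the integer (8 for fp16), whereas the
function under review returns the whole expression, which is why a name like
`get_vscale_expr` may describe it better.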
##########
python/tvm/tir/tensor_intrin/arm_cpu.py:
##########
@@ -295,7 +301,151 @@ def impl():
return desc, impl()
-def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K):
+def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
+ # pylint: disable=line-too-long
+ """
+ Transpose and block pack a matrix of size 2SVL x 1SVL (where 'SVL' is the Scalable Vector
+ Length for the fp16 datatype) using the Scalable Matrix Extension (SME).
+
+ Rows of the fp16 input matrix are loaded into the accumulator tile and columns are stored
+ as fp32 SVL length vectors to the output matrix. When loading, the accumulator tile is
+ interpreted to be of shape 2 * 8 * vscale x 8 * vscale. When storing, we interpret the
+ accumulator tile to be of shape 2 * 4 * vscale x 2 * 4 * vscale.
+
+ Example
+ -------
+ In the fp16 instance, the accumulator tile consists of two sub-tiles numbered 0-1. Rows
+ of A are loaded onto the accumulator tile by interleaving rows in the first half (0, SVL//2]
+ of the tile and rows in the second half (SVL//2, SVL]. Columns of fp32 values are stored
+ into the output buffer. The fp32 store is used to group pairs of consecutive values together,
+ resulting in the arrangement displayed below.
+
Review Comment:
Very nice documentation!
##########
python/tvm/tir/op.py:
##########
@@ -3370,6 +3370,22 @@ def get_active_lane_mask(dtype, base, limit):
return call_intrin(dtype, "tir.get_active_lane_mask", base, limit)
+def get_vscale_factor(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
+ """
+ Create a datatype dependent scalable expression.
+
+ Parameters
+ ----------
+ dtype : tvm.DataType
+ Element data type.
+ min_size : int
+ The minimum size of the scalable vector.
Review Comment:
```suggestion
The minimum size of the scalable vector in bits.
```