This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d3011ab609 [SME] Utilize predication in fp32 matmul and conv2d schedules (#17054)
d3011ab609 is described below
commit d3011ab609f30ef3363b230bd0f3702ba00aa270
Author: Luke Hutton <[email protected]>
AuthorDate: Fri Jun 14 10:47:00 2024 +0100
[SME] Utilize predication in fp32 matmul and conv2d schedules (#17054)
Prior to this commit, the matmul and conv2d schedules required padding
of the inputs to some multiple of vscale and a final "unpadding" stage.
Instead, we can leverage predicated operations to avoid the requirement
for padding. Both the transpose interleave and outer product fp32
intrinsics are updated to use predication. The `get_active_lane_mask`
intrinsic is utilized to generate a variably sized mask of active lanes
depending on the global position the tensor intrinsic is operating on,
as sketched below.
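A minimal sketch of the idea (not the literal patch: `base` and `limit`
here stand in for the expressions computed by `_create_active_lane_mask`
in the diff):

    # Lane i of the mask is active iff base + i < limit, so trailing
    # lanes that would run past the edge of an unpadded tensor are
    # disabled for the load/store instead of touching padding.
    predicate = T.get_active_lane_mask(
        "uint1xvscalex4",        # one predicate bit per fp32 lane (vscale x 4)
        T.Cast("int32", base),   # offset of the first element accessed
        T.Cast("int32", limit),  # first offset that must not be accessed
    )
    T.call_llvm_intrin(
        "void", "llvm.aarch64.sme.ld1w.horiz", T.uint32(4),
        predicate, input_ptr, sub_tile, slice_idx,
    )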
For now this relies on using `offset_of` and `stride` information from
the tensor we're predicating an access on, as shown in the sketch
below. We will likely want to build on this in the future with a more
intuitive API for determining the current tile location.
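Concretely (simplified from `_create_active_lane_mask` in the diff; the
names are the function's own):

    # Absolute offset of the first element the tile will touch.
    base = T.int32(tensor.offset_of([vertical_offset, horizontal_offset])[0])
    stride = tensor.strides[0]
    # First offset past the valid data in the row containing 'base':
    # rewind to the start of that row, then advance one full stride.
    limit = (
        base
        - T.int32(horizontal_offset)
        - T.int32(tensor.offset_of([0, 0])[0] % stride)
        + T.int32(stride)
    )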
Support for batched conv2d was removed since it causes numerical
issues, suspected to be due to how the current tile is determined
(see the paragraph above).
---
python/tvm/relay/op/strategy/arm_cpu.py | 7 ++
python/tvm/tir/tensor_intrin/arm_cpu.py | 134 +++++++++++++++++----
python/tvm/topi/arm_cpu/conv2d.py | 40 +++---
python/tvm/topi/arm_cpu/conv2d_gemm.py | 39 ++++--
python/tvm/topi/arm_cpu/matmul.py | 58 ++++-----
.../python/codegen/test_target_codegen_aarch64.py | 4 +
tests/python/topi/test_topi_conv2d_nhwc.py | 10 +-
7 files changed, 197 insertions(+), 95 deletions(-)
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 35fd2b7a78..f4b4708401 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -110,6 +110,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d arm cpu strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
+ data_shape = data.shape
+ kernel_shape = kernel.shape
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
stride_h, stride_w = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
@@ -258,6 +260,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
target.features.has_sme
and kernel.dtype == data.dtype
and out_type.dtype == "float32"
+ and data_shape[0] == 1
+ # The schedule uses tensorization which does not work when the
+ # reduction axis of the gemm has unit iters. See
+ # https://github.com/apache/tvm/issues/16566
+ and (data_shape[3] * kernel_shape[0] * kernel_shape[1]) > 1
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 3a3430af51..a6f3538846 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -176,7 +176,51 @@ def _create_ptrue_mask(dtype):
return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype))
-def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
+def _create_active_lane_mask(tensor, relative_offsets, vertical_limit):
+ """
+ Get the active lane mask intrinsic call for predicated accesses.
+
+ Parameters
+ ----------
+ tensor : tvm.tir.Buffer
+ The tensor the buffer access will be performed on.
+ relative_offsets : Tuple[PrimExpr, PrimExpr]
+ The vertical and horizontal offsets into the accumulator tile.
+ vertical_limit : PrimExpr
+ An absolute offset specifying the limit at which rows should be stored.
+
+ Returns
+ -------
+ PrimExpr
+ The active lane mask intrinsic.
+ """
+ vertical_offset, horizontal_offset = relative_offsets
+ stride = tensor.strides[0]
+
+ # The base is the offset of the first value we wish to store
+ base = T.int32(tensor.offset_of([vertical_offset, horizontal_offset])[0])
+
+ # The limit is the maximum offset in the current row of 'base' at which we wish to allow
+ # values to be stored. Calculating this limit is a bit tricky since we can only request
+ # offsets of elements in the tensorized tile of the output tensor. One way to calculate
+ # this is to find the offset of the first value in the row of the output tensor that
+ # 'base' is in and add 'stride' to it.
+ limit = (
+ base
+ - T.int32(horizontal_offset)
+ - T.int32((tensor.offset_of([0, 0])[0] % stride))
+ + T.int32(stride)
+ )
+ limit = T.Min(limit, T.Cast("int32", vertical_limit) * stride)
+
+ return T.get_active_lane_mask(
+ "uint1xvscalex4",
+ T.Cast("int32", base),
+ T.Cast("int32", limit),
+ )
+
+
+def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows):
"""
Transpose a matrix of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using
the Scalable Matrix Extension (SME).
@@ -247,9 +291,6 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
strides=[T.int32(), 1],
)
- # Disable predication
- ptrue = _create_ptrue_mask("float32")
-
with T.block("root"):
T.reads(A[0:SVF2, 0:SVF2])
T.writes(A_t[0:SVF2, 0:SVF2])
@@ -263,19 +304,22 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
input_ptr = A.access_ptr("r", offset=offset)
sub_tile = T.int32(sub_tile_idx)
+ predicate = _create_active_lane_mask(
+ A, (row_offset + slice_idx, col_offset), cols
+ )
T.evaluate(
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.ld1w.horiz",
T.uint32(4),
- ptrue,
+ predicate,
input_ptr,
sub_tile,
slice_idx,
)
)
- # Store columns to the ouptut matrix
+ # Store columns to the output matrix
with T.serial(0, SVF) as slice_idx:
for sub_tile_idx in range(0, sub_tile_count):
col_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0
@@ -284,12 +328,15 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
output_ptr = A_t.access_ptr("w", offset=offset)
sub_tile = T.int32(sub_tile_idx)
+ predicate = _create_active_lane_mask(
+ A_t, (row_offset + slice_idx, col_offset), rows
+ )
T.evaluate(
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.st1w.vert",
T.uint32(4),
- ptrue,
+ predicate,
output_ptr,
sub_tile,
slice_idx,
@@ -445,7 +492,24 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
return desc, impl()
-def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
+def get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_cols, extent_rows):
+ if in_dtype == "float32" and out_dtype == "float32":
+ sme_transpose_interleave_intrin_name = (
+ ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE + f"_{extent_cols}_{extent_rows}"
+ )
+ tir.TensorIntrin.register(
+ sme_transpose_interleave_intrin_name,
+ *get_sme_transpose_interleave_2svlx2svl_fp32_intrin(extent_cols, extent_rows),
+ override=True,
+ )
+ return sme_transpose_interleave_intrin_name
+ elif in_dtype == "float16" and out_dtype == "float32":
+ return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE
+ else:
+ raise ValueError("Input/output data type combination not supported.")
+
+
+def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M, K, in_dtype):
"""
Compute a GEMM of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using
outer product operations from the Scalable Matrix Extension (SME).
@@ -579,15 +643,39 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
k_row = k * rows_per_iter
in_dtype_svf = tir.get_vscale_expr(in_dtype)
- a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)])
- b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)])
-
+ # Ideally we'd rely on predicating the loads and use the same predicate
+ # for the outer product operation. However, predicated buffers are not
+ # currently supported by multiple lowering passes such as
+ # "LowerMatchBuffer", therefore the predicate is passed directly to the
+ # outer product operation for now.
if in_dtype == "float32":
- a_high = T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)])
- b_high = T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)])
+ a_low = (
+ T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]),
+ _create_active_lane_mask(A, (k_row, 0), K),
+ )
+ b_low = (
+ T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]),
+ _create_active_lane_mask(B, (k_row, 0), K),
+ )
+ a_high = (
+ T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]),
+ _create_active_lane_mask(A, (k_row, in_dtype_svf), K),
+ )
+ b_high = (
+ T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]),
+ _create_active_lane_mask(B, (k_row, in_dtype_svf), K),
+ )
else:
- a_high = T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)])
- b_high = T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)])
+ a_low = (T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]), ptrue)
+ b_low = (T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]), ptrue)
+ a_high = (
+ T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]),
+ ptrue,
+ )
+ b_high = (
+ T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]),
+ ptrue,
+ )
input_combinations = [
(a_low, b_low),
@@ -606,10 +694,10 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
fmopa_intrin,
T.uint32(5),
sub_tile,
- ptrue,
- ptrue,
- input_1,
- input_2,
+ input_1[1],
+ input_2[1],
+ input_1[0],
+ input_2[0],
)
)
@@ -626,7 +714,9 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
"void",
"llvm.aarch64.sme.st1w.horiz",
T.uint32(4),
- _create_ptrue_mask("float32"),
+ _create_active_lane_mask(
+ C, (vert_offset + slice_idx, horiz_offset), M
+ ),
output_ptr,
T.int32(sub_tile_idx),
T.int32(slice_idx),
@@ -691,10 +781,6 @@ ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA = "sme_2svlx2svl_gemm_interleaved_mopa"
# in versions of LLVM >= 15. Installations with older versions of LLVM will
# not be able to use them.
if llvm_version_major() >= 15:
- TensorIntrin.register(
- ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
- *get_sme_transpose_interleave_2svlx2svl_fp32_intrin(),
- )
TensorIntrin.register(
ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE,
*get_sme_transpose_interleave_block2_2svl_fp16_intrin(),
diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py
index a6c951c078..b7327d5b52 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -24,7 +24,6 @@ from tvm import autotvm
from tvm.script import tir as T
import tvm.contrib.nnpack
from tvm.tir.schedule.analysis import has_block
-from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name
from ..utils import traverse_inline, get_const_tuple
from .. import nn
@@ -773,10 +772,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
- )
-
- transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(
- in_dtype, out_dtype
+ get_transpose_interleave_intrin_name,
)
# Interleave the padded im2col matrix utilizing the matrix tile
@@ -787,7 +783,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
sch.parallel(b)
sch.reorder(b, ko, mo, ki, mi)
- sch.tensorize(ki, transpose_interleave_intrin_name)
+ sch.tensorize(
+ ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded)
+ )
# Interleave the padded weights matrix utilizing the matrix tile
if in_dtype == "float16":
@@ -797,7 +795,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.reorder(ko, no, ki, ni)
- sch.tensorize(ki, transpose_interleave_intrin_name)
+ sch.tensorize(
+ ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded)
+ )
# Split and reorder the loops of the GeMM for tensorization
b, m, n, k = sch.get_loops(gemm_block)
@@ -816,11 +816,11 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
# Tensorize the GeMM update
sme_gemm_interleaved_intrin_name = (
- ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}"
+ ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{M_padded}_{K_padded}_{in_dtype}"
)
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
- *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, in_dtype),
+ *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M_padded, K_padded, in_dtype),
override=True,
)
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
@@ -922,16 +922,18 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
reshape_block = func_blocks["T_reshape"]
A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] else None
A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] else A_pad_block
- if use_sme:
- sch.compute_inline(reshape_block)
- elif A_pad_block:
- sch.compute_inline(reshape_block)
- b, m, k = sch.get_loops(A_pad_block)
- _, k_inner = sch.split(k, [None, tile_N])
- sch.vectorize(k_inner)
- sch.compute_at(A_pad_block, mi)
- else:
- sch.compute_at(reshape_block, mi)
+ use_explicit_predication = use_sme and in_dtype == "float32"
+ if not use_explicit_predication:
+ if use_sme:
+ sch.compute_inline(reshape_block)
+ elif A_pad_block:
+ sch.compute_inline(reshape_block)
+ b, m, k = sch.get_loops(A_pad_block)
+ _, k_inner = sch.split(k, [None, tile_N])
+ sch.vectorize(k_inner)
+ sch.compute_at(A_pad_block, mi)
+ else:
+ sch.compute_at(reshape_block, mi)
# Weight flattening
if func_blocks["weight_flatten"]:
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index e637aa91e5..bf6a9c7551 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -133,23 +133,25 @@ def compute_conv2d_gemm_without_weight_transform(
)
# Pad to tiles (if necessary)
- pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A)
- pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)
+ use_explicit_predication = use_sme and in_dtype == "float32"
+ if not use_explicit_predication:
+ pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A)
+ pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)
- M_padded = M + pad_M
- K_padded = K + pad_K
- N_padded = N + pad_N
+ M_padded = M + pad_M
+ K_padded = K + pad_K
+ N_padded = N + pad_N
- pad_before = (0, 0, 0)
- pad_after = (0, pad_M, pad_K)
+ pad_before = (0, 0, 0)
+ pad_after = (0, pad_M, pad_K)
- if pad_K != 0:
- A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
- elif pad_M != 0:
- A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")
+ if pad_K != 0:
+ A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
+ elif pad_M != 0:
+ A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")
idxm = tvm.tir.indexmod
- k = te.reduce_axis((0, K_padded), "k")
+ k = te.reduce_axis((0, K if use_explicit_predication else K_padded), "k")
# Determine matrix multiplication compute definition
target = Target.current(allow_none=False)
@@ -300,7 +302,18 @@ def compute_conv2d_gemm_without_weight_transform(
name="C",
)
zero = tvm.tir.const(0)
- elif use_scalable_vectors or use_sme:
+ elif use_explicit_predication:
+ assert len(B_interleaved_t.shape) == 2
+ C = te.compute(
+ (batches, M, N),
+ lambda b, x, y: te.sum(
+ A[b, x, k].astype(in_dtype) * B_interleaved_t[k, y].astype(in_dtype),
+ axis=k,
+ ),
+ name="C",
+ )
+ zero = tvm.tir.const(0)
+ elif use_scalable_vectors:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
(batches, M_padded, N_padded),
diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py
index 23b8734a0b..63f6289f0e 100644
--- a/python/tvm/topi/arm_cpu/matmul.py
+++ b/python/tvm/topi/arm_cpu/matmul.py
@@ -53,19 +53,16 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, tra
tile_k *= 2
tile_n = 2 * tvm.tir.get_vscale_expr(data_a.dtype)
- M_padded, pad_M = pad_dim_to_multiple(M, tile_m)
- _, pad_K = pad_dim_to_multiple(K, tile_k)
- N_padded, pad_N = pad_dim_to_multiple(N, tile_n)
-
- m_pad_after = (pad_M, pad_K)
- n_pad_after = (pad_K, pad_N)
- if transpose_b:
- n_pad_after = (pad_N, pad_K)
-
- if pad_M != 0:
- data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=m_pad_after)
- if pad_N != 0:
- data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=n_pad_after)
+ if data_a.dtype == "float16":
+ _, pad_M = pad_dim_to_multiple(M, tile_m)
+ _, pad_K = pad_dim_to_multiple(K, tile_k)
+ _, pad_N = pad_dim_to_multiple(N, tile_n)
+ m_pad_after = (pad_M, pad_K)
+ n_pad_after = (pad_N, pad_K) if transpose_b else (pad_K, pad_N)
+ if pad_M != 0:
+ data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=m_pad_after)
+ if pad_N != 0:
+ data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=n_pad_after)
if out_dtype is None:
out_dtype = data_a.dtype
@@ -87,28 +84,12 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, tra
(False, False): "T_matmul_NN",
}[(transpose_a, transpose_b)]
- C = te.compute(
- (M_padded, N_padded),
+ return te.compute(
+ (M, N),
compute,
name=compute_name,
attrs={"schedule_type": "sme"},
)
- return te.compute((M, N), lambda m, n: C[m, n])
-
-
-def _get_transpose_interleave_intrin_name(in_dtype, out_dtype):
- # pylint: disable=import-outside-toplevel
- from tvm.tir.tensor_intrin.arm_cpu import (
- ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
- ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE,
- )
-
- if in_dtype == "float32" and out_dtype == "float32":
- return ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE
- elif in_dtype == "float16" and out_dtype == "float32":
- return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE
- else:
- raise ValueError("Input/output data type combination not supported.")
def tir_schedule_matmul_sme(sch):
@@ -120,6 +101,7 @@ def tir_schedule_matmul_sme(sch):
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
+ get_transpose_interleave_intrin_name,
)
main_func = sch.mod["main"]
@@ -157,9 +139,9 @@ def tir_schedule_matmul_sme(sch):
outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True)
outer_k, inner_k = sch.split(k, factors=(None, tile_k), disable_predication=True)
sch.reorder(outer_k, outer_m, inner_k, inner_m)
-
- transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(in_dtype, out_dtype)
- sch.tensorize(inner_k, transpose_interleave_intrin_name)
+ sch.tensorize(
+ inner_k, get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_m, extent_k)
+ )
# Interleave the weights utilizing the matrix tile
if transpose_b:
@@ -169,7 +151,9 @@ def tir_schedule_matmul_sme(sch):
outer_k, inner_k = sch.split(k, factors=(None, tile_k), disable_predication=True)
outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True)
sch.reorder(outer_k, outer_n, inner_k, inner_n)
- sch.tensorize(inner_k, transpose_interleave_intrin_name)
+ sch.tensorize(
+ inner_k, get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_k, extent_n)
+ )
# Split and reorder the loops of the GeMM for tensorization
tile_m = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_m.dtype)
@@ -185,11 +169,11 @@ def tir_schedule_matmul_sme(sch):
# Tensorize the GeMM update
sme_gemm_interleaved_intrin_name = (
- ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_k}_{in_dtype}"
+ ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_m}_{extent_k}_{in_dtype}"
)
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
- *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_k, in_dtype),
+ *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_m, extent_k, in_dtype),
override=True,
)
sch.tensorize(inner_m, sme_gemm_interleaved_intrin_name)
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 9b0408b949..f596549a10 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -530,12 +530,14 @@ def test_matmul_sme(dtype):
)
stores = re.findall(r"st1[whdb]\t{\s?za", assembly)
smstop = re.findall(r"smstop\t(sm|za)", assembly)
+ whilelo = re.findall(r"whilelo\tp[0-9].[shdb]", assembly)
assert len(smstart) > 0
assert len(loads) > 0
assert len(mopa) > 0
assert len(stores) > 0
assert len(smstop) > 0
+ assert len(whilelo) > 0
check_correct_assembly(dtype=dtype)
@@ -819,12 +821,14 @@ def test_conv2d_sme(dtype):
)
stores = re.findall(r"st1[whdb]\t{\s?za", assembly)
smstop = re.findall(r"smstop\t(sm|za)", assembly)
+ whilelo = re.findall(r"whilelo\tp[0-9].[shdb]", assembly)
assert len(smstart) > 0
assert len(loads) > 0
assert len(mopa) > 0
assert len(stores) > 0
assert len(smstop) > 0
+ assert len(whilelo) > 0
with tvm.target.Target(target):
check_correct_assembly(dtype=dtype)
diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py
index d46db1b28b..e7009ed179 100644
--- a/tests/python/topi/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/test_topi_conv2d_nhwc.py
@@ -168,10 +168,16 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation):
target = tvm.target.Target(target_string)
if target.features.has_sve and llvm_version_major() < 15:
- pytest.skip(f"LLVM {llvm_version_major()} does not support targetting
SVE.")
+ pytest.skip(f"LLVM {llvm_version_major()} does not support targeting
SVE.")
if target.features.has_sme and llvm_version_major() < 16:
- pytest.skip(f"LLVM {llvm_version_major()} does not support targetting
SME.")
+ pytest.skip(f"LLVM {llvm_version_major()} does not support targeting
SME.")
+
+ if target.features.has_sme and a_np.shape[0] > 1:
+ pytest.skip(f"Conv2d with batches > 1 targeting SME not implemented.")
+
+ if target.features.has_sme and (a_np.shape[3] * w_np.shape[0] * w_np.shape[1]) <= 1:
+ pytest.skip(f"Conv2d with unit reduction dimension targeting SME not supported.")
# SME schedule always outputs float32 results, regardless of input dtype.
# Otherwise, output dtype is the same as input dtype.