This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d1cd95fa9c [SME] Extract gemm block correctly when fused with bias (#17076)
d1cd95fa9c is described below
commit d1cd95fa9c73fac4eced85548b919a4d69c16cdd
Author: Luke Hutton <[email protected]>
AuthorDate: Tue Jun 11 09:42:22 2024 +0100
[SME] Extract gemm block correctly when fused with bias (#17076)
[SME] Extract gemm block correctly when fused with bias/activation
Prior to this commit, the scheduling assumed the gemm block would
be the second to last block in the function ("unpadding" step is the
final block). However, when dense is fused with a bias or activation
the gemm block is no longer the second to last block. This commit
instead searches for a single reduction block to use as the gemm block.
---
python/tvm/topi/arm_cpu/matmul.py | 8 ++++---
.../python/codegen/test_target_codegen_aarch64.py | 15 +++++++++++++
tests/python/relay/strategy/arm_cpu/test_dense.py | 26 ++++++++++++++--------
3 files changed, 37 insertions(+), 12 deletions(-)
diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py
index 2f09e24c87..23b8734a0b 100644
--- a/python/tvm/topi/arm_cpu/matmul.py
+++ b/python/tvm/topi/arm_cpu/matmul.py
@@ -26,6 +26,7 @@ from tvm.topi import nn
from tvm.topi.utils import get_const_tuple
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
from tvm.topi.arm_cpu.arm_utils import pad_dim_to_multiple
+from tvm.dlight.base.analysis import normalize_prim_func
@autotvm.register_topi_compute("matmul.arm_cpu.sme")
@@ -126,9 +127,10 @@ def tir_schedule_matmul_sme(sch):
in_dtype = main_func.buffer_map[data_handle].dtype
out_dtype = "float32"
- root_block = sch.get_block(main_func.body.block.name_hint)
- gemm_block = sch.get_child_blocks(root_block)[-2]
-
+ block_infos = normalize_prim_func(sch)
+ reduction_block_infos = [block_info for block_info in block_infos if block_info.is_reduction()]
+ assert len(reduction_block_infos) == 1, "Expected a single gemm reduction block."
+ gemm_block = reduction_block_infos[0].block_rv
gemm_block_name = sch.get(gemm_block).name_hint
transpose = gemm_block_name.split("_")[-1]
transpose_b = transpose[1] == "T"
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 77c22761a9..9b0408b949 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -540,6 +540,21 @@ def test_matmul_sme(dtype):
check_correct_assembly(dtype=dtype)
+def test_matmul_sme_no_reduction_block():
+ @T.prim_func
+ def prim_func(a: T.handle, b: T.handle):
+ A = T.match_buffer(a, (4,))
+ B = T.match_buffer(b, (4,))
+ for i in range(3):
+ with T.block("block"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi]
+
+ sch = tvm.tir.Schedule(prim_func)
+ with pytest.raises(AssertionError, match="Expected a single gemm reduction block."):
+ tvm.topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
+
+
@pytest.mark.skipif(
llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
)
diff --git a/tests/python/relay/strategy/arm_cpu/test_dense.py b/tests/python/relay/strategy/arm_cpu/test_dense.py
index 3a8427e815..fee8a87f12 100644
--- a/tests/python/relay/strategy/arm_cpu/test_dense.py
+++ b/tests/python/relay/strategy/arm_cpu/test_dense.py
@@ -99,16 +99,16 @@ class TestDense(BasicDenseTests):
)
@tvm.testing.requires_aprofile_aem_fvp
@pytest.mark.parametrize(
- "data_shape,weight_shape",
+ "data_shape,weight_shape,enable_bias",
[
- ((32, 32), (32, 32)),
- ((2, 35), (6, 35)),
- ((3, 3), (68, 3)),
- ((79, 65), (152, 65)),
+ ((32, 32), (32, 32), False),
+ ((2, 35), (6, 35), False),
+ ((3, 3), (68, 3), False),
+ ((79, 65), (152, 65), True),
],
)
@pytest.mark.parametrize("in_dtype", ["float32", "float16"])
-def test_sme_dense(data_shape, weight_shape, in_dtype):
+def test_sme_dense(data_shape, weight_shape, enable_bias, in_dtype):
np.random.seed(0)
out_dtype = "float32"
@@ -117,8 +117,14 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
weight_data = np.random.uniform(size=weight_shape).astype(in_dtype)
weight = relay.const(weight_data, dtype=in_dtype)
- dense = relay.nn.dense(inp, weight, out_dtype=out_dtype)
- func = relay.Function(relay.analysis.free_vars(dense), dense)
+ relay_op = relay.nn.dense(inp, weight, out_dtype=out_dtype)
+
+ if enable_bias:
+ bias_data = np.random.uniform(size=weight_shape[0]).astype(out_dtype)
+ bias = relay.const(bias_data, dtype=out_dtype)
+ relay_op = relay.nn.bias_add(relay_op, bias)
+
+ func = relay.Function(relay.analysis.free_vars(relay_op), relay_op)
ir_mod = tvm.IRModule.from_expr(func)
ir_mod = tvm.relay.transform.InferType()(ir_mod)
@@ -147,8 +153,10 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
runtime=runtime,
params=params,
)
+
+ bias_postfix = "_add" if enable_bias else ""
generated_func = executor_factory.lowered_ir_mods.items()[0][1][
- "tvmgen_default_fused_nn_matmul"
+ f"tvmgen_default_fused_nn_matmul{bias_postfix}"
]
extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)