This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d1cd95fa9c [SME] Extract gemm block correctly when fused with bias 
(#17076)
d1cd95fa9c is described below

commit d1cd95fa9c73fac4eced85548b919a4d69c16cdd
Author: Luke Hutton <[email protected]>
AuthorDate: Tue Jun 11 09:42:22 2024 +0100

    [SME] Extract gemm block correctly when fused with bias (#17076)
    
    [SME] Extract gemm block correctly when fused with bias/activation
    
    Prior to this commit, the scheduling assumed the gemm block would
    be the second to last block in the function (the "unpadding" step is the
    final block). However, when dense is fused with a bias or activation
    the gemm block is no longer the second to last block. This commit
    instead searches for a single reduction block to use as the gemm block.
---
 python/tvm/topi/arm_cpu/matmul.py                  |  8 ++++---
 .../python/codegen/test_target_codegen_aarch64.py  | 15 +++++++++++++
 tests/python/relay/strategy/arm_cpu/test_dense.py  | 26 ++++++++++++++--------
 3 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/matmul.py 
b/python/tvm/topi/arm_cpu/matmul.py
index 2f09e24c87..23b8734a0b 100644
--- a/python/tvm/topi/arm_cpu/matmul.py
+++ b/python/tvm/topi/arm_cpu/matmul.py
@@ -26,6 +26,7 @@ from tvm.topi import nn
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
 from tvm.topi.arm_cpu.arm_utils import pad_dim_to_multiple
+from tvm.dlight.base.analysis import normalize_prim_func
 
 
 @autotvm.register_topi_compute("matmul.arm_cpu.sme")
@@ -126,9 +127,10 @@ def tir_schedule_matmul_sme(sch):
     in_dtype = main_func.buffer_map[data_handle].dtype
     out_dtype = "float32"
 
-    root_block = sch.get_block(main_func.body.block.name_hint)
-    gemm_block = sch.get_child_blocks(root_block)[-2]
-
+    block_infos = normalize_prim_func(sch)
+    reduction_block_infos = [block_info for block_info in block_infos if 
block_info.is_reduction()]
+    assert len(reduction_block_infos) == 1, "Expected a single gemm reduction 
block."
+    gemm_block = reduction_block_infos[0].block_rv
     gemm_block_name = sch.get(gemm_block).name_hint
     transpose = gemm_block_name.split("_")[-1]
     transpose_b = transpose[1] == "T"
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py 
b/tests/python/codegen/test_target_codegen_aarch64.py
index 77c22761a9..9b0408b949 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -540,6 +540,21 @@ def test_matmul_sme(dtype):
     check_correct_assembly(dtype=dtype)
 
 
+def test_matmul_sme_no_reduction_block():
+    @T.prim_func
+    def prim_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (4,))
+        B = T.match_buffer(b, (4,))
+        for i in range(3):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+    sch = tvm.tir.Schedule(prim_func)
+    with pytest.raises(AssertionError, match="Expected a single gemm reduction 
block."):
+        tvm.topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
+
+
 @pytest.mark.skipif(
     llvm_version_major() < 11, reason="Vscale is not supported in earlier 
versions of LLVM"
 )
diff --git a/tests/python/relay/strategy/arm_cpu/test_dense.py 
b/tests/python/relay/strategy/arm_cpu/test_dense.py
index 3a8427e815..fee8a87f12 100644
--- a/tests/python/relay/strategy/arm_cpu/test_dense.py
+++ b/tests/python/relay/strategy/arm_cpu/test_dense.py
@@ -99,16 +99,16 @@ class TestDense(BasicDenseTests):
 )
 @tvm.testing.requires_aprofile_aem_fvp
 @pytest.mark.parametrize(
-    "data_shape,weight_shape",
+    "data_shape,weight_shape,enable_bias",
     [
-        ((32, 32), (32, 32)),
-        ((2, 35), (6, 35)),
-        ((3, 3), (68, 3)),
-        ((79, 65), (152, 65)),
+        ((32, 32), (32, 32), False),
+        ((2, 35), (6, 35), False),
+        ((3, 3), (68, 3), False),
+        ((79, 65), (152, 65), True),
     ],
 )
 @pytest.mark.parametrize("in_dtype", ["float32", "float16"])
-def test_sme_dense(data_shape, weight_shape, in_dtype):
+def test_sme_dense(data_shape, weight_shape, enable_bias, in_dtype):
     np.random.seed(0)
     out_dtype = "float32"
 
@@ -117,8 +117,14 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
     weight_data = np.random.uniform(size=weight_shape).astype(in_dtype)
     weight = relay.const(weight_data, dtype=in_dtype)
 
-    dense = relay.nn.dense(inp, weight, out_dtype=out_dtype)
-    func = relay.Function(relay.analysis.free_vars(dense), dense)
+    relay_op = relay.nn.dense(inp, weight, out_dtype=out_dtype)
+
+    if enable_bias:
+        bias_data = np.random.uniform(size=weight_shape[0]).astype(out_dtype)
+        bias = relay.const(bias_data, dtype=out_dtype)
+        relay_op = relay.nn.bias_add(relay_op, bias)
+
+    func = relay.Function(relay.analysis.free_vars(relay_op), relay_op)
 
     ir_mod = tvm.IRModule.from_expr(func)
     ir_mod = tvm.relay.transform.InferType()(ir_mod)
@@ -147,8 +153,10 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
             runtime=runtime,
             params=params,
         )
+
+    bias_postfix = "_add" if enable_bias else ""
     generated_func = executor_factory.lowered_ir_mods.items()[0][1][
-        "tvmgen_default_fused_nn_matmul"
+        f"tvmgen_default_fused_nn_matmul{bias_postfix}"
     ]
     extra_memory_in_bytes = 
calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
 

Reply via email to