https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/148198
>From 5b7e081d51f2c98c91ed53964998b3f7f61fb747 Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Fri, 11 Jul 2025 10:03:18 +0000 Subject: [PATCH 1/4] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations --- mlir/include/mlir/Conversion/Passes.td | 4 + .../TransformOps/ArmNeonVectorTransformOps.td | 15 +- .../include/mlir/Dialect/ArmNeon/Transforms.h | 4 +- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 4 +- .../ArmNeonVectorTransformOps.cpp | 7 +- .../Dialect/ArmNeon/Transforms/CMakeLists.txt | 2 +- ...rn.cpp => LowerContractToNeonPatterns.cpp} | 126 +++++++--- .../LowerContractionToSVEI8MMPattern.cpp | 2 +- mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir | 225 ++++++++++++++++++ .../CPU/ArmNeon/vector-contract-bfmmla.mlir | 176 ++++++++++++++ .../CPU/ArmNeon/vector-contract-i8mm.mlir | 2 +- 11 files changed, 531 insertions(+), 36 deletions(-) rename mlir/lib/Dialect/ArmNeon/Transforms/{LowerContractionToNeonI8MMPattern.cpp => LowerContractToNeonPatterns.cpp} (81%) create mode 100644 mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 76e751243a12c..8183f355795a9 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1449,6 +1449,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of Arm FEAT_I8MM instructions while lowering " "the vector dialect.">, + Option<"armBF16", "enable-arm-bf16", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_BF16 instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td index bcaca7da967fa..35747126d3db1 100644 --- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td +++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td @@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp "apply_patterns.arm_neon.vector_contract_to_i8mm", [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { let description = [{ - Indicates that vector.contract operations should be lowered to - finer-grained vector primitives from the ArmNeon dialect. + Indicates that vector contract operations should be lowered to + to ArmNeon dialect operations mapping to instructions from FEAT_I8MM. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyArmNeonContractionToBFMMLAPatternsOp + : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Indicates that vector contract operations should be lowered to + to ArmNeon dialect operations mapping to instructions from FEAT_BF16. 
}]; let assemblyFormat = "attr-dict"; diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h index 2f0f634a96770..08065a3b25266 100644 --- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h @@ -13,8 +13,8 @@ namespace mlir { class RewritePatternSet; namespace arm_neon { -void populateLowerContractionToNeonI8MMPatternPatterns( - RewritePatternSet &patterns); +void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns); +void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns); } // namespace arm_neon } // namespace mlir diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 66d2ffbde1751..e9ae7131b8e1d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -96,10 +96,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorGatherLoweringPatterns(patterns); if (armI8MM) { if (armNeon) - arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns); + arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); if (armSVE) populateLowerContractionToSVEI8MMPatternPatterns(patterns); } + if (armBF16 && armNeon) + arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp index d07e6a52d8b5f..d069bde6d9979 100644 --- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp +++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp @@ -20,7 +20,12 @@ using namespace mlir; void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns( RewritePatternSet &patterns) { - arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns); + arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); +} + +void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt index 06bafde451cbb..368dacac7b835 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(MLIRArmNeonTransforms - LowerContractionToNeonI8MMPattern.cpp + LowerContractToNeonPatterns.cpp DEPENDS MLIRArmNeonIncGen diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp similarity index 81% rename from mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp rename to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp index 0738de6c7788c..1ad563537d874 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp @@ -1,4 +1,4 @@ -//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===// +//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===// // 
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -93,15 +93,20 @@ class VectorContractRewriter { // multiplications. enum class MMLA { Nop, - Signed, // smmla - Unsigned, // ummla - Mixed, // usmmla - MixedSwapped // usmmla with LHS and RHS swapped + SignedInt, // smmla + UnsignedInt, // ummla + MixedInt, // usmmla + Bfloat // bfmmla }; // Lower-level operation to be emitted. MMLA mmlaOp = MMLA::Nop; + // Indicate if the operands for the ArmNeon dialect operation need to be + // swapped. Currently this is needed in order to emulate an "summla" + // operation. + bool swapOperands = false; + // The operand tiles. These are not necessarily the operands of // `vector.contract`, for example they could be operands to `arith.extsi` // that is in turn fed into `vector.contract`. @@ -126,21 +131,22 @@ class VectorContractRewriter { // Create the matrix multiply and accumulate operation according to `mmlaOp`. Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc, Value lhs, Value rhs) { + + if (swapOperands) + std::swap(lhs, rhs); switch (mmlaOp) { - case MMLA::Signed: + case MMLA::SignedInt: return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc, lhs, rhs); - case MMLA::Unsigned: + case MMLA::UnsignedInt: return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc, lhs, rhs); - case MMLA::Mixed: + case MMLA::MixedInt: return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc, lhs, rhs); - case MMLA::MixedSwapped: - // The accumulator comes transposed and the result will be transposed - // later, so all we have to do here is swap the operands. - return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc, - rhs, lhs); + case MMLA::Bfloat: + return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs, + rhs); case MMLA::Nop: llvm_unreachable("Uninitialized operation type"); } @@ -273,7 +279,7 @@ class VectorContractRewriter { // Transpose ACC if doing signed by unsigned multiplication, because we're // using the instruction for unsigned by signed multiplication with // reversed operands. - if (mmlaOp == MMLA::MixedSwapped) + if (swapOperands) tiledAcc = rewriter.create<vector::TransposeOp>( loc, tiledAcc, ArrayRef<int64_t>({1, 0})); @@ -302,7 +308,7 @@ class VectorContractRewriter { // Because of the reversed operands the result is obtained transposed. // Transpose it back, - if (mmlaOp == MMLA::MixedSwapped) + if (swapOperands) tiledRes = rewriter.create<vector::TransposeOp>( loc, tiledRes, ArrayRef<int64_t>({1, 0})); @@ -339,10 +345,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter { // values before the extension. All four signed/unsigned combinations for // input operands are supported, but they are lowered to different // operations. Determine which is the appropriate operation to lower to. 
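// For reference, the selection logic below boils down to the following
// mapping (a rough summary of the code, not additional semantics):
//   extsi(LHS) * extsi(RHS)  ->  smmla
//   extui(LHS) * extui(RHS)  ->  ummla
//   extui(LHS) * extsi(RHS)  ->  usmmla
//   extsi(LHS) * extui(RHS)  ->  usmmla with LHS/RHS swapped and the
//                                accumulator/result transposed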
- mmlaOp = MMLA::Signed; + mmlaOp = MMLA::SignedInt; auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs()); if (!maybeLhs) { - mmlaOp = MMLA::Unsigned; + mmlaOp = MMLA::UnsignedInt; maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs()); } if (!maybeLhs) @@ -351,11 +357,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter { auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs()); if (maybeRhs) { - if (mmlaOp == MMLA::Unsigned) - mmlaOp = MMLA::Mixed; + if (mmlaOp == MMLA::UnsignedInt) + mmlaOp = MMLA::MixedInt; } else { - if (mmlaOp == MMLA::Signed) - mmlaOp = MMLA::MixedSwapped; + if (mmlaOp == MMLA::SignedInt) { + mmlaOp = MMLA::MixedInt; + swapOperands = true; + } maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs()); } @@ -372,16 +380,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter { auto lhsExtInType = cast<VectorType>(lhs.getType()); if (lhsExtInType.getElementTypeBitWidth() < 8) lhs = extendSmallIntVector(loc, lhsExtInType, lhs, - /* signExt */ mmlaOp == MMLA::Signed || - mmlaOp == MMLA::Mixed, + /* signExt */ + (mmlaOp == MMLA::SignedInt || + (mmlaOp == MMLA::MixedInt && !swapOperands)), rewriter); auto rhsExtInType = cast<VectorType>(rhs.getType()); if (rhsExtInType.getElementTypeBitWidth() < 8) - rhs = extendSmallIntVector(loc, rhsExtInType, rhs, - /* signExt */ mmlaOp != MMLA::Unsigned && - mmlaOp != MMLA::Mixed, + /* signExt */ + (mmlaOp == MMLA::SignedInt || + (mmlaOp == MMLA::MixedInt && swapOperands)), rewriter); // Initialize parameters for unrolling. @@ -395,6 +404,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter { } }; +class VectorContractRewriterBFMMLA : public VectorContractRewriter { +public: + LogicalResult matchAndInit(vector::ContractionOp op, + PatternRewriter &rewriter) { + + if (failed(VectorContractRewriter::matchAndInit(op, rewriter))) + return failure(); + + // Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for + // tiling. + if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0) + return rewriter.notifyMatchFailure(op, "Unsupported operand shapes"); + + // Check the output is a vector of Float32 elements. + auto outTy = dyn_cast<VectorType>(op.getResultType()); + if (!outTy || outTy.getElementType() != rewriter.getF32Type()) + return rewriter.notifyMatchFailure(op, + "output type is not a vector of f32"); + + // Check the inputs are vectors of BFloat16 elements. + if (op.getLhsType().getElementType() != rewriter.getBF16Type()) + return rewriter.notifyMatchFailure(op, + "input type is not a vector of bf16"); + + mmlaOp = MMLA::Bfloat; + swapOperands = false; + lhs = op.getLhs(); + rhs = op.getRhs(); + acc = op.getAcc(); + + // Initialize parameters for unrolling. + iterationBounds = *op.getShapeForUnroll(); + if (iterationBounds.size() == 3) + subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4}); + else + subTileShape = SmallVector<int64_t>({2, 4}); + + return success(); + } +}; + /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile /// any vector.contract into multiple smmla instructions with unrolling so long /// as [2,2,8] is a divisor of its shape. 
It can also process vecmats with dimM @@ -416,10 +466,32 @@ class LowerContractionToNeonI8MMPattern } }; +class LowerContractionToNeonBFMMLAPattern + : public OpRewritePattern<vector::ContractionOp> { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + + VectorContractRewriterBFMMLA vcr; + if (failed(vcr.matchAndInit(op, rewriter))) + return failure(); + vcr.rewrite(op, rewriter); + + return success(); + } +}; + } // namespace -void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns( +void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns( RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2); } + +void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns( + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2); +} diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp index bd051b100a91b..a1c6f2cb5dd9f 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp @@ -12,7 +12,7 @@ // TODO: There may be opportunities to unify this with a similar pattern // for Neon. See: // https://github.com/llvm/llvm-project/issues/145559 -// LowerContractionToNeonI8MMPattern.cpp +// LowerContracToNeonPatterns.cpp // //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir new file mode 100644 index 0000000000000..229c4e5b2dc3a --- /dev/null +++ b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir @@ -0,0 +1,225 @@ +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s + +// Test lowering of vector.contract to BFMMLA operations. +// For each iteration [I, J, K] sub-tiles are extracted from offsets as follows: +// LHS: [2*I, 4*K] +// RHS: [2*J, 4*K] +// ACC: [2*I, 2*J] +// Sub-tile insert offsets for the result are as like ACC (there are redundant +// inserts). 
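// For reference, the unrolled sequence checked below corresponds roughly to
// the following loop nest (a C-like sketch of the tiling scheme, not actual
// pattern code; dimM, dimN, dimK denote the contraction dimensions):
//
//   for (I = 0; I < dimM / 2; ++I)
//     for (J = 0; J < dimN / 2; ++J)
//       for (K = 0; K < dimK / 4; ++K) {
//         lhsTile = LHS[2*I .. 2*I+1][4*K .. 4*K+3];            // 2x4 bf16
//         rhsTile = RHS[2*J .. 2*J+1][4*K .. 4*K+3];            // 2x4 bf16
//         accTile = (K == 0) ? ACC[2*I .. 2*I+1][2*J .. 2*J+1]  // 2x2 f32
//                            : accTile;                         // carried over K
//         accTile = bfmmla(accTile, lhsTile, rhsTile);
//         RES[2*I .. 2*I+1][2*J .. 2*J+1] = accTile;
//       }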
+ +// CHECK-LABEL: func.func @vector_contract_to_bfmmla +// CHECK-SAME: %[[LHS:.+]]: vector<4x8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4x4xf32> + +// %[[INIT_RES:.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32> + +// Iteration [0, 0, 0] +// Extract sib-tiles from each of LHS, RHS and ACC +// %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> + +// Flatten the operands to fit the `bfmmla` operation types +// %[[T3:.+]] = vector.shape_cast %[[T0]] : vector<2x4xbf16> to vector<8xbf16> +// %[[T4:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16> +// %[[T5:.+]] = vector.shape_cast %[[T2]] : vector<2x2xf32> to vector<4xf32> + +// Perform the matrix multiply and accumulate +// %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T5]], %[[T3]], %[[T4]] : vector<8xbf16> to vector<4xf32> + +// Un-flatten the output sub-tile and inserr into the result +// %[[T7:.+]] = vector.shape_cast %[[K_ACC_0]] : vectK_ACCor<4xf32> to vector<2x2xf32> +// %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T7]], %[[INIT_RES]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + +// Iteration [0, 0, 1] +// %[[T9:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T10:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T11:.+]] = vector.shape_cast %[[T9]] : vector<2x4xbf16> to vector<8xbf16> +// %[[T12:.+]] = vector.shape_cast %[[T1]]0 : vector<2x4xbf16> to vector<8xbf16> +// %[[T13:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T1]]1, %[[T1]]2 : vector<8xbf16> to vector<4xf32> +// %[[T14:.+]] = vector.shape_cast %[[T1]]3 : vector<4xf32> to vector<2x2xf32> +// %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T1]]4, %[[TMP_RES_0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + +// Iteration [0, 1, 0] +// %[[T16:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T17:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T18:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// %[[T19:.+]] = vector.shape_cast %[[T1]]6 : vector<2x4xbf16> to vector<8xbf16> +// %[[T20:.+]] = vector.shape_cast %[[T1]]7 : vector<2x4xbf16> to vector<8xbf16> +// %[[T21:.+]] = vector.shape_cast %[[T1]]8 : vector<2x2xf32> to vector<4xf32> +// %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T2]]1, %[[T1]]9, %[[T2]]0 : vector<8xbf16> to vector<4xf32> +// %[[T23:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32> +// %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T2]]3, %[[TMP_RES_1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + +// Iteration [0, 1, 1] +// %[[T25:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to 
vector<2x4xbf16> +// %[[T26:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T27:.+]] = vector.shape_cast %[[T2]]5 : vector<2x4xbf16> to vector<8xbf16> +// %[[T28:.+]] = vector.shape_cast %[[T2]]6 : vector<2x4xbf16> to vector<8xbf16> +// %[[T29:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T2]]7, %[[T2]]8 : vector<8xbf16> to vector<4xf32> +// %[[T30:.+]] = vector.shape_cast %[[T2]]9 : vector<4xf32> to vector<2x2xf32> +// %[[TMP_RES_3:.+]] = vector.insert_strided_slice %[[T3]]0, %[[TMP_RES_2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + +// Iteration [1, 0, 0] +// %[[T32:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T33:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T34:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// %[[T35:.+]] = vector.shape_cast %[[T3]]2 : vector<2x4xbf16> to vector<8xbf16> +// %[[T36:.+]] = vector.shape_cast %[[T3]]3 : vector<2x4xbf16> to vector<8xbf16> +// %[[T37:.+]] = vector.shape_cast %[[T3]]4 : vector<2x2xf32> to vector<4xf32> +// %[[K_ACC_2:.+]] = arm_neon.intr.bfmmla %[[T3]]7, %[[T3]]5, %[[T3]]6 : vector<8xbf16> to vector<4xf32> +// %[[T39:.+]] = vector.shape_cast %[[K_ACC_2]] : vector<4xf32> to vector<2x2xf32> +//%[[TMP_RES_4:.+]] = vector.insert_strided_slice %[[T3]]9, %[[TMP_RES_3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + +// Iteration [1, 0, 1] +// %[[T41:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T42:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T43:.+]] = vector.shape_cast %[[T4]]1 : vector<2x4xbf16> to vector<8xbf16> +// %[[T44:.+]] = vector.shape_cast %[[T4]]2 : vector<2x4xbf16> to vector<8xbf16> +// %[[T45:.+]] = arm_neon.intr.bfmmla %[[K_ACC_2]], %[[T4]]3, %[[T4]]4 : vector<8xbf16> to vector<4xf32> +// %[[T46:.+]] = vector.shape_cast %[[T4]]5 : vector<4xf32> to vector<2x2xf32> +//%[[TMP_RES_5:.+]] = vector.insert_strided_slice %[[T4]]6,%[[TMP_RES_4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + +// Iteration [1, 1, 0] +// %[[T48:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T49:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T50:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// %[[T51:.+]] = vector.shape_cast %[[T4]]8 : vector<2x4xbf16> to vector<8xbf16> +// %[[T52:.+]] = vector.shape_cast %[[T4]]9 : vector<2x4xbf16> to vector<8xbf16> +// %[[T53:.+]] = vector.shape_cast %[[T5]]0 : vector<2x2xf32> to vector<4xf32> +// %[[K_ACC_3:.+]] = arm_neon.intr.bfmmla %[[T5]]3, %[[T5]]1, %[[T5]]2 : vector<8xbf16> to vector<4xf32> +// %[[T55:.+]] = vector.shape_cast %[[K_ACC_3]] : vector<4xf32> to vector<2x2xf32> +//%[[TMP_RES_6:.+]] = vector.insert_strided_slice %[[T5]]5,%[[TMP_RES_5]] {offsets = [2, 2], strides = [1, 
1]} : vector<2x2xf32> into vector<4x4xf32> + +// Iteration [1, 1, 1] +// %[[T57:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T58:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// %[[T59:.+]] = vector.shape_cast %[[T5]]7 : vector<2x4xbf16> to vector<8xbf16> +// %[[T60:.+]] = vector.shape_cast %[[T5]]8 : vector<2x4xbf16> to vector<8xbf16> +// %[[T61:.+]] = arm_neon.intr.bfmmla %[[K_ACC_3]], %[[T5]]9, %[[T6]]0 : vector<8xbf16> to vector<4xf32> +// %[[T62:.+]] = vector.shape_cast %[[T6]]1 : vector<4xf32> to vector<2x2xf32> +// %[[RESULT:.+]] = vector.insert_strided_slice %[[T6]]2,%[[TMP_RES_6]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + +// return %[[RESULT]] : vector<4x4xf32> + +func.func @vector_contract_to_bfmmla(%lhs: vector<4x8xbf16>, + %rhs: vector<4x8xbf16>, + %acc: vector<4x4xf32>) -> vector<4x4xf32> { + %0 = vector.contract { indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<add> + } + %lhs, %rhs, %acc : vector<4x8xbf16>, vector<4x8xbf16> into vector<4x4xf32> + + return %0 : vector<4x4xf32> +} + +// Test lowering of vector.contract, representing vector by matrix multiply and +// accumulate, to BFMMLA operations. + +// For each iteration [J, K] sub-tiles are extracted from offsets as follows: +// LHS: [4*K] +// RHS: [2*J, 4*K] +// ACC: [2*J] +// Sub-tile insert offsets for the result are as like ACC (there are redundant +// inserts). +// CHECK-LABEL: func.func @vector_contract_vecmat_to_bfmmla +// CHECK-SAME: %[[LHS:.+]]: vector<8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4xf32>) -> vector<4xf32> { +// CHECK: %[[ACC_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[LHS_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x4xbf16> +// CHECK: %[[RES_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> + +// Iteration [0, 0] +// Extract sub-tiles +// CHECK: %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16> +// CHECK: %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// CHECK: %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> + +// Pad LHS sub-tile/vector with an extra row of zeroes +// CHECK: %[[T3:.+]] = vector.insert_strided_slice %[[T0]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16> + +// Pad ACC sub-tile/vector with an extra row of zeroes +// CHECK: %[[T4:.+]] = vector.insert_strided_slice %[[T2]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32> + +// Flatten the operands to fit the `bfmmla` operation types +// CHECK: %[[T5:.+]] = vector.shape_cast %[[T3]] : vector<2x4xbf16> to vector<8xbf16> +// CHECK: %[[T6:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16> +// CHECK: %[[T7:.+]] = vector.shape_cast %[[T4]] : vector<2x2xf32> to vector<4xf32> + +// Perform the matrix multiply and accumulate +// CHECK: %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T7]], %[[T5]], %[[T6]] : vector<8xbf16> to vector<4xf32> + +// 
Un-flatten the output sub-tile +// CHECK: %[[T9:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32> + +// Extract the first rows (the second row is padding) and insert into the result +// CHECK: %[[T10:.+]] = vector.extract %[[T9]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T10]], %[[RES_INIT]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> + +// Iteration [0, 1] +// CHECK: %[[T12:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16> +// CHECK: %[[T13:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// CHECK: %[[T14:.+]] = vector.insert_strided_slice %[[T12]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16> +// CHECK: %[[T15:.+]] = vector.shape_cast %[[T14]] : vector<2x4xbf16> to vector<8xbf16> +// CHECK: %[[T16:.+]] = vector.shape_cast %[[T13]] : vector<2x4xbf16> to vector<8xbf16> +// CHECK: %[[T17:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T15]], %[[T16]] : vector<8xbf16> to vector<4xf32> +// CHECK: %[[T18:.+]] = vector.shape_cast %[[T17]] : vector<4xf32> to vector<2x2xf32> +// CHECK: %[[T19:.+]] = vector.extract %[[T18]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T19]], %[[TMP_RES_0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> + +// Iteration [1, 0] +// CHECK: %[[T21:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16> +// CHECK: %[[T22:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// CHECK: %[[T23:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> +// CHECK: %[[T24:.+]] = vector.insert_strided_slice %[[T21]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16> +// CHECK: %[[T25:.+]] = vector.insert_strided_slice %[[T23]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32> +// CHECK: %[[T26:.+]] = vector.shape_cast %[[T24]] : vector<2x4xbf16> to vector<8xbf16> +// CHECK: %[[T27:.+]] = vector.shape_cast %[[T22]] : vector<2x4xbf16> to vector<8xbf16> +// CHECK: %[[T28:.+]] = vector.shape_cast %[[T25]] : vector<2x2xf32> to vector<4xf32> +// CHECK: %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T28]], %[[T26]], %[[T27]] : vector<8xbf16> to vector<4xf32> +// CHECK: %[[T30:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32> +// CHECK: %[[T31:.+]] = vector.extract %[[T30]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T31]], %[[TMP_RES_1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> + +// Iteration [1, 1] +// CHECK: %[[T33:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16> +// CHECK: %[[T34:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16> +// CHECK: %[[T35:.+]] = vector.insert_strided_slice %[[T33]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16> +// CHECK: %[[T36:.+]] = vector.shape_cast %[[T35]] : vector<2x4xbf16> to 
vector<8xbf16> +// CHECK: %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x4xbf16> to vector<8xbf16> +// CHECK: %[[T38:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T36]], %[[T37]] : vector<8xbf16> to vector<4xf32> +// CHECK: %[[T39:.+]] = vector.shape_cast %[[T38]] : vector<4xf32> to vector<2x2xf32> +// CHECK: %[[T40:.+]] = vector.extract %[[T39]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[RESULT:.+]] = vector.insert_strided_slice %[[T40]], %[[TMP_RES_2]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> +// CHECK: return %[[RESULT]] : vector<4xf32> +func.func @vector_contract_vecmat_to_bfmmla(%lhs: vector<8xbf16>, + %rhs: vector<4x8xbf16>, + %acc: vector<4xf32>) -> vector<4xf32> { + %0 = vector.contract { indexing_maps = [ + affine_map<(n, k) -> (k)>, + affine_map<(n, k) -> (n, k)>, + affine_map<(n, k) -> (n)> + ], + iterator_types = ["parallel", "reduction"], + kind = #vector.kind<add> + } + %lhs, %rhs, %acc : vector<8xbf16>, vector<4x8xbf16> into vector<4xf32> + + return %0 : vector<4xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + + transform.apply_patterns to %func { + transform.apply_patterns.arm_neon.vector_contract_to_bfmmla + } : !transform.op<"func.func"> + + transform.yield + } +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir new file mode 100644 index 0000000000000..b62ae040f364b --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir @@ -0,0 +1,176 @@ +// REQUIRES: arm-emulator + +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-neon enable-arm-bf16' \ +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \ +// DEFINE: --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \ +// DEFINE: -o %t + +// DEFINE: %{entry_point} = main + +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+bf16" \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils + +// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s + +#packed_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> +] + +// +// Test the lowering of `vector.contract` using the `LowerContractionToNeonBFMMLAPattern` +// +// The operation that the `vector.contract` in this test performs is matrix +// multiplication with accumulate +// OUT = ACC + LHS * RHS +// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC into a Float32 OUT. +// +// Tested are calculations as well as that the relevant `ArmNeon` dialect +// operation (`arm_neon.intr.bfmmla`) is emitted. +// +// That pattern above handles (therefore this test prepares) input/output vectors with +// specific shapes: +// * LHS: vector<MxKxbf16> +// * RHS: vector<NxKxbf16> +// * ACC, OUT: vector<MxNxf32> +// where the M and N are even and K is divisible by 4. +// Note that the RHS is transposed. +// This data layout makes it efficient to load data into SIMD +// registers in the layout expected by BFMMLA instruction. 
+// Such a `vector.contract` is representative of the code we aim to generate +// by vectorisation of `linalg.mmt4d`. +// +// In this specific test we use M == 4, N == 4, and K == 4. + +// CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc +// CHECK-IR-COUNT-4: arm_neon.intr.bfmmla +func.func @matrix_by_matrix_mul_and_acc() { + + %c0 = arith.constant 0 : index + %c0_f32 = arith.constant 0.0 : f32 + %c0_bf16 = arith.constant 0.0 : bf16 + + // Accumulator test data + %acc_cst = arith.constant dense<[[ 0.7, 1.0, -0.1, 1.8], + [-0.5, 0.9, 0.7, -0.7], + [ 0.5, -1.3, -2.2, 0.1], + [-0.7, 1.0, 1.7, -1.0]]> : vector<4x4xf32> + + %acc_mem = memref.alloc() : memref<4x4xf32> + vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32> + %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32> + + // LHS test data + %lhs_cst = arith.constant dense<[[ 0.1, 0.7, -0.9, 1.3], + [-1.6, 0.7, -0.3, -0.3], + [-0.4, 0.6, 0.8, -0.5], + [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16> + + %lhs_mem = memref.alloc() : memref<4x4xbf16> + vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16> + %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16> + + // RHS test data + %rhs_cst = arith.constant dense<[[ 0.6, 1.3, 0.1, -0.9], + [ 0.5, 1.6, 1.8, 1.6], + [-0.2, 0.4, 1.0, 0.4], + [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16> + + %rhs_mem = memref.alloc() : memref<4x4xbf16> + vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16> + %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16> + + // Matrix multiplication and accumulate with transposed RHS. + %0 = vector.contract {indexing_maps = #packed_maps, + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<add>} %lhs, %rhs, %acc + : vector<4x4xbf16>, vector<4x4xbf16> into vector<4x4xf32> + + // Display the result of the multiplication + vector.print str "Result(BFMMLA):\n" + %u0 = vector.extract %0[0] : vector<4xf32> from vector<4x4xf32> + %u1 = vector.extract %0[1] : vector<4xf32> from vector<4x4xf32> + %u2 = vector.extract %0[2] : vector<4xf32> from vector<4x4xf32> + %u3 = vector.extract %0[3] : vector<4xf32> from vector<4x4xf32> + vector.print %u0 : vector<4xf32> + vector.print %u1 : vector<4xf32> + vector.print %u2 : vector<4xf32> + vector.print %u3 : vector<4xf32> + + return +} + +// Test when the LHS is a one-dimensional vector. +// +// In the vector by matrix case the dhapes ae as follows: +// * LHS: vector<Kxbf16> +// * RHS: vector<NxKxbf16> +// * ACC, OUT: vector<Nxf32> +// N is even and K is divisible by 4. +// In this specific test we use N == 4, and K == 4. 
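// For reference, the lowering reuses the same 2x2 BFMMLA tile in the vecmat
// case by padding with zeroes (a sketch of the scheme, not additional
// semantics; the unit test above shows the exact IR):
//
//   lhsTile = [ LHS[4*K .. 4*K+3] ; 0 0 0 0 ];          // pad to 2x4 bf16
//   accTile = (K == 0) ? [ ACC[2*J .. 2*J+1] ; 0 0 ]    // pad to 2x2 f32
//                      : accTile;                       // carried over K
//   accTile = bfmmla(accTile, lhsTile, RHS[2*J .. 2*J+1][4*K .. 4*K+3]);
//   OUT[2*J .. 2*J+1] = accTile[0];                     // second row is padding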
+ +// CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc +// CHECK-IR-COUNT-2: arm_neon.intr.bfmmla +func.func @vector_by_matrix_mul_and_acc() { + %c0 = arith.constant 0 : index + %c0_f32 = arith.constant 0.0 : f32 + %c0_bf16 = arith.constant 0.0 : bf16 + + // Accumulator test data + %acc_cst = arith.constant dense<[0.7, 1.0, -0.1, 1.8]> : vector<4xf32> + + %acc_mem = memref.alloc() : memref<4xf32> + vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32> + %acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32> + + // LHS test data + %lhs_cst = arith.constant dense<[0.1, 0.7, -0.9, 1.3]> : vector<4xbf16> + + %lhs_mem = memref.alloc() : memref<4xbf16> + vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16> + %lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16> + + // RHS test data + %rhs_cst = arith.constant dense<[[ 0.6, 1.3, 0.1, -0.9], + [ 0.5, 1.6, 1.8, 1.6], + [-0.2, 0.4, 1.0, 0.4], + [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16> + + %rhs_mem = memref.alloc() : memref<4x4xbf16> + vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16> + %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16> + + // Vector by matrix multiplication and accumulate with transposed RHS. + %0 = vector.contract { indexing_maps = [ + affine_map<(n, k) -> (k)>, + affine_map<(n, k) -> (n, k)>, + affine_map<(n, k) -> (n)> + ], + iterator_types = ["parallel", "reduction"], + kind = #vector.kind<add> + } + %lhs, %rhs, %acc : vector<4xbf16>, vector<4x4xbf16> into vector<4xf32> + + // Display the result of the multiplication + vector.print str "Result(BFMMLA, vecmat):\n" + vector.print %0 : vector<4xf32> + + return +} + +func.func @main() { + // CHECK-LABEL: Result(BFMMLA): + // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 ) + // CHECK: ( -0.316515, 0.196875, 0.879375, 1.80924 ) + // CHECK: ( 1.56867, 0.101367, -1.2784, -1.41579 ) + // CHECK: ( -1.56041, -4.30078, 0.0196488, 1.88269 ) + func.call @matrix_by_matrix_mul_and_acc() : () -> () + + // CHECK-LABEL: Result(BFMMLA, vecmat): + // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 ) + func.call @vector_by_matrix_mul_and_acc() : () -> () + + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir index 1ce55ca05c90e..f6012bbd3d0b2 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir @@ -240,7 +240,7 @@ func.func @test_usmmla() { // Test the operation where LHS is interpreted as signed and RHS is interpreted // as unsigned. In this test we ultimately emit end execute the `usmmla` -// instruction with reversed operands, see `LowerContractionToNeonI8MMPattern.cpp` +// instruction with reversed operands, see `LowerContractoNeonPatterns.cpp` // for more details. 
// CHECK-IR-LABEL: llvm.func @test_summla >From 8dfd754e1d998bc3b2c1df8f38823aed40ff1e18 Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Mon, 21 Jul 2025 11:16:02 +0000 Subject: [PATCH 2/4] [fixup] Rename a member function and chanege some allocs to allocas --- .../Transforms/LowerContractToNeonPatterns.cpp | 6 +++--- .../Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp index 1ad563537d874..6112eb586907e 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp @@ -201,7 +201,7 @@ class VectorContractRewriter { } public: - void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) { + void lower(vector::ContractionOp op, PatternRewriter &rewriter) { // Create some convenience types. auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType(); auto accElementType = cast<ShapedType>(acc.getType()).getElementType(); @@ -460,7 +460,7 @@ class LowerContractionToNeonI8MMPattern VectorContractRewriterI8MM vcr; if (failed(vcr.matchAndInit(op, rewriter))) return failure(); - vcr.rewrite(op, rewriter); + vcr.lower(op, rewriter); return success(); } @@ -476,7 +476,7 @@ class LowerContractionToNeonBFMMLAPattern VectorContractRewriterBFMMLA vcr; if (failed(vcr.matchAndInit(op, rewriter))) return failure(); - vcr.rewrite(op, rewriter); + vcr.lower(op, rewriter); return success(); } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir index b62ae040f364b..9acc97da0d53c 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir @@ -58,7 +58,7 @@ func.func @matrix_by_matrix_mul_and_acc() { [ 0.5, -1.3, -2.2, 0.1], [-0.7, 1.0, 1.7, -1.0]]> : vector<4x4xf32> - %acc_mem = memref.alloc() : memref<4x4xf32> + %acc_mem = memref.alloca() : memref<4x4xf32> vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32> %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32> @@ -68,7 +68,7 @@ func.func @matrix_by_matrix_mul_and_acc() { [-0.4, 0.6, 0.8, -0.5], [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16> - %lhs_mem = memref.alloc() : memref<4x4xbf16> + %lhs_mem = memref.alloca() : memref<4x4xbf16> vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16> %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16> @@ -78,7 +78,7 @@ func.func @matrix_by_matrix_mul_and_acc() { [-0.2, 0.4, 1.0, 0.4], [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16> - %rhs_mem = memref.alloc() : memref<4x4xbf16> + %rhs_mem = memref.alloca() : memref<4x4xbf16> vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16> %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16> @@ -121,14 +121,14 @@ func.func @vector_by_matrix_mul_and_acc() { // Accumulator test data %acc_cst = arith.constant dense<[0.7, 1.0, -0.1, 
1.8]> : vector<4xf32> - %acc_mem = memref.alloc() : memref<4xf32> + %acc_mem = memref.alloca() : memref<4xf32> vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32> %acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32> // LHS test data %lhs_cst = arith.constant dense<[0.1, 0.7, -0.9, 1.3]> : vector<4xbf16> - %lhs_mem = memref.alloc() : memref<4xbf16> + %lhs_mem = memref.alloca() : memref<4xbf16> vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16> %lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16> @@ -138,7 +138,7 @@ func.func @vector_by_matrix_mul_and_acc() { [-0.2, 0.4, 1.0, 0.4], [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16> - %rhs_mem = memref.alloc() : memref<4x4xbf16> + %rhs_mem = memref.alloca() : memref<4x4xbf16> vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16> %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16> >From e9a91c0702165277e70259b5d43bf6a0693b112e Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Mon, 21 Jul 2025 13:01:35 +0000 Subject: [PATCH 3/4] [fixup] Add a comment about memory ops --- .../Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir index 9acc97da0d53c..368f332e40602 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir @@ -110,6 +110,10 @@ func.func @matrix_by_matrix_mul_and_acc() { // * ACC, OUT: vector<Nxf32> // N is even and K is divisible by 4. // In this specific test we use N == 4, and K == 4. +// Note: the seemingly unnecessary writes of test vectors to memory are done +// in order to introduce memory load operations on the path leading up to +// `vector.contract` since that's more representation of the typical usage +// when multiplying big, non-constant tensors. // CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc // CHECK-IR-COUNT-2: arm_neon.intr.bfmmla >From 0e349f5e12865f44fedad9ee13ebd68007c94fbb Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Mon, 21 Jul 2025 13:10:58 +0000 Subject: [PATCH 4/4] [fixup] Move a comment, it was accidentally in the wrong place --- .../Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir index 368f332e40602..4285260906251 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir @@ -44,6 +44,12 @@ // // In this specific test we use M == 4, N == 4, and K == 4. 
+// Note: In this and in the following test the seemingly unnecessary +// writes of test vectors to memory are done in order to introduce memory +// load operations on the path leading up to `vector.contract` since +// that's more representative of the typical usage when multiplying +// big, non-constant tensors. + // CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc // CHECK-IR-COUNT-4: arm_neon.intr.bfmmla func.func @matrix_by_matrix_mul_and_acc() { @@ -110,10 +116,6 @@ func.func @matrix_by_matrix_mul_and_acc() { // * ACC, OUT: vector<Nxf32> // N is even and K is divisible by 4. // In this specific test we use N == 4, and K == 4. -// Note: the seemingly unnecessary writes of test vectors to memory are done -// in order to introduce memory load operations on the path leading up to -// `vector.contract` since that's more representation of the typical usage -// when multiplying big, non-constant tensors. // CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc // CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
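For reference, enabling the new lowering outside the `convert-vector-to-llvm` pass mirrors the ConvertVectorToLLVMPass hunk above. A minimal sketch, assuming a custom pass that declares the ArmNeon dialect as a dependency (the helper name below is illustrative, not part of the patch):

#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Collect the BFMMLA lowering patterns and run them with the greedy rewrite
// driver, the same way the vector-to-LLVM pass does after this patch.
void lowerContractsToBFMMLA(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
  (void)mlir::applyPatternsGreedily(root, std::move(patterns));
}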