https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/135634
Supersedes https://github.com/llvm/llvm-project/pull/135358 From 71e2f13ad5922bf93961c5d81fd9d1f5899c80b0 Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Thu, 10 Apr 2025 14:38:27 +0000 Subject: [PATCH] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to `svusmmla` --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 32 +++++++++++++++++++ .../Transforms/LegalizeForLLVMExport.cpp | 4 +++ .../Dialect/ArmSVE/legalize-for-llvm.mlir | 12 +++++++ mlir/test/Dialect/ArmSVE/roundtrip.mlir | 11 +++++++ mlir/test/Target/LLVMIR/arm-sve.mlir | 12 +++++++ 5 files changed, 71 insertions(+) diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 1a59062ccc93d..da2a8f89b4cfd 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } +def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ + USMMLA: Unsigned by signed integer matrix multiply-accumulate. + + The unsigned by signed integer matrix multiply-accumulate operation + multiplies the 2×8 matrix of unsigned 8-bit integer values held in + the first source vector by the 8×2 matrix of signed 8-bit integer + values in the second source vector. The resulting 2×2 widened 32-bit + integer matrix product is then added to the 32-bit integer matrix + accumulator.
+ + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports (vector<[16]xi8>, vector<[16]xi8>) -> (vector<[4]xi32>) + let arguments = (ins + ScalableVectorOfLengthAndType<[4], [I32]>:$acc, + ScalableVectorOfLengthAndType<[16], [I8]>:$src1, + ScalableVectorOfLengthAndType<[16], [I8]>:$src2 + ); + let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; +} + class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith< "expected corresponding svbool type widened to [16]xi1", lhsArg, rhsArg, @@ -568,6 +596,10 @@ def SmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"smmla">, Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; +def UsmmlaIntrOp : + ArmSVE_IntrBinaryOverloadedOp<"usmmla">, + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; + def SdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdot">, Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index fe13ed03356b2..b1846e15196fc 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; +using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>; using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>; using ScalableMaskedAddIOpLowering = @@ -194,6 +195,7 @@ void
mlir::populateArmSVELegalizeForLLVMExportPatterns( SmmlaOpLowering, UdotOpLowering, UmmlaOpLowering, + UsmmlaOpLowering, DupQLaneLowering, ScalableMaskedAddIOpLowering, ScalableMaskedAddFOpLowering, @@ -222,6 +224,7 @@ void mlir::configureArmSVELegalizeForExportTarget( SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp, + UsmmlaIntrOp, DupQLaneIntrOp, ScalableMaskedAddIIntrOp, ScalableMaskedAddFIntrOp, @@ -242,6 +245,7 @@ void mlir::configureArmSVELegalizeForExportTarget( SmmlaOp, UdotOp, UmmlaOp, + UsmmlaOp, DupQLaneOp, ScalableMaskedAddIOp, ScalableMaskedAddFOp, diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir index 5d044517e0ea8..47587aa26506c 100644 --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>, // ----- +func.func @arm_sve_usmmla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) + -> vector<[4]xi32> { + // CHECK: arm_sve.intr.usmmla + %0 = arm_sve.usmmla %c, %a, %b : + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> +} + +// ----- + func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index 0f0c5a8575772..64e0cff39eb06 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>, // ----- +func.func @arm_sve_usmmla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi32> + %0 = arm_sve.usmmla %c, %a, %b : + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> +} + +// ----- + func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, diff --git
a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir index ced59eb513b57..4d9b0da611cb0 100644 --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>, llvm.return %0 : vector<[4]xi32> } +// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla +llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { + // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4 + %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) : + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> +} + // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>, %arg1: vector<[4]xi32>, _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits