https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135634
>From 528237309c0bfd7bbb51a8fea37b54e07f21ad1d Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Thu, 10 Apr 2025 14:38:27 +0000 Subject: [PATCH] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to `svusmmla` --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 95 +++++++++++-------- .../Transforms/LegalizeForLLVMExport.cpp | 4 + .../Dialect/ArmSVE/legalize-for-llvm.mlir | 12 +++ mlir/test/Dialect/ArmSVE/roundtrip.mlir | 11 +++ mlir/test/Target/LLVMIR/arm-sve.mlir | 12 +++ 5 files changed, 96 insertions(+), 38 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 3a990f8464ef8..7385bb73b449a 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -147,11 +147,9 @@ class ScalableMaskedIOp<string mnemonic, string op_description, "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)"; } -def SdotOp : ArmSVE_Op<"sdot", - [Pure, - AllTypesMatch<["src1", "src2"]>, - AllTypesMatch<["acc", "dst"]>, - ]> { +def SdotOp : ArmSVE_Op<"sdot", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { let summary = "Vector-vector dot product and accumulate op"; let description = [{ SDOT: Signed integer addition of dot product. @@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } -def SmmlaOp : ArmSVE_Op<"smmla", - [Pure, - AllTypesMatch<["src1", "src2"]>, - AllTypesMatch<["acc", "dst"]>, - ]> { +def SmmlaOp : ArmSVE_Op<"smmla", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { let summary = "Matrix-matrix multiply and accumulate op"; let description = [{ SMMLA: Signed integer matrix multiply-accumulate. 
@@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } -def UdotOp : ArmSVE_Op<"udot", - [Pure, - AllTypesMatch<["src1", "src2"]>, - AllTypesMatch<["acc", "dst"]>, - ]> { +def UdotOp : ArmSVE_Op<"udot", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { let summary = "Vector-vector dot product and accumulate op"; let description = [{ UDOT: Unsigned integer addition of dot product. @@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } -def UmmlaOp : ArmSVE_Op<"ummla", - [Pure, - AllTypesMatch<["src1", "src2"]>, - AllTypesMatch<["acc", "dst"]>, - ]> { +def UmmlaOp : ArmSVE_Op<"ummla", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { let summary = "Matrix-matrix multiply and accumulate op"; let description = [{ UMMLA: Unsigned integer matrix multiply-accumulate. @@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } +def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ + USMMLA: Unsigned by signed integer matrix multiply-accumulate. + + The unsigned by signed integer matrix multiply-accumulate operation + multiplies the 2×8 matrix of unsigned 8-bit integer values held in + the first source vector by the 8×2 matrix of signed 8-bit integer + values in the second source vector. The resulting 2×2 widened 32-bit + integer matrix product is then added to the 32-bit integer matrix + accumulator. 
+ + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>) + let arguments = (ins + ScalableVectorOfLengthAndType<[4], [I32]>:$acc, + ScalableVectorOfLengthAndType<[16], [I8]>:$src1, + ScalableVectorOfLengthAndType<[16], [I8]>:$src2 + ); + let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; +} + class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith< "expected corresponding svbool type widened to [16]xi1", lhsArg, rhsArg, "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">; def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool", - [Pure, SvboolTypeConstraint<"result", "source">]> -{ + [Pure, + SvboolTypeConstraint<"result", "source">]> { let summary = "Convert a svbool type to a SVE predicate type"; let description = [{ Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g. @@ -313,8 +333,8 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool", } def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool", - [Pure, SvboolTypeConstraint<"source", "result">]> -{ + [Pure, + SvboolTypeConstraint<"source", "result">]> { let summary = "Convert a SVE predicate type to a svbool type"; let description = [{ Converts SVE predicate types (or vectors of predicate types, e.g. 
@@ -356,10 +376,9 @@ def ZipInputVectorType : AnyTypeOf<[ Scalable1DVectorOfLength<16, [I8]>], "an SVE vector with element size <= 64-bit">; -def ZipX2Op : ArmSVE_Op<"zip.x2", [ - Pure, - AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>] -> { +def ZipX2Op : ArmSVE_Op<"zip.x2", [Pure, + AllTypesMatch<["sourceV1", "sourceV2", + "resultV1", "resultV2"]>]> { let summary = "Multi-vector two-way zip op"; let description = [{ @@ -400,12 +419,11 @@ def ZipX2Op : ArmSVE_Op<"zip.x2", [ }]; } -def ZipX4Op : ArmSVE_Op<"zip.x4", [ - Pure, - AllTypesMatch<[ - "sourceV1", "sourceV2", "sourceV3", "sourceV4", - "resultV1", "resultV2", "resultV3", "resultV4"]>] -> { +def ZipX4Op + : ArmSVE_Op<"zip.x4", + [Pure, + AllTypesMatch<["sourceV1", "sourceV2", "sourceV3", "sourceV4", + "resultV1", "resultV2", "resultV3", "resultV4"]>]> { let summary = "Multi-vector four-way zip op"; let description = [{ @@ -463,10 +481,7 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [ }]; } -def PselOp : ArmSVE_Op<"psel", [ - Pure, - AllTypesMatch<["p1", "result"]>, -]> { +def PselOp : ArmSVE_Op<"psel", [Pure, AllTypesMatch<["p1", "result"]>]> { let summary = "Predicate select"; let description = [{ @@ -571,6 +586,10 @@ def SmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"smmla">, Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; +def UsmmlaIntrOp : + ArmSVE_IntrBinaryOverloadedOp<"usmmla">, + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; + def SdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdot">, Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index 536373b82c67f..35f2a02cc4ec6 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp 
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; +using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>; using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>; using ScalableMaskedAddIOpLowering = @@ -206,6 +207,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( SmmlaOpLowering, UdotOpLowering, UmmlaOpLowering, + UsmmlaOpLowering, ZipX2OpLowering, ZipX4OpLowering, SdotOpLowering>(converter); @@ -234,6 +236,7 @@ void mlir::configureArmSVELegalizeForExportTarget( SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp, + UsmmlaIntrOp, WhileLTIntrOp, ZipX2IntrOp, ZipX4IntrOp, @@ -254,6 +257,7 @@ void mlir::configureArmSVELegalizeForExportTarget( SmmlaOp, UdotOp, UmmlaOp, + UsmmlaOp, ZipX2Op, ZipX4Op, SdotOp>(); diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir index 650b3e72d4ecd..8c658db009adf 100644 --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>, // ----- +func.func @arm_sve_usmmla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) + -> vector<[4]xi32> { + // CHECK: arm_sve.intr.usmmla + %0 = arm_sve.usmmla %c, %a, %b : + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> +} + +// ----- + func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index 0f0c5a8575772..64e0cff39eb06 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -44,6 +44,17 @@ func.func 
@arm_sve_ummla(%a: vector<[16]xi8>, // ----- +func.func @arm_sve_usmmla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3 + %0 = arm_sve.usmmla %c, %a, %b : + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> +} + +// ----- + func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>, %b: vector<[4]xi32>, %c: vector<[4]xi32>, diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir index 14c68b21fd86c..da71cb5a63bd2 100644 --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>, llvm.return %0 : vector<[4]xi32> } +// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla +llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { + // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4 + %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) : + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> +} + // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>, %arg1: vector<[4]xi32>, _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits